diff --git a/pungi/phases/gather/methods/method_nodeps.py b/pungi/phases/gather/methods/method_nodeps.py index bf881ae8..2979cd7d 100644 --- a/pungi/phases/gather/methods/method_nodeps.py +++ b/pungi/phases/gather/methods/method_nodeps.py @@ -16,6 +16,7 @@ import pungi.arch from pungi.util import pkg_is_rpm, pkg_is_srpm, pkg_is_debug +from pungi.wrappers.comps import CompsWrapper import pungi.phases.gather.method from kobo.pkgset import SimpleRpmWrapper, RpmWrapper @@ -24,7 +25,9 @@ from kobo.pkgset import SimpleRpmWrapper, RpmWrapper class GatherMethodNodeps(pungi.phases.gather.method.GatherMethodBase): enabled = True - def __call__(self, arch, variant, packages, groups, filter_packages, multilib_whitelist, multilib_blacklist, package_sets, path_prefix=None, fulltree_excludes=None, prepopulate=None): + def __call__(self, arch, variant, pkgs, groups, filter_packages, + multilib_whitelist, multilib_blacklist, package_sets, + path_prefix=None, fulltree_excludes=None, prepopulate=None): global_pkgset = package_sets["global"] result = { "rpm": [], @@ -32,6 +35,9 @@ class GatherMethodNodeps(pungi.phases.gather.method.GatherMethodBase): "debuginfo": [], } + group_packages = expand_groups(self.compose, arch, groups) + packages = pkgs | group_packages + seen_rpms = {} seen_srpms = {} @@ -89,3 +95,19 @@ class GatherMethodNodeps(pungi.phases.gather.method.GatherMethodBase): }) return result + + +def expand_groups(compose, arch, groups): + """Read comps file filtered for given architecture and return all packages + in given groups. + + :returns: A set of tuples (pkg_name, arch) + """ + comps_file = compose.paths.work.comps(arch, create_dir=False) + comps = CompsWrapper(comps_file) + packages = set() + + for group in groups: + packages.update([(pkg, arch) for pkg in comps.get_packages(group)]) + + return packages diff --git a/pungi/wrappers/comps.py b/pungi/wrappers/comps.py index 6c43e5fa..a392158b 100644 --- a/pungi/wrappers/comps.py +++ b/pungi/wrappers/comps.py @@ -58,6 +58,13 @@ class CompsWrapper(object): """Return a list of group IDs.""" return [group.id for group in self.comps.groups] + def get_packages(self, group): + """Return list of package names in given group.""" + for grp in self.comps.groups: + if grp.id == group: + return [pkg.name for pkg in grp.packages] + raise KeyError('No such group %r' % group) + def write_comps(self, comps_obj=None, target_file=None): if not comps_obj: comps_obj = self.generate_comps() diff --git a/tests/test_comps_wrapper.py b/tests/test_comps_wrapper.py index c433259d..a45decd4 100644 --- a/tests/test_comps_wrapper.py +++ b/tests/test_comps_wrapper.py @@ -41,6 +41,17 @@ class CompsWrapperTest(unittest.TestCase): comps.get_comps_groups(), ['core', 'standard', 'text-internet', 'firefox', 'resilient-storage', 'basic-desktop']) + def test_get_packages(self): + comps = CompsWrapper(COMPS_FILE) + self.assertItemsEqual( + comps.get_packages('text-internet'), + {'dummy-elinks', 'dummy-tftp'}) + + def test_get_packages_for_non_existing_group(self): + comps = CompsWrapper(COMPS_FILE) + with self.assertRaises(KeyError): + comps.get_packages('foo') + def test_write_comps(self): comps = CompsWrapper(COMPS_FILE) comps.write_comps(target_file=self.file.name) diff --git a/tests/test_gather_method_nodeps.py b/tests/test_gather_method_nodeps.py new file mode 100644 index 00000000..cf22476d --- /dev/null +++ b/tests/test_gather_method_nodeps.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +import mock +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from pungi.phases.gather.methods import method_nodeps as nodeps +from tests import helpers + +COMPS_FILE = os.path.join(helpers.FIXTURE_DIR, 'comps.xml') + + +class TestWritePungiConfig(helpers.PungiTestCase): + def setUp(self): + super(TestWritePungiConfig, self).setUp() + self.compose = helpers.DummyCompose(self.topdir, {}) + self.compose.DEBUG = False + self.compose.paths.work.comps = mock.Mock(return_value=COMPS_FILE) + + def test_expand_group(self): + packages = nodeps.expand_groups(self.compose, 'x86_64', ['core', 'text-internet']) + self.assertItemsEqual(packages, [('dummy-bash', 'x86_64'), + ('dummy-elinks', 'x86_64'), + ('dummy-tftp', 'x86_64')])