diff --git a/pungi/phases/init.py b/pungi/phases/init.py index c347222f..19631447 100644 --- a/pungi/phases/init.py +++ b/pungi/phases/init.py @@ -14,17 +14,20 @@ # along with this program; if not, see . +import collections +import glob import os import shutil from kobo.shortcuts import run +from pungi import Modulemd from pungi.phases.base import PhaseBase from pungi.phases.gather import write_prepopulate_file -from pungi.wrappers.createrepo import CreaterepoWrapper -from pungi.wrappers.comps import CompsWrapper -from pungi.wrappers.scm import get_file_from_scm, get_dir_from_scm from pungi.util import temp_dir +from pungi.wrappers.comps import CompsWrapper +from pungi.wrappers.createrepo import CreaterepoWrapper +from pungi.wrappers.scm import get_dir_from_scm, get_file_from_scm class InitPhase(PhaseBase): @@ -55,6 +58,9 @@ class InitPhase(PhaseBase): # download module defaults if self.compose.has_module_defaults: write_module_defaults(self.compose) + validate_module_defaults( + self.compose.paths.work.module_defaults_dir(create_dir=False) + ) # write prepopulate file write_prepopulate_file(self.compose) @@ -170,3 +176,30 @@ def write_module_defaults(compose): get_dir_from_scm(scm_dict, tmp_dir, logger=compose._logger) compose.log_debug("Writing module defaults") shutil.copytree(tmp_dir, compose.paths.work.module_defaults_dir(create_dir=False)) + + +def validate_module_defaults(path): + """Make sure there are no conflicting defaults. Each module name can only + have one default stream. + + :param str path: directory with cloned module defaults + """ + seen_defaults = collections.defaultdict(set) + for file in glob.glob(os.path.join(path, "*.yaml")): + for mmddef in Modulemd.objects_from_file(file): + if not isinstance(mmddef, Modulemd.Defaults): + continue + seen_defaults[mmddef.peek_module_name()].add(mmddef.peek_default_stream()) + + errors = [] + for module_name, defaults in seen_defaults.items(): + if len(defaults) > 1: + errors.append( + "Module %s has multiple defaults: %s" + % (module_name, ", ".join(sorted(defaults))) + ) + + if errors: + raise RuntimeError( + "There are duplicated module defaults:\n%s" % "\n".join(errors) + ) diff --git a/tests/test_initphase.py b/tests/test_initphase.py index 5eaa3658..769e68fe 100644 --- a/tests/test_initphase.py +++ b/tests/test_initphase.py @@ -10,6 +10,7 @@ import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from pungi import Modulemd from pungi.phases import init from tests.helpers import DummyCompose, PungiTestCase, touch @@ -333,5 +334,66 @@ class TestWriteVariantComps(PungiTestCase): self.assertEqual(comps.write_comps.mock_calls, []) +@unittest.skipUnless(Modulemd, "Skipped test, no module support.") +class TestValidateModuleDefaults(PungiTestCase): + + def _write_defaults(self, defs): + for mod_name, streams in defs.items(): + for stream in streams: + mmddef = Modulemd.Defaults.new() + mmddef.set_version(1) + mmddef.set_module_name(mod_name) + mmddef.set_default_stream(stream) + mmddef.dump( + os.path.join(self.topdir, "%s-%s.yaml" % (mod_name, stream)) + ) + + def test_valid_files(self): + self._write_defaults({"httpd": ["1"], "python": ["3.6"]}) + + init.validate_module_defaults(self.topdir) + + def test_duplicated_stream(self): + self._write_defaults({"httpd": ["1"], "python": ["3.6", "3.5"]}) + + with self.assertRaises(RuntimeError) as ctx: + init.validate_module_defaults(self.topdir) + + self.assertIn( + "Module python has multiple defaults: 3.5, 3.6", str(ctx.exception) + ) + + def test_reports_all(self): + self._write_defaults({"httpd": ["1", "2"], "python": ["3.6", "3.5"]}) + + with self.assertRaises(RuntimeError) as ctx: + init.validate_module_defaults(self.topdir) + + self.assertIn("Module httpd has multiple defaults: 1, 2", str(ctx.exception)) + self.assertIn( + "Module python has multiple defaults: 3.5, 3.6", str(ctx.exception) + ) + + def test_handles_non_defaults_file(self): + self._write_defaults({"httpd": ["1"], "python": ["3.6"]}) + touch( + os.path.join(self.topdir, "boom.yaml"), + "\n".join( + [ + "document: modulemd", + "version: 2", + "data:", + " summary: dummy module", + " description: dummy module", + " license:", + " module: [GPL]", + " content: [GPL]", + ] + ), + ) + + init.validate_module_defaults(self.topdir) + + if __name__ == "__main__": unittest.main()