From 857aee05c10eb04a4f0d4c8fda9ef5b6b8c3d925 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lubom=C3=ADr=20Sedl=C3=A1=C5=99?= Date: Fri, 17 Feb 2017 13:44:11 +0100 Subject: [PATCH] util: Add a utility for managing temporary files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In multiple situations we need to create temporary files or directories that should not be preserved after compose is finished. Let's add context managers that ensure these get cleaned up. This fixes tests leaving garbage around in /tmp. Signed-off-by: Lubomír Sedlář --- bin/pungi-config-validate | 13 ++---------- pungi/util.py | 20 ++++++++++++++++++ pungi/wrappers/scm.py | 44 ++++++++++++--------------------------- tests/test_util.py | 22 ++++++++++++++++++++ 4 files changed, 57 insertions(+), 42 deletions(-) diff --git a/bin/pungi-config-validate b/bin/pungi-config-validate index 8dedf44f..ce3f5cda 100755 --- a/bin/pungi-config-validate +++ b/bin/pungi-config-validate @@ -4,12 +4,9 @@ from __future__ import print_function import argparse -import contextlib import kobo.conf import os -import shutil import sys -import tempfile here = sys.path[0] if here != '/usr/bin': @@ -20,6 +17,7 @@ import pungi.compose import pungi.checks import pungi.paths import pungi.phases +import pungi.util class ValidationCompose(pungi.compose.Compose): @@ -53,13 +51,6 @@ class ValidationCompose(pungi.compose.Compose): return '0' -@contextlib.contextmanager -def in_temp_dir(): - tempdir = tempfile.mkdtemp() - yield tempdir - shutil.rmtree(tempdir) - - def run(config, topdir, has_old): conf = kobo.conf.PyConfigParser() conf.load_from_file(config) @@ -112,7 +103,7 @@ def main(args=None): help='indicate if pungi-koji will be run with --old-composes option') opts = parser.parse_args(args) - with in_temp_dir() as topdir: + with pungi.util.temp_dir() as topdir: errors = run(opts.config, topdir, opts.old_composes) for msg in errors: diff --git a/pungi/util.py b/pungi/util.py index 2b785e03..01cab53c 100644 --- a/pungi/util.py +++ b/pungi/util.py @@ -25,6 +25,7 @@ import re import urlparse import contextlib import traceback +import tempfile from kobo.shortcuts import run, force_list from productmd.common import get_major_version @@ -572,3 +573,22 @@ def levenshtein(a, b): mat[j - 1][i - 1] + cost) return mat[len(b)][len(a)] + + +@contextlib.contextmanager +def temp_dir(log=None, *args, **kwargs): + """Create a temporary directory and ensure it's deleted.""" + if kwargs.get('dir'): + # If we are supposed to create the temp dir in a particular location, + # ensure the location already exists. + makedirs(kwargs['dir']) + dir = tempfile.mkdtemp(*args, **kwargs) + try: + yield dir + finally: + try: + shutil.rmtree(dir) + except OSError as exc: + # Okay, we failed to delete temporary dir. + if log: + log.warning('Error removing %s: %s', dir, exc.strerror) diff --git a/pungi/wrappers/scm.py b/pungi/wrappers/scm.py index 474af16e..6cb88953 100644 --- a/pungi/wrappers/scm.py +++ b/pungi/wrappers/scm.py @@ -16,36 +16,20 @@ from __future__ import absolute_import import os -import tempfile import shutil import pipes import glob import time -import contextlib import kobo.log from kobo.shortcuts import run, force_list -from pungi.util import explode_rpm_package, makedirs, copy_all +from pungi.util import explode_rpm_package, makedirs, copy_all, temp_dir class ScmBase(kobo.log.LoggingBase): def __init__(self, logger=None): kobo.log.LoggingBase.__init__(self, logger=logger) - @contextlib.contextmanager - def _temp_dir(self, tmp_dir=None): - if tmp_dir is not None: - makedirs(tmp_dir) - path = tempfile.mkdtemp(prefix="cvswrapper_", dir=tmp_dir) - - yield path - - self.log_debug("Removing %s" % path) - try: - shutil.rmtree(path) - except OSError as ex: - self.log_warning("Error removing %s: %s" % (path, ex)) - def retry_run(self, cmd, retries=5, timeout=60, **kwargs): """ @param cmd - cmd passed to kobo.shortcuts.run() @@ -96,7 +80,7 @@ class CvsWrapper(ScmBase): def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None): scm_dir = scm_dir.lstrip("/") scm_branch = scm_branch or "HEAD" - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: self.log_debug("Exporting directory %s from CVS %s (branch %s)..." % (scm_dir, scm_root, scm_branch)) self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_dir], @@ -106,7 +90,7 @@ class CvsWrapper(ScmBase): def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None): scm_file = scm_file.lstrip("/") scm_branch = scm_branch or "HEAD" - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: target_path = os.path.join(target_dir, os.path.basename(scm_file)) self.log_debug("Exporting file %s from CVS %s (branch %s)..." % (scm_file, scm_root, scm_branch)) self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_file], @@ -121,7 +105,7 @@ class GitWrapper(ScmBase): scm_dir = scm_dir.lstrip("/") scm_branch = scm_branch or "master" - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: if "://" not in scm_root: scm_root = "file://%s" % scm_root @@ -142,7 +126,7 @@ class GitWrapper(ScmBase): scm_file = scm_file.lstrip("/") scm_branch = scm_branch or "master" - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: target_path = os.path.join(target_dir, os.path.basename(scm_file)) if "://" not in scm_root: @@ -172,7 +156,7 @@ class RpmScmWrapper(ScmBase): def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None): for rpm in self._list_rpms(scm_root): scm_dir = scm_dir.lstrip("/") - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: self.log_debug("Extracting directory %s from RPM package %s..." % (scm_dir, rpm)) explode_rpm_package(rpm, tmp_dir) @@ -187,7 +171,7 @@ class RpmScmWrapper(ScmBase): def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None): for rpm in self._list_rpms(scm_root): scm_file = scm_file.lstrip("/") - with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: + with temp_dir(dir=tmp_dir) as tmp_dir: self.log_debug("Exporting file %s from RPM file %s..." % (scm_file, rpm)) explode_rpm_package(rpm, tmp_dir) @@ -256,10 +240,9 @@ def get_file_from_scm(scm_dict, target_path, logger=None): files_copied = [] for i in force_list(scm_file): - tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_") - scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir) - files_copied += copy_all(tmp_dir, target_path) - shutil.rmtree(tmp_dir) + with temp_dir(prefix="scm_checkout_") as tmp_dir: + scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir) + files_copied += copy_all(tmp_dir, target_path) return files_copied @@ -306,8 +289,7 @@ def get_dir_from_scm(scm_dict, target_path, logger=None): scm = _get_wrapper(scm_type, logger=logger) - tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_") - scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir) - files_copied = copy_all(tmp_dir, target_path) - shutil.rmtree(tmp_dir) + with temp_dir(prefix="scm_checkout_") as tmp_dir: + scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir) + files_copied = copy_all(tmp_dir, target_path) return files_copied diff --git a/tests/test_util.py b/tests/test_util.py index 8e21cbc0..3052aab5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -406,5 +406,27 @@ class TestRecursiveFileList(unittest.TestCase): self.assertEqual(expected_files, actual_files) +class TestTempFiles(unittest.TestCase): + def test_temp_dir_ok(self): + with util.temp_dir() as tmp: + self.assertTrue(os.path.isdir(tmp)) + self.assertFalse(os.path.exists(tmp)) + + def test_temp_dir_fail(self): + with self.assertRaises(RuntimeError): + with util.temp_dir() as tmp: + self.assertTrue(os.path.isdir(tmp)) + raise RuntimeError('BOOM') + self.assertFalse(os.path.exists(tmp)) + + def test_temp_dir_in_non_existing_dir(self): + with util.temp_dir() as playground: + root = os.path.join(playground, 'missing') + with util.temp_dir(dir=root) as tmp: + self.assertTrue(os.path.isdir(tmp)) + self.assertTrue(os.path.isdir(root)) + self.assertFalse(os.path.exists(tmp)) + + if __name__ == "__main__": unittest.main()