util: Add a utility for managing temporary files
In multiple situations we need to create temporary files or directories that should not be preserved after compose is finished. Let's add context managers that ensure these get cleaned up. This fixes tests leaving garbage around in /tmp. Signed-off-by: Lubomír Sedlář <lsedlar@redhat.com>
This commit is contained in:
parent
a57bc13e30
commit
857aee05c1
@ -4,12 +4,9 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import kobo.conf
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
here = sys.path[0]
|
||||
if here != '/usr/bin':
|
||||
@ -20,6 +17,7 @@ import pungi.compose
|
||||
import pungi.checks
|
||||
import pungi.paths
|
||||
import pungi.phases
|
||||
import pungi.util
|
||||
|
||||
|
||||
class ValidationCompose(pungi.compose.Compose):
|
||||
@ -53,13 +51,6 @@ class ValidationCompose(pungi.compose.Compose):
|
||||
return '0'
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def in_temp_dir():
|
||||
tempdir = tempfile.mkdtemp()
|
||||
yield tempdir
|
||||
shutil.rmtree(tempdir)
|
||||
|
||||
|
||||
def run(config, topdir, has_old):
|
||||
conf = kobo.conf.PyConfigParser()
|
||||
conf.load_from_file(config)
|
||||
@ -112,7 +103,7 @@ def main(args=None):
|
||||
help='indicate if pungi-koji will be run with --old-composes option')
|
||||
opts = parser.parse_args(args)
|
||||
|
||||
with in_temp_dir() as topdir:
|
||||
with pungi.util.temp_dir() as topdir:
|
||||
errors = run(opts.config, topdir, opts.old_composes)
|
||||
|
||||
for msg in errors:
|
||||
|
@ -25,6 +25,7 @@ import re
|
||||
import urlparse
|
||||
import contextlib
|
||||
import traceback
|
||||
import tempfile
|
||||
|
||||
from kobo.shortcuts import run, force_list
|
||||
from productmd.common import get_major_version
|
||||
@ -572,3 +573,22 @@ def levenshtein(a, b):
|
||||
mat[j - 1][i - 1] + cost)
|
||||
|
||||
return mat[len(b)][len(a)]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_dir(log=None, *args, **kwargs):
|
||||
"""Create a temporary directory and ensure it's deleted."""
|
||||
if kwargs.get('dir'):
|
||||
# If we are supposed to create the temp dir in a particular location,
|
||||
# ensure the location already exists.
|
||||
makedirs(kwargs['dir'])
|
||||
dir = tempfile.mkdtemp(*args, **kwargs)
|
||||
try:
|
||||
yield dir
|
||||
finally:
|
||||
try:
|
||||
shutil.rmtree(dir)
|
||||
except OSError as exc:
|
||||
# Okay, we failed to delete temporary dir.
|
||||
if log:
|
||||
log.warning('Error removing %s: %s', dir, exc.strerror)
|
||||
|
@ -16,36 +16,20 @@ from __future__ import absolute_import
|
||||
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import pipes
|
||||
import glob
|
||||
import time
|
||||
import contextlib
|
||||
|
||||
import kobo.log
|
||||
from kobo.shortcuts import run, force_list
|
||||
from pungi.util import explode_rpm_package, makedirs, copy_all
|
||||
from pungi.util import explode_rpm_package, makedirs, copy_all, temp_dir
|
||||
|
||||
|
||||
class ScmBase(kobo.log.LoggingBase):
|
||||
def __init__(self, logger=None):
|
||||
kobo.log.LoggingBase.__init__(self, logger=logger)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temp_dir(self, tmp_dir=None):
|
||||
if tmp_dir is not None:
|
||||
makedirs(tmp_dir)
|
||||
path = tempfile.mkdtemp(prefix="cvswrapper_", dir=tmp_dir)
|
||||
|
||||
yield path
|
||||
|
||||
self.log_debug("Removing %s" % path)
|
||||
try:
|
||||
shutil.rmtree(path)
|
||||
except OSError as ex:
|
||||
self.log_warning("Error removing %s: %s" % (path, ex))
|
||||
|
||||
def retry_run(self, cmd, retries=5, timeout=60, **kwargs):
|
||||
"""
|
||||
@param cmd - cmd passed to kobo.shortcuts.run()
|
||||
@ -96,7 +80,7 @@ class CvsWrapper(ScmBase):
|
||||
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
|
||||
scm_dir = scm_dir.lstrip("/")
|
||||
scm_branch = scm_branch or "HEAD"
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
self.log_debug("Exporting directory %s from CVS %s (branch %s)..."
|
||||
% (scm_dir, scm_root, scm_branch))
|
||||
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_dir],
|
||||
@ -106,7 +90,7 @@ class CvsWrapper(ScmBase):
|
||||
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
|
||||
scm_file = scm_file.lstrip("/")
|
||||
scm_branch = scm_branch or "HEAD"
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
target_path = os.path.join(target_dir, os.path.basename(scm_file))
|
||||
self.log_debug("Exporting file %s from CVS %s (branch %s)..." % (scm_file, scm_root, scm_branch))
|
||||
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_file],
|
||||
@ -121,7 +105,7 @@ class GitWrapper(ScmBase):
|
||||
scm_dir = scm_dir.lstrip("/")
|
||||
scm_branch = scm_branch or "master"
|
||||
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
if "://" not in scm_root:
|
||||
scm_root = "file://%s" % scm_root
|
||||
|
||||
@ -142,7 +126,7 @@ class GitWrapper(ScmBase):
|
||||
scm_file = scm_file.lstrip("/")
|
||||
scm_branch = scm_branch or "master"
|
||||
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
target_path = os.path.join(target_dir, os.path.basename(scm_file))
|
||||
|
||||
if "://" not in scm_root:
|
||||
@ -172,7 +156,7 @@ class RpmScmWrapper(ScmBase):
|
||||
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
|
||||
for rpm in self._list_rpms(scm_root):
|
||||
scm_dir = scm_dir.lstrip("/")
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
self.log_debug("Extracting directory %s from RPM package %s..." % (scm_dir, rpm))
|
||||
explode_rpm_package(rpm, tmp_dir)
|
||||
|
||||
@ -187,7 +171,7 @@ class RpmScmWrapper(ScmBase):
|
||||
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
|
||||
for rpm in self._list_rpms(scm_root):
|
||||
scm_file = scm_file.lstrip("/")
|
||||
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
|
||||
with temp_dir(dir=tmp_dir) as tmp_dir:
|
||||
self.log_debug("Exporting file %s from RPM file %s..." % (scm_file, rpm))
|
||||
explode_rpm_package(rpm, tmp_dir)
|
||||
|
||||
@ -256,10 +240,9 @@ def get_file_from_scm(scm_dict, target_path, logger=None):
|
||||
|
||||
files_copied = []
|
||||
for i in force_list(scm_file):
|
||||
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_")
|
||||
scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir)
|
||||
files_copied += copy_all(tmp_dir, target_path)
|
||||
shutil.rmtree(tmp_dir)
|
||||
with temp_dir(prefix="scm_checkout_") as tmp_dir:
|
||||
scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir)
|
||||
files_copied += copy_all(tmp_dir, target_path)
|
||||
return files_copied
|
||||
|
||||
|
||||
@ -306,8 +289,7 @@ def get_dir_from_scm(scm_dict, target_path, logger=None):
|
||||
|
||||
scm = _get_wrapper(scm_type, logger=logger)
|
||||
|
||||
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_")
|
||||
scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir)
|
||||
files_copied = copy_all(tmp_dir, target_path)
|
||||
shutil.rmtree(tmp_dir)
|
||||
with temp_dir(prefix="scm_checkout_") as tmp_dir:
|
||||
scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir)
|
||||
files_copied = copy_all(tmp_dir, target_path)
|
||||
return files_copied
|
||||
|
@ -406,5 +406,27 @@ class TestRecursiveFileList(unittest.TestCase):
|
||||
self.assertEqual(expected_files, actual_files)
|
||||
|
||||
|
||||
class TestTempFiles(unittest.TestCase):
|
||||
def test_temp_dir_ok(self):
|
||||
with util.temp_dir() as tmp:
|
||||
self.assertTrue(os.path.isdir(tmp))
|
||||
self.assertFalse(os.path.exists(tmp))
|
||||
|
||||
def test_temp_dir_fail(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
with util.temp_dir() as tmp:
|
||||
self.assertTrue(os.path.isdir(tmp))
|
||||
raise RuntimeError('BOOM')
|
||||
self.assertFalse(os.path.exists(tmp))
|
||||
|
||||
def test_temp_dir_in_non_existing_dir(self):
|
||||
with util.temp_dir() as playground:
|
||||
root = os.path.join(playground, 'missing')
|
||||
with util.temp_dir(dir=root) as tmp:
|
||||
self.assertTrue(os.path.isdir(tmp))
|
||||
self.assertTrue(os.path.isdir(root))
|
||||
self.assertFalse(os.path.exists(tmp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user