util: Add a utility for managing temporary files

In multiple situations we need to create temporary files or directories
that should not be preserved after compose is finished. Let's add
context managers that ensure these get cleaned up.

This fixes tests leaving garbage around in /tmp.

Signed-off-by: Lubomír Sedlář <lsedlar@redhat.com>
This commit is contained in:
Lubomír Sedlář 2017-02-17 13:44:11 +01:00
parent a57bc13e30
commit 857aee05c1
4 changed files with 57 additions and 42 deletions

View File

@ -4,12 +4,9 @@
from __future__ import print_function
import argparse
import contextlib
import kobo.conf
import os
import shutil
import sys
import tempfile
here = sys.path[0]
if here != '/usr/bin':
@ -20,6 +17,7 @@ import pungi.compose
import pungi.checks
import pungi.paths
import pungi.phases
import pungi.util
class ValidationCompose(pungi.compose.Compose):
@ -53,13 +51,6 @@ class ValidationCompose(pungi.compose.Compose):
return '0'
@contextlib.contextmanager
def in_temp_dir():
tempdir = tempfile.mkdtemp()
yield tempdir
shutil.rmtree(tempdir)
def run(config, topdir, has_old):
conf = kobo.conf.PyConfigParser()
conf.load_from_file(config)
@ -112,7 +103,7 @@ def main(args=None):
help='indicate if pungi-koji will be run with --old-composes option')
opts = parser.parse_args(args)
with in_temp_dir() as topdir:
with pungi.util.temp_dir() as topdir:
errors = run(opts.config, topdir, opts.old_composes)
for msg in errors:

View File

@ -25,6 +25,7 @@ import re
import urlparse
import contextlib
import traceback
import tempfile
from kobo.shortcuts import run, force_list
from productmd.common import get_major_version
@ -572,3 +573,22 @@ def levenshtein(a, b):
mat[j - 1][i - 1] + cost)
return mat[len(b)][len(a)]
@contextlib.contextmanager
def temp_dir(log=None, *args, **kwargs):
"""Create a temporary directory and ensure it's deleted."""
if kwargs.get('dir'):
# If we are supposed to create the temp dir in a particular location,
# ensure the location already exists.
makedirs(kwargs['dir'])
dir = tempfile.mkdtemp(*args, **kwargs)
try:
yield dir
finally:
try:
shutil.rmtree(dir)
except OSError as exc:
# Okay, we failed to delete temporary dir.
if log:
log.warning('Error removing %s: %s', dir, exc.strerror)

View File

@ -16,36 +16,20 @@ from __future__ import absolute_import
import os
import tempfile
import shutil
import pipes
import glob
import time
import contextlib
import kobo.log
from kobo.shortcuts import run, force_list
from pungi.util import explode_rpm_package, makedirs, copy_all
from pungi.util import explode_rpm_package, makedirs, copy_all, temp_dir
class ScmBase(kobo.log.LoggingBase):
def __init__(self, logger=None):
kobo.log.LoggingBase.__init__(self, logger=logger)
@contextlib.contextmanager
def _temp_dir(self, tmp_dir=None):
if tmp_dir is not None:
makedirs(tmp_dir)
path = tempfile.mkdtemp(prefix="cvswrapper_", dir=tmp_dir)
yield path
self.log_debug("Removing %s" % path)
try:
shutil.rmtree(path)
except OSError as ex:
self.log_warning("Error removing %s: %s" % (path, ex))
def retry_run(self, cmd, retries=5, timeout=60, **kwargs):
"""
@param cmd - cmd passed to kobo.shortcuts.run()
@ -96,7 +80,7 @@ class CvsWrapper(ScmBase):
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
scm_dir = scm_dir.lstrip("/")
scm_branch = scm_branch or "HEAD"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Exporting directory %s from CVS %s (branch %s)..."
% (scm_dir, scm_root, scm_branch))
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_dir],
@ -106,7 +90,7 @@ class CvsWrapper(ScmBase):
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
scm_file = scm_file.lstrip("/")
scm_branch = scm_branch or "HEAD"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
target_path = os.path.join(target_dir, os.path.basename(scm_file))
self.log_debug("Exporting file %s from CVS %s (branch %s)..." % (scm_file, scm_root, scm_branch))
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_file],
@ -121,7 +105,7 @@ class GitWrapper(ScmBase):
scm_dir = scm_dir.lstrip("/")
scm_branch = scm_branch or "master"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
if "://" not in scm_root:
scm_root = "file://%s" % scm_root
@ -142,7 +126,7 @@ class GitWrapper(ScmBase):
scm_file = scm_file.lstrip("/")
scm_branch = scm_branch or "master"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
target_path = os.path.join(target_dir, os.path.basename(scm_file))
if "://" not in scm_root:
@ -172,7 +156,7 @@ class RpmScmWrapper(ScmBase):
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
for rpm in self._list_rpms(scm_root):
scm_dir = scm_dir.lstrip("/")
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Extracting directory %s from RPM package %s..." % (scm_dir, rpm))
explode_rpm_package(rpm, tmp_dir)
@ -187,7 +171,7 @@ class RpmScmWrapper(ScmBase):
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
for rpm in self._list_rpms(scm_root):
scm_file = scm_file.lstrip("/")
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir:
with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Exporting file %s from RPM file %s..." % (scm_file, rpm))
explode_rpm_package(rpm, tmp_dir)
@ -256,10 +240,9 @@ def get_file_from_scm(scm_dict, target_path, logger=None):
files_copied = []
for i in force_list(scm_file):
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_")
scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied += copy_all(tmp_dir, target_path)
shutil.rmtree(tmp_dir)
with temp_dir(prefix="scm_checkout_") as tmp_dir:
scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied += copy_all(tmp_dir, target_path)
return files_copied
@ -306,8 +289,7 @@ def get_dir_from_scm(scm_dict, target_path, logger=None):
scm = _get_wrapper(scm_type, logger=logger)
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_")
scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied = copy_all(tmp_dir, target_path)
shutil.rmtree(tmp_dir)
with temp_dir(prefix="scm_checkout_") as tmp_dir:
scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied = copy_all(tmp_dir, target_path)
return files_copied

View File

@ -406,5 +406,27 @@ class TestRecursiveFileList(unittest.TestCase):
self.assertEqual(expected_files, actual_files)
class TestTempFiles(unittest.TestCase):
def test_temp_dir_ok(self):
with util.temp_dir() as tmp:
self.assertTrue(os.path.isdir(tmp))
self.assertFalse(os.path.exists(tmp))
def test_temp_dir_fail(self):
with self.assertRaises(RuntimeError):
with util.temp_dir() as tmp:
self.assertTrue(os.path.isdir(tmp))
raise RuntimeError('BOOM')
self.assertFalse(os.path.exists(tmp))
def test_temp_dir_in_non_existing_dir(self):
with util.temp_dir() as playground:
root = os.path.join(playground, 'missing')
with util.temp_dir(dir=root) as tmp:
self.assertTrue(os.path.isdir(tmp))
self.assertTrue(os.path.isdir(root))
self.assertFalse(os.path.exists(tmp))
if __name__ == "__main__":
unittest.main()