util: Add a utility for managing temporary files

In multiple situations we need to create temporary files or directories
that should not be preserved after compose is finished. Let's add
context managers that ensure these get cleaned up.

This fixes tests leaving garbage around in /tmp.

Signed-off-by: Lubomír Sedlář <lsedlar@redhat.com>
This commit is contained in:
Lubomír Sedlář 2017-02-17 13:44:11 +01:00
parent a57bc13e30
commit 857aee05c1
4 changed files with 57 additions and 42 deletions

View File

@ -4,12 +4,9 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import contextlib
import kobo.conf import kobo.conf
import os import os
import shutil
import sys import sys
import tempfile
here = sys.path[0] here = sys.path[0]
if here != '/usr/bin': if here != '/usr/bin':
@ -20,6 +17,7 @@ import pungi.compose
import pungi.checks import pungi.checks
import pungi.paths import pungi.paths
import pungi.phases import pungi.phases
import pungi.util
class ValidationCompose(pungi.compose.Compose): class ValidationCompose(pungi.compose.Compose):
@ -53,13 +51,6 @@ class ValidationCompose(pungi.compose.Compose):
return '0' return '0'
@contextlib.contextmanager
def in_temp_dir():
tempdir = tempfile.mkdtemp()
yield tempdir
shutil.rmtree(tempdir)
def run(config, topdir, has_old): def run(config, topdir, has_old):
conf = kobo.conf.PyConfigParser() conf = kobo.conf.PyConfigParser()
conf.load_from_file(config) conf.load_from_file(config)
@ -112,7 +103,7 @@ def main(args=None):
help='indicate if pungi-koji will be run with --old-composes option') help='indicate if pungi-koji will be run with --old-composes option')
opts = parser.parse_args(args) opts = parser.parse_args(args)
with in_temp_dir() as topdir: with pungi.util.temp_dir() as topdir:
errors = run(opts.config, topdir, opts.old_composes) errors = run(opts.config, topdir, opts.old_composes)
for msg in errors: for msg in errors:

View File

@ -25,6 +25,7 @@ import re
import urlparse import urlparse
import contextlib import contextlib
import traceback import traceback
import tempfile
from kobo.shortcuts import run, force_list from kobo.shortcuts import run, force_list
from productmd.common import get_major_version from productmd.common import get_major_version
@ -572,3 +573,22 @@ def levenshtein(a, b):
mat[j - 1][i - 1] + cost) mat[j - 1][i - 1] + cost)
return mat[len(b)][len(a)] return mat[len(b)][len(a)]
@contextlib.contextmanager
def temp_dir(log=None, *args, **kwargs):
"""Create a temporary directory and ensure it's deleted."""
if kwargs.get('dir'):
# If we are supposed to create the temp dir in a particular location,
# ensure the location already exists.
makedirs(kwargs['dir'])
dir = tempfile.mkdtemp(*args, **kwargs)
try:
yield dir
finally:
try:
shutil.rmtree(dir)
except OSError as exc:
# Okay, we failed to delete temporary dir.
if log:
log.warning('Error removing %s: %s', dir, exc.strerror)

View File

@ -16,36 +16,20 @@ from __future__ import absolute_import
import os import os
import tempfile
import shutil import shutil
import pipes import pipes
import glob import glob
import time import time
import contextlib
import kobo.log import kobo.log
from kobo.shortcuts import run, force_list from kobo.shortcuts import run, force_list
from pungi.util import explode_rpm_package, makedirs, copy_all from pungi.util import explode_rpm_package, makedirs, copy_all, temp_dir
class ScmBase(kobo.log.LoggingBase): class ScmBase(kobo.log.LoggingBase):
def __init__(self, logger=None): def __init__(self, logger=None):
kobo.log.LoggingBase.__init__(self, logger=logger) kobo.log.LoggingBase.__init__(self, logger=logger)
@contextlib.contextmanager
def _temp_dir(self, tmp_dir=None):
if tmp_dir is not None:
makedirs(tmp_dir)
path = tempfile.mkdtemp(prefix="cvswrapper_", dir=tmp_dir)
yield path
self.log_debug("Removing %s" % path)
try:
shutil.rmtree(path)
except OSError as ex:
self.log_warning("Error removing %s: %s" % (path, ex))
def retry_run(self, cmd, retries=5, timeout=60, **kwargs): def retry_run(self, cmd, retries=5, timeout=60, **kwargs):
""" """
@param cmd - cmd passed to kobo.shortcuts.run() @param cmd - cmd passed to kobo.shortcuts.run()
@ -96,7 +80,7 @@ class CvsWrapper(ScmBase):
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None): def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
scm_dir = scm_dir.lstrip("/") scm_dir = scm_dir.lstrip("/")
scm_branch = scm_branch or "HEAD" scm_branch = scm_branch or "HEAD"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Exporting directory %s from CVS %s (branch %s)..." self.log_debug("Exporting directory %s from CVS %s (branch %s)..."
% (scm_dir, scm_root, scm_branch)) % (scm_dir, scm_root, scm_branch))
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_dir], self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_dir],
@ -106,7 +90,7 @@ class CvsWrapper(ScmBase):
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None): def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
scm_file = scm_file.lstrip("/") scm_file = scm_file.lstrip("/")
scm_branch = scm_branch or "HEAD" scm_branch = scm_branch or "HEAD"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
target_path = os.path.join(target_dir, os.path.basename(scm_file)) target_path = os.path.join(target_dir, os.path.basename(scm_file))
self.log_debug("Exporting file %s from CVS %s (branch %s)..." % (scm_file, scm_root, scm_branch)) self.log_debug("Exporting file %s from CVS %s (branch %s)..." % (scm_file, scm_root, scm_branch))
self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_file], self.retry_run(["/usr/bin/cvs", "-q", "-d", scm_root, "export", "-r", scm_branch, scm_file],
@ -121,7 +105,7 @@ class GitWrapper(ScmBase):
scm_dir = scm_dir.lstrip("/") scm_dir = scm_dir.lstrip("/")
scm_branch = scm_branch or "master" scm_branch = scm_branch or "master"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
if "://" not in scm_root: if "://" not in scm_root:
scm_root = "file://%s" % scm_root scm_root = "file://%s" % scm_root
@ -142,7 +126,7 @@ class GitWrapper(ScmBase):
scm_file = scm_file.lstrip("/") scm_file = scm_file.lstrip("/")
scm_branch = scm_branch or "master" scm_branch = scm_branch or "master"
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
target_path = os.path.join(target_dir, os.path.basename(scm_file)) target_path = os.path.join(target_dir, os.path.basename(scm_file))
if "://" not in scm_root: if "://" not in scm_root:
@ -172,7 +156,7 @@ class RpmScmWrapper(ScmBase):
def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None): def export_dir(self, scm_root, scm_dir, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
for rpm in self._list_rpms(scm_root): for rpm in self._list_rpms(scm_root):
scm_dir = scm_dir.lstrip("/") scm_dir = scm_dir.lstrip("/")
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Extracting directory %s from RPM package %s..." % (scm_dir, rpm)) self.log_debug("Extracting directory %s from RPM package %s..." % (scm_dir, rpm))
explode_rpm_package(rpm, tmp_dir) explode_rpm_package(rpm, tmp_dir)
@ -187,7 +171,7 @@ class RpmScmWrapper(ScmBase):
def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None): def export_file(self, scm_root, scm_file, target_dir, scm_branch=None, tmp_dir=None, log_file=None):
for rpm in self._list_rpms(scm_root): for rpm in self._list_rpms(scm_root):
scm_file = scm_file.lstrip("/") scm_file = scm_file.lstrip("/")
with self._temp_dir(tmp_dir=tmp_dir) as tmp_dir: with temp_dir(dir=tmp_dir) as tmp_dir:
self.log_debug("Exporting file %s from RPM file %s..." % (scm_file, rpm)) self.log_debug("Exporting file %s from RPM file %s..." % (scm_file, rpm))
explode_rpm_package(rpm, tmp_dir) explode_rpm_package(rpm, tmp_dir)
@ -256,10 +240,9 @@ def get_file_from_scm(scm_dict, target_path, logger=None):
files_copied = [] files_copied = []
for i in force_list(scm_file): for i in force_list(scm_file):
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_") with temp_dir(prefix="scm_checkout_") as tmp_dir:
scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir) scm.export_file(scm_repo, i, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied += copy_all(tmp_dir, target_path) files_copied += copy_all(tmp_dir, target_path)
shutil.rmtree(tmp_dir)
return files_copied return files_copied
@ -306,8 +289,7 @@ def get_dir_from_scm(scm_dict, target_path, logger=None):
scm = _get_wrapper(scm_type, logger=logger) scm = _get_wrapper(scm_type, logger=logger)
tmp_dir = tempfile.mkdtemp(prefix="scm_checkout_") with temp_dir(prefix="scm_checkout_") as tmp_dir:
scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir) scm.export_dir(scm_repo, scm_dir, scm_branch=scm_branch, target_dir=tmp_dir)
files_copied = copy_all(tmp_dir, target_path) files_copied = copy_all(tmp_dir, target_path)
shutil.rmtree(tmp_dir)
return files_copied return files_copied

View File

@ -406,5 +406,27 @@ class TestRecursiveFileList(unittest.TestCase):
self.assertEqual(expected_files, actual_files) self.assertEqual(expected_files, actual_files)
class TestTempFiles(unittest.TestCase):
def test_temp_dir_ok(self):
with util.temp_dir() as tmp:
self.assertTrue(os.path.isdir(tmp))
self.assertFalse(os.path.exists(tmp))
def test_temp_dir_fail(self):
with self.assertRaises(RuntimeError):
with util.temp_dir() as tmp:
self.assertTrue(os.path.isdir(tmp))
raise RuntimeError('BOOM')
self.assertFalse(os.path.exists(tmp))
def test_temp_dir_in_non_existing_dir(self):
with util.temp_dir() as playground:
root = os.path.join(playground, 'missing')
with util.temp_dir(dir=root) as tmp:
self.assertTrue(os.path.isdir(tmp))
self.assertTrue(os.path.isdir(root))
self.assertFalse(os.path.exists(tmp))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()