From 1eaad216e290d5935f59e9137a233ac8516a8afb Mon Sep 17 00:00:00 2001 From: Sergio Correia Date: Tue, 9 Dec 2025 11:11:43 +0000 Subject: [PATCH 13/14] Add shared memory infrastructure for multiprocess communication Backport of upstream https://github.com/keylime/keylime/pull/1817/commits/1024e19d Signed-off-by: Sergio Correia --- keylime-selinux-42.1.2/keylime.te | 2 + keylime/cloud_verifier_tornado.py | 89 ++--- keylime/cmd/verifier.py | 6 + keylime/config.py | 87 +++++ keylime/shared_data.py | 513 +++++++++++++++++++++++++ keylime/tpm/tpm_main.py | 17 +- keylime/web/base/default_controller.py | 6 + test/test_shared_data.py | 199 ++++++++++ 8 files changed, 868 insertions(+), 51 deletions(-) create mode 100644 keylime/shared_data.py create mode 100644 test/test_shared_data.py diff --git a/keylime-selinux-42.1.2/keylime.te b/keylime-selinux-42.1.2/keylime.te index 2c6a59e..8b8a615 100644 --- a/keylime-selinux-42.1.2/keylime.te +++ b/keylime-selinux-42.1.2/keylime.te @@ -77,6 +77,8 @@ optional_policy(` allow keylime_server_t self:key { create read setattr view write }; allow keylime_server_t self:netlink_route_socket { create_stream_socket_perms nlmsg_read }; allow keylime_server_t self:udp_socket create_stream_socket_perms; +allow keylime_server_t keylime_tmp_t:sock_file { create write }; +allow keylime_server_t self:unix_stream_socket connectto; fs_dontaudit_search_cgroup_dirs(keylime_server_t) diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py index 89aa703..67ba8af 100644 --- a/keylime/cloud_verifier_tornado.py +++ b/keylime/cloud_verifier_tornado.py @@ -6,8 +6,8 @@ import signal import sys import traceback from concurrent.futures import ThreadPoolExecutor -from multiprocessing import Process from contextlib import contextmanager +from multiprocessing import Process from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast import tornado.httpserver @@ -27,6 +27,7 @@ from keylime import ( json, keylime_logging, revocation_notifier, + shared_data, signing, tornado_requests, web_util, @@ -43,7 +44,6 @@ from keylime.mba import mba logger = keylime_logging.init_logging("verifier") -GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {} set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy")) @@ -140,44 +140,41 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]: return agent_dict -def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str: - checksum = ima_policy_data.get("checksum", "") - name = ima_policy_data.get("name", "empty") - agent_id = ima_policy_data.get("agent_id", "") +def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str: + checksum = "" + name = "empty" + agent_id = str(stored_agent.agent_id) - if not agent_id: - return "" + # Initialize agent policy cache if it doesn't exist + shared_data.initialize_agent_policy_cache(agent_id) - if agent_id not in GLOBAL_POLICY_CACHE: - GLOBAL_POLICY_CACHE[agent_id] = {} - GLOBAL_POLICY_CACHE[agent_id][""] = "" + if stored_agent.ima_policy: + checksum = str(stored_agent.ima_policy.checksum) + name = stored_agent.ima_policy.name - if checksum not in GLOBAL_POLICY_CACHE[agent_id]: - if len(GLOBAL_POLICY_CACHE[agent_id]) > 1: - # Perform a cleanup of the contents, IMA policy checksum changed - logger.debug( - "Cleaning up policy cache for policy named %s, with checksum %s, used by agent %s", - name, - checksum, - agent_id, - ) + # Check if policy is already cached + cached_policy = 
shared_data.get_cached_policy(agent_id, checksum) + if cached_policy is not None: + return cached_policy - GLOBAL_POLICY_CACHE[agent_id] = {} - GLOBAL_POLICY_CACHE[agent_id][""] = "" + # Policy not cached, need to clean up and load from database + shared_data.cleanup_agent_policy_cache(agent_id, checksum) - logger.debug( - "IMA policy named %s, with checksum %s, used by agent %s is not present on policy cache on this verifier, performing SQLAlchemy load", - name, - checksum, - agent_id, - ) + logger.debug( + "IMA policy named %s, with checksum %s, used by agent %s is not present on policy cache on this verifier, performing SQLAlchemy load", + name, + checksum, + agent_id, + ) - # Get the large ima_policy content - it's already loaded in ima_policy_data - ima_policy = ima_policy_data.get("ima_policy", "") - assert isinstance(ima_policy, str) - GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy + # Actually contacts the database and load the (large) ima_policy column for "allowlists" table + ima_policy = stored_agent.ima_policy.ima_policy + assert isinstance(ima_policy, str) - return GLOBAL_POLICY_CACHE[agent_id][checksum] + # Cache the policy for future use + shared_data.cache_policy(agent_id, checksum, ima_policy) + + return ima_policy def verifier_db_delete_agent(session: Session, agent_id: str) -> None: @@ -475,12 +472,11 @@ class AgentsHandler(BaseHandler): return # Cleanup the cache when the agent is deleted. Do it early. - if agent_id in GLOBAL_POLICY_CACHE: - del GLOBAL_POLICY_CACHE[agent_id] - logger.debug( - "Cleaned up policy cache from all entries used by agent %s", - agent_id, - ) + shared_data.clear_agent_policy_cache(agent_id) + logger.debug( + "Cleaned up policy cache from all entries used by agent %s", + agent_id, + ) op_state = agent.operational_state if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE): @@ -1763,7 +1759,6 @@ async def process_agent( stored_agent = None # First database operation - read agent data and extract all needed data within session context - ima_policy_data = {} mb_policy_data = None with session_context() as session: try: @@ -1779,15 +1774,6 @@ async def process_agent( .first() ) - # Extract IMA policy data within session context to avoid DetachedInstanceError - if stored_agent and stored_agent.ima_policy: - ima_policy_data = { - "checksum": str(stored_agent.ima_policy.checksum), - "name": stored_agent.ima_policy.name, - "agent_id": str(stored_agent.agent_id), - "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too - } - # Extract MB policy data within session context if stored_agent and stored_agent.mb_policy: mb_policy_data = stored_agent.mb_policy.mb_policy @@ -1869,7 +1855,10 @@ async def process_agent( logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) # Load agent's IMA policy - runtime_policy = verifier_read_policy_from_cache(ima_policy_data) + if stored_agent: + runtime_policy = verifier_read_policy_from_cache(stored_agent) + else: + runtime_policy = "" # Get agent's measured boot policy mb_policy = mb_policy_data diff --git a/keylime/cmd/verifier.py b/keylime/cmd/verifier.py index f3e1a86..1f9f4e5 100644 --- a/keylime/cmd/verifier.py +++ b/keylime/cmd/verifier.py @@ -1,6 +1,7 @@ from keylime import cloud_verifier_tornado, config, keylime_logging from keylime.common.migrations import apply from keylime.mba import mba +from keylime.shared_data import initialize_shared_memory logger = keylime_logging.init_logging("verifier") @@ -10,6 
+11,11 @@ def main() -> None: if config.has_option("verifier", "auto_migrate_db") and config.getboolean("verifier", "auto_migrate_db"): apply("cloud_verifier") + # Initialize shared memory BEFORE creating server instance + # This MUST happen before verifier instantiation and worker forking + logger.info("Initializing shared memory manager in main process before server creation") + initialize_shared_memory() + # Explicitly load and initialize measured boot components mba.load_imports() cloud_verifier_tornado.main() diff --git a/keylime/config.py b/keylime/config.py index e7ac634..b5cd546 100644 --- a/keylime/config.py +++ b/keylime/config.py @@ -114,6 +114,85 @@ if "KEYLIME_LOGGING_CONFIG" in os.environ: _config: Optional[Dict[str, RawConfigParser]] = None +def _check_file_permissions(component: str, file_path: str) -> bool: + """Check if a config file has correct permissions and is readable. + + Args: + component: The component name (e.g., 'verifier', 'agent') + file_path: Path to the config file + + Returns: + True if file is readable, False otherwise + """ + if not os.path.exists(file_path): + return False + + if not os.access(file_path, os.R_OK): + import grp # pylint: disable=import-outside-toplevel + import pwd # pylint: disable=import-outside-toplevel + import stat # pylint: disable=import-outside-toplevel + + try: + file_stat = os.stat(file_path) + owner = pwd.getpwuid(file_stat.st_uid).pw_name + group = grp.getgrgid(file_stat.st_gid).gr_name + mode = stat.filemode(file_stat.st_mode) + except Exception: + owner = group = mode = "unknown" + + base_logger.error( # pylint: disable=logging-not-lazy + "=" * 80 + + "\n" + + "CRITICAL CONFIG ERROR: Config file %s exists but is not readable!\n" + + "File permissions: %s (owner: %s, group: %s)\n" + + "The keylime_%s service needs read access to this file.\n" + + "Fix with: chown keylime:keylime %s && chmod 440 %s\n" + + "=" * 80, + file_path, + mode, + owner, + group, + component, + file_path, + file_path, + ) + return False + + return True + + +def _validate_config_files(component: str, file_paths: List[str], files_read: List[str]) -> None: + """Validate that config files were successfully parsed. + + Args: + component: The component name (e.g., 'verifier', 'agent') + file_paths: List of file paths that were attempted to be read + files_read: List of files that ConfigParser successfully read + """ + for file_path in file_paths: + # Check file permissions first + if not _check_file_permissions(component, file_path): + continue + + if file_path not in files_read: + base_logger.error( # pylint: disable=logging-not-lazy + "=" * 80 + + "\n" + + "CRITICAL CONFIG ERROR: Config file %s exists but failed to parse!\n" + + "This usually indicates duplicate keys within the same file.\n" + + "Common issues:\n" + + " - Same option appears multiple times in the same [%s] section\n" + + " - Empty values (key = ) conflicting with defined values\n" + + " - Invalid INI file syntax\n" + + "Please check the file for duplicate entries.\n" + + "You can validate the file with: python3 -c \"import configparser; c = configparser.RawConfigParser(); print(c.read('%s'))\"\n" + + "=" * 80, + file_path, + component, + file_path, + ) + + def get_config(component: str) -> RawConfigParser: """Find the configuration file to use for the given component and apply the overrides defined by configuration snippets. 
@@ -216,6 +295,10 @@ def get_config(component: str) -> RawConfigParser: # Validate that at least one config file is present config_file = _config[component].read(c) + + # Validate the config file was parsed successfully + _validate_config_files(component, [c], config_file) + if config_file: base_logger.info("Reading configuration from %s", config_file) @@ -230,6 +313,10 @@ def get_config(component: str) -> RawConfigParser: [os.path.join(d, f) for f in os.listdir(d) if f and os.path.isfile(os.path.join(d, f))] ) applied_snippets = _config[component].read(snippets) + + # Validate all snippet files were parsed successfully + _validate_config_files(component, snippets, applied_snippets) + if applied_snippets: base_logger.info("Applied configuration snippets from %s", d) diff --git a/keylime/shared_data.py b/keylime/shared_data.py new file mode 100644 index 0000000..23a3d81 --- /dev/null +++ b/keylime/shared_data.py @@ -0,0 +1,513 @@ +"""Shared memory management for keylime multiprocess applications. + +This module provides thread-safe shared data management between processes +using multiprocessing.Manager(). +""" + +import atexit +import multiprocessing as mp +import threading +import time +from typing import Any, Dict, List, Optional + +from keylime import keylime_logging + +logger = keylime_logging.init_logging("shared_data") + + +class FlatDictView: + """A dictionary-like view over a flat key-value store. + + This class provides dict-like access to a subset of keys in a flat store, + identified by a namespace prefix. This avoids the nested DictProxy issues. + + Example: + store = manager.dict() # Flat store + view = FlatDictView(store, lock, "sessions") + view["123"] = "data" # Stores as "dict:sessions:123" in flat store + val = view["123"] # Retrieves from "dict:sessions:123" + """ + + def __init__(self, store: Any, lock: Any, namespace: str) -> None: + self._store = store + self._lock = lock + self._namespace = namespace + + def _make_key(self, key: Any) -> str: + """Convert user key to internal flat key with namespace prefix.""" + return f"dict:{self._namespace}:{key}" + + def __getitem__(self, key: Any) -> Any: + with self._lock: + return self._store[self._make_key(key)] + + def __setitem__(self, key: Any, value: Any) -> None: + flat_key = self._make_key(key) + with self._lock: + self._store[flat_key] = value + + def __delitem__(self, key: Any) -> None: + flat_key = self._make_key(key) + with self._lock: + del self._store[flat_key] + + def __contains__(self, key: Any) -> bool: + return self._make_key(key) in self._store + + def get(self, key: Any, default: Any = None) -> Any: + with self._lock: + return self._store.get(self._make_key(key), default) + + def keys(self) -> List[Any]: + """Return keys in this namespace.""" + prefix = f"dict:{self._namespace}:" + all_store_keys = list(self._store.keys()) + matching_keys = [k[len(prefix) :] for k in all_store_keys if k.startswith(prefix)] + return matching_keys + + def values(self) -> List[Any]: + """Return values in this namespace.""" + prefix = f"dict:{self._namespace}:" + with self._lock: + return [v for k, v in self._store.items() if k.startswith(prefix)] + + def items(self) -> List[tuple[Any, Any]]: + """Return (key, value) pairs in this namespace.""" + prefix = f"dict:{self._namespace}:" + with self._lock: + result = [(k[len(prefix) :], v) for k, v in self._store.items() if k.startswith(prefix)] + return result + + def __len__(self) -> int: + """Return number of items in this namespace.""" + return len(self.keys()) + + def 
__repr__(self) -> str: + return f"FlatDictView({self._namespace}, {len(self)} items)" + + +class SharedDataManager: + """Thread-safe shared data manager for multiprocess applications. + + This class uses multiprocessing.Manager() to create proxy objects that can + be safely accessed from multiple processes. All data stored must be pickleable. + + Example: + manager = SharedDataManager() + + # Store simple data + manager.set_data("config_value", "some_config") + value = manager.get_data("config_value") + + # Work with shared dictionaries + agent_cache = manager.get_or_create_dict("agent_cache") + agent_cache["agent_123"] = {"last_seen": time.time()} + + # Work with shared lists + event_log = manager.get_or_create_list("events") + event_log.append({"type": "attestation", "agent": "agent_123"}) + """ + + def __init__(self) -> None: + """Initialize the shared data manager. + + This must be called before any process forking occurs to ensure + all child processes inherit access to the shared data. + """ + logger.debug("Initializing SharedDataManager") + + # Use explicit context to ensure fork compatibility + # The Manager must be started BEFORE any fork() calls + ctx = mp.get_context("fork") + self._manager = ctx.Manager() + + # CRITICAL FIX: Use a SINGLE flat dict instead of nested dicts + # Nested DictProxy objects have synchronization issues + # We'll use key prefixes like "dict:auth_sessions:session_id" instead + self._store = self._manager.dict() # Single flat store for all data + self._lock = self._manager.Lock() + self._initialized_at = time.time() + + # Register handler to reinitialize manager connection after fork + # This is needed because Manager uses network connections that don't survive fork + try: + import os # pylint: disable=import-outside-toplevel + + self._parent_pid = os.getpid() + logger.debug("SharedDataManager initialized in process %d", self._parent_pid) + except Exception as e: + logger.warning("Could not register PID tracking: %s", e) + + # Ensure cleanup on exit + atexit.register(self.cleanup) + + logger.info("SharedDataManager initialized successfully") + + def set_data(self, key: str, value: Any) -> None: + """Store arbitrary pickleable data by key. + + Args: + key: Unique identifier for the data + value: Any pickleable Python object + + Raises: + TypeError: If value is not pickleable + """ + with self._lock: + try: + self._store[key] = value + logger.debug("Stored data for key: %s", key) + except Exception as e: + logger.error("Failed to store data for key '%s': %s", key, e) + raise + + def get_data(self, key: str, default: Any = None) -> Any: + """Retrieve data by key. + + Args: + key: The key to retrieve + default: Value to return if key doesn't exist + + Returns: + The stored value or default if key doesn't exist + """ + with self._lock: + value = self._store.get(key, default) + logger.debug("Retrieved data for key: %s (found: %s)", key, value is not default) + return value + + def get_or_create_dict(self, key: str) -> Dict[str, Any]: + """Get or create a shared dictionary. + + Args: + key: Unique identifier for the dictionary + + Returns: + A shared dictionary-like object that syncs across processes + + Note: + Returns a FlatDictView that uses key prefixes in the flat store + instead of actual nested dicts, to avoid DictProxy nesting issues. 
+ """ + # Mark that this namespace exists + namespace_key = f"__namespace__{key}" + if namespace_key not in self._store: + with self._lock: + self._store[namespace_key] = True + + # Return a view that operates on the flat store with key prefix + return FlatDictView(self._store, self._lock, key) # type: ignore[return-value,no-untyped-call] + + def get_or_create_list(self, key: str) -> List[Any]: + """Get or create a shared list. + + Args: + key: Unique identifier for the list + + Returns: + A shared list (proxy object) that syncs across processes + """ + with self._lock: + if key not in self._store: + self._store[key] = self._manager.list() + logger.debug("Created new shared list for key: %s", key) + else: + logger.debug("Retrieved existing shared list for key: %s", key) + return self._store[key] # type: ignore[no-any-return] + + def delete_data(self, key: str) -> bool: + """Delete data by key. + + Args: + key: The key to delete + + Returns: + True if the key existed and was deleted, False otherwise + """ + with self._lock: + if key in self._store: + del self._store[key] + logger.debug("Deleted data for key: %s", key) + return True + logger.debug("Key not found for deletion: %s", key) + return False + + def has_key(self, key: str) -> bool: + """Check if a key exists. + + Args: + key: The key to check + + Returns: + True if key exists, False otherwise + """ + with self._lock: + return key in self._store + + def get_keys(self) -> List[str]: + """Get all stored keys. + + Returns: + List of all keys in the store + """ + with self._lock: + return list(self._store.keys()) + + def clear_all(self) -> None: + """Clear all stored data. Use with caution!""" + with self._lock: + key_count = len(self._store) + self._store.clear() + logger.warning("Cleared all shared data (%d keys)", key_count) + + def get_stats(self) -> Dict[str, Any]: + """Get statistics about stored data. + + Returns: + Dictionary containing storage statistics + """ + with self._lock: + return { + "total_keys": len(self._store), + "initialized_at": self._initialized_at, + "uptime_seconds": time.time() - self._initialized_at, + } + + def cleanup(self) -> None: + """Cleanup shared resources. + + This is automatically called on exit but can be called manually + for explicit cleanup. + """ + if hasattr(self, "_manager"): + logger.debug("Shutting down SharedDataManager") + try: + self._manager.shutdown() + logger.info("SharedDataManager shutdown complete") + except Exception as e: + logger.error("Error during SharedDataManager shutdown: %s", e) + + def __repr__(self) -> str: + stats = self.get_stats() + return f"SharedDataManager(keys={stats['total_keys']}, " f"uptime={stats['uptime_seconds']:.1f}s)" + + @property + def manager(self) -> Any: # type: ignore[misc] + """Access to the underlying multiprocessing Manager for advanced usage.""" + return self._manager + + +# Global shared memory manager instance +_global_shared_manager: Optional[SharedDataManager] = None +_manager_lock = threading.Lock() + + +def initialize_shared_memory() -> SharedDataManager: + """Initialize the global shared memory manager. + + This function MUST be called before any process forking occurs to ensure + all child processes share the same manager instance. + + For tornado/multiprocess servers, call this before starting workers. 
+ + Returns: + SharedDataManager: The global shared memory manager instance + + Raises: + RuntimeError: If called after manager is already initialized + """ + global _global_shared_manager + + with _manager_lock: + if _global_shared_manager is not None: + logger.warning("Shared memory manager already initialized, returning existing instance") + return _global_shared_manager + + logger.info("Initializing global shared memory manager") + _global_shared_manager = SharedDataManager() + logger.info("Global shared memory manager initialized") + + return _global_shared_manager + + +def get_shared_memory() -> SharedDataManager: + """Get the global shared memory manager instance. + + This function returns a singleton SharedDataManager that can be used + throughout keylime for caching and inter-process communication. + + The manager is automatically initialized on first access and cleaned up + on process exit. + + IMPORTANT: In multiprocess applications (like tornado with workers), + you MUST call initialize_shared_memory() BEFORE forking workers. + Otherwise each worker will get its own separate manager. + + Returns: + SharedDataManager: The global shared memory manager instance + """ + global _global_shared_manager + + if _global_shared_manager is None: + with _manager_lock: + if _global_shared_manager is None: + logger.info("Initializing global shared memory manager") + _global_shared_manager = SharedDataManager() # type: ignore[no-untyped-call] + logger.info("Global shared memory manager initialized") + + return _global_shared_manager + + +def cleanup_global_shared_memory() -> None: + """Cleanup the global shared memory manager. + + This is automatically called on exit but can be called manually. + """ + global _global_shared_manager + + if _global_shared_manager is not None: + logger.info("Cleaning up global shared memory manager") + _global_shared_manager.cleanup() + _global_shared_manager = None + + +# Convenience functions for common keylime patterns + + +def cache_policy(agent_id: str, checksum: str, policy: str) -> None: + """Cache a policy in shared memory. + + Args: + agent_id: The agent identifier + checksum: The policy checksum + policy: The policy content to cache + """ + manager = get_shared_memory() + policy_cache = manager.get_or_create_dict("policy_cache") + + if agent_id not in policy_cache: + policy_cache[agent_id] = manager.manager.dict() # type: ignore[attr-defined] + + policy_cache[agent_id][checksum] = policy + logger.debug("Cached policy for agent %s with checksum %s", agent_id, checksum) + + +def get_cached_policy(agent_id: str, checksum: str) -> Optional[str]: + """Retrieve cached policy. + + Args: + agent_id: The agent identifier + checksum: The policy checksum + + Returns: + The cached policy content or None if not found + """ + manager = get_shared_memory() + policy_cache = manager.get_or_create_dict("policy_cache") + agent_policies = policy_cache.get(agent_id, {}) + + result = agent_policies.get(checksum) + if result: + logger.debug("Found cached policy for agent %s with checksum %s", agent_id, checksum) + else: + logger.debug("No cached policy found for agent %s with checksum %s", agent_id, checksum) + + return result # type: ignore[no-any-return] + + +def clear_agent_policy_cache(agent_id: str) -> None: + """Clear all cached policies for an agent. 
+ + Args: + agent_id: The agent identifier + """ + manager = get_shared_memory() + policy_cache = manager.get_or_create_dict("policy_cache") + + if agent_id in policy_cache: + del policy_cache[agent_id] + logger.debug("Cleared policy cache for agent %s", agent_id) + + +def cleanup_agent_policy_cache(agent_id: str, keep_checksum: str = "") -> None: + """Clean up agent policy cache, keeping only the specified checksum. + + This mimics the cleanup behavior from GLOBAL_POLICY_CACHE where when + a new policy checksum is encountered, old cached policies are removed. + + Args: + agent_id: The agent identifier + keep_checksum: The checksum to keep in the cache (empty string by default) + """ + manager = get_shared_memory() + policy_cache = manager.get_or_create_dict("policy_cache") + + if agent_id in policy_cache and len(policy_cache[agent_id]) > 1: + # Keep only the empty entry and the specified checksum + old_policies = dict(policy_cache[agent_id]) + policy_cache[agent_id] = manager.manager.dict() + + # Always keep the empty entry + policy_cache[agent_id][""] = old_policies.get("", "") + + # Keep the specified checksum if it exists and is not empty + if keep_checksum and keep_checksum in old_policies: + policy_cache[agent_id][keep_checksum] = old_policies[keep_checksum] + + logger.debug("Cleaned up policy cache for agent %s, keeping checksum %s", agent_id, keep_checksum) + + +def initialize_agent_policy_cache(agent_id: str) -> Dict[str, Any]: + """Initialize policy cache for an agent if it doesn't exist. + + Args: + agent_id: The agent identifier + + Returns: + The agent's policy cache dictionary + """ + manager = get_shared_memory() + policy_cache = manager.get_or_create_dict("policy_cache") + + if agent_id not in policy_cache: + policy_cache[agent_id] = manager.manager.dict() # type: ignore[attr-defined] + policy_cache[agent_id][""] = "" + logger.debug("Initialized policy cache for agent %s", agent_id) + + return policy_cache[agent_id] # type: ignore[no-any-return] + + +def get_agent_cache(agent_id: str) -> Dict[str, Any]: + """Get shared cache for a specific agent. + + Args: + agent_id: The agent identifier + + Returns: + A shared dictionary for caching agent-specific data + """ + manager = get_shared_memory() + return manager.get_or_create_dict(f"agent_cache:{agent_id}") + + +def get_verification_queue(agent_id: str) -> List[Any]: + """Get verification queue for batching database operations. + + Args: + agent_id: The agent identifier + + Returns: + A shared list for queuing verification operations + """ + manager = get_shared_memory() + return manager.get_or_create_list(f"verification_queue:{agent_id}") + + +def get_shared_stats() -> Dict[str, Any]: + """Get statistics about shared memory usage. 
+ + Returns: + Dictionary containing storage statistics + """ + manager = get_shared_memory() + return manager.get_stats() diff --git a/keylime/tpm/tpm_main.py b/keylime/tpm/tpm_main.py index 6f2e89f..9b54fc3 100644 --- a/keylime/tpm/tpm_main.py +++ b/keylime/tpm/tpm_main.py @@ -10,7 +10,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from keylime import cert_utils, config, json, keylime_logging from keylime.agentstates import AgentAttestState, TPMClockInfo -from keylime.common.algorithms import Hash +from keylime.common.algorithms import Hash, Sign from keylime.failure import Component, Failure from keylime.ima import ima from keylime.ima.file_signatures import ImaKeyrings @@ -50,6 +50,21 @@ class Tpm: return (keyblob, key) + # Mapping from keylime.common.algorithms enums to TPM algorithm constants + # Used for validating that TPM attestations use expected cryptographic algorithms + HASH_ALG_TO_TPM = { + Hash.SHA1: tpm2_objects.TPM_ALG_SHA1, + Hash.SHA256: tpm2_objects.TPM_ALG_SHA256, + Hash.SHA384: tpm2_objects.TPM_ALG_SHA384, + Hash.SHA512: tpm2_objects.TPM_ALG_SHA512, + } + + SIGN_ALG_TO_TPM = { + Sign.RSASSA: tpm2_objects.TPM_ALG_RSASSA, + Sign.RSAPSS: tpm2_objects.TPM_ALG_RSAPSS, + Sign.ECDSA: tpm2_objects.TPM_ALG_ECDSA, + } + @staticmethod def verify_aik_with_iak(uuid: str, aik_tpm: bytes, iak_tpm: bytes, iak_attest: bytes, iak_sign: bytes) -> bool: attest_body = iak_attest.split(b"\x00$")[1] diff --git a/keylime/web/base/default_controller.py b/keylime/web/base/default_controller.py index 971ed06..ba0782e 100644 --- a/keylime/web/base/default_controller.py +++ b/keylime/web/base/default_controller.py @@ -19,6 +19,12 @@ class DefaultController(Controller): self.send_response(400, "Bad Request") def malformed_params(self, **_params: Any) -> None: + import traceback # pylint: disable=import-outside-toplevel + + from keylime import keylime_logging # pylint: disable=import-outside-toplevel + + logger = keylime_logging.init_logging("web") + logger.error("Malformed params error. 
Traceback: %s", traceback.format_exc()) self.send_response(400, "Malformed Request Parameter") def action_dispatch_error(self, **_param: Any) -> None: diff --git a/test/test_shared_data.py b/test/test_shared_data.py new file mode 100644 index 0000000..8de7e64 --- /dev/null +++ b/test/test_shared_data.py @@ -0,0 +1,199 @@ +"""Unit tests for shared memory infrastructure.""" + +import unittest + +from keylime.shared_data import ( + SharedDataManager, + cache_policy, + cleanup_agent_policy_cache, + cleanup_global_shared_memory, + clear_agent_policy_cache, + get_cached_policy, + get_shared_memory, + initialize_agent_policy_cache, +) + + +class TestSharedDataManager(unittest.TestCase): + """Test cases for SharedDataManager class.""" + + def setUp(self): + """Set up test fixtures.""" + self.manager = SharedDataManager() + + def tearDown(self): + """Clean up after tests.""" + if self.manager: + self.manager.cleanup() + + def test_set_and_get_data(self): + """Test basic set and get operations.""" + self.manager.set_data("test_key", "test_value") + result = self.manager.get_data("test_key") + self.assertEqual(result, "test_value") + + def test_get_nonexistent_data(self): + """Test getting data that doesn't exist returns None.""" + result = self.manager.get_data("nonexistent_key") + self.assertIsNone(result) + + def test_get_data_with_default(self): + """Test getting data with default value.""" + result = self.manager.get_data("nonexistent_key", default="default_value") + self.assertEqual(result, "default_value") + + def test_delete_data(self): + """Test deleting data.""" + self.manager.set_data("test_key", "test_value") + result = self.manager.delete_data("test_key") + self.assertTrue(result) + + # Verify it's actually deleted + self.assertIsNone(self.manager.get_data("test_key")) + + def test_delete_nonexistent_data(self): + """Test deleting data that doesn't exist returns False.""" + result = self.manager.delete_data("nonexistent_key") + self.assertFalse(result) + + def test_has_key(self): + """Test checking if key exists.""" + self.manager.set_data("test_key", "test_value") + self.assertTrue(self.manager.has_key("test_key")) + self.assertFalse(self.manager.has_key("nonexistent_key")) + + def test_get_or_create_dict(self): + """Test getting or creating a shared dictionary.""" + shared_dict = self.manager.get_or_create_dict("test_dict") + shared_dict["key1"] = "value1" + shared_dict["key2"] = "value2" + + # Retrieve the same dict + retrieved_dict = self.manager.get_or_create_dict("test_dict") + self.assertEqual(retrieved_dict["key1"], "value1") + self.assertEqual(retrieved_dict["key2"], "value2") + + def test_get_or_create_list(self): + """Test getting or creating a shared list.""" + shared_list = self.manager.get_or_create_list("test_list") + shared_list.append("item1") + shared_list.append("item2") + + # Retrieve the same list + retrieved_list = self.manager.get_or_create_list("test_list") + self.assertEqual(len(retrieved_list), 2) + self.assertEqual(retrieved_list[0], "item1") + self.assertEqual(retrieved_list[1], "item2") + + def test_get_stats(self): + """Test getting manager statistics.""" + self.manager.set_data("key1", "value1") + self.manager.set_data("key2", "value2") + + stats = self.manager.get_stats() + self.assertIn("total_keys", stats) + self.assertIn("uptime_seconds", stats) + self.assertEqual(stats["total_keys"], 2) + self.assertGreaterEqual(stats["uptime_seconds"], 0) + + +class TestPolicyCacheFunctions(unittest.TestCase): + """Test cases for policy cache functions.""" + + def 
setUp(self): + """Set up test fixtures.""" + # Get the global shared memory manager + self.manager = get_shared_memory() + + def tearDown(self): + """Clean up after tests.""" + # Clean up global shared memory + cleanup_global_shared_memory() + + def test_initialize_agent_policy_cache(self): + """Test initializing agent policy cache.""" + agent_id = "test_agent_123" + initialize_agent_policy_cache(agent_id) + + # Verify the cache was initialized + policy_cache = self.manager.get_or_create_dict("policy_cache") + self.assertIn(agent_id, policy_cache) + + def test_cache_and_get_policy(self): + """Test caching and retrieving a policy.""" + agent_id = "test_agent_123" + checksum = "abc123def456" + policy_content = '{"policy": "test_policy_content"}' + + # Initialize and cache policy + initialize_agent_policy_cache(agent_id) + cache_policy(agent_id, checksum, policy_content) + + # Retrieve cached policy + cached = get_cached_policy(agent_id, checksum) + self.assertEqual(cached, policy_content) + + def test_get_nonexistent_cached_policy(self): + """Test getting a policy that hasn't been cached.""" + agent_id = "test_agent_123" + checksum = "nonexistent_checksum" + + initialize_agent_policy_cache(agent_id) + cached = get_cached_policy(agent_id, checksum) + self.assertIsNone(cached) + + def test_clear_agent_policy_cache(self): + """Test clearing an agent's policy cache.""" + agent_id = "test_agent_123" + checksum = "abc123def456" + policy_content = '{"policy": "test_policy_content"}' + + # Initialize, cache, and then clear + initialize_agent_policy_cache(agent_id) + cache_policy(agent_id, checksum, policy_content) + clear_agent_policy_cache(agent_id) + + # Verify it's cleared + cached = get_cached_policy(agent_id, checksum) + self.assertIsNone(cached) + + def test_cleanup_agent_policy_cache(self): + """Test cleaning up old policy checksums.""" + agent_id = "test_agent_123" + old_checksum = "old_checksum" + new_checksum = "new_checksum" + policy_content = '{"policy": "test"}' + + # Initialize and cache multiple policies + initialize_agent_policy_cache(agent_id) + cache_policy(agent_id, old_checksum, policy_content) + cache_policy(agent_id, new_checksum, policy_content) + + # Cleanup old checksums (keeping only new_checksum) + cleanup_agent_policy_cache(agent_id, new_checksum) + + # Verify old checksum is removed but new one remains + self.assertIsNone(get_cached_policy(agent_id, old_checksum)) + self.assertEqual(get_cached_policy(agent_id, new_checksum), policy_content) + + def test_cache_multiple_agents(self): + """Test caching policies for multiple agents.""" + agent1 = "agent_1" + agent2 = "agent_2" + checksum = "same_checksum" + policy1 = '{"policy": "agent1_policy"}' + policy2 = '{"policy": "agent2_policy"}' + + # Cache policies for different agents + initialize_agent_policy_cache(agent1) + initialize_agent_policy_cache(agent2) + cache_policy(agent1, checksum, policy1) + cache_policy(agent2, checksum, policy2) + + # Verify each agent has its own policy + self.assertEqual(get_cached_policy(agent1, checksum), policy1) + self.assertEqual(get_cached_policy(agent2, checksum), policy2) + + +if __name__ == "__main__": + unittest.main() -- 2.47.3
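
A minimal usage sketch (not part of the patch itself) of the shared_data API introduced above, assuming keylime is installed with this change applied and the fork start method is in use (matching the module's own fork context). It mirrors the initialization order enforced in keylime/cmd/verifier.py: initialize_shared_memory() runs in the parent before any fork, and forked workers then read the same policy cache through the module-level helpers. The agent ID, checksum, policy string, and worker count below are illustrative values only.

import os
from multiprocessing import Process

from keylime.shared_data import (
    cache_policy,
    get_cached_policy,
    get_shared_stats,
    initialize_shared_memory,
)


def worker() -> None:
    # A forked child inherits the parent's Manager proxies, so it sees the
    # policy cached by the parent rather than a private per-process copy.
    cached = get_cached_policy("agent-123", "abc123")
    print(f"worker {os.getpid()} sees cached policy: {cached!r}")


def main() -> None:
    # Must run in the parent before any fork, as keylime/cmd/verifier.py does
    # before instantiating the verifier and spawning workers.
    initialize_shared_memory()
    cache_policy("agent-123", "abc123", '{"meta": {"version": 1}}')

    procs = [Process(target=worker) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    # Reports key count and manager uptime from the shared flat store.
    print(get_shared_stats())


if __name__ == "__main__":
    main()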