keylime/0013-Add-shared-memory-infrastructure-for-multiprocess-co.patch

From 1eaad216e290d5935f59e9137a233ac8516a8afb Mon Sep 17 00:00:00 2001
From: Sergio Correia <scorreia@redhat.com>
Date: Tue, 9 Dec 2025 11:11:43 +0000
Subject: [PATCH 13/14] Add shared memory infrastructure for multiprocess
 communication

Backport of upstream https://github.com/keylime/keylime/pull/1817/commits/1024e19d
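
A minimal usage sketch of the new API (based on the docstrings added in
keylime/shared_data.py below; the agent id, checksum and policy strings are
illustrative only):

    from keylime.shared_data import (
        cache_policy,
        get_cached_policy,
        initialize_shared_memory,
    )

    # Must run in the main process, before any worker forking, so that all
    # child processes inherit the same multiprocessing.Manager() instance.
    initialize_shared_memory()

    # ... tornado worker processes are forked here ...

    # Any worker can then populate or read the shared policy cache.
    cache_policy("agent-123", "abc123", '{"policy": "example"}')
    assert get_cached_policy("agent-123", "abc123") == '{"policy": "example"}'
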
Signed-off-by: Sergio Correia <scorreia@redhat.com>
---
keylime-selinux-42.1.2/keylime.te | 2 +
keylime/cloud_verifier_tornado.py | 89 ++---
keylime/cmd/verifier.py | 6 +
keylime/config.py | 87 +++++
keylime/shared_data.py | 513 +++++++++++++++++++++++++
keylime/tpm/tpm_main.py | 17 +-
keylime/web/base/default_controller.py | 6 +
test/test_shared_data.py | 199 ++++++++++
8 files changed, 868 insertions(+), 51 deletions(-)
create mode 100644 keylime/shared_data.py
create mode 100644 test/test_shared_data.py
diff --git a/keylime-selinux-42.1.2/keylime.te b/keylime-selinux-42.1.2/keylime.te
index 2c6a59e..8b8a615 100644
--- a/keylime-selinux-42.1.2/keylime.te
+++ b/keylime-selinux-42.1.2/keylime.te
@@ -77,6 +77,8 @@ optional_policy(`
allow keylime_server_t self:key { create read setattr view write };
allow keylime_server_t self:netlink_route_socket { create_stream_socket_perms nlmsg_read };
allow keylime_server_t self:udp_socket create_stream_socket_perms;
+allow keylime_server_t keylime_tmp_t:sock_file { create write };
+allow keylime_server_t self:unix_stream_socket connectto;
fs_dontaudit_search_cgroup_dirs(keylime_server_t)
diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py
index 89aa703..67ba8af 100644
--- a/keylime/cloud_verifier_tornado.py
+++ b/keylime/cloud_verifier_tornado.py
@@ -6,8 +6,8 @@ import signal
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
-from multiprocessing import Process
from contextlib import contextmanager
+from multiprocessing import Process
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
import tornado.httpserver
@@ -27,6 +27,7 @@ from keylime import (
json,
keylime_logging,
revocation_notifier,
+ shared_data,
signing,
tornado_requests,
web_util,
@@ -43,7 +44,6 @@ from keylime.mba import mba
logger = keylime_logging.init_logging("verifier")
-GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {}
set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy"))
@@ -140,44 +140,41 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]:
return agent_dict
-def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str:
- checksum = ima_policy_data.get("checksum", "")
- name = ima_policy_data.get("name", "empty")
- agent_id = ima_policy_data.get("agent_id", "")
+def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str:
+ checksum = ""
+ name = "empty"
+ agent_id = str(stored_agent.agent_id)
- if not agent_id:
- return ""
+ # Initialize agent policy cache if it doesn't exist
+ shared_data.initialize_agent_policy_cache(agent_id)
- if agent_id not in GLOBAL_POLICY_CACHE:
- GLOBAL_POLICY_CACHE[agent_id] = {}
- GLOBAL_POLICY_CACHE[agent_id][""] = ""
+ if stored_agent.ima_policy:
+ checksum = str(stored_agent.ima_policy.checksum)
+ name = stored_agent.ima_policy.name
- if checksum not in GLOBAL_POLICY_CACHE[agent_id]:
- if len(GLOBAL_POLICY_CACHE[agent_id]) > 1:
- # Perform a cleanup of the contents, IMA policy checksum changed
- logger.debug(
- "Cleaning up policy cache for policy named %s, with checksum %s, used by agent %s",
- name,
- checksum,
- agent_id,
- )
+ # Check if policy is already cached
+ cached_policy = shared_data.get_cached_policy(agent_id, checksum)
+ if cached_policy is not None:
+ return cached_policy
- GLOBAL_POLICY_CACHE[agent_id] = {}
- GLOBAL_POLICY_CACHE[agent_id][""] = ""
+ # Policy not cached, need to clean up and load from database
+ shared_data.cleanup_agent_policy_cache(agent_id, checksum)
- logger.debug(
- "IMA policy named %s, with checksum %s, used by agent %s is not present on policy cache on this verifier, performing SQLAlchemy load",
- name,
- checksum,
- agent_id,
- )
+ logger.debug(
+ "IMA policy named %s, with checksum %s, used by agent %s is not present on policy cache on this verifier, performing SQLAlchemy load",
+ name,
+ checksum,
+ agent_id,
+ )
- # Get the large ima_policy content - it's already loaded in ima_policy_data
- ima_policy = ima_policy_data.get("ima_policy", "")
- assert isinstance(ima_policy, str)
- GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy
+ # Actually contacts the database and loads the (large) ima_policy column from the "allowlists" table
+ ima_policy = stored_agent.ima_policy.ima_policy
+ assert isinstance(ima_policy, str)
- return GLOBAL_POLICY_CACHE[agent_id][checksum]
+ # Cache the policy for future use
+ shared_data.cache_policy(agent_id, checksum, ima_policy)
+
+ return ima_policy
def verifier_db_delete_agent(session: Session, agent_id: str) -> None:
@@ -475,12 +472,11 @@ class AgentsHandler(BaseHandler):
return
# Cleanup the cache when the agent is deleted. Do it early.
- if agent_id in GLOBAL_POLICY_CACHE:
- del GLOBAL_POLICY_CACHE[agent_id]
- logger.debug(
- "Cleaned up policy cache from all entries used by agent %s",
- agent_id,
- )
+ shared_data.clear_agent_policy_cache(agent_id)
+ logger.debug(
+ "Cleaned up policy cache from all entries used by agent %s",
+ agent_id,
+ )
op_state = agent.operational_state
if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE):
@@ -1763,7 +1759,6 @@ async def process_agent(
stored_agent = None
# First database operation - read agent data and extract all needed data within session context
- ima_policy_data = {}
mb_policy_data = None
with session_context() as session:
try:
@@ -1779,15 +1774,6 @@ async def process_agent(
.first()
)
- # Extract IMA policy data within session context to avoid DetachedInstanceError
- if stored_agent and stored_agent.ima_policy:
- ima_policy_data = {
- "checksum": str(stored_agent.ima_policy.checksum),
- "name": stored_agent.ima_policy.name,
- "agent_id": str(stored_agent.agent_id),
- "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too
- }
-
# Extract MB policy data within session context
if stored_agent and stored_agent.mb_policy:
mb_policy_data = stored_agent.mb_policy.mb_policy
@@ -1869,7 +1855,10 @@ async def process_agent(
logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
# Load agent's IMA policy
- runtime_policy = verifier_read_policy_from_cache(ima_policy_data)
+ if stored_agent:
+ runtime_policy = verifier_read_policy_from_cache(stored_agent)
+ else:
+ runtime_policy = ""
# Get agent's measured boot policy
mb_policy = mb_policy_data
diff --git a/keylime/cmd/verifier.py b/keylime/cmd/verifier.py
index f3e1a86..1f9f4e5 100644
--- a/keylime/cmd/verifier.py
+++ b/keylime/cmd/verifier.py
@@ -1,6 +1,7 @@
from keylime import cloud_verifier_tornado, config, keylime_logging
from keylime.common.migrations import apply
from keylime.mba import mba
+from keylime.shared_data import initialize_shared_memory
logger = keylime_logging.init_logging("verifier")
@@ -10,6 +11,11 @@ def main() -> None:
if config.has_option("verifier", "auto_migrate_db") and config.getboolean("verifier", "auto_migrate_db"):
apply("cloud_verifier")
+ # Initialize shared memory BEFORE creating server instance
+ # This MUST happen before verifier instantiation and worker forking
+ logger.info("Initializing shared memory manager in main process before server creation")
+ initialize_shared_memory()
+
# Explicitly load and initialize measured boot components
mba.load_imports()
cloud_verifier_tornado.main()
diff --git a/keylime/config.py b/keylime/config.py
index e7ac634..b5cd546 100644
--- a/keylime/config.py
+++ b/keylime/config.py
@@ -114,6 +114,85 @@ if "KEYLIME_LOGGING_CONFIG" in os.environ:
_config: Optional[Dict[str, RawConfigParser]] = None
+def _check_file_permissions(component: str, file_path: str) -> bool:
+ """Check if a config file has correct permissions and is readable.
+
+ Args:
+ component: The component name (e.g., 'verifier', 'agent')
+ file_path: Path to the config file
+
+ Returns:
+ True if file is readable, False otherwise
+ """
+ if not os.path.exists(file_path):
+ return False
+
+ if not os.access(file_path, os.R_OK):
+ import grp # pylint: disable=import-outside-toplevel
+ import pwd # pylint: disable=import-outside-toplevel
+ import stat # pylint: disable=import-outside-toplevel
+
+ try:
+ file_stat = os.stat(file_path)
+ owner = pwd.getpwuid(file_stat.st_uid).pw_name
+ group = grp.getgrgid(file_stat.st_gid).gr_name
+ mode = stat.filemode(file_stat.st_mode)
+ except Exception:
+ owner = group = mode = "unknown"
+
+ base_logger.error( # pylint: disable=logging-not-lazy
+ "=" * 80
+ + "\n"
+ + "CRITICAL CONFIG ERROR: Config file %s exists but is not readable!\n"
+ + "File permissions: %s (owner: %s, group: %s)\n"
+ + "The keylime_%s service needs read access to this file.\n"
+ + "Fix with: chown keylime:keylime %s && chmod 440 %s\n"
+ + "=" * 80,
+ file_path,
+ mode,
+ owner,
+ group,
+ component,
+ file_path,
+ file_path,
+ )
+ return False
+
+ return True
+
+
+def _validate_config_files(component: str, file_paths: List[str], files_read: List[str]) -> None:
+ """Validate that config files were successfully parsed.
+
+ Args:
+ component: The component name (e.g., 'verifier', 'agent')
+ file_paths: List of file paths that were attempted to be read
+ files_read: List of files that ConfigParser successfully read
+ """
+ for file_path in file_paths:
+ # Check file permissions first
+ if not _check_file_permissions(component, file_path):
+ continue
+
+ if file_path not in files_read:
+ base_logger.error( # pylint: disable=logging-not-lazy
+ "=" * 80
+ + "\n"
+ + "CRITICAL CONFIG ERROR: Config file %s exists but failed to parse!\n"
+ + "This usually indicates duplicate keys within the same file.\n"
+ + "Common issues:\n"
+ + " - Same option appears multiple times in the same [%s] section\n"
+ + " - Empty values (key = ) conflicting with defined values\n"
+ + " - Invalid INI file syntax\n"
+ + "Please check the file for duplicate entries.\n"
+ + "You can validate the file with: python3 -c \"import configparser; c = configparser.RawConfigParser(); print(c.read('%s'))\"\n"
+ + "=" * 80,
+ file_path,
+ component,
+ file_path,
+ )
+
+
def get_config(component: str) -> RawConfigParser:
"""Find the configuration file to use for the given component and apply the
overrides defined by configuration snippets.
@@ -216,6 +295,10 @@ def get_config(component: str) -> RawConfigParser:
# Validate that at least one config file is present
config_file = _config[component].read(c)
+
+ # Validate the config file was parsed successfully
+ _validate_config_files(component, [c], config_file)
+
if config_file:
base_logger.info("Reading configuration from %s", config_file)
@@ -230,6 +313,10 @@ def get_config(component: str) -> RawConfigParser:
[os.path.join(d, f) for f in os.listdir(d) if f and os.path.isfile(os.path.join(d, f))]
)
applied_snippets = _config[component].read(snippets)
+
+ # Validate all snippet files were parsed successfully
+ _validate_config_files(component, snippets, applied_snippets)
+
if applied_snippets:
base_logger.info("Applied configuration snippets from %s", d)
diff --git a/keylime/shared_data.py b/keylime/shared_data.py
new file mode 100644
index 0000000..23a3d81
--- /dev/null
+++ b/keylime/shared_data.py
@@ -0,0 +1,513 @@
+"""Shared memory management for keylime multiprocess applications.
+
+This module provides thread-safe shared data management between processes
+using multiprocessing.Manager().
+"""
+
+import atexit
+import multiprocessing as mp
+import threading
+import time
+from typing import Any, Dict, List, Optional
+
+from keylime import keylime_logging
+
+logger = keylime_logging.init_logging("shared_data")
+
+
+class FlatDictView:
+ """A dictionary-like view over a flat key-value store.
+
+ This class provides dict-like access to a subset of keys in a flat store,
+ identified by a namespace prefix. This avoids the nested DictProxy issues.
+
+ Example:
+ store = manager.dict() # Flat store
+ view = FlatDictView(store, lock, "sessions")
+ view["123"] = "data" # Stores as "dict:sessions:123" in flat store
+ val = view["123"] # Retrieves from "dict:sessions:123"
+ """
+
+ def __init__(self, store: Any, lock: Any, namespace: str) -> None:
+ self._store = store
+ self._lock = lock
+ self._namespace = namespace
+
+ def _make_key(self, key: Any) -> str:
+ """Convert user key to internal flat key with namespace prefix."""
+ return f"dict:{self._namespace}:{key}"
+
+ def __getitem__(self, key: Any) -> Any:
+ with self._lock:
+ return self._store[self._make_key(key)]
+
+ def __setitem__(self, key: Any, value: Any) -> None:
+ flat_key = self._make_key(key)
+ with self._lock:
+ self._store[flat_key] = value
+
+ def __delitem__(self, key: Any) -> None:
+ flat_key = self._make_key(key)
+ with self._lock:
+ del self._store[flat_key]
+
+ def __contains__(self, key: Any) -> bool:
+ return self._make_key(key) in self._store
+
+ def get(self, key: Any, default: Any = None) -> Any:
+ with self._lock:
+ return self._store.get(self._make_key(key), default)
+
+ def keys(self) -> List[Any]:
+ """Return keys in this namespace."""
+ prefix = f"dict:{self._namespace}:"
+ all_store_keys = list(self._store.keys())
+ matching_keys = [k[len(prefix) :] for k in all_store_keys if k.startswith(prefix)]
+ return matching_keys
+
+ def values(self) -> List[Any]:
+ """Return values in this namespace."""
+ prefix = f"dict:{self._namespace}:"
+ with self._lock:
+ return [v for k, v in self._store.items() if k.startswith(prefix)]
+
+ def items(self) -> List[tuple[Any, Any]]:
+ """Return (key, value) pairs in this namespace."""
+ prefix = f"dict:{self._namespace}:"
+ with self._lock:
+ result = [(k[len(prefix) :], v) for k, v in self._store.items() if k.startswith(prefix)]
+ return result
+
+ def __len__(self) -> int:
+ """Return number of items in this namespace."""
+ return len(self.keys())
+
+ def __repr__(self) -> str:
+ return f"FlatDictView({self._namespace}, {len(self)} items)"
+
+
+class SharedDataManager:
+ """Thread-safe shared data manager for multiprocess applications.
+
+ This class uses multiprocessing.Manager() to create proxy objects that can
+ be safely accessed from multiple processes. All data stored must be pickleable.
+
+ Example:
+ manager = SharedDataManager()
+
+ # Store simple data
+ manager.set_data("config_value", "some_config")
+ value = manager.get_data("config_value")
+
+ # Work with shared dictionaries
+ agent_cache = manager.get_or_create_dict("agent_cache")
+ agent_cache["agent_123"] = {"last_seen": time.time()}
+
+ # Work with shared lists
+ event_log = manager.get_or_create_list("events")
+ event_log.append({"type": "attestation", "agent": "agent_123"})
+ """
+
+ def __init__(self) -> None:
+ """Initialize the shared data manager.
+
+ This must be called before any process forking occurs to ensure
+ all child processes inherit access to the shared data.
+ """
+ logger.debug("Initializing SharedDataManager")
+
+ # Use explicit context to ensure fork compatibility
+ # The Manager must be started BEFORE any fork() calls
+ ctx = mp.get_context("fork")
+ self._manager = ctx.Manager()
+
+ # CRITICAL FIX: Use a SINGLE flat dict instead of nested dicts
+ # Nested DictProxy objects have synchronization issues
+ # We'll use key prefixes like "dict:auth_sessions:session_id" instead
+ self._store = self._manager.dict() # Single flat store for all data
+ self._lock = self._manager.Lock()
+ self._initialized_at = time.time()
+
+ # Register handler to reinitialize manager connection after fork
+ # This is needed because Manager uses network connections that don't survive fork
+ try:
+ import os # pylint: disable=import-outside-toplevel
+
+ self._parent_pid = os.getpid()
+ logger.debug("SharedDataManager initialized in process %d", self._parent_pid)
+ except Exception as e:
+ logger.warning("Could not register PID tracking: %s", e)
+
+ # Ensure cleanup on exit
+ atexit.register(self.cleanup)
+
+ logger.info("SharedDataManager initialized successfully")
+
+ def set_data(self, key: str, value: Any) -> None:
+ """Store arbitrary pickleable data by key.
+
+ Args:
+ key: Unique identifier for the data
+ value: Any pickleable Python object
+
+ Raises:
+ TypeError: If value is not pickleable
+ """
+ with self._lock:
+ try:
+ self._store[key] = value
+ logger.debug("Stored data for key: %s", key)
+ except Exception as e:
+ logger.error("Failed to store data for key '%s': %s", key, e)
+ raise
+
+ def get_data(self, key: str, default: Any = None) -> Any:
+ """Retrieve data by key.
+
+ Args:
+ key: The key to retrieve
+ default: Value to return if key doesn't exist
+
+ Returns:
+ The stored value or default if key doesn't exist
+ """
+ with self._lock:
+ value = self._store.get(key, default)
+ logger.debug("Retrieved data for key: %s (found: %s)", key, value is not default)
+ return value
+
+ def get_or_create_dict(self, key: str) -> Dict[str, Any]:
+ """Get or create a shared dictionary.
+
+ Args:
+ key: Unique identifier for the dictionary
+
+ Returns:
+ A shared dictionary-like object that syncs across processes
+
+ Note:
+ Returns a FlatDictView that uses key prefixes in the flat store
+ instead of actual nested dicts, to avoid DictProxy nesting issues.
+ """
+ # Mark that this namespace exists
+ namespace_key = f"__namespace__{key}"
+ if namespace_key not in self._store:
+ with self._lock:
+ self._store[namespace_key] = True
+
+ # Return a view that operates on the flat store with key prefix
+ return FlatDictView(self._store, self._lock, key) # type: ignore[return-value,no-untyped-call]
+
+ def get_or_create_list(self, key: str) -> List[Any]:
+ """Get or create a shared list.
+
+ Args:
+ key: Unique identifier for the list
+
+ Returns:
+ A shared list (proxy object) that syncs across processes
+ """
+ with self._lock:
+ if key not in self._store:
+ self._store[key] = self._manager.list()
+ logger.debug("Created new shared list for key: %s", key)
+ else:
+ logger.debug("Retrieved existing shared list for key: %s", key)
+ return self._store[key] # type: ignore[no-any-return]
+
+ def delete_data(self, key: str) -> bool:
+ """Delete data by key.
+
+ Args:
+ key: The key to delete
+
+ Returns:
+ True if the key existed and was deleted, False otherwise
+ """
+ with self._lock:
+ if key in self._store:
+ del self._store[key]
+ logger.debug("Deleted data for key: %s", key)
+ return True
+ logger.debug("Key not found for deletion: %s", key)
+ return False
+
+ def has_key(self, key: str) -> bool:
+ """Check if a key exists.
+
+ Args:
+ key: The key to check
+
+ Returns:
+ True if key exists, False otherwise
+ """
+ with self._lock:
+ return key in self._store
+
+ def get_keys(self) -> List[str]:
+ """Get all stored keys.
+
+ Returns:
+ List of all keys in the store
+ """
+ with self._lock:
+ return list(self._store.keys())
+
+ def clear_all(self) -> None:
+ """Clear all stored data. Use with caution!"""
+ with self._lock:
+ key_count = len(self._store)
+ self._store.clear()
+ logger.warning("Cleared all shared data (%d keys)", key_count)
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get statistics about stored data.
+
+ Returns:
+ Dictionary containing storage statistics
+ """
+ with self._lock:
+ return {
+ "total_keys": len(self._store),
+ "initialized_at": self._initialized_at,
+ "uptime_seconds": time.time() - self._initialized_at,
+ }
+
+ def cleanup(self) -> None:
+ """Cleanup shared resources.
+
+ This is automatically called on exit but can be called manually
+ for explicit cleanup.
+ """
+ if hasattr(self, "_manager"):
+ logger.debug("Shutting down SharedDataManager")
+ try:
+ self._manager.shutdown()
+ logger.info("SharedDataManager shutdown complete")
+ except Exception as e:
+ logger.error("Error during SharedDataManager shutdown: %s", e)
+
+ def __repr__(self) -> str:
+ stats = self.get_stats()
+ return f"SharedDataManager(keys={stats['total_keys']}, " f"uptime={stats['uptime_seconds']:.1f}s)"
+
+ @property
+ def manager(self) -> Any: # type: ignore[misc]
+ """Access to the underlying multiprocessing Manager for advanced usage."""
+ return self._manager
+
+
+# Global shared memory manager instance
+_global_shared_manager: Optional[SharedDataManager] = None
+_manager_lock = threading.Lock()
+
+
+def initialize_shared_memory() -> SharedDataManager:
+ """Initialize the global shared memory manager.
+
+ This function MUST be called before any process forking occurs to ensure
+ all child processes share the same manager instance.
+
+ For tornado/multiprocess servers, call this before starting workers.
+
+ Returns:
+ SharedDataManager: The global shared memory manager instance
+
+ Raises:
+ RuntimeError: If called after manager is already initialized
+ """
+ global _global_shared_manager
+
+ with _manager_lock:
+ if _global_shared_manager is not None:
+ logger.warning("Shared memory manager already initialized, returning existing instance")
+ return _global_shared_manager
+
+ logger.info("Initializing global shared memory manager")
+ _global_shared_manager = SharedDataManager()
+ logger.info("Global shared memory manager initialized")
+
+ return _global_shared_manager
+
+
+def get_shared_memory() -> SharedDataManager:
+ """Get the global shared memory manager instance.
+
+ This function returns a singleton SharedDataManager that can be used
+ throughout keylime for caching and inter-process communication.
+
+ The manager is automatically initialized on first access and cleaned up
+ on process exit.
+
+ IMPORTANT: In multiprocess applications (like tornado with workers),
+ you MUST call initialize_shared_memory() BEFORE forking workers.
+ Otherwise each worker will get its own separate manager.
+
+ Returns:
+ SharedDataManager: The global shared memory manager instance
+ """
+ global _global_shared_manager
+
+ if _global_shared_manager is None:
+ with _manager_lock:
+ if _global_shared_manager is None:
+ logger.info("Initializing global shared memory manager")
+ _global_shared_manager = SharedDataManager() # type: ignore[no-untyped-call]
+ logger.info("Global shared memory manager initialized")
+
+ return _global_shared_manager
+
+
+def cleanup_global_shared_memory() -> None:
+ """Cleanup the global shared memory manager.
+
+ This is automatically called on exit but can be called manually.
+ """
+ global _global_shared_manager
+
+ if _global_shared_manager is not None:
+ logger.info("Cleaning up global shared memory manager")
+ _global_shared_manager.cleanup()
+ _global_shared_manager = None
+
+
+# Convenience functions for common keylime patterns
+
+
+def cache_policy(agent_id: str, checksum: str, policy: str) -> None:
+ """Cache a policy in shared memory.
+
+ Args:
+ agent_id: The agent identifier
+ checksum: The policy checksum
+ policy: The policy content to cache
+ """
+ manager = get_shared_memory()
+ policy_cache = manager.get_or_create_dict("policy_cache")
+
+ if agent_id not in policy_cache:
+ policy_cache[agent_id] = manager.manager.dict() # type: ignore[attr-defined]
+
+ policy_cache[agent_id][checksum] = policy
+ logger.debug("Cached policy for agent %s with checksum %s", agent_id, checksum)
+
+
+def get_cached_policy(agent_id: str, checksum: str) -> Optional[str]:
+ """Retrieve cached policy.
+
+ Args:
+ agent_id: The agent identifier
+ checksum: The policy checksum
+
+ Returns:
+ The cached policy content or None if not found
+ """
+ manager = get_shared_memory()
+ policy_cache = manager.get_or_create_dict("policy_cache")
+ agent_policies = policy_cache.get(agent_id, {})
+
+ result = agent_policies.get(checksum)
+ if result:
+ logger.debug("Found cached policy for agent %s with checksum %s", agent_id, checksum)
+ else:
+ logger.debug("No cached policy found for agent %s with checksum %s", agent_id, checksum)
+
+ return result # type: ignore[no-any-return]
+
+
+def clear_agent_policy_cache(agent_id: str) -> None:
+ """Clear all cached policies for an agent.
+
+ Args:
+ agent_id: The agent identifier
+ """
+ manager = get_shared_memory()
+ policy_cache = manager.get_or_create_dict("policy_cache")
+
+ if agent_id in policy_cache:
+ del policy_cache[agent_id]
+ logger.debug("Cleared policy cache for agent %s", agent_id)
+
+
+def cleanup_agent_policy_cache(agent_id: str, keep_checksum: str = "") -> None:
+ """Clean up agent policy cache, keeping only the specified checksum.
+
+ This mimics the cleanup behavior from GLOBAL_POLICY_CACHE where when
+ a new policy checksum is encountered, old cached policies are removed.
+
+ Args:
+ agent_id: The agent identifier
+ keep_checksum: The checksum to keep in the cache (empty string by default)
+ """
+ manager = get_shared_memory()
+ policy_cache = manager.get_or_create_dict("policy_cache")
+
+ if agent_id in policy_cache and len(policy_cache[agent_id]) > 1:
+ # Keep only the empty entry and the specified checksum
+ old_policies = dict(policy_cache[agent_id])
+ policy_cache[agent_id] = manager.manager.dict()
+
+ # Always keep the empty entry
+ policy_cache[agent_id][""] = old_policies.get("", "")
+
+ # Keep the specified checksum if it exists and is not empty
+ if keep_checksum and keep_checksum in old_policies:
+ policy_cache[agent_id][keep_checksum] = old_policies[keep_checksum]
+
+ logger.debug("Cleaned up policy cache for agent %s, keeping checksum %s", agent_id, keep_checksum)
+
+
+def initialize_agent_policy_cache(agent_id: str) -> Dict[str, Any]:
+ """Initialize policy cache for an agent if it doesn't exist.
+
+ Args:
+ agent_id: The agent identifier
+
+ Returns:
+ The agent's policy cache dictionary
+ """
+ manager = get_shared_memory()
+ policy_cache = manager.get_or_create_dict("policy_cache")
+
+ if agent_id not in policy_cache:
+ policy_cache[agent_id] = manager.manager.dict() # type: ignore[attr-defined]
+ policy_cache[agent_id][""] = ""
+ logger.debug("Initialized policy cache for agent %s", agent_id)
+
+ return policy_cache[agent_id] # type: ignore[no-any-return]
+
+
+def get_agent_cache(agent_id: str) -> Dict[str, Any]:
+ """Get shared cache for a specific agent.
+
+ Args:
+ agent_id: The agent identifier
+
+ Returns:
+ A shared dictionary for caching agent-specific data
+ """
+ manager = get_shared_memory()
+ return manager.get_or_create_dict(f"agent_cache:{agent_id}")
+
+
+def get_verification_queue(agent_id: str) -> List[Any]:
+ """Get verification queue for batching database operations.
+
+ Args:
+ agent_id: The agent identifier
+
+ Returns:
+ A shared list for queuing verification operations
+ """
+ manager = get_shared_memory()
+ return manager.get_or_create_list(f"verification_queue:{agent_id}")
+
+
+def get_shared_stats() -> Dict[str, Any]:
+ """Get statistics about shared memory usage.
+
+ Returns:
+ Dictionary containing storage statistics
+ """
+ manager = get_shared_memory()
+ return manager.get_stats()
diff --git a/keylime/tpm/tpm_main.py b/keylime/tpm/tpm_main.py
index 6f2e89f..9b54fc3 100644
--- a/keylime/tpm/tpm_main.py
+++ b/keylime/tpm/tpm_main.py
@@ -10,7 +10,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from keylime import cert_utils, config, json, keylime_logging
from keylime.agentstates import AgentAttestState, TPMClockInfo
-from keylime.common.algorithms import Hash
+from keylime.common.algorithms import Hash, Sign
from keylime.failure import Component, Failure
from keylime.ima import ima
from keylime.ima.file_signatures import ImaKeyrings
@@ -50,6 +50,21 @@ class Tpm:
return (keyblob, key)
+ # Mapping from keylime.common.algorithms enums to TPM algorithm constants
+ # Used for validating that TPM attestations use expected cryptographic algorithms
+ HASH_ALG_TO_TPM = {
+ Hash.SHA1: tpm2_objects.TPM_ALG_SHA1,
+ Hash.SHA256: tpm2_objects.TPM_ALG_SHA256,
+ Hash.SHA384: tpm2_objects.TPM_ALG_SHA384,
+ Hash.SHA512: tpm2_objects.TPM_ALG_SHA512,
+ }
+
+ SIGN_ALG_TO_TPM = {
+ Sign.RSASSA: tpm2_objects.TPM_ALG_RSASSA,
+ Sign.RSAPSS: tpm2_objects.TPM_ALG_RSAPSS,
+ Sign.ECDSA: tpm2_objects.TPM_ALG_ECDSA,
+ }
+
@staticmethod
def verify_aik_with_iak(uuid: str, aik_tpm: bytes, iak_tpm: bytes, iak_attest: bytes, iak_sign: bytes) -> bool:
attest_body = iak_attest.split(b"\x00$")[1]
diff --git a/keylime/web/base/default_controller.py b/keylime/web/base/default_controller.py
index 971ed06..ba0782e 100644
--- a/keylime/web/base/default_controller.py
+++ b/keylime/web/base/default_controller.py
@@ -19,6 +19,12 @@ class DefaultController(Controller):
self.send_response(400, "Bad Request")
def malformed_params(self, **_params: Any) -> None:
+ import traceback # pylint: disable=import-outside-toplevel
+
+ from keylime import keylime_logging # pylint: disable=import-outside-toplevel
+
+ logger = keylime_logging.init_logging("web")
+ logger.error("Malformed params error. Traceback: %s", traceback.format_exc())
self.send_response(400, "Malformed Request Parameter")
def action_dispatch_error(self, **_param: Any) -> None:
diff --git a/test/test_shared_data.py b/test/test_shared_data.py
new file mode 100644
index 0000000..8de7e64
--- /dev/null
+++ b/test/test_shared_data.py
@@ -0,0 +1,199 @@
+"""Unit tests for shared memory infrastructure."""
+
+import unittest
+
+from keylime.shared_data import (
+ SharedDataManager,
+ cache_policy,
+ cleanup_agent_policy_cache,
+ cleanup_global_shared_memory,
+ clear_agent_policy_cache,
+ get_cached_policy,
+ get_shared_memory,
+ initialize_agent_policy_cache,
+)
+
+
+class TestSharedDataManager(unittest.TestCase):
+ """Test cases for SharedDataManager class."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ self.manager = SharedDataManager()
+
+ def tearDown(self):
+ """Clean up after tests."""
+ if self.manager:
+ self.manager.cleanup()
+
+ def test_set_and_get_data(self):
+ """Test basic set and get operations."""
+ self.manager.set_data("test_key", "test_value")
+ result = self.manager.get_data("test_key")
+ self.assertEqual(result, "test_value")
+
+ def test_get_nonexistent_data(self):
+ """Test getting data that doesn't exist returns None."""
+ result = self.manager.get_data("nonexistent_key")
+ self.assertIsNone(result)
+
+ def test_get_data_with_default(self):
+ """Test getting data with default value."""
+ result = self.manager.get_data("nonexistent_key", default="default_value")
+ self.assertEqual(result, "default_value")
+
+ def test_delete_data(self):
+ """Test deleting data."""
+ self.manager.set_data("test_key", "test_value")
+ result = self.manager.delete_data("test_key")
+ self.assertTrue(result)
+
+ # Verify it's actually deleted
+ self.assertIsNone(self.manager.get_data("test_key"))
+
+ def test_delete_nonexistent_data(self):
+ """Test deleting data that doesn't exist returns False."""
+ result = self.manager.delete_data("nonexistent_key")
+ self.assertFalse(result)
+
+ def test_has_key(self):
+ """Test checking if key exists."""
+ self.manager.set_data("test_key", "test_value")
+ self.assertTrue(self.manager.has_key("test_key"))
+ self.assertFalse(self.manager.has_key("nonexistent_key"))
+
+ def test_get_or_create_dict(self):
+ """Test getting or creating a shared dictionary."""
+ shared_dict = self.manager.get_or_create_dict("test_dict")
+ shared_dict["key1"] = "value1"
+ shared_dict["key2"] = "value2"
+
+ # Retrieve the same dict
+ retrieved_dict = self.manager.get_or_create_dict("test_dict")
+ self.assertEqual(retrieved_dict["key1"], "value1")
+ self.assertEqual(retrieved_dict["key2"], "value2")
+
+ def test_get_or_create_list(self):
+ """Test getting or creating a shared list."""
+ shared_list = self.manager.get_or_create_list("test_list")
+ shared_list.append("item1")
+ shared_list.append("item2")
+
+ # Retrieve the same list
+ retrieved_list = self.manager.get_or_create_list("test_list")
+ self.assertEqual(len(retrieved_list), 2)
+ self.assertEqual(retrieved_list[0], "item1")
+ self.assertEqual(retrieved_list[1], "item2")
+
+ def test_get_stats(self):
+ """Test getting manager statistics."""
+ self.manager.set_data("key1", "value1")
+ self.manager.set_data("key2", "value2")
+
+ stats = self.manager.get_stats()
+ self.assertIn("total_keys", stats)
+ self.assertIn("uptime_seconds", stats)
+ self.assertEqual(stats["total_keys"], 2)
+ self.assertGreaterEqual(stats["uptime_seconds"], 0)
+
+
+class TestPolicyCacheFunctions(unittest.TestCase):
+ """Test cases for policy cache functions."""
+
+ def setUp(self):
+ """Set up test fixtures."""
+ # Get the global shared memory manager
+ self.manager = get_shared_memory()
+
+ def tearDown(self):
+ """Clean up after tests."""
+ # Clean up global shared memory
+ cleanup_global_shared_memory()
+
+ def test_initialize_agent_policy_cache(self):
+ """Test initializing agent policy cache."""
+ agent_id = "test_agent_123"
+ initialize_agent_policy_cache(agent_id)
+
+ # Verify the cache was initialized
+ policy_cache = self.manager.get_or_create_dict("policy_cache")
+ self.assertIn(agent_id, policy_cache)
+
+ def test_cache_and_get_policy(self):
+ """Test caching and retrieving a policy."""
+ agent_id = "test_agent_123"
+ checksum = "abc123def456"
+ policy_content = '{"policy": "test_policy_content"}'
+
+ # Initialize and cache policy
+ initialize_agent_policy_cache(agent_id)
+ cache_policy(agent_id, checksum, policy_content)
+
+ # Retrieve cached policy
+ cached = get_cached_policy(agent_id, checksum)
+ self.assertEqual(cached, policy_content)
+
+ def test_get_nonexistent_cached_policy(self):
+ """Test getting a policy that hasn't been cached."""
+ agent_id = "test_agent_123"
+ checksum = "nonexistent_checksum"
+
+ initialize_agent_policy_cache(agent_id)
+ cached = get_cached_policy(agent_id, checksum)
+ self.assertIsNone(cached)
+
+ def test_clear_agent_policy_cache(self):
+ """Test clearing an agent's policy cache."""
+ agent_id = "test_agent_123"
+ checksum = "abc123def456"
+ policy_content = '{"policy": "test_policy_content"}'
+
+ # Initialize, cache, and then clear
+ initialize_agent_policy_cache(agent_id)
+ cache_policy(agent_id, checksum, policy_content)
+ clear_agent_policy_cache(agent_id)
+
+ # Verify it's cleared
+ cached = get_cached_policy(agent_id, checksum)
+ self.assertIsNone(cached)
+
+ def test_cleanup_agent_policy_cache(self):
+ """Test cleaning up old policy checksums."""
+ agent_id = "test_agent_123"
+ old_checksum = "old_checksum"
+ new_checksum = "new_checksum"
+ policy_content = '{"policy": "test"}'
+
+ # Initialize and cache multiple policies
+ initialize_agent_policy_cache(agent_id)
+ cache_policy(agent_id, old_checksum, policy_content)
+ cache_policy(agent_id, new_checksum, policy_content)
+
+ # Cleanup old checksums (keeping only new_checksum)
+ cleanup_agent_policy_cache(agent_id, new_checksum)
+
+ # Verify old checksum is removed but new one remains
+ self.assertIsNone(get_cached_policy(agent_id, old_checksum))
+ self.assertEqual(get_cached_policy(agent_id, new_checksum), policy_content)
+
+ def test_cache_multiple_agents(self):
+ """Test caching policies for multiple agents."""
+ agent1 = "agent_1"
+ agent2 = "agent_2"
+ checksum = "same_checksum"
+ policy1 = '{"policy": "agent1_policy"}'
+ policy2 = '{"policy": "agent2_policy"}'
+
+ # Cache policies for different agents
+ initialize_agent_policy_cache(agent1)
+ initialize_agent_policy_cache(agent2)
+ cache_policy(agent1, checksum, policy1)
+ cache_policy(agent2, checksum, policy2)
+
+ # Verify each agent has its own policy
+ self.assertEqual(get_cached_policy(agent1, checksum), policy1)
+ self.assertEqual(get_cached_policy(agent2, checksum), policy2)
+
+
+if __name__ == "__main__":
+ unittest.main()
--
2.47.3