From 1f0f824cc12f0c905eded5db886272a035977262 Mon Sep 17 00:00:00 2001 From: Anderson Toshiyuki Sasaki Date: Fri, 8 Aug 2025 17:30:28 +0200 Subject: [PATCH] Fix DB connection leaks Resolves: RHEL-108263 Signed-off-by: Anderson Toshiyuki Sasaki --- 0007-fix_db_connection_leaks.patch | 2208 ++++++++++++++++++++++++++++ keylime.spec | 8 +- 2 files changed, 2215 insertions(+), 1 deletion(-) create mode 100644 0007-fix_db_connection_leaks.patch diff --git a/0007-fix_db_connection_leaks.patch b/0007-fix_db_connection_leaks.patch new file mode 100644 index 0000000..64be967 --- /dev/null +++ b/0007-fix_db_connection_leaks.patch @@ -0,0 +1,2208 @@ +diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py +index 8ab81d1..7553ac8 100644 +--- a/keylime/cloud_verifier_tornado.py ++++ b/keylime/cloud_verifier_tornado.py +@@ -7,7 +7,8 @@ import sys + import traceback + from concurrent.futures import ThreadPoolExecutor + from multiprocessing import Process +-from typing import Any, Dict, List, Optional, Tuple, Union, cast ++from contextlib import contextmanager ++from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast + + import tornado.httpserver + import tornado.ioloop +@@ -34,7 +35,7 @@ from keylime.agentstates import AgentAttestState, AgentAttestStates + from keylime.common import retry, states, validators + from keylime.common.version import str_to_version + from keylime.da import record +-from keylime.db.keylime_db import DBEngineManager, SessionManager ++from keylime.db.keylime_db import SessionManager, make_engine + from keylime.db.verifier_db import VerfierMain, VerifierAllowlist, VerifierMbpolicy + from keylime.failure import MAX_SEVERITY_LABEL, Component, Event, Failure, set_severity_config + from keylime.ima import ima +@@ -47,7 +48,7 @@ GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {} + set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy")) + + try: +- engine = 
DBEngineManager().make_engine("cloud_verifier") ++ engine = make_engine("cloud_verifier") + except SQLAlchemyError as err: + logger.error("Error creating SQL engine or session: %s", err) + sys.exit(1) +@@ -61,8 +62,17 @@ except record.RecordManagementException as rme: + sys.exit(1) + + +-def get_session() -> Session: +- return SessionManager().make_session(engine) ++@contextmanager ++def session_context() -> Iterator[Session]: ++ """ ++ Context manager for database sessions that ensures proper cleanup. ++ To use: ++ with session_context() as session: ++ # use session ++ """ ++ session_manager = SessionManager() ++ with session_manager.session_context(engine) as session: ++ yield session + + + def get_AgentAttestStates() -> AgentAttestStates: +@@ -130,19 +140,18 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]: + return agent_dict + + +-def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str: +- checksum = "" +- name = "empty" +- agent_id = str(stored_agent.agent_id) ++def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str: ++ checksum = ima_policy_data.get("checksum", "") ++ name = ima_policy_data.get("name", "empty") ++ agent_id = ima_policy_data.get("agent_id", "") ++ ++ if not agent_id: ++ return "" + + if agent_id not in GLOBAL_POLICY_CACHE: + GLOBAL_POLICY_CACHE[agent_id] = {} + GLOBAL_POLICY_CACHE[agent_id][""] = "" + +- if stored_agent.ima_policy: +- checksum = str(stored_agent.ima_policy.checksum) +- name = stored_agent.ima_policy.name +- + if checksum not in GLOBAL_POLICY_CACHE[agent_id]: + if len(GLOBAL_POLICY_CACHE[agent_id]) > 1: + # Perform a cleanup of the contents, IMA policy checksum changed +@@ -162,8 +171,9 @@ def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str: + checksum, + agent_id, + ) +- # Actually contacts the database and load the (large) ima_policy column for "allowlists" table +- ima_policy = stored_agent.ima_policy.ima_policy ++ ++ # Get the large ima_policy content - 
it's already loaded in ima_policy_data ++ ima_policy = ima_policy_data.get("ima_policy", "") + assert isinstance(ima_policy, str) + GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy + +@@ -182,22 +192,19 @@ def store_attestation_state(agentAttestState: AgentAttestState) -> None: + # Only store if IMA log was evaluated + if agentAttestState.get_ima_pcrs(): + agent_id = agentAttestState.agent_id +- session = get_session() + try: +- update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id()) +- assert update_agent +- update_agent.boottime = agentAttestState.get_boottime() +- update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry() +- ima_pcrs_dict = agentAttestState.get_ima_pcrs() +- update_agent.ima_pcrs = list(ima_pcrs_dict.keys()) +- for pcr_num, value in ima_pcrs_dict.items(): +- setattr(update_agent, f"pcr{pcr_num}", value) +- update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json() +- try: ++ with session_context() as session: ++ update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id()) ++ assert update_agent ++ update_agent.boottime = agentAttestState.get_boottime() ++ update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry() ++ ima_pcrs_dict = agentAttestState.get_ima_pcrs() ++ update_agent.ima_pcrs = list(ima_pcrs_dict.keys()) ++ for pcr_num, value in ima_pcrs_dict.items(): ++ setattr(update_agent, f"pcr{pcr_num}", value) ++ update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json() + session.add(update_agent) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e) +- session.commit() ++ # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e) + +@@ -354,45 +361,17 @@ class AgentsHandler(BaseHandler): + was not found, it either completed 
successfully, or failed. If found, the agent_id is still polling + to contact the Cloud Agent. + """ +- session = get_session() +- + rest_params, agent_id = self.__validate_input("GET") + if not rest_params: + return + +- if (agent_id is not None) and (agent_id != ""): +- # If the agent ID is not valid (wrong set of characters), +- # just do nothing. +- agent = None +- try: +- agent = ( +- session.query(VerfierMain) +- .options( # type: ignore +- joinedload(VerfierMain.ima_policy).load_only( +- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore +- ) +- ) +- .options( # type: ignore +- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore +- ) +- .filter_by(agent_id=agent_id) +- .one_or_none() +- ) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- +- if agent is not None: +- response = cloud_verifier_common.process_get_status(agent) +- web_util.echo_json_response(self, 200, "Success", response) +- else: +- web_util.echo_json_response(self, 404, "agent id not found") +- else: +- json_response = None +- if "bulk" in rest_params: +- agent_list = None +- +- if ("verifier" in rest_params) and (rest_params["verifier"] != ""): +- agent_list = ( ++ with session_context() as session: ++ if (agent_id is not None) and (agent_id != ""): ++ # If the agent ID is not valid (wrong set of characters), ++ # just do nothing. 
++ agent = None ++ try: ++ agent = ( + session.query(VerfierMain) + .options( # type: ignore + joinedload(VerfierMain.ima_policy).load_only( +@@ -402,39 +381,70 @@ class AgentsHandler(BaseHandler): + .options( # type: ignore + joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore + ) +- .filter_by(verifier_id=rest_params["verifier"]) +- .all() ++ .filter_by(agent_id=agent_id) ++ .one_or_none() + ) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ ++ if agent is not None: ++ response = cloud_verifier_common.process_get_status(agent) ++ web_util.echo_json_response(self, 200, "Success", response) + else: +- agent_list = ( +- session.query(VerfierMain) +- .options( # type: ignore +- joinedload(VerfierMain.ima_policy).load_only( +- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore ++ web_util.echo_json_response(self, 404, "agent id not found") ++ else: ++ json_response = None ++ if "bulk" in rest_params: ++ agent_list = None ++ ++ if ("verifier" in rest_params) and (rest_params["verifier"] != ""): ++ agent_list = ( ++ session.query(VerfierMain) ++ .options( # type: ignore ++ joinedload(VerfierMain.ima_policy).load_only( ++ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore ++ ) + ) ++ .options( # type: ignore ++ joinedload(VerfierMain.mb_policy).load_only( ++ VerifierMbpolicy.mb_policy # type: ignore[arg-type] ++ ) ++ ) ++ .filter_by(verifier_id=rest_params["verifier"]) ++ .all() + ) +- .options( # type: ignore +- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore ++ else: ++ agent_list = ( ++ session.query(VerfierMain) ++ .options( # type: ignore ++ joinedload(VerfierMain.ima_policy).load_only( ++ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore ++ ) ++ ) ++ .options( # type: ignore ++ joinedload(VerfierMain.mb_policy).load_only( ++ VerifierMbpolicy.mb_policy # type: 
ignore[arg-type] ++ ) ++ ) ++ .all() + ) +- .all() +- ) + +- json_response = {} +- for agent in agent_list: +- json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent) ++ json_response = {} ++ for agent in agent_list: ++ json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent) + +- web_util.echo_json_response(self, 200, "Success", json_response) +- else: +- if ("verifier" in rest_params) and (rest_params["verifier"] != ""): +- json_response_list = ( +- session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all() +- ) ++ web_util.echo_json_response(self, 200, "Success", json_response) + else: +- json_response_list = session.query(VerfierMain.agent_id).all() ++ if ("verifier" in rest_params) and (rest_params["verifier"] != ""): ++ json_response_list = ( ++ session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all() ++ ) ++ else: ++ json_response_list = session.query(VerfierMain.agent_id).all() + +- web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list}) ++ web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list}) + +- logger.info("GET returning 200 response for agent_id list") ++ logger.info("GET returning 200 response for agent_id list") + + def delete(self) -> None: + """This method handles the DELETE requests to remove agents from the Cloud Verifier. +@@ -442,59 +452,55 @@ class AgentsHandler(BaseHandler): + Currently, only agents resources are available for DELETEing, i.e. /agents. All other DELETE uri's will return errors. + agents requests require a single agent_id parameter which identifies the agent to be deleted. 
+ """ +- session = get_session() +- + rest_params, agent_id = self.__validate_input("DELETE") + if not rest_params or not agent_id: + return + +- agent = None +- try: +- agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ with session_context() as session: ++ agent = None ++ try: ++ agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + +- if agent is None: +- web_util.echo_json_response(self, 404, "agent id not found") +- logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id) +- return ++ if agent is None: ++ web_util.echo_json_response(self, 404, "agent id not found") ++ logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id) ++ return + +- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) +- if verifier_id != agent.verifier_id: +- web_util.echo_json_response(self, 404, "agent id associated to this verifier") +- logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id) +- return ++ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) ++ if verifier_id != agent.verifier_id: ++ web_util.echo_json_response(self, 404, "agent id associated to this verifier") ++ logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id) ++ return + +- # Cleanup the cache when the agent is deleted. Do it early. +- if agent_id in GLOBAL_POLICY_CACHE: +- del GLOBAL_POLICY_CACHE[agent_id] +- logger.debug( +- "Cleaned up policy cache from all entries used by agent %s", +- agent_id, +- ) ++ # Cleanup the cache when the agent is deleted. Do it early. 
++ if agent_id in GLOBAL_POLICY_CACHE: ++ del GLOBAL_POLICY_CACHE[agent_id] ++ logger.debug( ++ "Cleaned up policy cache from all entries used by agent %s", ++ agent_id, ++ ) + +- op_state = agent.operational_state +- if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE): +- try: +- verifier_db_delete_agent(session, agent_id) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- web_util.echo_json_response(self, 200, "Success") +- logger.info("DELETE returning 200 response for agent id: %s", agent_id) +- else: +- try: +- update_agent = session.query(VerfierMain).get(agent_id) +- assert update_agent +- update_agent.operational_state = states.TERMINATED ++ op_state = agent.operational_state ++ if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE): + try: ++ verifier_db_delete_agent(session, agent_id) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 200, "Success") ++ logger.info("DELETE returning 200 response for agent id: %s", agent_id) ++ else: ++ try: ++ update_agent = session.query(VerfierMain).get(agent_id) ++ assert update_agent ++ update_agent.operational_state = states.TERMINATED + session.add(update_agent) ++ # session.commit() is automatically called by context manager ++ web_util.echo_json_response(self, 202, "Accepted") ++ logger.info("DELETE returning 202 response for agent id: %s", agent_id) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- session.commit() +- web_util.echo_json_response(self, 202, "Accepted") +- logger.info("DELETE returning 202 response for agent id: %s", agent_id) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + + def post(self) -> None: + """This method handles the POST requests to add agents to the Cloud Verifier. 
+@@ -502,7 +508,6 @@ class AgentsHandler(BaseHandler): + Currently, only agents resources are available for POSTing, i.e. /agents. All other POST uri's will return errors. + agents requests require a json block sent in the body + """ +- session = get_session() + # TODO: exception handling needs fixing + # Maybe handle exceptions with if/else if/else blocks ... simple and avoids nesting + try: # pylint: disable=too-many-nested-blocks +@@ -585,201 +590,208 @@ class AgentsHandler(BaseHandler): + runtime_policy = base64.b64decode(json_body.get("runtime_policy")).decode() + runtime_policy_stored = None + +- if runtime_policy_name: ++ with session_context() as session: ++ if runtime_policy_name: ++ try: ++ runtime_policy_stored = ( ++ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() ++ ) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ raise ++ ++ # Prevent overwriting existing IMA policies with name provided in request ++ if runtime_policy and runtime_policy_stored: ++ web_util.echo_json_response( ++ self, ++ 409, ++ f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.", ++ ) ++ logger.warning("IMA policy with name %s already exists", runtime_policy_name) ++ return ++ ++ # Return an error code if the named allowlist does not exist in the database ++ if not runtime_policy and not runtime_policy_stored: ++ web_util.echo_json_response( ++ self, 404, f"Could not find IMA policy with name {runtime_policy_name}!" 
++ ) ++ logger.warning("Could not find IMA policy with name %s", runtime_policy_name) ++ return ++ ++ # Prevent overwriting existing agents with UUID provided in request + try: +- runtime_policy_stored = ( +- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() +- ) ++ new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise ++ raise e + +- # Prevent overwriting existing IMA policies with name provided in request +- if runtime_policy and runtime_policy_stored: ++ if new_agent_count > 0: + web_util.echo_json_response( + self, + 409, +- f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.", ++ f"Agent of uuid {agent_id} already exists. Please use delete or update.", + ) +- logger.warning("IMA policy with name %s already exists", runtime_policy_name) ++ logger.warning("Agent of uuid %s already exists", agent_id) + return + +- # Return an error code if the named allowlist does not exist in the database +- if not runtime_policy and not runtime_policy_stored: +- web_util.echo_json_response( +- self, 404, f"Could not find IMA policy with name {runtime_policy_name}!" +- ) +- logger.warning("Could not find IMA policy with name %s", runtime_policy_name) +- return +- +- # Prevent overwriting existing agents with UUID provided in request +- try: +- new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise e +- +- if new_agent_count > 0: +- web_util.echo_json_response( +- self, +- 409, +- f"Agent of uuid {agent_id} already exists. 
Please use delete or update.", +- ) +- logger.warning("Agent of uuid %s already exists", agent_id) +- return +- +- # Write IMA policy to database if needed +- if not runtime_policy_name and not runtime_policy: +- logger.info("IMA policy data not provided with request! Using default empty IMA policy.") +- runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY)) ++ # Write IMA policy to database if needed ++ if not runtime_policy_name and not runtime_policy: ++ logger.info("IMA policy data not provided with request! Using default empty IMA policy.") ++ runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY)) + +- if runtime_policy: +- runtime_policy_key_bytes = signing.get_runtime_policy_keys( +- runtime_policy.encode(), +- json_body.get("runtime_policy_key"), +- ) +- +- try: +- ima.verify_runtime_policy( ++ if runtime_policy: ++ runtime_policy_key_bytes = signing.get_runtime_policy_keys( + runtime_policy.encode(), +- runtime_policy_key_bytes, +- verify_sig=config.getboolean( +- "verifier", "require_allow_list_signatures", fallback=False +- ), ++ json_body.get("runtime_policy_key"), + ) +- except ima.ImaValidationError as e: +- web_util.echo_json_response(self, e.code, e.message) +- logger.warning(e.message) +- return + +- if not runtime_policy_name: +- runtime_policy_name = agent_id +- +- try: +- runtime_policy_db_format = ima.runtime_policy_db_contents( +- runtime_policy_name, runtime_policy +- ) +- except ima.ImaValidationError as e: +- message = f"Runtime policy is malformatted: {e.message}" +- web_util.echo_json_response(self, e.code, message) +- logger.warning(message) +- return +- +- try: +- runtime_policy_stored = ( +- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() +- ) +- except SQLAlchemyError as e: +- logger.error( +- "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", agent_id, e +- ) +- raise +- try: +- if runtime_policy_stored is None: +- 
runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format) +- session.add(runtime_policy_stored) ++ try: ++ ima.verify_runtime_policy( ++ runtime_policy.encode(), ++ runtime_policy_key_bytes, ++ verify_sig=config.getboolean( ++ "verifier", "require_allow_list_signatures", fallback=False ++ ), ++ ) ++ except ima.ImaValidationError as e: ++ web_util.echo_json_response(self, e.code, e.message) ++ logger.warning(e.message) ++ return ++ ++ if not runtime_policy_name: ++ runtime_policy_name = agent_id ++ ++ try: ++ runtime_policy_db_format = ima.runtime_policy_db_contents( ++ runtime_policy_name, runtime_policy ++ ) ++ except ima.ImaValidationError as e: ++ message = f"Runtime policy is malformatted: {e.message}" ++ web_util.echo_json_response(self, e.code, message) ++ logger.warning(message) ++ return ++ ++ try: ++ runtime_policy_stored = ( ++ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() ++ ) ++ except SQLAlchemyError as e: ++ logger.error( ++ "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", ++ agent_id, ++ e, ++ ) ++ raise ++ try: ++ if runtime_policy_stored is None: ++ runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format) ++ session.add(runtime_policy_stored) ++ session.commit() ++ except SQLAlchemyError as e: ++ logger.error( ++ "SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e ++ ) ++ raise ++ ++ # Handle measured boot policy ++ # - No name, mb_policy : store mb_policy using agent UUID as name ++ # - Name, no mb_policy : fetch existing mb_policy from DB ++ # - Name, mb_policy : store mb_policy using name ++ ++ mb_policy_name = json_body["mb_policy_name"] ++ mb_policy = json_body["mb_policy"] ++ mb_policy_stored = None ++ ++ if mb_policy_name: ++ try: ++ mb_policy_stored = ( ++ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() ++ ) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", 
agent_id, e) ++ raise ++ ++ # Prevent overwriting existing mb_policy with name provided in request ++ if mb_policy and mb_policy_stored: ++ web_util.echo_json_response( ++ self, ++ 409, ++ f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.", ++ ) ++ logger.warning("mb_policy with name %s already exists", mb_policy_name) ++ return ++ ++ # Return error if the mb_policy is neither provided nor stored. ++ if not mb_policy and not mb_policy_stored: ++ web_util.echo_json_response( ++ self, 404, f"Could not find mb_policy with name {mb_policy_name}!" ++ ) ++ logger.warning("Could not find mb_policy with name %s", mb_policy_name) ++ return ++ ++ else: ++ # Use the UUID of the agent ++ mb_policy_name = agent_id ++ try: ++ mb_policy_stored = ( ++ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() ++ ) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ raise ++ ++ # Prevent overwriting existing mb_policy ++ if mb_policy and mb_policy_stored: ++ web_util.echo_json_response( ++ self, ++ 409, ++ f"mb_policy with name {mb_policy_name} already exists. 
You can delete the mb_policy from the verifier.", ++ ) ++ logger.warning("mb_policy with name %s already exists", mb_policy_name) ++ return ++ ++ # Store the policy into database if not stored ++ if mb_policy_stored is None: ++ try: ++ mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy) ++ mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format) ++ session.add(mb_policy_stored) + session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e) +- raise +- +- # Handle measured boot policy +- # - No name, mb_policy : store mb_policy using agent UUID as name +- # - Name, no mb_policy : fetch existing mb_policy from DB +- # - Name, mb_policy : store mb_policy using name +- +- mb_policy_name = json_body["mb_policy_name"] +- mb_policy = json_body["mb_policy"] +- mb_policy_stored = None ++ except SQLAlchemyError as e: ++ logger.error( ++ "SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e ++ ) ++ raise + +- if mb_policy_name: ++ # Write the agent to the database, attaching associated stored ima_policy and mb_policy + try: +- mb_policy_stored = ( +- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() ++ assert runtime_policy_stored ++ assert mb_policy_stored ++ session.add( ++ VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored) + ) ++ session.commit() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise ++ raise e + +- # Prevent overwriting existing mb_policy with name provided in request +- if mb_policy and mb_policy_stored: +- web_util.echo_json_response( +- self, +- 409, +- f"mb_policy with name {mb_policy_name} already exists. 
Please use a different name or delete the mb_policy from the verifier.", +- ) +- logger.warning("mb_policy with name %s already exists", mb_policy_name) +- return ++ # add default fields that are ephemeral ++ for key, val in exclude_db.items(): ++ agent_data[key] = val + +- # Return error if the mb_policy is neither provided nor stored. +- if not mb_policy and not mb_policy_stored: +- web_util.echo_json_response( +- self, 404, f"Could not find mb_policy with name {mb_policy_name}!" ++ # Prepare SSLContext for mTLS connections ++ agent_data["ssl_context"] = None ++ if agent_mtls_cert_enabled: ++ agent_data["ssl_context"] = web_util.generate_agent_tls_context( ++ "verifier", agent_data["mtls_cert"], logger=logger + ) +- logger.warning("Could not find mb_policy with name %s", mb_policy_name) +- return + +- else: +- # Use the UUID of the agent +- mb_policy_name = agent_id +- try: +- mb_policy_stored = ( +- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() +- ) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise +- +- # Prevent overwriting existing mb_policy +- if mb_policy and mb_policy_stored: +- web_util.echo_json_response( +- self, +- 409, +- f"mb_policy with name {mb_policy_name} already exists. 
You can delete the mb_policy from the verifier.", +- ) +- logger.warning("mb_policy with name %s already exists", mb_policy_name) +- return +- +- # Store the policy into database if not stored +- if mb_policy_stored is None: +- try: +- mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy) +- mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format) +- session.add(mb_policy_stored) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e) +- raise ++ if agent_data["ssl_context"] is None: ++ logger.warning("Connecting to agent without mTLS: %s", agent_id) + +- # Write the agent to the database, attaching associated stored ima_policy and mb_policy +- try: +- assert runtime_policy_stored +- assert mb_policy_stored +- session.add( +- VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored) +- ) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise e +- +- # add default fields that are ephemeral +- for key, val in exclude_db.items(): +- agent_data[key] = val +- +- # Prepare SSLContext for mTLS connections +- agent_data["ssl_context"] = None +- if agent_mtls_cert_enabled: +- agent_data["ssl_context"] = web_util.generate_agent_tls_context( +- "verifier", agent_data["mtls_cert"], logger=logger +- ) +- +- if agent_data["ssl_context"] is None: +- logger.warning("Connecting to agent without mTLS: %s", agent_id) +- +- asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE)) +- web_util.echo_json_response(self, 200, "Success") +- logger.info("POST returning 200 response for adding agent id: %s", agent_id) ++ asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE)) ++ web_util.echo_json_response(self, 200, "Success") ++ logger.info("POST returning 200 response for adding agent id: %s", agent_id) + else: + web_util.echo_json_response(self, 400, 
"uri not supported") + logger.warning("POST returning 400 response. uri not supported") +@@ -794,54 +806,54 @@ class AgentsHandler(BaseHandler): + Currently, only agents resources are available for PUTing, i.e. /agents. All other PUT uri's will return errors. + agents requests require a json block sent in the body + """ +- session = get_session() + try: + rest_params, agent_id = self.__validate_input("PUT") + if not rest_params: + return + +- try: +- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) +- db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) +- raise e ++ with session_context() as session: ++ try: ++ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) ++ db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ raise e + +- if db_agent is None: +- web_util.echo_json_response(self, 404, "agent id not found") +- logger.info("PUT returning 404 response. agent id: %s not found.", agent_id) +- return ++ if db_agent is None: ++ web_util.echo_json_response(self, 404, "agent id not found") ++ logger.info("PUT returning 404 response. 
agent id: %s not found.", agent_id) ++ return + +- if "reactivate" in rest_params: +- agent = _from_db_obj(db_agent) ++ if "reactivate" in rest_params: ++ agent = _from_db_obj(db_agent) + +- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": +- agent["ssl_context"] = web_util.generate_agent_tls_context( +- "verifier", agent["mtls_cert"], logger=logger +- ) +- if agent["ssl_context"] is None: +- logger.warning("Connecting to agent without mTLS: %s", agent_id) ++ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": ++ agent["ssl_context"] = web_util.generate_agent_tls_context( ++ "verifier", agent["mtls_cert"], logger=logger ++ ) ++ if agent["ssl_context"] is None: ++ logger.warning("Connecting to agent without mTLS: %s", agent_id) + +- agent["operational_state"] = states.START +- asyncio.ensure_future(process_agent(agent, states.GET_QUOTE)) +- web_util.echo_json_response(self, 200, "Success") +- logger.info("PUT returning 200 response for agent id: %s", agent_id) +- elif "stop" in rest_params: +- # do stuff for terminate +- logger.debug("Stopping polling on %s", agent_id) +- try: +- session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore +- {"operational_state": states.TENANT_FAILED} +- ) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) ++ agent["operational_state"] = states.START ++ asyncio.ensure_future(process_agent(agent, states.GET_QUOTE)) ++ web_util.echo_json_response(self, 200, "Success") ++ logger.info("PUT returning 200 response for agent id: %s", agent_id) ++ elif "stop" in rest_params: ++ # do stuff for terminate ++ logger.debug("Stopping polling on %s", agent_id) ++ try: ++ session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore ++ {"operational_state": states.TENANT_FAILED} ++ ) ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) 
+ +- web_util.echo_json_response(self, 200, "Success") +- logger.info("PUT returning 200 response for agent id: %s", agent_id) +- else: +- web_util.echo_json_response(self, 400, "uri not supported") +- logger.warning("PUT returning 400 response. uri not supported") ++ web_util.echo_json_response(self, 200, "Success") ++ logger.info("PUT returning 200 response for agent id: %s", agent_id) ++ else: ++ web_util.echo_json_response(self, 400, "uri not supported") ++ logger.warning("PUT returning 400 response. uri not supported") + + except Exception as e: + web_util.echo_json_response(self, 400, f"Exception error: {str(e)}") +@@ -887,36 +899,36 @@ class AllowlistHandler(BaseHandler): + if not params_valid: + return + +- session = get_session() +- if allowlist_name is None: +- try: +- names_allowlists = session.query(VerifierAllowlist.name).all() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- web_util.echo_json_response(self, 500, "Failed to get names of allowlists") +- raise ++ with session_context() as session: ++ if allowlist_name is None: ++ try: ++ names_allowlists = session.query(VerifierAllowlist.name).all() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, "Failed to get names of allowlists") ++ raise + +- names_response = [] +- for name in names_allowlists: +- names_response.append(name[0]) +- web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response}) ++ names_response = [] ++ for name in names_allowlists: ++ names_response.append(name[0]) ++ web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response}) + +- else: +- try: +- allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one() +- except NoResultFound: +- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found") +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- 
web_util.echo_json_response(self, 500, "Failed to get allowlist")
+- raise
++ else:
++ try:
++ allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
++ except NoResultFound:
++ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
++ return
++ except SQLAlchemyError as e:
++ logger.error("SQLAlchemy Error: %s", e)
++ web_util.echo_json_response(self, 500, "Failed to get allowlist")
++ raise
+
+- response = {}
+- for field in ("name", "tpm_policy"):
+- response[field] = getattr(allowlist, field, None)
+- response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
+- web_util.echo_json_response(self, 200, "Success", response)
++ response = {}
++ for field in ("name", "tpm_policy"):
++ response[field] = getattr(allowlist, field, None)
++ response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
++ web_util.echo_json_response(self, 200, "Success", response)
+
+ def delete(self) -> None:
+ """Delete an allowlist
+@@ -928,45 +940,44 @@
+ if not params_valid or allowlist_name is None:
+ return
+
+- session = get_session()
+- try:
+- runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
+- except NoResultFound:
+- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
+- return
+- except SQLAlchemyError as e:
+- logger.error("SQLAlchemy Error: %s", e)
+- web_util.echo_json_response(self, 500, "Failed to get allowlist")
+- raise
++ with session_context() as session:
++ try:
++ runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
++ except NoResultFound:
++ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
++ return
++ except SQLAlchemyError as e:
++ logger.error("SQLAlchemy Error: %s", e)
++ web_util.echo_json_response(self, 500, "Failed to get allowlist")
++ raise
+
+- try:
+- agent = 
session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise +- if agent is not None: +- web_util.echo_json_response( +- self, +- 409, +- f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}", +- ) +- return ++ try: ++ agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise ++ if agent is not None: ++ web_util.echo_json_response( ++ self, ++ 409, ++ f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}", ++ ) ++ return + +- try: +- session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete() +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- session.close() +- web_util.echo_json_response(self, 500, f"Database error: {e}") +- raise ++ try: ++ session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete() ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, f"Database error: {e}") ++ raise + +- # NOTE(kaifeng) 204 Can not have response body, but current helper +- # doesn't support this case. +- self.set_status(204) +- self.set_header("Content-Type", "application/json") +- self.finish() +- logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name) ++ # NOTE(kaifeng) 204 Can not have response body, but current helper ++ # doesn't support this case. 
++ self.set_status(204) ++ self.set_header("Content-Type", "application/json") ++ self.finish() ++ logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name) + + def __get_runtime_policy_db_format(self, runtime_policy_name: str) -> Dict[str, Any]: + """Get the IMA policy from the request and return it in Db format""" +@@ -1022,28 +1033,30 @@ class AllowlistHandler(BaseHandler): + if not runtime_policy_db_format: + return + +- session = get_session() +- # don't allow overwritting +- try: +- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() +- if runtime_policy_count > 0: +- web_util.echo_json_response(self, 409, f"Runtime policy with name {runtime_policy_name} already exists") +- logger.warning("Runtime policy with name %s already exists", runtime_policy_name) +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ with session_context() as session: ++ # don't allow overwritting ++ try: ++ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() ++ if runtime_policy_count > 0: ++ web_util.echo_json_response( ++ self, 409, f"Runtime policy with name {runtime_policy_name} already exists" ++ ) ++ logger.warning("Runtime policy with name %s already exists", runtime_policy_name) ++ return ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- try: +- # Add the agent and data +- session.add(VerifierAllowlist(**runtime_policy_db_format)) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ try: ++ # Add the agent and data ++ session.add(VerifierAllowlist(**runtime_policy_db_format)) ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- web_util.echo_json_response(self, 201) +- logger.info("POST returning 201") ++ 
web_util.echo_json_response(self, 201) ++ logger.info("POST returning 201") + + def put(self) -> None: + """Update an allowlist +@@ -1060,32 +1073,34 @@ class AllowlistHandler(BaseHandler): + if not runtime_policy_db_format: + return + +- session = get_session() +- # don't allow creating a new policy +- try: +- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() +- if runtime_policy_count != 1: +- web_util.echo_json_response( +- self, 409, f"Runtime policy with name {runtime_policy_name} does not already exist" +- ) +- logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name) +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ with session_context() as session: ++ # don't allow creating a new policy ++ try: ++ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() ++ if runtime_policy_count != 1: ++ web_util.echo_json_response( ++ self, ++ 404, ++ f"Runtime policy with name {runtime_policy_name} does not already exist, use POST to create", ++ ) ++ logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name) ++ return ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- try: +- # Update the named runtime policy +- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update( +- runtime_policy_db_format # pyright: ignore +- ) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ try: ++ # Update the named runtime policy ++ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update( ++ runtime_policy_db_format # pyright: ignore ++ ) ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- web_util.echo_json_response(self, 201) +- logger.info("PUT returning 
201") ++ web_util.echo_json_response(self, 201) ++ logger.info("PUT returning 201") + + def data_received(self, chunk: Any) -> None: + raise NotImplementedError() +@@ -1113,8 +1128,6 @@ class VerifyIdentityHandler(BaseHandler): + + This is useful for 3rd party tools and integrations to independently verify the state of an agent. + """ +- session = get_session() +- + # validate the parameters of our request + if self.request.uri is None: + web_util.echo_json_response(self, 400, "URI not specified") +@@ -1159,36 +1172,37 @@ class VerifyIdentityHandler(BaseHandler): + return + + # get the agent information from the DB +- agent = None +- try: +- agent = ( +- session.query(VerfierMain) +- .options( # type: ignore +- joinedload(VerfierMain.ima_policy).load_only( +- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore ++ with session_context() as session: ++ agent = None ++ try: ++ agent = ( ++ session.query(VerfierMain) ++ .options( # type: ignore ++ joinedload(VerfierMain.ima_policy).load_only( ++ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore ++ ) + ) ++ .filter_by(agent_id=agent_id) ++ .one_or_none() + ) +- .filter_by(agent_id=agent_id) +- .one_or_none() +- ) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + +- if agent is not None: +- agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id) +- failure = cloud_verifier_common.process_verify_identity_quote( +- agent, quote, nonce, hash_alg, agentAttestState +- ) +- if failure: +- failure_contexts = "; ".join(x.context for x in failure.events) +- web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts}) +- logger.info("GET returning 200, but validation failed") ++ if agent is not None: ++ agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id) ++ failure = 
cloud_verifier_common.process_verify_identity_quote( ++ agent, quote, nonce, hash_alg, agentAttestState ++ ) ++ if failure: ++ failure_contexts = "; ".join(x.context for x in failure.events) ++ web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts}) ++ logger.info("GET returning 200, but validation failed") ++ else: ++ web_util.echo_json_response(self, 200, "Success", {"valid": 1}) ++ logger.info("GET returning 200, validation successful") + else: +- web_util.echo_json_response(self, 200, "Success", {"valid": 1}) +- logger.info("GET returning 200, validation successful") +- else: +- web_util.echo_json_response(self, 404, "agent id not found") +- logger.info("GET returning 404, agaent not found") ++ web_util.echo_json_response(self, 404, "agent id not found") ++ logger.info("GET returning 404, agaent not found") + + def data_received(self, chunk: Any) -> None: + raise NotImplementedError() +@@ -1231,35 +1245,35 @@ class MbpolicyHandler(BaseHandler): + if not params_valid: + return + +- session = get_session() +- if mb_policy_name is None: +- try: +- names_mbpolicies = session.query(VerifierMbpolicy.name).all() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies") +- raise ++ with session_context() as session: ++ if mb_policy_name is None: ++ try: ++ names_mbpolicies = session.query(VerifierMbpolicy.name).all() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies") ++ raise + +- names_response = [] +- for name in names_mbpolicies: +- names_response.append(name[0]) +- web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response}) ++ names_response = [] ++ for name in names_mbpolicies: ++ names_response.append(name[0]) ++ web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response}) + +- 
else: +- try: +- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() +- except NoResultFound: +- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- web_util.echo_json_response(self, 500, "Failed to get mb_policy") +- raise ++ else: ++ try: ++ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() ++ except NoResultFound: ++ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") ++ return ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, "Failed to get mb_policy") ++ raise + +- response = {} +- response["name"] = getattr(mbpolicy, "name", None) +- response["mb_policy"] = getattr(mbpolicy, "mb_policy", None) +- web_util.echo_json_response(self, 200, "Success", response) ++ response = {} ++ response["name"] = getattr(mbpolicy, "name", None) ++ response["mb_policy"] = getattr(mbpolicy, "mb_policy", None) ++ web_util.echo_json_response(self, 200, "Success", response) + + def delete(self) -> None: + """Delete a mb_policy +@@ -1271,45 +1285,44 @@ class MbpolicyHandler(BaseHandler): + if not params_valid or mb_policy_name is None: + return + +- session = get_session() +- try: +- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() +- except NoResultFound: +- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- web_util.echo_json_response(self, 500, "Failed to get mb_policy") +- raise ++ with session_context() as session: ++ try: ++ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() ++ except NoResultFound: ++ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") ++ return ++ 
except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, "Failed to get mb_policy") ++ raise + +- try: +- agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise +- if agent is not None: +- web_util.echo_json_response( +- self, +- 409, +- f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}", +- ) +- return ++ try: ++ agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none() ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise ++ if agent is not None: ++ web_util.echo_json_response( ++ self, ++ 409, ++ f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}", ++ ) ++ return + +- try: +- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete() +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- session.close() +- web_util.echo_json_response(self, 500, f"Database error: {e}") +- raise ++ try: ++ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete() ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ web_util.echo_json_response(self, 500, f"Database error: {e}") ++ raise + +- # NOTE(kaifeng) 204 Can not have response body, but current helper +- # doesn't support this case. +- self.set_status(204) +- self.set_header("Content-Type", "application/json") +- self.finish() +- logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name) ++ # NOTE(kaifeng) 204 Can not have response body, but current helper ++ # doesn't support this case. 
++ self.set_status(204) ++ self.set_header("Content-Type", "application/json") ++ self.finish() ++ logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name) + + def __get_mb_policy_db_format(self, mb_policy_name: str) -> Dict[str, Any]: + """Get the measured boot policy from the request and return it in Db format""" +@@ -1341,30 +1354,30 @@ class MbpolicyHandler(BaseHandler): + if not mb_policy_db_format: + return + +- session = get_session() +- # don't allow overwritting +- try: +- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() +- if mbpolicy_count > 0: +- web_util.echo_json_response( +- self, 409, f"Measured boot policy with name {mb_policy_name} already exists" +- ) +- logger.warning("Measured boot policy with name %s already exists", mb_policy_name) +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ with session_context() as session: ++ # don't allow overwritting ++ try: ++ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() ++ if mbpolicy_count > 0: ++ web_util.echo_json_response( ++ self, 409, f"Measured boot policy with name {mb_policy_name} already exists" ++ ) ++ logger.warning("Measured boot policy with name %s already exists", mb_policy_name) ++ return ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- try: +- # Add the data +- session.add(VerifierMbpolicy(**mb_policy_db_format)) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ try: ++ # Add the data ++ session.add(VerifierMbpolicy(**mb_policy_db_format)) ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- web_util.echo_json_response(self, 201) +- logger.info("POST returning 201") ++ web_util.echo_json_response(self, 201) ++ logger.info("POST returning 201") + + def 
put(self) -> None: + """Update an mb_policy +@@ -1381,32 +1394,32 @@ class MbpolicyHandler(BaseHandler): + if not mb_policy_db_format: + return + +- session = get_session() +- # don't allow creating a new policy +- try: +- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() +- if mbpolicy_count != 1: +- web_util.echo_json_response( +- self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist" +- ) +- logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name) +- return +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ with session_context() as session: ++ # don't allow creating a new policy ++ try: ++ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() ++ if mbpolicy_count != 1: ++ web_util.echo_json_response( ++ self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist" ++ ) ++ logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name) ++ return ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- try: +- # Update the named mb_policy +- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update( +- mb_policy_db_format # pyright: ignore +- ) +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) +- raise ++ try: ++ # Update the named mb_policy ++ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update( ++ mb_policy_db_format # pyright: ignore ++ ) ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) ++ raise + +- web_util.echo_json_response(self, 201) +- logger.info("PUT returning 201") ++ web_util.echo_json_response(self, 201) ++ logger.info("PUT returning 201") + + def data_received(self, chunk: Any) -> None: + raise NotImplementedError() +@@ -1460,17 
+1473,18 @@ async def update_agent_api_version(agent: Dict[str, Any], timeout: float = 60.0) + return None + + logger.info("Agent %s new API version %s is supported", agent_id, new_version) +- session = get_session() +- agent["supported_version"] = new_version + +- # Remove keys that should not go to the DB +- agent_db = dict(agent) +- for key in exclude_db: +- if key in agent_db: +- del agent_db[key] ++ with session_context() as session: ++ agent["supported_version"] = new_version + +- session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore +- session.commit() ++ # Remove keys that should not go to the DB ++ agent_db = dict(agent) ++ for key in exclude_db: ++ if key in agent_db: ++ del agent_db[key] ++ ++ session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore ++ # session.commit() is automatically called by context manager + else: + logger.warning("Agent %s new API version %s is not supported", agent_id, new_version) + return None +@@ -1718,50 +1732,68 @@ async def notify_error( + revocation_notifier.notify(tosend) + if "agent" in notifiers: + verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) +- session = get_session() +- agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() +- futures = [] +- loop = asyncio.get_event_loop() +- # Notify all agents asynchronously through a thread pool +- with ThreadPoolExecutor() as pool: +- for agent_db_obj in agents: +- if agent_db_obj.agent_id != agent["agent_id"]: +- agent = _from_db_obj(agent_db_obj) +- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": +- agent["ssl_context"] = web_util.generate_agent_tls_context( +- "verifier", agent["mtls_cert"], logger=logger +- ) +- func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout) +- futures.append(await loop.run_in_executor(pool, func)) +- # Wait for all tasks complete in 60 seconds +- try: +- for f in 
asyncio.as_completed(futures, timeout=60): +- await f +- except asyncio.TimeoutError as e: +- logger.error("Timeout during notifying error to agents: %s", e) ++ with session_context() as session: ++ agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() ++ futures = [] ++ loop = asyncio.get_event_loop() ++ # Notify all agents asynchronously through a thread pool ++ with ThreadPoolExecutor() as pool: ++ for agent_db_obj in agents: ++ if agent_db_obj.agent_id != agent["agent_id"]: ++ agent = _from_db_obj(agent_db_obj) ++ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": ++ agent["ssl_context"] = web_util.generate_agent_tls_context( ++ "verifier", agent["mtls_cert"], logger=logger ++ ) ++ func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout) ++ futures.append(await loop.run_in_executor(pool, func)) ++ # Wait for all tasks complete in 60 seconds ++ try: ++ for f in asyncio.as_completed(futures, timeout=60): ++ await f ++ except asyncio.TimeoutError as e: ++ logger.error("Timeout during notifying error to agents: %s", e) + + + async def process_agent( + agent: Dict[str, Any], new_operational_state: int, failure: Failure = Failure(Component.INTERNAL, ["verifier"]) + ) -> None: +- session = get_session() + try: # pylint: disable=R1702 + main_agent_operational_state = agent["operational_state"] + stored_agent = None +- try: +- stored_agent = ( +- session.query(VerfierMain) +- .options( # type: ignore +- joinedload(VerfierMain.ima_policy).load_only(VerifierAllowlist.checksum) # pyright: ignore +- ) +- .options( # type: ignore +- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore ++ ++ # First database operation - read agent data and extract all needed data within session context ++ ima_policy_data = {} ++ mb_policy_data = None ++ with session_context() as session: ++ try: ++ stored_agent = ( ++ session.query(VerfierMain) ++ .options( # type: ignore ++ 
joinedload(VerfierMain.ima_policy) # Load full IMA policy object including content ++ ) ++ .options( # type: ignore ++ joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore ++ ) ++ .filter_by(agent_id=str(agent["agent_id"])) ++ .first() + ) +- .filter_by(agent_id=str(agent["agent_id"])) +- .first() +- ) +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) ++ ++ # Extract IMA policy data within session context to avoid DetachedInstanceError ++ if stored_agent and stored_agent.ima_policy: ++ ima_policy_data = { ++ "checksum": str(stored_agent.ima_policy.checksum), ++ "name": stored_agent.ima_policy.name, ++ "agent_id": str(stored_agent.agent_id), ++ "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too ++ } ++ ++ # Extract MB policy data within session context ++ if stored_agent and stored_agent.mb_policy: ++ mb_policy_data = stored_agent.mb_policy.mb_policy ++ ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) + + # if the stored agent could not be recovered from the database, stop polling + if not stored_agent: +@@ -1775,7 +1807,10 @@ async def process_agent( + logger.warning("Agent %s terminated by user.", agent["agent_id"]) + if agent["pending_event"] is not None: + tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"]) +- verifier_db_delete_agent(session, agent["agent_id"]) ++ ++ # Second database operation - delete agent ++ with session_context() as session: ++ verifier_db_delete_agent(session, agent["agent_id"]) + return + + # if the user tells us to stop polling because the tenant quote check failed +@@ -1808,11 +1843,16 @@ async def process_agent( + if not failure.recoverable or failure.highest_severity == MAX_SEVERITY_LABEL: + if agent["pending_event"] is not None: + tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"]) +- for key in exclude_db: +- 
if key in agent: +- del agent[key] +- session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(agent) # pyright: ignore +- session.commit() ++ ++ # Third database operation - update agent with failure state ++ with session_context() as session: ++ for key in exclude_db: ++ if key in agent: ++ del agent[key] ++ session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update( ++ agent # type: ignore[arg-type] ++ ) ++ # session.commit() is automatically called by context manager + + # propagate all state, but remove none DB keys first (using exclude_db) + try: +@@ -1821,18 +1861,18 @@ async def process_agent( + if key in agent_db: + del agent_db[key] + +- session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore +- session.commit() ++ # Fourth database operation - update agent state ++ with session_context() as session: ++ session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore ++ # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) + + # Load agent's IMA policy +- runtime_policy = verifier_read_policy_from_cache(stored_agent) ++ runtime_policy = verifier_read_policy_from_cache(ima_policy_data) + + # Get agent's measured boot policy +- mb_policy = None +- if stored_agent.mb_policy is not None: +- mb_policy = stored_agent.mb_policy.mb_policy ++ mb_policy = mb_policy_data + + # If agent was in a failed state we check if we either stop polling + # or just add it again to the event loop +@@ -1876,7 +1916,14 @@ async def process_agent( + ) + + pending = tornado.ioloop.IOLoop.current().call_later( +- interval, invoke_get_quote, agent, mb_policy, runtime_policy, False, timeout=timeout # type: ignore # due to python <3.9 ++ # type: ignore # due to python <3.9 ++ interval, ++ invoke_get_quote, ++ agent, ++ mb_policy, ++ runtime_policy, ++ False, ++ 
timeout=timeout, + ) + agent["pending_event"] = pending + return +@@ -1911,7 +1958,14 @@ async def process_agent( + next_retry, + ) + tornado.ioloop.IOLoop.current().call_later( +- next_retry, invoke_get_quote, agent, mb_policy, runtime_policy, True, timeout=timeout # type: ignore # due to python <3.9 ++ # type: ignore # due to python <3.9 ++ next_retry, ++ invoke_get_quote, ++ agent, ++ mb_policy, ++ runtime_policy, ++ True, ++ timeout=timeout, + ) + return + +@@ -1980,9 +2034,9 @@ async def activate_agents(agents: List[VerfierMain], verifier_ip: str, verifier_ + + + def get_agents_by_verifier_id(verifier_id: str) -> List[VerfierMain]: +- session = get_session() + try: +- return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() ++ with session_context() as session: ++ return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + return [] +@@ -2007,20 +2061,20 @@ def main() -> None: + os.umask(0o077) + + VerfierMain.metadata.create_all(engine, checkfirst=True) # pyright: ignore +- session = get_session() +- try: +- query_all = session.query(VerfierMain).all() +- for row in query_all: +- if row.operational_state in states.APPROVED_REACTIVATE_STATES: +- row.operational_state = states.START # pyright: ignore +- session.commit() +- except SQLAlchemyError as e: +- logger.error("SQLAlchemy Error: %s", e) ++ with session_context() as session: ++ try: ++ query_all = session.query(VerfierMain).all() ++ for row in query_all: ++ if row.operational_state in states.APPROVED_REACTIVATE_STATES: ++ row.operational_state = states.START # pyright: ignore ++ # session.commit() is automatically called by context manager ++ except SQLAlchemyError as e: ++ logger.error("SQLAlchemy Error: %s", e) + +- num = session.query(VerfierMain.agent_id).count() +- if num > 0: +- agent_ids = session.query(VerfierMain.agent_id).all() +- logger.info("Agent ids in db loaded from file: %s", 
agent_ids) ++ num = session.query(VerfierMain.agent_id).count() ++ if num > 0: ++ agent_ids = session.query(VerfierMain.agent_id).all() ++ logger.info("Agent ids in db loaded from file: %s", agent_ids) + + logger.info("Starting Cloud Verifier (tornado) on port %s, use to stop", verifier_port) + +diff --git a/keylime/da/examples/sqldb.py b/keylime/da/examples/sqldb.py +index 8efc84e..04a8afb 100644 +--- a/keylime/da/examples/sqldb.py ++++ b/keylime/da/examples/sqldb.py +@@ -1,7 +1,10 @@ + import time ++from contextlib import contextmanager ++from typing import Iterator + + import sqlalchemy + import sqlalchemy.ext.declarative ++from sqlalchemy.orm import sessionmaker + + from keylime import keylime_logging + from keylime.da.record import BaseRecordManagement, base_build_key_list +@@ -45,23 +48,23 @@ class RecordManagement(BaseRecordManagement): + BaseRecordManagement.__init__(self, service) + + self.engine = sqlalchemy.create_engine(self.ps_url._replace(fragment="").geturl(), pool_recycle=1800) +- sm = sqlalchemy.orm.sessionmaker() +- self.session = sqlalchemy.orm.scoped_session(sm) +- self.session.configure(bind=self.engine) +- TableBase.metadata.create_all(self.engine) +- +- def agent_list_retrieval(self, record_prefix="auto", service="auto"): +- if record_prefix == "auto": +- record_prefix = "" +- +- agent_list = [] ++ self.SessionLocal = sessionmaker(bind=self.engine) + +- recordtype = self.get_record_type(service) +- tbl = type2table(recordtype) +- for agentid in self.session.query(tbl.agentid).distinct(): # pylint: disable=no-member +- agent_list.append(agentid[0]) ++ # Create tables if they don't exist ++ TableBase.metadata.create_all(self.engine) + +- return agent_list ++ @contextmanager ++ def session_context(self) -> Iterator: ++ """Context manager for database sessions that ensures proper cleanup.""" ++ session = self.SessionLocal() ++ try: ++ yield session ++ session.commit() ++ except Exception: ++ session.rollback() ++ raise ++ finally: ++ 
session.close() + + def record_create( + self, +@@ -84,8 +87,9 @@ class RecordManagement(BaseRecordManagement): + d = {"time": recordtime, "agentid": agentid, "record": rcrd} + + try: +- self.session.add((type2table(recordtype))(**d)) # pylint: disable=no-member +- self.session.commit() # pylint: disable=no-member ++ with self.session_context() as session: ++ session.add((type2table(recordtype))(**d)) ++ # session.commit() is automatically called by context manager + except Exception as e: + logger.error("Failed to create attestation record: %s", e) + +@@ -106,23 +110,23 @@ class RecordManagement(BaseRecordManagement): + if f"{end_date}" == "auto": + end_date = self.end_of_times + +- if self.only_last_record_wanted(start_date, end_date): +- attestion_record_rows = ( +- self.session.query(tbl) # pylint: disable=no-member +- .filter(tbl.agentid == record_identifier) +- .order_by(sqlalchemy.desc(tbl.time)) +- .limit(1) +- ) +- +- else: +- attestion_record_rows = self.session.query(tbl).filter( # pylint: disable=no-member +- tbl.agentid == record_identifier +- ) +- +- for row in attestion_record_rows: +- decoded_record_object = self.record_deserialize(row.record) +- self.record_signature_check(decoded_record_object, record_identifier) +- record_list.append(decoded_record_object) ++ with self.session_context() as session: ++ if self.only_last_record_wanted(start_date, end_date): ++ attestion_record_rows = ( ++ session.query(tbl) ++ .filter(tbl.agentid == record_identifier) ++ .order_by(sqlalchemy.desc(tbl.time)) ++ .limit(1) ++ ) ++ ++ else: ++ attestion_record_rows = session.query(tbl).filter(tbl.agentid == record_identifier) ++ ++ for row in attestion_record_rows: ++ decoded_record_object = self.record_deserialize(row.record) ++ self.record_signature_check(decoded_record_object, record_identifier) ++ record_list.append(decoded_record_object) ++ + return record_list + + def build_key_list(self, agent_identifier, service="auto"): +diff --git a/keylime/db/keylime_db.py 
b/keylime/db/keylime_db.py +index 5620a28..aa49e51 100644 +--- a/keylime/db/keylime_db.py ++++ b/keylime/db/keylime_db.py +@@ -1,7 +1,8 @@ + import os + from configparser import NoOptionError ++from contextlib import contextmanager + from sqlite3 import Connection as SQLite3Connection +-from typing import Any, Dict, Optional, cast ++from typing import Any, Iterator, Optional, cast + + from sqlalchemy import create_engine, event + from sqlalchemy.engine import Engine +@@ -22,90 +23,108 @@ def _set_sqlite_pragma(dbapi_connection: SQLite3Connection, _) -> None: + cursor.close() + + +-class DBEngineManager: +- service: Optional[str] +- +- def __init__(self) -> None: +- self.service = None +- +- def make_engine(self, service: str) -> Engine: +- """ +- To use: engine = self.make_engine('cloud_verifier') +- """ +- +- # Keep DB related stuff as it is, but read configuration from new +- # configs +- if service == "cloud_verifier": +- config_service = "verifier" +- else: +- config_service = service +- +- self.service = service +- +- try: +- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") +- p_sz, m_ovfl = p_sz_m_ovfl.split(",") +- except NoOptionError: +- p_sz = "5" +- m_ovfl = "10" +- +- engine_args: Dict[str, Any] = {} +- +- url = config.get(config_service, "database_url") +- if url: +- logger.info("database_url is set, using it to establish database connection") +- +- # If the keyword sqlite is provided as the database url, use the +- # cv_data.sqlite for the verifier or the file reg_data.sqlite for +- # the registrar, located at the config.WORK_DIR directory +- if url == "sqlite": ++def make_engine(service: str, **engine_args: Any) -> Engine: ++ """Create a database engine for a keylime service.""" ++ # Keep DB related stuff as it is, but read configuration from new ++ # configs ++ if service == "cloud_verifier": ++ config_service = "verifier" ++ else: ++ config_service = service ++ ++ url = config.get(config_service, "database_url") ++ if url: ++ 
logger.info("database_url is set, using it to establish database connection") ++ ++ # If the keyword sqlite is provided as the database url, use the ++ # cv_data.sqlite for the verifier or the file reg_data.sqlite for ++ # the registrar, located at the config.WORK_DIR directory ++ if url == "sqlite": ++ logger.info( ++ "database_url is set as 'sqlite' keyword, using default values to establish database connection" ++ ) ++ if service == "cloud_verifier": ++ database = "cv_data.sqlite" ++ elif service == "registrar": ++ database = "reg_data.sqlite" ++ else: ++ logger.error("Tried to setup database access for unknown service '%s'", service) ++ raise Exception(f"Unknown service '{service}' for database setup") ++ ++ database_file = os.path.abspath(os.path.join(config.WORK_DIR, database)) ++ url = f"sqlite:///{database_file}" ++ ++ kl_dir = os.path.dirname(os.path.abspath(database_file)) ++ if not os.path.exists(kl_dir): ++ os.makedirs(kl_dir, 0o700) ++ ++ engine_args["connect_args"] = {"check_same_thread": False} ++ ++ if not url.count("sqlite:"): ++ # sqlite does not support setting pool size and max overflow, only ++ # read from the config when it is going to be used ++ try: ++ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") ++ p_sz, m_ovfl = p_sz_m_ovfl.split(",") ++ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl) ++ except NoOptionError: ++ p_sz = "5" ++ m_ovfl = "10" + logger.info( +- "database_url is set as 'sqlite' keyword, using default values to establish database connection" ++ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", p_sz, m_ovfl + ) +- if service == "cloud_verifier": +- database = "cv_data.sqlite" +- elif service == "registrar": +- database = "reg_data.sqlite" +- else: +- logger.error("Tried to setup database access for unknown service '%s'", service) +- raise Exception(f"Unknown service '{service}' for database setup") +- +- database_file = 
os.path.abspath(os.path.join(config.WORK_DIR, database)) +- url = f"sqlite:///{database_file}" +- +- kl_dir = os.path.dirname(os.path.abspath(database_file)) +- if not os.path.exists(kl_dir): +- os.makedirs(kl_dir, 0o700) +- +- engine_args["connect_args"] = {"check_same_thread": False} + +- if not url.count("sqlite:"): +- engine_args["pool_size"] = int(p_sz) +- engine_args["max_overflow"] = int(m_ovfl) +- engine_args["pool_pre_ping"] = True ++ engine_args["pool_size"] = int(p_sz) ++ engine_args["max_overflow"] = int(m_ovfl) ++ engine_args["pool_pre_ping"] = True + +- # Enable DB debugging +- if config.DEBUG_DB and config.INSECURE_DEBUG: +- engine_args["echo"] = True ++ # Enable DB debugging ++ if config.DEBUG_DB and config.INSECURE_DEBUG: ++ engine_args["echo"] = True + +- engine = create_engine(url, **engine_args) +- return engine ++ engine = create_engine(url, **engine_args) ++ return engine + + + class SessionManager: + engine: Optional[Engine] ++ _scoped_session: Optional[scoped_session] + + def __init__(self) -> None: + self.engine = None ++ self._scoped_session = None + + def make_session(self, engine: Engine) -> Session: + """ + To use: session = self.make_session(engine) + """ + self.engine = engine +- my_session = scoped_session(sessionmaker()) ++ if self._scoped_session is None: ++ self._scoped_session = scoped_session(sessionmaker()) + try: +- my_session.configure(bind=self.engine) # type: ignore ++ self._scoped_session.configure(bind=self.engine) # type: ignore ++ self._scoped_session.configure(expire_on_commit=False) # type: ignore + except SQLAlchemyError as err: + logger.error("Error creating SQL session manager %s", err) +- return cast(Session, my_session()) ++ return cast(Session, self._scoped_session()) ++ ++ @contextmanager ++ def session_context(self, engine: Engine) -> Iterator[Session]: ++ """ ++ Context manager for database sessions that ensures proper cleanup. 
++ To use: ++ with session_manager.session_context(engine) as session: ++ # use session ++ """ ++ session = self.make_session(engine) ++ try: ++ yield session ++ session.commit() ++ except Exception: ++ session.rollback() ++ raise ++ finally: ++ # Important: remove the session from the scoped session registry ++ # to prevent connection leaks with scoped_session ++ if self._scoped_session is not None: ++ self._scoped_session.remove() # type: ignore[no-untyped-call] +diff --git a/keylime/migrations/env.py b/keylime/migrations/env.py +index ac98349..a1881f2 100644 +--- a/keylime/migrations/env.py ++++ b/keylime/migrations/env.py +@@ -8,7 +8,7 @@ import sys + + from alembic import context + +-from keylime.db.keylime_db import DBEngineManager ++from keylime.db.keylime_db import make_engine + from keylime.db.registrar_db import Base as RegistrarBase + from keylime.db.verifier_db import Base as VerifierBase + +@@ -74,7 +74,7 @@ def run_migrations_offline(): + logger.info("Writing output to %s", file_) + + with open(file_, "w", encoding="utf-8") as buffer: +- engine = DBEngineManager().make_engine(name) ++ engine = make_engine(name) + connection = engine.connect() + context.configure( + connection=connection, +@@ -102,7 +102,7 @@ def run_migrations_online(): + engines = {} + for name in re.split(r",\s*", db_names): + engines[name] = rec = {} +- rec["engine"] = DBEngineManager().make_engine(name) ++ rec["engine"] = make_engine(name) + + for name, rec in engines.items(): + engine = rec["engine"] +diff --git a/keylime/models/base/db.py b/keylime/models/base/db.py +index dd47d63..0229765 100644 +--- a/keylime/models/base/db.py ++++ b/keylime/models/base/db.py +@@ -41,13 +41,6 @@ class DBManager: + + self._service = service + +- try: +- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") +- p_sz, m_ovfl = p_sz_m_ovfl.split(",") +- except NoOptionError: +- p_sz = "5" +- m_ovfl = "10" +- + engine_args: Dict[str, Any] = {} + + url = config.get(config_service, 
"database_url") +@@ -79,6 +72,21 @@ class DBManager: + engine_args["connect_args"] = {"check_same_thread": False} + + if not url.count("sqlite:"): ++ # sqlite does not support setting pool size and max overflow, only ++ # read from the config when it is going to be used ++ try: ++ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") ++ p_sz, m_ovfl = p_sz_m_ovfl.split(",") ++ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl) ++ except NoOptionError: ++ p_sz = "5" ++ m_ovfl = "10" ++ logger.info( ++ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", ++ p_sz, ++ m_ovfl, ++ ) ++ + engine_args["pool_size"] = int(p_sz) + engine_args["max_overflow"] = int(m_ovfl) + engine_args["pool_pre_ping"] = True +diff --git a/keylime/models/base/persistable_model.py b/keylime/models/base/persistable_model.py +index 18f7d0d..a779f0b 100644 +--- a/keylime/models/base/persistable_model.py ++++ b/keylime/models/base/persistable_model.py +@@ -207,10 +207,16 @@ class PersistableModel(BasicModel, metaclass=PersistableModelMeta): + setattr(self._db_mapping_inst, name, field.data_type.db_dump(value, db_manager.engine.dialect)) + + with db_manager.session_context() as session: +- session.add(self._db_mapping_inst) ++ # Merge the potentially detached object into the new session ++ merged_instance = session.merge(self._db_mapping_inst) ++ session.add(merged_instance) ++ # Update our reference to the merged instance ++ self._db_mapping_inst = merged_instance # pylint: disable=attribute-defined-outside-init + + self.clear_changes() + + def delete(self) -> None: + with db_manager.session_context() as session: +- session.delete(self._db_mapping_inst) # type: ignore[no-untyped-call] ++ # Merge the potentially detached object into the new session before deleting ++ merged_instance = session.merge(self._db_mapping_inst) ++ session.delete(merged_instance) # type: ignore[no-untyped-call] +diff --git a/packit-ci.fmf 
b/packit-ci.fmf +index 2d1e5e5..cb64faf 100644 +--- a/packit-ci.fmf ++++ b/packit-ci.fmf +@@ -101,6 +101,7 @@ adjust: + - /regression/CVE-2023-3674 + - /regression/issue-1380-agent-removed-and-re-added + - /regression/keylime-agent-option-override-through-envvar ++ - /regression/db-connection-leak-reproducer + - /sanity/keylime-secure_mount + - /sanity/opened-conf-files + - /upstream/run_keylime_tests +diff --git a/test/test_verifier_db.py b/test/test_verifier_db.py +index ad72fa6..aae8f8a 100644 +--- a/test/test_verifier_db.py ++++ b/test/test_verifier_db.py +@@ -172,3 +172,102 @@ class TestVerfierDB(unittest.TestCase): + + def tearDown(self): + self.session.close() ++ ++ def test_11_relationship_access_after_session_commit(self): ++ """Test that relationships can be accessed after session commits (DetachedInstanceError fix)""" ++ # This test reproduces the problematic pattern from cloud_verifier_tornado.py ++ # where objects are loaded with joinedload and then accessed after session closes ++ ++ # Create a new session manager and context (like in cloud_verifier_tornado.py) ++ session_manager = SessionManager() ++ ++ # First, load the agent with eager loading for relationships ++ stored_agent = None ++ with session_manager.session_context(self.engine) as session: ++ stored_agent = ( ++ session.query(VerfierMain) ++ .options(joinedload(VerfierMain.ima_policy)) ++ .options(joinedload(VerfierMain.mb_policy)) ++ .filter_by(agent_id=agent_id) ++ .first() ++ ) ++ # Verify agent was loaded correctly ++ self.assertIsNotNone(stored_agent) ++ # session.commit() is automatically called by context manager when exiting ++ ++ # Now verify we can access relationships AFTER the session has been closed ++ # This would previously trigger DetachedInstanceError ++ ++ # Ensure stored_agent is not None before proceeding ++ assert stored_agent is not None ++ ++ # Test accessing ima_policy relationship ++ self.assertIsNotNone(stored_agent.ima_policy) ++ assert stored_agent.ima_policy is 
not None # Type narrowing for linter ++ self.assertEqual(stored_agent.ima_policy.name, "test-allowlist") ++ # checksum is not set in test data ++ self.assertEqual(stored_agent.ima_policy.checksum, None) ++ ++ # Test accessing the ima_policy.ima_policy attribute (similar to verifier_read_policy_from_cache) ++ ima_policy_content = stored_agent.ima_policy.ima_policy ++ self.assertEqual(ima_policy_content, test_allowlist_data["ima_policy"]) ++ ++ # Test accessing mb_policy relationship ++ self.assertIsNotNone(stored_agent.mb_policy) ++ assert stored_agent.mb_policy is not None # Type narrowing for linter ++ self.assertEqual(stored_agent.mb_policy.name, "test-mbpolicy") ++ ++ # Test accessing the mb_policy.mb_policy attribute (similar to process_agent function) ++ mb_policy_content = stored_agent.mb_policy.mb_policy ++ self.assertEqual(mb_policy_content, test_mbpolicy_data["mb_policy"]) ++ ++ # Test that we can access these relationships multiple times without issues ++ for _ in range(3): ++ self.assertIsNotNone(stored_agent.ima_policy.ima_policy) ++ self.assertIsNotNone(stored_agent.mb_policy.mb_policy) ++ ++ def test_12_persistable_model_cross_session_fix(self): ++ """Test that PersistableModel can handle cross-session operations safely""" ++ # This test would previously fail with DetachedInstanceError before the fix ++ # Note: This is a conceptual test since we don't have actual PersistableModel ++ # subclasses in the test environment, but demonstrates the pattern ++ ++ # Simulate creating a SQLAlchemy object in one session ++ session_manager = SessionManager() ++ ++ # Load an object in one session context ++ test_agent = None ++ with session_manager.session_context(self.engine) as session: ++ test_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() ++ self.assertIsNotNone(test_agent) ++ # Session closes here ++ ++ # Ensure test_agent is not None before proceeding ++ assert test_agent is not None ++ ++ # Now simulate using this object in a 
different session context ++ # This tests the pattern where PersistableModel would use session.add() or session.delete() ++ # on a cross-session object ++ with session_manager.session_context(self.engine) as session: ++ # Before the fix, this would cause DetachedInstanceError ++ # The fix uses session.merge() to handle detached objects safely ++ merged_agent = session.merge(test_agent) ++ assert merged_agent is not None # Type narrowing for linter ++ ++ # Test that we can modify and save the merged object ++ original_port = merged_agent.port ++ # Use setattr to avoid linter issues with Column assignment ++ setattr(merged_agent, "port", 9999) ++ session.add(merged_agent) ++ # session.commit() called automatically by context manager ++ ++ # Verify the change was persisted ++ with session_manager.session_context(self.engine) as session: ++ updated_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() ++ assert updated_agent is not None # Type narrowing for linter ++ self.assertEqual(updated_agent.port, 9999) ++ ++ # Restore original value ++ # Use setattr to avoid linter issues ++ setattr(updated_agent, "port", original_port) ++ session.add(updated_agent) diff --git a/keylime.spec b/keylime.spec index 78360fc..3e52794 100644 --- a/keylime.spec +++ b/keylime.spec @@ -9,7 +9,7 @@ Name: keylime Version: 7.12.1 -Release: 7%{?dist} +Release: 8%{?dist} Summary: Open source TPM software for Bootstrapping and Maintaining Trust URL: https://github.com/keylime/keylime @@ -27,6 +27,8 @@ Patch: 0004-templates-duplicate-str_to_version-in-the-adjust-scr.patch # DO NOT REMOVE THE FOLLOWING TWO PATCHES IN FOLLOWING RHEL-9.x REBASES. 
Patch: 0005-Restore-RHEL-9-version-of-create_allowlist.sh.patch Patch: 0006-Revert-default-server_key_password-for-verifier-regi.patch +# Backported from https://github.com/keylime/keylime/pull/1782 +Patch: 0007-fix_db_connection_leaks.patch License: ASL 2.0 and MIT @@ -421,6 +423,10 @@ fi %license LICENSE %changelog +* Fri Aug 08 2025 Anderson Toshiyuki Sasaki - 7.12.1-8 +- Fix DB connection leaks + Resolves: RHEL-108263 + * Tue Jul 22 2025 Sergio Correia - 7.12.1-7 - Fix tmpfiles.d configuration related to the cert store Resolves: RHEL-104572