diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py index 8ab81d1..7553ac8 100644 --- a/keylime/cloud_verifier_tornado.py +++ b/keylime/cloud_verifier_tornado.py @@ -7,7 +7,8 @@ import sys import traceback from concurrent.futures import ThreadPoolExecutor from multiprocessing import Process -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast import tornado.httpserver import tornado.ioloop @@ -34,7 +35,7 @@ from keylime.agentstates import AgentAttestState, AgentAttestStates from keylime.common import retry, states, validators from keylime.common.version import str_to_version from keylime.da import record -from keylime.db.keylime_db import DBEngineManager, SessionManager +from keylime.db.keylime_db import SessionManager, make_engine from keylime.db.verifier_db import VerfierMain, VerifierAllowlist, VerifierMbpolicy from keylime.failure import MAX_SEVERITY_LABEL, Component, Event, Failure, set_severity_config from keylime.ima import ima @@ -47,7 +48,7 @@ GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {} set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy")) try: - engine = DBEngineManager().make_engine("cloud_verifier") + engine = make_engine("cloud_verifier") except SQLAlchemyError as err: logger.error("Error creating SQL engine or session: %s", err) sys.exit(1) @@ -61,8 +62,17 @@ except record.RecordManagementException as rme: sys.exit(1) -def get_session() -> Session: - return SessionManager().make_session(engine) +@contextmanager +def session_context() -> Iterator[Session]: + """ + Context manager for database sessions that ensures proper cleanup. 
+ To use: + with session_context() as session: + # use session + """ + session_manager = SessionManager() + with session_manager.session_context(engine) as session: + yield session def get_AgentAttestStates() -> AgentAttestStates: @@ -130,19 +140,18 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]: return agent_dict -def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str: - checksum = "" - name = "empty" - agent_id = str(stored_agent.agent_id) +def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str: + checksum = ima_policy_data.get("checksum", "") + name = ima_policy_data.get("name", "empty") + agent_id = ima_policy_data.get("agent_id", "") + + if not agent_id: + return "" if agent_id not in GLOBAL_POLICY_CACHE: GLOBAL_POLICY_CACHE[agent_id] = {} GLOBAL_POLICY_CACHE[agent_id][""] = "" - if stored_agent.ima_policy: - checksum = str(stored_agent.ima_policy.checksum) - name = stored_agent.ima_policy.name - if checksum not in GLOBAL_POLICY_CACHE[agent_id]: if len(GLOBAL_POLICY_CACHE[agent_id]) > 1: # Perform a cleanup of the contents, IMA policy checksum changed @@ -162,8 +171,9 @@ def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str: checksum, agent_id, ) - # Actually contacts the database and load the (large) ima_policy column for "allowlists" table - ima_policy = stored_agent.ima_policy.ima_policy + + # Get the large ima_policy content - it's already loaded in ima_policy_data + ima_policy = ima_policy_data.get("ima_policy", "") assert isinstance(ima_policy, str) GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy @@ -182,22 +192,19 @@ def store_attestation_state(agentAttestState: AgentAttestState) -> None: # Only store if IMA log was evaluated if agentAttestState.get_ima_pcrs(): agent_id = agentAttestState.agent_id - session = get_session() try: - update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id()) - assert update_agent - update_agent.boottime = 
agentAttestState.get_boottime() - update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry() - ima_pcrs_dict = agentAttestState.get_ima_pcrs() - update_agent.ima_pcrs = list(ima_pcrs_dict.keys()) - for pcr_num, value in ima_pcrs_dict.items(): - setattr(update_agent, f"pcr{pcr_num}", value) - update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json() - try: + with session_context() as session: + update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id()) + assert update_agent + update_agent.boottime = agentAttestState.get_boottime() + update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry() + ima_pcrs_dict = agentAttestState.get_ima_pcrs() + update_agent.ima_pcrs = list(ima_pcrs_dict.keys()) + for pcr_num, value in ima_pcrs_dict.items(): + setattr(update_agent, f"pcr{pcr_num}", value) + update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json() session.add(update_agent) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e) - session.commit() + # session.commit() is automatically called by context manager except SQLAlchemyError as e: logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e) @@ -354,45 +361,17 @@ class AgentsHandler(BaseHandler): was not found, it either completed successfully, or failed. If found, the agent_id is still polling to contact the Cloud Agent. """ - session = get_session() - rest_params, agent_id = self.__validate_input("GET") if not rest_params: return - if (agent_id is not None) and (agent_id != ""): - # If the agent ID is not valid (wrong set of characters), - # just do nothing. 
- agent = None - try: - agent = ( - session.query(VerfierMain) - .options( # type: ignore - joinedload(VerfierMain.ima_policy).load_only( - VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore - ) - ) - .options( # type: ignore - joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore - ) - .filter_by(agent_id=agent_id) - .one_or_none() - ) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - - if agent is not None: - response = cloud_verifier_common.process_get_status(agent) - web_util.echo_json_response(self, 200, "Success", response) - else: - web_util.echo_json_response(self, 404, "agent id not found") - else: - json_response = None - if "bulk" in rest_params: - agent_list = None - - if ("verifier" in rest_params) and (rest_params["verifier"] != ""): - agent_list = ( + with session_context() as session: + if (agent_id is not None) and (agent_id != ""): + # If the agent ID is not valid (wrong set of characters), + # just do nothing. 
+ agent = None + try: + agent = ( session.query(VerfierMain) .options( # type: ignore joinedload(VerfierMain.ima_policy).load_only( @@ -402,39 +381,70 @@ class AgentsHandler(BaseHandler): .options( # type: ignore joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore ) - .filter_by(verifier_id=rest_params["verifier"]) - .all() + .filter_by(agent_id=agent_id) + .one_or_none() ) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + + if agent is not None: + response = cloud_verifier_common.process_get_status(agent) + web_util.echo_json_response(self, 200, "Success", response) else: - agent_list = ( - session.query(VerfierMain) - .options( # type: ignore - joinedload(VerfierMain.ima_policy).load_only( - VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore + web_util.echo_json_response(self, 404, "agent id not found") + else: + json_response = None + if "bulk" in rest_params: + agent_list = None + + if ("verifier" in rest_params) and (rest_params["verifier"] != ""): + agent_list = ( + session.query(VerfierMain) + .options( # type: ignore + joinedload(VerfierMain.ima_policy).load_only( + VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore + ) ) + .options( # type: ignore + joinedload(VerfierMain.mb_policy).load_only( + VerifierMbpolicy.mb_policy # type: ignore[arg-type] + ) + ) + .filter_by(verifier_id=rest_params["verifier"]) + .all() ) - .options( # type: ignore - joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore + else: + agent_list = ( + session.query(VerfierMain) + .options( # type: ignore + joinedload(VerfierMain.ima_policy).load_only( + VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore + ) + ) + .options( # type: ignore + joinedload(VerfierMain.mb_policy).load_only( + VerifierMbpolicy.mb_policy # type: ignore[arg-type] + ) + ) + .all() ) - .all() - ) - json_response = {} - for 
agent in agent_list: - json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent) + json_response = {} + for agent in agent_list: + json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent) - web_util.echo_json_response(self, 200, "Success", json_response) - else: - if ("verifier" in rest_params) and (rest_params["verifier"] != ""): - json_response_list = ( - session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all() - ) + web_util.echo_json_response(self, 200, "Success", json_response) else: - json_response_list = session.query(VerfierMain.agent_id).all() + if ("verifier" in rest_params) and (rest_params["verifier"] != ""): + json_response_list = ( + session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all() + ) + else: + json_response_list = session.query(VerfierMain.agent_id).all() - web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list}) + web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list}) - logger.info("GET returning 200 response for agent_id list") + logger.info("GET returning 200 response for agent_id list") def delete(self) -> None: """This method handles the DELETE requests to remove agents from the Cloud Verifier. @@ -442,59 +452,55 @@ class AgentsHandler(BaseHandler): Currently, only agents resources are available for DELETEing, i.e. /agents. All other DELETE uri's will return errors. agents requests require a single agent_id parameter which identifies the agent to be deleted. 
""" - session = get_session() - rest_params, agent_id = self.__validate_input("DELETE") if not rest_params or not agent_id: return - agent = None - try: - agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + with session_context() as session: + agent = None + try: + agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - if agent is None: - web_util.echo_json_response(self, 404, "agent id not found") - logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id) - return + if agent is None: + web_util.echo_json_response(self, 404, "agent id not found") + logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id) + return - verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) - if verifier_id != agent.verifier_id: - web_util.echo_json_response(self, 404, "agent id associated to this verifier") - logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id) - return + verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) + if verifier_id != agent.verifier_id: + web_util.echo_json_response(self, 404, "agent id associated to this verifier") + logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id) + return - # Cleanup the cache when the agent is deleted. Do it early. - if agent_id in GLOBAL_POLICY_CACHE: - del GLOBAL_POLICY_CACHE[agent_id] - logger.debug( - "Cleaned up policy cache from all entries used by agent %s", - agent_id, - ) + # Cleanup the cache when the agent is deleted. Do it early. 
+ if agent_id in GLOBAL_POLICY_CACHE: + del GLOBAL_POLICY_CACHE[agent_id] + logger.debug( + "Cleaned up policy cache from all entries used by agent %s", + agent_id, + ) - op_state = agent.operational_state - if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE): - try: - verifier_db_delete_agent(session, agent_id) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 200, "Success") - logger.info("DELETE returning 200 response for agent id: %s", agent_id) - else: - try: - update_agent = session.query(VerfierMain).get(agent_id) - assert update_agent - update_agent.operational_state = states.TERMINATED + op_state = agent.operational_state + if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE): try: + verifier_db_delete_agent(session, agent_id) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 200, "Success") + logger.info("DELETE returning 200 response for agent id: %s", agent_id) + else: + try: + update_agent = session.query(VerfierMain).get(agent_id) + assert update_agent + update_agent.operational_state = states.TERMINATED session.add(update_agent) + # session.commit() is automatically called by context manager + web_util.echo_json_response(self, 202, "Accepted") + logger.info("DELETE returning 202 response for agent id: %s", agent_id) except SQLAlchemyError as e: logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - session.commit() - web_util.echo_json_response(self, 202, "Accepted") - logger.info("DELETE returning 202 response for agent id: %s", agent_id) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) def post(self) -> None: """This method handles the POST requests to add agents to the Cloud Verifier. 
@@ -502,7 +508,6 @@ class AgentsHandler(BaseHandler): Currently, only agents resources are available for POSTing, i.e. /agents. All other POST uri's will return errors. agents requests require a json block sent in the body """ - session = get_session() # TODO: exception handling needs fixing # Maybe handle exceptions with if/else if/else blocks ... simple and avoids nesting try: # pylint: disable=too-many-nested-blocks @@ -585,201 +590,208 @@ class AgentsHandler(BaseHandler): runtime_policy = base64.b64decode(json_body.get("runtime_policy")).decode() runtime_policy_stored = None - if runtime_policy_name: + with session_context() as session: + if runtime_policy_name: + try: + runtime_policy_stored = ( + session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() + ) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + raise + + # Prevent overwriting existing IMA policies with name provided in request + if runtime_policy and runtime_policy_stored: + web_util.echo_json_response( + self, + 409, + f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.", + ) + logger.warning("IMA policy with name %s already exists", runtime_policy_name) + return + + # Return an error code if the named allowlist does not exist in the database + if not runtime_policy and not runtime_policy_stored: + web_util.echo_json_response( + self, 404, f"Could not find IMA policy with name {runtime_policy_name}!" 
+ ) + logger.warning("Could not find IMA policy with name %s", runtime_policy_name) + return + + # Prevent overwriting existing agents with UUID provided in request try: - runtime_policy_stored = ( - session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() - ) + new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count() except SQLAlchemyError as e: logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise + raise e - # Prevent overwriting existing IMA policies with name provided in request - if runtime_policy and runtime_policy_stored: + if new_agent_count > 0: web_util.echo_json_response( self, 409, - f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.", + f"Agent of uuid {agent_id} already exists. Please use delete or update.", ) - logger.warning("IMA policy with name %s already exists", runtime_policy_name) + logger.warning("Agent of uuid %s already exists", agent_id) return - # Return an error code if the named allowlist does not exist in the database - if not runtime_policy and not runtime_policy_stored: - web_util.echo_json_response( - self, 404, f"Could not find IMA policy with name {runtime_policy_name}!" - ) - logger.warning("Could not find IMA policy with name %s", runtime_policy_name) - return - - # Prevent overwriting existing agents with UUID provided in request - try: - new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise e - - if new_agent_count > 0: - web_util.echo_json_response( - self, - 409, - f"Agent of uuid {agent_id} already exists. 
Please use delete or update.", - ) - logger.warning("Agent of uuid %s already exists", agent_id) - return - - # Write IMA policy to database if needed - if not runtime_policy_name and not runtime_policy: - logger.info("IMA policy data not provided with request! Using default empty IMA policy.") - runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY)) + # Write IMA policy to database if needed + if not runtime_policy_name and not runtime_policy: + logger.info("IMA policy data not provided with request! Using default empty IMA policy.") + runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY)) - if runtime_policy: - runtime_policy_key_bytes = signing.get_runtime_policy_keys( - runtime_policy.encode(), - json_body.get("runtime_policy_key"), - ) - - try: - ima.verify_runtime_policy( + if runtime_policy: + runtime_policy_key_bytes = signing.get_runtime_policy_keys( runtime_policy.encode(), - runtime_policy_key_bytes, - verify_sig=config.getboolean( - "verifier", "require_allow_list_signatures", fallback=False - ), + json_body.get("runtime_policy_key"), ) - except ima.ImaValidationError as e: - web_util.echo_json_response(self, e.code, e.message) - logger.warning(e.message) - return - if not runtime_policy_name: - runtime_policy_name = agent_id - - try: - runtime_policy_db_format = ima.runtime_policy_db_contents( - runtime_policy_name, runtime_policy - ) - except ima.ImaValidationError as e: - message = f"Runtime policy is malformatted: {e.message}" - web_util.echo_json_response(self, e.code, message) - logger.warning(message) - return - - try: - runtime_policy_stored = ( - session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() - ) - except SQLAlchemyError as e: - logger.error( - "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", agent_id, e - ) - raise - try: - if runtime_policy_stored is None: - runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format) - 
session.add(runtime_policy_stored) + try: + ima.verify_runtime_policy( + runtime_policy.encode(), + runtime_policy_key_bytes, + verify_sig=config.getboolean( + "verifier", "require_allow_list_signatures", fallback=False + ), + ) + except ima.ImaValidationError as e: + web_util.echo_json_response(self, e.code, e.message) + logger.warning(e.message) + return + + if not runtime_policy_name: + runtime_policy_name = agent_id + + try: + runtime_policy_db_format = ima.runtime_policy_db_contents( + runtime_policy_name, runtime_policy + ) + except ima.ImaValidationError as e: + message = f"Runtime policy is malformatted: {e.message}" + web_util.echo_json_response(self, e.code, message) + logger.warning(message) + return + + try: + runtime_policy_stored = ( + session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none() + ) + except SQLAlchemyError as e: + logger.error( + "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", + agent_id, + e, + ) + raise + try: + if runtime_policy_stored is None: + runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format) + session.add(runtime_policy_stored) + session.commit() + except SQLAlchemyError as e: + logger.error( + "SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e + ) + raise + + # Handle measured boot policy + # - No name, mb_policy : store mb_policy using agent UUID as name + # - Name, no mb_policy : fetch existing mb_policy from DB + # - Name, mb_policy : store mb_policy using name + + mb_policy_name = json_body["mb_policy_name"] + mb_policy = json_body["mb_policy"] + mb_policy_stored = None + + if mb_policy_name: + try: + mb_policy_stored = ( + session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() + ) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + raise + + # Prevent overwriting existing mb_policy with name provided in request + if mb_policy and mb_policy_stored: + 
web_util.echo_json_response( + self, + 409, + f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.", + ) + logger.warning("mb_policy with name %s already exists", mb_policy_name) + return + + # Return error if the mb_policy is neither provided nor stored. + if not mb_policy and not mb_policy_stored: + web_util.echo_json_response( + self, 404, f"Could not find mb_policy with name {mb_policy_name}!" + ) + logger.warning("Could not find mb_policy with name %s", mb_policy_name) + return + + else: + # Use the UUID of the agent + mb_policy_name = agent_id + try: + mb_policy_stored = ( + session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() + ) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + raise + + # Prevent overwriting existing mb_policy + if mb_policy and mb_policy_stored: + web_util.echo_json_response( + self, + 409, + f"mb_policy with name {mb_policy_name} already exists. 
You can delete the mb_policy from the verifier.", + ) + logger.warning("mb_policy with name %s already exists", mb_policy_name) + return + + # Store the policy into database if not stored + if mb_policy_stored is None: + try: + mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy) + mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format) + session.add(mb_policy_stored) session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e) - raise - - # Handle measured boot policy - # - No name, mb_policy : store mb_policy using agent UUID as name - # - Name, no mb_policy : fetch existing mb_policy from DB - # - Name, mb_policy : store mb_policy using name - - mb_policy_name = json_body["mb_policy_name"] - mb_policy = json_body["mb_policy"] - mb_policy_stored = None + except SQLAlchemyError as e: + logger.error( + "SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e + ) + raise - if mb_policy_name: + # Write the agent to the database, attaching associated stored ima_policy and mb_policy try: - mb_policy_stored = ( - session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() + assert runtime_policy_stored + assert mb_policy_stored + session.add( + VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored) ) + session.commit() except SQLAlchemyError as e: logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise + raise e - # Prevent overwriting existing mb_policy with name provided in request - if mb_policy and mb_policy_stored: - web_util.echo_json_response( - self, - 409, - f"mb_policy with name {mb_policy_name} already exists. 
Please use a different name or delete the mb_policy from the verifier.", - ) - logger.warning("mb_policy with name %s already exists", mb_policy_name) - return + # add default fields that are ephemeral + for key, val in exclude_db.items(): + agent_data[key] = val - # Return error if the mb_policy is neither provided nor stored. - if not mb_policy and not mb_policy_stored: - web_util.echo_json_response( - self, 404, f"Could not find mb_policy with name {mb_policy_name}!" + # Prepare SSLContext for mTLS connections + agent_data["ssl_context"] = None + if agent_mtls_cert_enabled: + agent_data["ssl_context"] = web_util.generate_agent_tls_context( + "verifier", agent_data["mtls_cert"], logger=logger ) - logger.warning("Could not find mb_policy with name %s", mb_policy_name) - return - else: - # Use the UUID of the agent - mb_policy_name = agent_id - try: - mb_policy_stored = ( - session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none() - ) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise - - # Prevent overwriting existing mb_policy - if mb_policy and mb_policy_stored: - web_util.echo_json_response( - self, - 409, - f"mb_policy with name {mb_policy_name} already exists. 
You can delete the mb_policy from the verifier.", - ) - logger.warning("mb_policy with name %s already exists", mb_policy_name) - return - - # Store the policy into database if not stored - if mb_policy_stored is None: - try: - mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy) - mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format) - session.add(mb_policy_stored) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e) - raise + if agent_data["ssl_context"] is None: + logger.warning("Connecting to agent without mTLS: %s", agent_id) - # Write the agent to the database, attaching associated stored ima_policy and mb_policy - try: - assert runtime_policy_stored - assert mb_policy_stored - session.add( - VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored) - ) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise e - - # add default fields that are ephemeral - for key, val in exclude_db.items(): - agent_data[key] = val - - # Prepare SSLContext for mTLS connections - agent_data["ssl_context"] = None - if agent_mtls_cert_enabled: - agent_data["ssl_context"] = web_util.generate_agent_tls_context( - "verifier", agent_data["mtls_cert"], logger=logger - ) - - if agent_data["ssl_context"] is None: - logger.warning("Connecting to agent without mTLS: %s", agent_id) - - asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE)) - web_util.echo_json_response(self, 200, "Success") - logger.info("POST returning 200 response for adding agent id: %s", agent_id) + asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE)) + web_util.echo_json_response(self, 200, "Success") + logger.info("POST returning 200 response for adding agent id: %s", agent_id) else: web_util.echo_json_response(self, 400, "uri not supported") logger.warning("POST returning 400 
response. uri not supported") @@ -794,54 +806,54 @@ class AgentsHandler(BaseHandler): Currently, only agents resources are available for PUTing, i.e. /agents. All other PUT uri's will return errors. agents requests require a json block sent in the body """ - session = get_session() try: rest_params, agent_id = self.__validate_input("PUT") if not rest_params: return - try: - verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) - db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - raise e + with session_context() as session: + try: + verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) + db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + raise e - if db_agent is None: - web_util.echo_json_response(self, 404, "agent id not found") - logger.info("PUT returning 404 response. agent id: %s not found.", agent_id) - return + if db_agent is None: + web_util.echo_json_response(self, 404, "agent id not found") + logger.info("PUT returning 404 response. 
agent id: %s not found.", agent_id) + return - if "reactivate" in rest_params: - agent = _from_db_obj(db_agent) + if "reactivate" in rest_params: + agent = _from_db_obj(db_agent) - if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": - agent["ssl_context"] = web_util.generate_agent_tls_context( - "verifier", agent["mtls_cert"], logger=logger - ) - if agent["ssl_context"] is None: - logger.warning("Connecting to agent without mTLS: %s", agent_id) + if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": + agent["ssl_context"] = web_util.generate_agent_tls_context( + "verifier", agent["mtls_cert"], logger=logger + ) + if agent["ssl_context"] is None: + logger.warning("Connecting to agent without mTLS: %s", agent_id) - agent["operational_state"] = states.START - asyncio.ensure_future(process_agent(agent, states.GET_QUOTE)) - web_util.echo_json_response(self, 200, "Success") - logger.info("PUT returning 200 response for agent id: %s", agent_id) - elif "stop" in rest_params: - # do stuff for terminate - logger.debug("Stopping polling on %s", agent_id) - try: - session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore - {"operational_state": states.TENANT_FAILED} - ) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) + agent["operational_state"] = states.START + asyncio.ensure_future(process_agent(agent, states.GET_QUOTE)) + web_util.echo_json_response(self, 200, "Success") + logger.info("PUT returning 200 response for agent id: %s", agent_id) + elif "stop" in rest_params: + # do stuff for terminate + logger.debug("Stopping polling on %s", agent_id) + try: + session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore + {"operational_state": states.TENANT_FAILED} + ) + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 200, "Success") 
- logger.info("PUT returning 200 response for agent id: %s", agent_id) - else: - web_util.echo_json_response(self, 400, "uri not supported") - logger.warning("PUT returning 400 response. uri not supported") + web_util.echo_json_response(self, 200, "Success") + logger.info("PUT returning 200 response for agent id: %s", agent_id) + else: + web_util.echo_json_response(self, 400, "uri not supported") + logger.warning("PUT returning 400 response. uri not supported") except Exception as e: web_util.echo_json_response(self, 400, f"Exception error: {str(e)}") @@ -887,36 +899,36 @@ class AllowlistHandler(BaseHandler): if not params_valid: return - session = get_session() - if allowlist_name is None: - try: - names_allowlists = session.query(VerifierAllowlist.name).all() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get names of allowlists") - raise + with session_context() as session: + if allowlist_name is None: + try: + names_allowlists = session.query(VerifierAllowlist.name).all() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get names of allowlists") + raise - names_response = [] - for name in names_allowlists: - names_response.append(name[0]) - web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response}) + names_response = [] + for name in names_allowlists: + names_response.append(name[0]) + web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response}) - else: - try: - allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one() - except NoResultFound: - web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found") - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get allowlist") - raise + else: + try: + allowlist = 
session.query(VerifierAllowlist).filter_by(name=allowlist_name).one() + except NoResultFound: + web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found") + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get allowlist") + raise - response = {} - for field in ("name", "tpm_policy"): - response[field] = getattr(allowlist, field, None) - response["runtime_policy"] = getattr(allowlist, "ima_policy", None) - web_util.echo_json_response(self, 200, "Success", response) + response = {} + for field in ("name", "tpm_policy"): + response[field] = getattr(allowlist, field, None) + response["runtime_policy"] = getattr(allowlist, "ima_policy", None) + web_util.echo_json_response(self, 200, "Success", response) def delete(self) -> None: """Delete an allowlist @@ -928,45 +940,44 @@ class AllowlistHandler(BaseHandler): if not params_valid or allowlist_name is None: return - session = get_session() - try: - runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one() - except NoResultFound: - web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found") - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get allowlist") - raise + with session_context() as session: + try: + runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one() + except NoResultFound: + web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found") + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get allowlist") + raise - try: - agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise - if agent is not None: - 
web_util.echo_json_response( - self, - 409, - f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}", - ) - return + try: + agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise + if agent is not None: + web_util.echo_json_response( + self, + 409, + f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}", + ) + return - try: - session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete() - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - session.close() - web_util.echo_json_response(self, 500, f"Database error: {e}") - raise + try: + session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete() + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, f"Database error: {e}") + raise - # NOTE(kaifeng) 204 Can not have response body, but current helper - # doesn't support this case. - self.set_status(204) - self.set_header("Content-Type", "application/json") - self.finish() - logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name) + # NOTE(kaifeng) 204 Can not have response body, but current helper + # doesn't support this case. 
+ self.set_status(204) + self.set_header("Content-Type", "application/json") + self.finish() + logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name) def __get_runtime_policy_db_format(self, runtime_policy_name: str) -> Dict[str, Any]: """Get the IMA policy from the request and return it in Db format""" @@ -1022,28 +1033,30 @@ class AllowlistHandler(BaseHandler): if not runtime_policy_db_format: return - session = get_session() - # don't allow overwritting - try: - runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() - if runtime_policy_count > 0: - web_util.echo_json_response(self, 409, f"Runtime policy with name {runtime_policy_name} already exists") - logger.warning("Runtime policy with name %s already exists", runtime_policy_name) - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + with session_context() as session: + # don't allow overwritting + try: + runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() + if runtime_policy_count > 0: + web_util.echo_json_response( + self, 409, f"Runtime policy with name {runtime_policy_name} already exists" + ) + logger.warning("Runtime policy with name %s already exists", runtime_policy_name) + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - try: - # Add the agent and data - session.add(VerifierAllowlist(**runtime_policy_db_format)) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + try: + # Add the agent and data + session.add(VerifierAllowlist(**runtime_policy_db_format)) + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - web_util.echo_json_response(self, 201) - logger.info("POST returning 201") + web_util.echo_json_response(self, 201) + logger.info("POST returning 201") def 
put(self) -> None: """Update an allowlist @@ -1060,32 +1073,34 @@ class AllowlistHandler(BaseHandler): if not runtime_policy_db_format: return - session = get_session() - # don't allow creating a new policy - try: - runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() - if runtime_policy_count != 1: - web_util.echo_json_response( - self, 409, f"Runtime policy with name {runtime_policy_name} does not already exist" - ) - logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name) - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + with session_context() as session: + # don't allow creating a new policy + try: + runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count() + if runtime_policy_count != 1: + web_util.echo_json_response( + self, + 404, + f"Runtime policy with name {runtime_policy_name} does not already exist, use POST to create", + ) + logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name) + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - try: - # Update the named runtime policy - session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update( - runtime_policy_db_format # pyright: ignore - ) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + try: + # Update the named runtime policy + session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update( + runtime_policy_db_format # pyright: ignore + ) + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - web_util.echo_json_response(self, 201) - logger.info("PUT returning 201") + web_util.echo_json_response(self, 201) + logger.info("PUT returning 201") def data_received(self, chunk: Any) -> None: raise 
NotImplementedError() @@ -1113,8 +1128,6 @@ class VerifyIdentityHandler(BaseHandler): This is useful for 3rd party tools and integrations to independently verify the state of an agent. """ - session = get_session() - # validate the parameters of our request if self.request.uri is None: web_util.echo_json_response(self, 400, "URI not specified") @@ -1159,36 +1172,37 @@ class VerifyIdentityHandler(BaseHandler): return # get the agent information from the DB - agent = None - try: - agent = ( - session.query(VerfierMain) - .options( # type: ignore - joinedload(VerfierMain.ima_policy).load_only( - VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore + with session_context() as session: + agent = None + try: + agent = ( + session.query(VerfierMain) + .options( # type: ignore + joinedload(VerfierMain.ima_policy).load_only( + VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore + ) ) + .filter_by(agent_id=agent_id) + .one_or_none() ) - .filter_by(agent_id=agent_id) - .one_or_none() - ) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e) - if agent is not None: - agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id) - failure = cloud_verifier_common.process_verify_identity_quote( - agent, quote, nonce, hash_alg, agentAttestState - ) - if failure: - failure_contexts = "; ".join(x.context for x in failure.events) - web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts}) - logger.info("GET returning 200, but validation failed") + if agent is not None: + agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id) + failure = cloud_verifier_common.process_verify_identity_quote( + agent, quote, nonce, hash_alg, agentAttestState + ) + if failure: + failure_contexts = "; ".join(x.context for x in failure.events) + 
web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts}) + logger.info("GET returning 200, but validation failed") + else: + web_util.echo_json_response(self, 200, "Success", {"valid": 1}) + logger.info("GET returning 200, validation successful") else: - web_util.echo_json_response(self, 200, "Success", {"valid": 1}) - logger.info("GET returning 200, validation successful") - else: - web_util.echo_json_response(self, 404, "agent id not found") - logger.info("GET returning 404, agaent not found") + web_util.echo_json_response(self, 404, "agent id not found") + logger.info("GET returning 404, agaent not found") def data_received(self, chunk: Any) -> None: raise NotImplementedError() @@ -1231,35 +1245,35 @@ class MbpolicyHandler(BaseHandler): if not params_valid: return - session = get_session() - if mb_policy_name is None: - try: - names_mbpolicies = session.query(VerifierMbpolicy.name).all() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies") - raise + with session_context() as session: + if mb_policy_name is None: + try: + names_mbpolicies = session.query(VerifierMbpolicy.name).all() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies") + raise - names_response = [] - for name in names_mbpolicies: - names_response.append(name[0]) - web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response}) + names_response = [] + for name in names_mbpolicies: + names_response.append(name[0]) + web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response}) - else: - try: - mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() - except NoResultFound: - web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") - return - except SQLAlchemyError as e: - 
logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get mb_policy") - raise + else: + try: + mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() + except NoResultFound: + web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get mb_policy") + raise - response = {} - response["name"] = getattr(mbpolicy, "name", None) - response["mb_policy"] = getattr(mbpolicy, "mb_policy", None) - web_util.echo_json_response(self, 200, "Success", response) + response = {} + response["name"] = getattr(mbpolicy, "name", None) + response["mb_policy"] = getattr(mbpolicy, "mb_policy", None) + web_util.echo_json_response(self, 200, "Success", response) def delete(self) -> None: """Delete a mb_policy @@ -1271,45 +1285,44 @@ class MbpolicyHandler(BaseHandler): if not params_valid or mb_policy_name is None: return - session = get_session() - try: - mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() - except NoResultFound: - web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - web_util.echo_json_response(self, 500, "Failed to get mb_policy") - raise + with session_context() as session: + try: + mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one() + except NoResultFound: + web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found") + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, "Failed to get mb_policy") + raise - try: - agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - 
raise - if agent is not None: - web_util.echo_json_response( - self, - 409, - f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}", - ) - return + try: + agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none() + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise + if agent is not None: + web_util.echo_json_response( + self, + 409, + f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}", + ) + return - try: - session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete() - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - session.close() - web_util.echo_json_response(self, 500, f"Database error: {e}") - raise + try: + session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete() + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + web_util.echo_json_response(self, 500, f"Database error: {e}") + raise - # NOTE(kaifeng) 204 Can not have response body, but current helper - # doesn't support this case. - self.set_status(204) - self.set_header("Content-Type", "application/json") - self.finish() - logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name) + # NOTE(kaifeng) 204 Can not have response body, but current helper + # doesn't support this case. 
+ self.set_status(204) + self.set_header("Content-Type", "application/json") + self.finish() + logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name) def __get_mb_policy_db_format(self, mb_policy_name: str) -> Dict[str, Any]: """Get the measured boot policy from the request and return it in Db format""" @@ -1341,30 +1354,30 @@ class MbpolicyHandler(BaseHandler): if not mb_policy_db_format: return - session = get_session() - # don't allow overwritting - try: - mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() - if mbpolicy_count > 0: - web_util.echo_json_response( - self, 409, f"Measured boot policy with name {mb_policy_name} already exists" - ) - logger.warning("Measured boot policy with name %s already exists", mb_policy_name) - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + with session_context() as session: + # don't allow overwritting + try: + mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() + if mbpolicy_count > 0: + web_util.echo_json_response( + self, 409, f"Measured boot policy with name {mb_policy_name} already exists" + ) + logger.warning("Measured boot policy with name %s already exists", mb_policy_name) + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - try: - # Add the data - session.add(VerifierMbpolicy(**mb_policy_db_format)) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + try: + # Add the data + session.add(VerifierMbpolicy(**mb_policy_db_format)) + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - web_util.echo_json_response(self, 201) - logger.info("POST returning 201") + web_util.echo_json_response(self, 201) + logger.info("POST returning 201") def put(self) -> None: """Update an mb_policy @@ -1381,32 +1394,32 @@ class 
MbpolicyHandler(BaseHandler): if not mb_policy_db_format: return - session = get_session() - # don't allow creating a new policy - try: - mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() - if mbpolicy_count != 1: - web_util.echo_json_response( - self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist" - ) - logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name) - return - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + with session_context() as session: + # don't allow creating a new policy + try: + mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count() + if mbpolicy_count != 1: + web_util.echo_json_response( + self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist" + ) + logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name) + return + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - try: - # Update the named mb_policy - session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update( - mb_policy_db_format # pyright: ignore - ) - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) - raise + try: + # Update the named mb_policy + session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update( + mb_policy_db_format # pyright: ignore + ) + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) + raise - web_util.echo_json_response(self, 201) - logger.info("PUT returning 201") + web_util.echo_json_response(self, 201) + logger.info("PUT returning 201") def data_received(self, chunk: Any) -> None: raise NotImplementedError() @@ -1460,17 +1473,18 @@ async def update_agent_api_version(agent: Dict[str, Any], timeout: float = 60.0) return None logger.info("Agent %s new API version %s 
is supported", agent_id, new_version) - session = get_session() - agent["supported_version"] = new_version - # Remove keys that should not go to the DB - agent_db = dict(agent) - for key in exclude_db: - if key in agent_db: - del agent_db[key] + with session_context() as session: + agent["supported_version"] = new_version - session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore - session.commit() + # Remove keys that should not go to the DB + agent_db = dict(agent) + for key in exclude_db: + if key in agent_db: + del agent_db[key] + + session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore + # session.commit() is automatically called by context manager else: logger.warning("Agent %s new API version %s is not supported", agent_id, new_version) return None @@ -1718,50 +1732,68 @@ async def notify_error( revocation_notifier.notify(tosend) if "agent" in notifiers: verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID) - session = get_session() - agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() - futures = [] - loop = asyncio.get_event_loop() - # Notify all agents asynchronously through a thread pool - with ThreadPoolExecutor() as pool: - for agent_db_obj in agents: - if agent_db_obj.agent_id != agent["agent_id"]: - agent = _from_db_obj(agent_db_obj) - if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": - agent["ssl_context"] = web_util.generate_agent_tls_context( - "verifier", agent["mtls_cert"], logger=logger - ) - func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout) - futures.append(await loop.run_in_executor(pool, func)) - # Wait for all tasks complete in 60 seconds - try: - for f in asyncio.as_completed(futures, timeout=60): - await f - except asyncio.TimeoutError as e: - logger.error("Timeout during notifying error to agents: %s", e) + with session_context() as session: + agents = 
session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() + futures = [] + loop = asyncio.get_event_loop() + # Notify all agents asynchronously through a thread pool + with ThreadPoolExecutor() as pool: + for agent_db_obj in agents: + if agent_db_obj.agent_id != agent["agent_id"]: + agent = _from_db_obj(agent_db_obj) + if agent["mtls_cert"] and agent["mtls_cert"] != "disabled": + agent["ssl_context"] = web_util.generate_agent_tls_context( + "verifier", agent["mtls_cert"], logger=logger + ) + func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout) + futures.append(await loop.run_in_executor(pool, func)) + # Wait for all tasks complete in 60 seconds + try: + for f in asyncio.as_completed(futures, timeout=60): + await f + except asyncio.TimeoutError as e: + logger.error("Timeout during notifying error to agents: %s", e) async def process_agent( agent: Dict[str, Any], new_operational_state: int, failure: Failure = Failure(Component.INTERNAL, ["verifier"]) ) -> None: - session = get_session() try: # pylint: disable=R1702 main_agent_operational_state = agent["operational_state"] stored_agent = None - try: - stored_agent = ( - session.query(VerfierMain) - .options( # type: ignore - joinedload(VerfierMain.ima_policy).load_only(VerifierAllowlist.checksum) # pyright: ignore - ) - .options( # type: ignore - joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore + + # First database operation - read agent data and extract all needed data within session context + ima_policy_data = {} + mb_policy_data = None + with session_context() as session: + try: + stored_agent = ( + session.query(VerfierMain) + .options( # type: ignore + joinedload(VerfierMain.ima_policy) # Load full IMA policy object including content + ) + .options( # type: ignore + joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore + ) + .filter_by(agent_id=str(agent["agent_id"])) + .first() ) - 
.filter_by(agent_id=str(agent["agent_id"])) - .first() - ) - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) + + # Extract IMA policy data within session context to avoid DetachedInstanceError + if stored_agent and stored_agent.ima_policy: + ima_policy_data = { + "checksum": str(stored_agent.ima_policy.checksum), + "name": stored_agent.ima_policy.name, + "agent_id": str(stored_agent.agent_id), + "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too + } + + # Extract MB policy data within session context + if stored_agent and stored_agent.mb_policy: + mb_policy_data = stored_agent.mb_policy.mb_policy + + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) # if the stored agent could not be recovered from the database, stop polling if not stored_agent: @@ -1775,7 +1807,10 @@ async def process_agent( logger.warning("Agent %s terminated by user.", agent["agent_id"]) if agent["pending_event"] is not None: tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"]) - verifier_db_delete_agent(session, agent["agent_id"]) + + # Second database operation - delete agent + with session_context() as session: + verifier_db_delete_agent(session, agent["agent_id"]) return # if the user tells us to stop polling because the tenant quote check failed @@ -1808,11 +1843,16 @@ async def process_agent( if not failure.recoverable or failure.highest_severity == MAX_SEVERITY_LABEL: if agent["pending_event"] is not None: tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"]) - for key in exclude_db: - if key in agent: - del agent[key] - session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(agent) # pyright: ignore - session.commit() + + # Third database operation - update agent with failure state + with session_context() as session: + for key in exclude_db: + if key in agent: + del agent[key] + 
session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update( + agent # type: ignore[arg-type] + ) + # session.commit() is automatically called by context manager # propagate all state, but remove none DB keys first (using exclude_db) try: @@ -1821,18 +1861,18 @@ async def process_agent( if key in agent_db: del agent_db[key] - session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore - session.commit() + # Fourth database operation - update agent state + with session_context() as session: + session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore + # session.commit() is automatically called by context manager except SQLAlchemyError as e: logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e) # Load agent's IMA policy - runtime_policy = verifier_read_policy_from_cache(stored_agent) + runtime_policy = verifier_read_policy_from_cache(ima_policy_data) # Get agent's measured boot policy - mb_policy = None - if stored_agent.mb_policy is not None: - mb_policy = stored_agent.mb_policy.mb_policy + mb_policy = mb_policy_data # If agent was in a failed state we check if we either stop polling # or just add it again to the event loop @@ -1876,7 +1916,14 @@ async def process_agent( ) pending = tornado.ioloop.IOLoop.current().call_later( - interval, invoke_get_quote, agent, mb_policy, runtime_policy, False, timeout=timeout # type: ignore # due to python <3.9 + # type: ignore # due to python <3.9 + interval, + invoke_get_quote, + agent, + mb_policy, + runtime_policy, + False, + timeout=timeout, ) agent["pending_event"] = pending return @@ -1911,7 +1958,14 @@ async def process_agent( next_retry, ) tornado.ioloop.IOLoop.current().call_later( - next_retry, invoke_get_quote, agent, mb_policy, runtime_policy, True, timeout=timeout # type: ignore # due to python <3.9 + # type: ignore # due to python <3.9 + next_retry, + invoke_get_quote, + agent, + mb_policy, + 
runtime_policy, + True, + timeout=timeout, ) return @@ -1980,9 +2034,9 @@ async def activate_agents(agents: List[VerfierMain], verifier_ip: str, verifier_ def get_agents_by_verifier_id(verifier_id: str) -> List[VerfierMain]: - session = get_session() try: - return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() + with session_context() as session: + return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all() except SQLAlchemyError as e: logger.error("SQLAlchemy Error: %s", e) return [] @@ -2007,20 +2061,20 @@ def main() -> None: os.umask(0o077) VerfierMain.metadata.create_all(engine, checkfirst=True) # pyright: ignore - session = get_session() - try: - query_all = session.query(VerfierMain).all() - for row in query_all: - if row.operational_state in states.APPROVED_REACTIVATE_STATES: - row.operational_state = states.START # pyright: ignore - session.commit() - except SQLAlchemyError as e: - logger.error("SQLAlchemy Error: %s", e) + with session_context() as session: + try: + query_all = session.query(VerfierMain).all() + for row in query_all: + if row.operational_state in states.APPROVED_REACTIVATE_STATES: + row.operational_state = states.START # pyright: ignore + # session.commit() is automatically called by context manager + except SQLAlchemyError as e: + logger.error("SQLAlchemy Error: %s", e) - num = session.query(VerfierMain.agent_id).count() - if num > 0: - agent_ids = session.query(VerfierMain.agent_id).all() - logger.info("Agent ids in db loaded from file: %s", agent_ids) + num = session.query(VerfierMain.agent_id).count() + if num > 0: + agent_ids = session.query(VerfierMain.agent_id).all() + logger.info("Agent ids in db loaded from file: %s", agent_ids) logger.info("Starting Cloud Verifier (tornado) on port %s, use to stop", verifier_port) diff --git a/keylime/da/examples/sqldb.py b/keylime/da/examples/sqldb.py index 8efc84e..04a8afb 100644 --- a/keylime/da/examples/sqldb.py +++ b/keylime/da/examples/sqldb.py @@ -1,7 
+1,10 @@ import time +from contextlib import contextmanager +from typing import Iterator import sqlalchemy import sqlalchemy.ext.declarative +from sqlalchemy.orm import sessionmaker from keylime import keylime_logging from keylime.da.record import BaseRecordManagement, base_build_key_list @@ -45,23 +48,23 @@ class RecordManagement(BaseRecordManagement): BaseRecordManagement.__init__(self, service) self.engine = sqlalchemy.create_engine(self.ps_url._replace(fragment="").geturl(), pool_recycle=1800) - sm = sqlalchemy.orm.sessionmaker() - self.session = sqlalchemy.orm.scoped_session(sm) - self.session.configure(bind=self.engine) - TableBase.metadata.create_all(self.engine) - - def agent_list_retrieval(self, record_prefix="auto", service="auto"): - if record_prefix == "auto": - record_prefix = "" - - agent_list = [] + self.SessionLocal = sessionmaker(bind=self.engine) - recordtype = self.get_record_type(service) - tbl = type2table(recordtype) - for agentid in self.session.query(tbl.agentid).distinct(): # pylint: disable=no-member - agent_list.append(agentid[0]) + # Create tables if they don't exist + TableBase.metadata.create_all(self.engine) - return agent_list + @contextmanager + def session_context(self) -> Iterator: + """Context manager for database sessions that ensures proper cleanup.""" + session = self.SessionLocal() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() def record_create( self, @@ -84,8 +87,9 @@ class RecordManagement(BaseRecordManagement): d = {"time": recordtime, "agentid": agentid, "record": rcrd} try: - self.session.add((type2table(recordtype))(**d)) # pylint: disable=no-member - self.session.commit() # pylint: disable=no-member + with self.session_context() as session: + session.add((type2table(recordtype))(**d)) + # session.commit() is automatically called by context manager except Exception as e: logger.error("Failed to create attestation record: %s", e) @@ -106,23 +110,23 
@@ class RecordManagement(BaseRecordManagement): if f"{end_date}" == "auto": end_date = self.end_of_times - if self.only_last_record_wanted(start_date, end_date): - attestion_record_rows = ( - self.session.query(tbl) # pylint: disable=no-member - .filter(tbl.agentid == record_identifier) - .order_by(sqlalchemy.desc(tbl.time)) - .limit(1) - ) - - else: - attestion_record_rows = self.session.query(tbl).filter( # pylint: disable=no-member - tbl.agentid == record_identifier - ) - - for row in attestion_record_rows: - decoded_record_object = self.record_deserialize(row.record) - self.record_signature_check(decoded_record_object, record_identifier) - record_list.append(decoded_record_object) + with self.session_context() as session: + if self.only_last_record_wanted(start_date, end_date): + attestion_record_rows = ( + session.query(tbl) + .filter(tbl.agentid == record_identifier) + .order_by(sqlalchemy.desc(tbl.time)) + .limit(1) + ) + + else: + attestion_record_rows = session.query(tbl).filter(tbl.agentid == record_identifier) + + for row in attestion_record_rows: + decoded_record_object = self.record_deserialize(row.record) + self.record_signature_check(decoded_record_object, record_identifier) + record_list.append(decoded_record_object) + return record_list def build_key_list(self, agent_identifier, service="auto"): diff --git a/keylime/db/keylime_db.py b/keylime/db/keylime_db.py index 5620a28..aa49e51 100644 --- a/keylime/db/keylime_db.py +++ b/keylime/db/keylime_db.py @@ -1,7 +1,8 @@ import os from configparser import NoOptionError +from contextlib import contextmanager from sqlite3 import Connection as SQLite3Connection -from typing import Any, Dict, Optional, cast +from typing import Any, Iterator, Optional, cast from sqlalchemy import create_engine, event from sqlalchemy.engine import Engine @@ -22,90 +23,108 @@ def _set_sqlite_pragma(dbapi_connection: SQLite3Connection, _) -> None: cursor.close() -class DBEngineManager: - service: Optional[str] - - def 
__init__(self) -> None: - self.service = None - - def make_engine(self, service: str) -> Engine: - """ - To use: engine = self.make_engine('cloud_verifier') - """ - - # Keep DB related stuff as it is, but read configuration from new - # configs - if service == "cloud_verifier": - config_service = "verifier" - else: - config_service = service - - self.service = service - - try: - p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") - p_sz, m_ovfl = p_sz_m_ovfl.split(",") - except NoOptionError: - p_sz = "5" - m_ovfl = "10" - - engine_args: Dict[str, Any] = {} - - url = config.get(config_service, "database_url") - if url: - logger.info("database_url is set, using it to establish database connection") - - # If the keyword sqlite is provided as the database url, use the - # cv_data.sqlite for the verifier or the file reg_data.sqlite for - # the registrar, located at the config.WORK_DIR directory - if url == "sqlite": +def make_engine(service: str, **engine_args: Any) -> Engine: + """Create a database engine for a keylime service.""" + # Keep DB related stuff as it is, but read configuration from new + # configs + if service == "cloud_verifier": + config_service = "verifier" + else: + config_service = service + + url = config.get(config_service, "database_url") + if url: + logger.info("database_url is set, using it to establish database connection") + + # If the keyword sqlite is provided as the database url, use the + # cv_data.sqlite for the verifier or the file reg_data.sqlite for + # the registrar, located at the config.WORK_DIR directory + if url == "sqlite": + logger.info( + "database_url is set as 'sqlite' keyword, using default values to establish database connection" + ) + if service == "cloud_verifier": + database = "cv_data.sqlite" + elif service == "registrar": + database = "reg_data.sqlite" + else: + logger.error("Tried to setup database access for unknown service '%s'", service) + raise Exception(f"Unknown service '{service}' for database 
setup") + + database_file = os.path.abspath(os.path.join(config.WORK_DIR, database)) + url = f"sqlite:///{database_file}" + + kl_dir = os.path.dirname(os.path.abspath(database_file)) + if not os.path.exists(kl_dir): + os.makedirs(kl_dir, 0o700) + + engine_args["connect_args"] = {"check_same_thread": False} + + if not url.count("sqlite:"): + # sqlite does not support setting pool size and max overflow, only + # read from the config when it is going to be used + try: + p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") + p_sz, m_ovfl = p_sz_m_ovfl.split(",") + logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl) + except NoOptionError: + p_sz = "5" + m_ovfl = "10" logger.info( - "database_url is set as 'sqlite' keyword, using default values to establish database connection" + "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", p_sz, m_ovfl ) - if service == "cloud_verifier": - database = "cv_data.sqlite" - elif service == "registrar": - database = "reg_data.sqlite" - else: - logger.error("Tried to setup database access for unknown service '%s'", service) - raise Exception(f"Unknown service '{service}' for database setup") - - database_file = os.path.abspath(os.path.join(config.WORK_DIR, database)) - url = f"sqlite:///{database_file}" - - kl_dir = os.path.dirname(os.path.abspath(database_file)) - if not os.path.exists(kl_dir): - os.makedirs(kl_dir, 0o700) - - engine_args["connect_args"] = {"check_same_thread": False} - if not url.count("sqlite:"): - engine_args["pool_size"] = int(p_sz) - engine_args["max_overflow"] = int(m_ovfl) - engine_args["pool_pre_ping"] = True + engine_args["pool_size"] = int(p_sz) + engine_args["max_overflow"] = int(m_ovfl) + engine_args["pool_pre_ping"] = True - # Enable DB debugging - if config.DEBUG_DB and config.INSECURE_DEBUG: - engine_args["echo"] = True + # Enable DB debugging + if config.DEBUG_DB and config.INSECURE_DEBUG: + engine_args["echo"] = 
True - engine = create_engine(url, **engine_args) - return engine + engine = create_engine(url, **engine_args) + return engine class SessionManager: engine: Optional[Engine] + _scoped_session: Optional[scoped_session] def __init__(self) -> None: self.engine = None + self._scoped_session = None def make_session(self, engine: Engine) -> Session: """ To use: session = self.make_session(engine) """ self.engine = engine - my_session = scoped_session(sessionmaker()) + if self._scoped_session is None: + self._scoped_session = scoped_session(sessionmaker()) try: - my_session.configure(bind=self.engine) # type: ignore + self._scoped_session.configure(bind=self.engine) # type: ignore + self._scoped_session.configure(expire_on_commit=False) # type: ignore except SQLAlchemyError as err: logger.error("Error creating SQL session manager %s", err) - return cast(Session, my_session()) + return cast(Session, self._scoped_session()) + + @contextmanager + def session_context(self, engine: Engine) -> Iterator[Session]: + """ + Context manager for database sessions that ensures proper cleanup. 
+ To use: + with session_manager.session_context(engine) as session: + # use session + """ + session = self.make_session(engine) + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + # Important: remove the session from the scoped session registry + # to prevent connection leaks with scoped_session + if self._scoped_session is not None: + self._scoped_session.remove() # type: ignore[no-untyped-call] diff --git a/keylime/migrations/env.py b/keylime/migrations/env.py index ac98349..a1881f2 100644 --- a/keylime/migrations/env.py +++ b/keylime/migrations/env.py @@ -8,7 +8,7 @@ import sys from alembic import context -from keylime.db.keylime_db import DBEngineManager +from keylime.db.keylime_db import make_engine from keylime.db.registrar_db import Base as RegistrarBase from keylime.db.verifier_db import Base as VerifierBase @@ -74,7 +74,7 @@ def run_migrations_offline(): logger.info("Writing output to %s", file_) with open(file_, "w", encoding="utf-8") as buffer: - engine = DBEngineManager().make_engine(name) + engine = make_engine(name) connection = engine.connect() context.configure( connection=connection, @@ -102,7 +102,7 @@ def run_migrations_online(): engines = {} for name in re.split(r",\s*", db_names): engines[name] = rec = {} - rec["engine"] = DBEngineManager().make_engine(name) + rec["engine"] = make_engine(name) for name, rec in engines.items(): engine = rec["engine"] diff --git a/keylime/models/base/db.py b/keylime/models/base/db.py index dd47d63..0229765 100644 --- a/keylime/models/base/db.py +++ b/keylime/models/base/db.py @@ -41,13 +41,6 @@ class DBManager: self._service = service - try: - p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") - p_sz, m_ovfl = p_sz_m_ovfl.split(",") - except NoOptionError: - p_sz = "5" - m_ovfl = "10" - engine_args: Dict[str, Any] = {} url = config.get(config_service, "database_url") @@ -79,6 +72,21 @@ class DBManager: engine_args["connect_args"] = 
{"check_same_thread": False} if not url.count("sqlite:"): + # sqlite does not support setting pool size and max overflow, only + # read from the config when it is going to be used + try: + p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl") + p_sz, m_ovfl = p_sz_m_ovfl.split(",") + logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl) + except NoOptionError: + p_sz = "5" + m_ovfl = "10" + logger.info( + "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", + p_sz, + m_ovfl, + ) + engine_args["pool_size"] = int(p_sz) engine_args["max_overflow"] = int(m_ovfl) engine_args["pool_pre_ping"] = True diff --git a/keylime/models/base/persistable_model.py b/keylime/models/base/persistable_model.py index 18f7d0d..a779f0b 100644 --- a/keylime/models/base/persistable_model.py +++ b/keylime/models/base/persistable_model.py @@ -207,10 +207,16 @@ class PersistableModel(BasicModel, metaclass=PersistableModelMeta): setattr(self._db_mapping_inst, name, field.data_type.db_dump(value, db_manager.engine.dialect)) with db_manager.session_context() as session: - session.add(self._db_mapping_inst) + # Merge the potentially detached object into the new session + merged_instance = session.merge(self._db_mapping_inst) + session.add(merged_instance) + # Update our reference to the merged instance + self._db_mapping_inst = merged_instance # pylint: disable=attribute-defined-outside-init self.clear_changes() def delete(self) -> None: with db_manager.session_context() as session: - session.delete(self._db_mapping_inst) # type: ignore[no-untyped-call] + # Merge the potentially detached object into the new session before deleting + merged_instance = session.merge(self._db_mapping_inst) + session.delete(merged_instance) # type: ignore[no-untyped-call] diff --git a/packit-ci.fmf b/packit-ci.fmf index 2d1e5e5..cb64faf 100644 --- a/packit-ci.fmf +++ b/packit-ci.fmf @@ -101,6 +101,7 @@ adjust: - /regression/CVE-2023-3674 
- /regression/issue-1380-agent-removed-and-re-added - /regression/keylime-agent-option-override-through-envvar + - /regression/db-connection-leak-reproducer - /sanity/keylime-secure_mount - /sanity/opened-conf-files - /upstream/run_keylime_tests diff --git a/test/test_verifier_db.py b/test/test_verifier_db.py index ad72fa6..aae8f8a 100644 --- a/test/test_verifier_db.py +++ b/test/test_verifier_db.py @@ -172,3 +172,102 @@ class TestVerfierDB(unittest.TestCase): def tearDown(self): self.session.close() + + def test_11_relationship_access_after_session_commit(self): + """Test that relationships can be accessed after session commits (DetachedInstanceError fix)""" + # This test reproduces the problematic pattern from cloud_verifier_tornado.py + # where objects are loaded with joinedload and then accessed after session closes + + # Create a new session manager and context (like in cloud_verifier_tornado.py) + session_manager = SessionManager() + + # First, load the agent with eager loading for relationships + stored_agent = None + with session_manager.session_context(self.engine) as session: + stored_agent = ( + session.query(VerfierMain) + .options(joinedload(VerfierMain.ima_policy)) + .options(joinedload(VerfierMain.mb_policy)) + .filter_by(agent_id=agent_id) + .first() + ) + # Verify agent was loaded correctly + self.assertIsNotNone(stored_agent) + # session.commit() is automatically called by context manager when exiting + + # Now verify we can access relationships AFTER the session has been closed + # This would previously trigger DetachedInstanceError + + # Ensure stored_agent is not None before proceeding + assert stored_agent is not None + + # Test accessing ima_policy relationship + self.assertIsNotNone(stored_agent.ima_policy) + assert stored_agent.ima_policy is not None # Type narrowing for linter + self.assertEqual(stored_agent.ima_policy.name, "test-allowlist") + # checksum is not set in test data + self.assertEqual(stored_agent.ima_policy.checksum, None) + 
+ # Test accessing the ima_policy.ima_policy attribute (similar to verifier_read_policy_from_cache) + ima_policy_content = stored_agent.ima_policy.ima_policy + self.assertEqual(ima_policy_content, test_allowlist_data["ima_policy"]) + + # Test accessing mb_policy relationship + self.assertIsNotNone(stored_agent.mb_policy) + assert stored_agent.mb_policy is not None # Type narrowing for linter + self.assertEqual(stored_agent.mb_policy.name, "test-mbpolicy") + + # Test accessing the mb_policy.mb_policy attribute (similar to process_agent function) + mb_policy_content = stored_agent.mb_policy.mb_policy + self.assertEqual(mb_policy_content, test_mbpolicy_data["mb_policy"]) + + # Test that we can access these relationships multiple times without issues + for _ in range(3): + self.assertIsNotNone(stored_agent.ima_policy.ima_policy) + self.assertIsNotNone(stored_agent.mb_policy.mb_policy) + + def test_12_persistable_model_cross_session_fix(self): + """Test that PersistableModel can handle cross-session operations safely""" + # This test would previously fail with DetachedInstanceError before the fix + # Note: This is a conceptual test since we don't have actual PersistableModel + # subclasses in the test environment, but demonstrates the pattern + + # Simulate creating a SQLAlchemy object in one session + session_manager = SessionManager() + + # Load an object in one session context + test_agent = None + with session_manager.session_context(self.engine) as session: + test_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() + self.assertIsNotNone(test_agent) + # Session closes here + + # Ensure test_agent is not None before proceeding + assert test_agent is not None + + # Now simulate using this object in a different session context + # This tests the pattern where PersistableModel would use session.add() or session.delete() + # on a cross-session object + with session_manager.session_context(self.engine) as session: + # Before the fix, this would cause 
DetachedInstanceError + # The fix uses session.merge() to handle detached objects safely + merged_agent = session.merge(test_agent) + assert merged_agent is not None # Type narrowing for linter + + # Test that we can modify and save the merged object + original_port = merged_agent.port + # Use setattr to avoid linter issues with Column assignment + setattr(merged_agent, "port", 9999) + session.add(merged_agent) + # session.commit() called automatically by context manager + + # Verify the change was persisted + with session_manager.session_context(self.engine) as session: + updated_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first() + assert updated_agent is not None # Type narrowing for linter + self.assertEqual(updated_agent.port, 9999) + + # Restore original value + # Use setattr to avoid linter issues + setattr(updated_agent, "port", original_port) + session.add(updated_agent)