keylime/SOURCES/0007-fix_db_connection_leaks.patch

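This patch removes the per-call get_session() helper and routes all verifier database access through a session_context() context manager, so each request commits, rolls back, and closes its session deterministically instead of leaking connections. The excerpt below does not include the keylime/db/keylime_db.py hunk that adds SessionManager.session_context() and make_engine(); as a rough, hypothetical sketch of the pattern the handlers now rely on (plain SQLAlchemy sessionmaker, not the actual keylime implementation), the context manager behaves roughly like this:

    # Illustrative sketch only -- the real SessionManager.session_context()
    # in keylime/db/keylime_db.py may differ in detail.
    from contextlib import contextmanager
    from typing import Iterator

    from sqlalchemy.engine import Engine
    from sqlalchemy.orm import Session, sessionmaker

    @contextmanager
    def session_context(engine: Engine) -> Iterator[Session]:
        """Yield a session; commit on success, roll back on error, always close."""
        session = sessionmaker(bind=engine)()
        try:
            yield session
            session.commit()    # handlers can rely on this instead of calling commit() themselves
        except Exception:
            session.rollback()  # failed requests no longer leave an open transaction behind
            raise
        finally:
            session.close()     # returns the connection to the pool, fixing the leak

Wrapping each handler body in "with session_context() as session:" is what lets the patch replace many of the removed session.commit() calls with the "# session.commit() is automatically called by context manager" comments seen throughout the hunks below.
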
diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py
index 8ab81d1..7553ac8 100644
--- a/keylime/cloud_verifier_tornado.py
+++ b/keylime/cloud_verifier_tornado.py
@@ -7,7 +7,8 @@ import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from contextlib import contextmanager
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
import tornado.httpserver
import tornado.ioloop
@@ -34,7 +35,7 @@ from keylime.agentstates import AgentAttestState, AgentAttestStates
from keylime.common import retry, states, validators
from keylime.common.version import str_to_version
from keylime.da import record
-from keylime.db.keylime_db import DBEngineManager, SessionManager
+from keylime.db.keylime_db import SessionManager, make_engine
from keylime.db.verifier_db import VerfierMain, VerifierAllowlist, VerifierMbpolicy
from keylime.failure import MAX_SEVERITY_LABEL, Component, Event, Failure, set_severity_config
from keylime.ima import ima
@@ -47,7 +48,7 @@ GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {}
set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy"))
try:
- engine = DBEngineManager().make_engine("cloud_verifier")
+ engine = make_engine("cloud_verifier")
except SQLAlchemyError as err:
logger.error("Error creating SQL engine or session: %s", err)
sys.exit(1)
@@ -61,8 +62,17 @@ except record.RecordManagementException as rme:
sys.exit(1)
-def get_session() -> Session:
- return SessionManager().make_session(engine)
+@contextmanager
+def session_context() -> Iterator[Session]:
+ """
+ Context manager for database sessions that ensures proper cleanup.
+ To use:
+ with session_context() as session:
+ # use session
+ """
+ session_manager = SessionManager()
+ with session_manager.session_context(engine) as session:
+ yield session
def get_AgentAttestStates() -> AgentAttestStates:
@@ -130,19 +140,18 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]:
return agent_dict
-def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str:
- checksum = ""
- name = "empty"
- agent_id = str(stored_agent.agent_id)
+def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str:
+ checksum = ima_policy_data.get("checksum", "")
+ name = ima_policy_data.get("name", "empty")
+ agent_id = ima_policy_data.get("agent_id", "")
+
+ if not agent_id:
+ return ""
if agent_id not in GLOBAL_POLICY_CACHE:
GLOBAL_POLICY_CACHE[agent_id] = {}
GLOBAL_POLICY_CACHE[agent_id][""] = ""
- if stored_agent.ima_policy:
- checksum = str(stored_agent.ima_policy.checksum)
- name = stored_agent.ima_policy.name
-
if checksum not in GLOBAL_POLICY_CACHE[agent_id]:
if len(GLOBAL_POLICY_CACHE[agent_id]) > 1:
# Perform a cleanup of the contents, IMA policy checksum changed
@@ -162,8 +171,9 @@ def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str:
checksum,
agent_id,
)
- # Actually contacts the database and load the (large) ima_policy column for "allowlists" table
- ima_policy = stored_agent.ima_policy.ima_policy
+
+ # Get the large ima_policy content - it's already loaded in ima_policy_data
+ ima_policy = ima_policy_data.get("ima_policy", "")
assert isinstance(ima_policy, str)
GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy
@@ -182,22 +192,19 @@ def store_attestation_state(agentAttestState: AgentAttestState) -> None:
# Only store if IMA log was evaluated
if agentAttestState.get_ima_pcrs():
agent_id = agentAttestState.agent_id
- session = get_session()
try:
- update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id())
- assert update_agent
- update_agent.boottime = agentAttestState.get_boottime()
- update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry()
- ima_pcrs_dict = agentAttestState.get_ima_pcrs()
- update_agent.ima_pcrs = list(ima_pcrs_dict.keys())
- for pcr_num, value in ima_pcrs_dict.items():
- setattr(update_agent, f"pcr{pcr_num}", value)
- update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json()
- try:
+ with session_context() as session:
+ update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id())
+ assert update_agent
+ update_agent.boottime = agentAttestState.get_boottime()
+ update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry()
+ ima_pcrs_dict = agentAttestState.get_ima_pcrs()
+ update_agent.ima_pcrs = list(ima_pcrs_dict.keys())
+ for pcr_num, value in ima_pcrs_dict.items():
+ setattr(update_agent, f"pcr{pcr_num}", value)
+ update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json()
session.add(update_agent)
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e)
- session.commit()
+ # session.commit() is automatically called by context manager
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e)
@@ -354,45 +361,17 @@ class AgentsHandler(BaseHandler):
was not found, it either completed successfully, or failed. If found, the agent_id is still polling
to contact the Cloud Agent.
"""
- session = get_session()
-
rest_params, agent_id = self.__validate_input("GET")
if not rest_params:
return
- if (agent_id is not None) and (agent_id != ""):
- # If the agent ID is not valid (wrong set of characters),
- # just do nothing.
- agent = None
- try:
- agent = (
- session.query(VerfierMain)
- .options( # type: ignore
- joinedload(VerfierMain.ima_policy).load_only(
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
- )
- )
- .options( # type: ignore
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
- )
- .filter_by(agent_id=agent_id)
- .one_or_none()
- )
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
-
- if agent is not None:
- response = cloud_verifier_common.process_get_status(agent)
- web_util.echo_json_response(self, 200, "Success", response)
- else:
- web_util.echo_json_response(self, 404, "agent id not found")
- else:
- json_response = None
- if "bulk" in rest_params:
- agent_list = None
-
- if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
- agent_list = (
+ with session_context() as session:
+ if (agent_id is not None) and (agent_id != ""):
+ # If the agent ID is not valid (wrong set of characters),
+ # just do nothing.
+ agent = None
+ try:
+ agent = (
session.query(VerfierMain)
.options( # type: ignore
joinedload(VerfierMain.ima_policy).load_only(
@@ -402,39 +381,70 @@ class AgentsHandler(BaseHandler):
.options( # type: ignore
joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
)
- .filter_by(verifier_id=rest_params["verifier"])
- .all()
+ .filter_by(agent_id=agent_id)
+ .one_or_none()
)
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+
+ if agent is not None:
+ response = cloud_verifier_common.process_get_status(agent)
+ web_util.echo_json_response(self, 200, "Success", response)
else:
- agent_list = (
- session.query(VerfierMain)
- .options( # type: ignore
- joinedload(VerfierMain.ima_policy).load_only(
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
+ web_util.echo_json_response(self, 404, "agent id not found")
+ else:
+ json_response = None
+ if "bulk" in rest_params:
+ agent_list = None
+
+ if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
+ agent_list = (
+ session.query(VerfierMain)
+ .options( # type: ignore
+ joinedload(VerfierMain.ima_policy).load_only(
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
+ )
)
+ .options( # type: ignore
+ joinedload(VerfierMain.mb_policy).load_only(
+ VerifierMbpolicy.mb_policy # type: ignore[arg-type]
+ )
+ )
+ .filter_by(verifier_id=rest_params["verifier"])
+ .all()
)
- .options( # type: ignore
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
+ else:
+ agent_list = (
+ session.query(VerfierMain)
+ .options( # type: ignore
+ joinedload(VerfierMain.ima_policy).load_only(
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
+ )
+ )
+ .options( # type: ignore
+ joinedload(VerfierMain.mb_policy).load_only(
+ VerifierMbpolicy.mb_policy # type: ignore[arg-type]
+ )
+ )
+ .all()
)
- .all()
- )
- json_response = {}
- for agent in agent_list:
- json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent)
+ json_response = {}
+ for agent in agent_list:
+ json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent)
- web_util.echo_json_response(self, 200, "Success", json_response)
- else:
- if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
- json_response_list = (
- session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all()
- )
+ web_util.echo_json_response(self, 200, "Success", json_response)
else:
- json_response_list = session.query(VerfierMain.agent_id).all()
+ if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
+ json_response_list = (
+ session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all()
+ )
+ else:
+ json_response_list = session.query(VerfierMain.agent_id).all()
- web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list})
+ web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list})
- logger.info("GET returning 200 response for agent_id list")
+ logger.info("GET returning 200 response for agent_id list")
def delete(self) -> None:
"""This method handles the DELETE requests to remove agents from the Cloud Verifier.
@@ -442,59 +452,55 @@ class AgentsHandler(BaseHandler):
Currently, only agents resources are available for DELETEing, i.e. /agents. All other DELETE uri's will return errors.
agents requests require a single agent_id parameter which identifies the agent to be deleted.
"""
- session = get_session()
-
rest_params, agent_id = self.__validate_input("DELETE")
if not rest_params or not agent_id:
return
- agent = None
- try:
- agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ with session_context() as session:
+ agent = None
+ try:
+ agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- if agent is None:
- web_util.echo_json_response(self, 404, "agent id not found")
- logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id)
- return
+ if agent is None:
+ web_util.echo_json_response(self, 404, "agent id not found")
+ logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id)
+ return
- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
- if verifier_id != agent.verifier_id:
- web_util.echo_json_response(self, 404, "agent id associated to this verifier")
- logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id)
- return
+ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
+ if verifier_id != agent.verifier_id:
+ web_util.echo_json_response(self, 404, "agent id associated to this verifier")
+ logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id)
+ return
- # Cleanup the cache when the agent is deleted. Do it early.
- if agent_id in GLOBAL_POLICY_CACHE:
- del GLOBAL_POLICY_CACHE[agent_id]
- logger.debug(
- "Cleaned up policy cache from all entries used by agent %s",
- agent_id,
- )
+ # Cleanup the cache when the agent is deleted. Do it early.
+ if agent_id in GLOBAL_POLICY_CACHE:
+ del GLOBAL_POLICY_CACHE[agent_id]
+ logger.debug(
+ "Cleaned up policy cache from all entries used by agent %s",
+ agent_id,
+ )
- op_state = agent.operational_state
- if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE):
- try:
- verifier_db_delete_agent(session, agent_id)
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 200, "Success")
- logger.info("DELETE returning 200 response for agent id: %s", agent_id)
- else:
- try:
- update_agent = session.query(VerfierMain).get(agent_id)
- assert update_agent
- update_agent.operational_state = states.TERMINATED
+ op_state = agent.operational_state
+ if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE):
try:
+ verifier_db_delete_agent(session, agent_id)
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 200, "Success")
+ logger.info("DELETE returning 200 response for agent id: %s", agent_id)
+ else:
+ try:
+ update_agent = session.query(VerfierMain).get(agent_id)
+ assert update_agent
+ update_agent.operational_state = states.TERMINATED
session.add(update_agent)
+ # session.commit() is automatically called by context manager
+ web_util.echo_json_response(self, 202, "Accepted")
+ logger.info("DELETE returning 202 response for agent id: %s", agent_id)
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- session.commit()
- web_util.echo_json_response(self, 202, "Accepted")
- logger.info("DELETE returning 202 response for agent id: %s", agent_id)
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
def post(self) -> None:
"""This method handles the POST requests to add agents to the Cloud Verifier.
@@ -502,7 +508,6 @@ class AgentsHandler(BaseHandler):
Currently, only agents resources are available for POSTing, i.e. /agents. All other POST uri's will return errors.
agents requests require a json block sent in the body
"""
- session = get_session()
# TODO: exception handling needs fixing
# Maybe handle exceptions with if/else if/else blocks ... simple and avoids nesting
try: # pylint: disable=too-many-nested-blocks
@@ -585,201 +590,208 @@ class AgentsHandler(BaseHandler):
runtime_policy = base64.b64decode(json_body.get("runtime_policy")).decode()
runtime_policy_stored = None
- if runtime_policy_name:
+ with session_context() as session:
+ if runtime_policy_name:
+ try:
+ runtime_policy_stored = (
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
+ )
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ raise
+
+ # Prevent overwriting existing IMA policies with name provided in request
+ if runtime_policy and runtime_policy_stored:
+ web_util.echo_json_response(
+ self,
+ 409,
+ f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.",
+ )
+ logger.warning("IMA policy with name %s already exists", runtime_policy_name)
+ return
+
+ # Return an error code if the named allowlist does not exist in the database
+ if not runtime_policy and not runtime_policy_stored:
+ web_util.echo_json_response(
+ self, 404, f"Could not find IMA policy with name {runtime_policy_name}!"
+ )
+ logger.warning("Could not find IMA policy with name %s", runtime_policy_name)
+ return
+
+ # Prevent overwriting existing agents with UUID provided in request
try:
- runtime_policy_stored = (
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
- )
+ new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count()
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise
+ raise e
- # Prevent overwriting existing IMA policies with name provided in request
- if runtime_policy and runtime_policy_stored:
+ if new_agent_count > 0:
web_util.echo_json_response(
self,
409,
- f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.",
+ f"Agent of uuid {agent_id} already exists. Please use delete or update.",
)
- logger.warning("IMA policy with name %s already exists", runtime_policy_name)
+ logger.warning("Agent of uuid %s already exists", agent_id)
return
- # Return an error code if the named allowlist does not exist in the database
- if not runtime_policy and not runtime_policy_stored:
- web_util.echo_json_response(
- self, 404, f"Could not find IMA policy with name {runtime_policy_name}!"
- )
- logger.warning("Could not find IMA policy with name %s", runtime_policy_name)
- return
-
- # Prevent overwriting existing agents with UUID provided in request
- try:
- new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise e
-
- if new_agent_count > 0:
- web_util.echo_json_response(
- self,
- 409,
- f"Agent of uuid {agent_id} already exists. Please use delete or update.",
- )
- logger.warning("Agent of uuid %s already exists", agent_id)
- return
-
- # Write IMA policy to database if needed
- if not runtime_policy_name and not runtime_policy:
- logger.info("IMA policy data not provided with request! Using default empty IMA policy.")
- runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY))
+ # Write IMA policy to database if needed
+ if not runtime_policy_name and not runtime_policy:
+ logger.info("IMA policy data not provided with request! Using default empty IMA policy.")
+ runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY))
- if runtime_policy:
- runtime_policy_key_bytes = signing.get_runtime_policy_keys(
- runtime_policy.encode(),
- json_body.get("runtime_policy_key"),
- )
-
- try:
- ima.verify_runtime_policy(
+ if runtime_policy:
+ runtime_policy_key_bytes = signing.get_runtime_policy_keys(
runtime_policy.encode(),
- runtime_policy_key_bytes,
- verify_sig=config.getboolean(
- "verifier", "require_allow_list_signatures", fallback=False
- ),
+ json_body.get("runtime_policy_key"),
)
- except ima.ImaValidationError as e:
- web_util.echo_json_response(self, e.code, e.message)
- logger.warning(e.message)
- return
- if not runtime_policy_name:
- runtime_policy_name = agent_id
-
- try:
- runtime_policy_db_format = ima.runtime_policy_db_contents(
- runtime_policy_name, runtime_policy
- )
- except ima.ImaValidationError as e:
- message = f"Runtime policy is malformatted: {e.message}"
- web_util.echo_json_response(self, e.code, message)
- logger.warning(message)
- return
-
- try:
- runtime_policy_stored = (
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
- )
- except SQLAlchemyError as e:
- logger.error(
- "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", agent_id, e
- )
- raise
- try:
- if runtime_policy_stored is None:
- runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format)
- session.add(runtime_policy_stored)
+ try:
+ ima.verify_runtime_policy(
+ runtime_policy.encode(),
+ runtime_policy_key_bytes,
+ verify_sig=config.getboolean(
+ "verifier", "require_allow_list_signatures", fallback=False
+ ),
+ )
+ except ima.ImaValidationError as e:
+ web_util.echo_json_response(self, e.code, e.message)
+ logger.warning(e.message)
+ return
+
+ if not runtime_policy_name:
+ runtime_policy_name = agent_id
+
+ try:
+ runtime_policy_db_format = ima.runtime_policy_db_contents(
+ runtime_policy_name, runtime_policy
+ )
+ except ima.ImaValidationError as e:
+ message = f"Runtime policy is malformatted: {e.message}"
+ web_util.echo_json_response(self, e.code, message)
+ logger.warning(message)
+ return
+
+ try:
+ runtime_policy_stored = (
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
+ )
+ except SQLAlchemyError as e:
+ logger.error(
+ "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s",
+ agent_id,
+ e,
+ )
+ raise
+ try:
+ if runtime_policy_stored is None:
+ runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format)
+ session.add(runtime_policy_stored)
+ session.commit()
+ except SQLAlchemyError as e:
+ logger.error(
+ "SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e
+ )
+ raise
+
+ # Handle measured boot policy
+ # - No name, mb_policy : store mb_policy using agent UUID as name
+ # - Name, no mb_policy : fetch existing mb_policy from DB
+ # - Name, mb_policy : store mb_policy using name
+
+ mb_policy_name = json_body["mb_policy_name"]
+ mb_policy = json_body["mb_policy"]
+ mb_policy_stored = None
+
+ if mb_policy_name:
+ try:
+ mb_policy_stored = (
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
+ )
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ raise
+
+ # Prevent overwriting existing mb_policy with name provided in request
+ if mb_policy and mb_policy_stored:
+ web_util.echo_json_response(
+ self,
+ 409,
+ f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.",
+ )
+ logger.warning("mb_policy with name %s already exists", mb_policy_name)
+ return
+
+ # Return error if the mb_policy is neither provided nor stored.
+ if not mb_policy and not mb_policy_stored:
+ web_util.echo_json_response(
+ self, 404, f"Could not find mb_policy with name {mb_policy_name}!"
+ )
+ logger.warning("Could not find mb_policy with name %s", mb_policy_name)
+ return
+
+ else:
+ # Use the UUID of the agent
+ mb_policy_name = agent_id
+ try:
+ mb_policy_stored = (
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
+ )
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ raise
+
+ # Prevent overwriting existing mb_policy
+ if mb_policy and mb_policy_stored:
+ web_util.echo_json_response(
+ self,
+ 409,
+ f"mb_policy with name {mb_policy_name} already exists. You can delete the mb_policy from the verifier.",
+ )
+ logger.warning("mb_policy with name %s already exists", mb_policy_name)
+ return
+
+ # Store the policy into database if not stored
+ if mb_policy_stored is None:
+ try:
+ mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy)
+ mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format)
+ session.add(mb_policy_stored)
session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e)
- raise
-
- # Handle measured boot policy
- # - No name, mb_policy : store mb_policy using agent UUID as name
- # - Name, no mb_policy : fetch existing mb_policy from DB
- # - Name, mb_policy : store mb_policy using name
-
- mb_policy_name = json_body["mb_policy_name"]
- mb_policy = json_body["mb_policy"]
- mb_policy_stored = None
+ except SQLAlchemyError as e:
+ logger.error(
+ "SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e
+ )
+ raise
- if mb_policy_name:
+ # Write the agent to the database, attaching associated stored ima_policy and mb_policy
try:
- mb_policy_stored = (
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
+ assert runtime_policy_stored
+ assert mb_policy_stored
+ session.add(
+ VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored)
)
+ session.commit()
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise
+ raise e
- # Prevent overwriting existing mb_policy with name provided in request
- if mb_policy and mb_policy_stored:
- web_util.echo_json_response(
- self,
- 409,
- f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.",
- )
- logger.warning("mb_policy with name %s already exists", mb_policy_name)
- return
+ # add default fields that are ephemeral
+ for key, val in exclude_db.items():
+ agent_data[key] = val
- # Return error if the mb_policy is neither provided nor stored.
- if not mb_policy and not mb_policy_stored:
- web_util.echo_json_response(
- self, 404, f"Could not find mb_policy with name {mb_policy_name}!"
+ # Prepare SSLContext for mTLS connections
+ agent_data["ssl_context"] = None
+ if agent_mtls_cert_enabled:
+ agent_data["ssl_context"] = web_util.generate_agent_tls_context(
+ "verifier", agent_data["mtls_cert"], logger=logger
)
- logger.warning("Could not find mb_policy with name %s", mb_policy_name)
- return
- else:
- # Use the UUID of the agent
- mb_policy_name = agent_id
- try:
- mb_policy_stored = (
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
- )
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise
-
- # Prevent overwriting existing mb_policy
- if mb_policy and mb_policy_stored:
- web_util.echo_json_response(
- self,
- 409,
- f"mb_policy with name {mb_policy_name} already exists. You can delete the mb_policy from the verifier.",
- )
- logger.warning("mb_policy with name %s already exists", mb_policy_name)
- return
-
- # Store the policy into database if not stored
- if mb_policy_stored is None:
- try:
- mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy)
- mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format)
- session.add(mb_policy_stored)
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e)
- raise
+ if agent_data["ssl_context"] is None:
+ logger.warning("Connecting to agent without mTLS: %s", agent_id)
- # Write the agent to the database, attaching associated stored ima_policy and mb_policy
- try:
- assert runtime_policy_stored
- assert mb_policy_stored
- session.add(
- VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored)
- )
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise e
-
- # add default fields that are ephemeral
- for key, val in exclude_db.items():
- agent_data[key] = val
-
- # Prepare SSLContext for mTLS connections
- agent_data["ssl_context"] = None
- if agent_mtls_cert_enabled:
- agent_data["ssl_context"] = web_util.generate_agent_tls_context(
- "verifier", agent_data["mtls_cert"], logger=logger
- )
-
- if agent_data["ssl_context"] is None:
- logger.warning("Connecting to agent without mTLS: %s", agent_id)
-
- asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE))
- web_util.echo_json_response(self, 200, "Success")
- logger.info("POST returning 200 response for adding agent id: %s", agent_id)
+ asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE))
+ web_util.echo_json_response(self, 200, "Success")
+ logger.info("POST returning 200 response for adding agent id: %s", agent_id)
else:
web_util.echo_json_response(self, 400, "uri not supported")
logger.warning("POST returning 400 response. uri not supported")
@@ -794,54 +806,54 @@ class AgentsHandler(BaseHandler):
Currently, only agents resources are available for PUTing, i.e. /agents. All other PUT uri's will return errors.
agents requests require a json block sent in the body
"""
- session = get_session()
try:
rest_params, agent_id = self.__validate_input("PUT")
if not rest_params:
return
- try:
- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
- db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- raise e
+ with session_context() as session:
+ try:
+ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
+ db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ raise e
- if db_agent is None:
- web_util.echo_json_response(self, 404, "agent id not found")
- logger.info("PUT returning 404 response. agent id: %s not found.", agent_id)
- return
+ if db_agent is None:
+ web_util.echo_json_response(self, 404, "agent id not found")
+ logger.info("PUT returning 404 response. agent id: %s not found.", agent_id)
+ return
- if "reactivate" in rest_params:
- agent = _from_db_obj(db_agent)
+ if "reactivate" in rest_params:
+ agent = _from_db_obj(db_agent)
- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
- agent["ssl_context"] = web_util.generate_agent_tls_context(
- "verifier", agent["mtls_cert"], logger=logger
- )
- if agent["ssl_context"] is None:
- logger.warning("Connecting to agent without mTLS: %s", agent_id)
+ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
+ agent["ssl_context"] = web_util.generate_agent_tls_context(
+ "verifier", agent["mtls_cert"], logger=logger
+ )
+ if agent["ssl_context"] is None:
+ logger.warning("Connecting to agent without mTLS: %s", agent_id)
- agent["operational_state"] = states.START
- asyncio.ensure_future(process_agent(agent, states.GET_QUOTE))
- web_util.echo_json_response(self, 200, "Success")
- logger.info("PUT returning 200 response for agent id: %s", agent_id)
- elif "stop" in rest_params:
- # do stuff for terminate
- logger.debug("Stopping polling on %s", agent_id)
- try:
- session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore
- {"operational_state": states.TENANT_FAILED}
- )
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
+ agent["operational_state"] = states.START
+ asyncio.ensure_future(process_agent(agent, states.GET_QUOTE))
+ web_util.echo_json_response(self, 200, "Success")
+ logger.info("PUT returning 200 response for agent id: %s", agent_id)
+ elif "stop" in rest_params:
+ # do stuff for terminate
+ logger.debug("Stopping polling on %s", agent_id)
+ try:
+ session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore
+ {"operational_state": states.TENANT_FAILED}
+ )
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 200, "Success")
- logger.info("PUT returning 200 response for agent id: %s", agent_id)
- else:
- web_util.echo_json_response(self, 400, "uri not supported")
- logger.warning("PUT returning 400 response. uri not supported")
+ web_util.echo_json_response(self, 200, "Success")
+ logger.info("PUT returning 200 response for agent id: %s", agent_id)
+ else:
+ web_util.echo_json_response(self, 400, "uri not supported")
+ logger.warning("PUT returning 400 response. uri not supported")
except Exception as e:
web_util.echo_json_response(self, 400, f"Exception error: {str(e)}")
@@ -887,36 +899,36 @@ class AllowlistHandler(BaseHandler):
if not params_valid:
return
- session = get_session()
- if allowlist_name is None:
- try:
- names_allowlists = session.query(VerifierAllowlist.name).all()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get names of allowlists")
- raise
+ with session_context() as session:
+ if allowlist_name is None:
+ try:
+ names_allowlists = session.query(VerifierAllowlist.name).all()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get names of allowlists")
+ raise
- names_response = []
- for name in names_allowlists:
- names_response.append(name[0])
- web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response})
+ names_response = []
+ for name in names_allowlists:
+ names_response.append(name[0])
+ web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response})
- else:
- try:
- allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
- except NoResultFound:
- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get allowlist")
- raise
+ else:
+ try:
+ allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
+ except NoResultFound:
+ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get allowlist")
+ raise
- response = {}
- for field in ("name", "tpm_policy"):
- response[field] = getattr(allowlist, field, None)
- response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
- web_util.echo_json_response(self, 200, "Success", response)
+ response = {}
+                for field in ("name", "tpm_policy"):
+ response[field] = getattr(allowlist, field, None)
+ response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
+ web_util.echo_json_response(self, 200, "Success", response)
def delete(self) -> None:
"""Delete an allowlist
@@ -928,45 +940,44 @@ class AllowlistHandler(BaseHandler):
if not params_valid or allowlist_name is None:
return
- session = get_session()
- try:
- runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
- except NoResultFound:
- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get allowlist")
- raise
+ with session_context() as session:
+ try:
+ runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
+ except NoResultFound:
+ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get allowlist")
+ raise
- try:
- agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
- if agent is not None:
- web_util.echo_json_response(
- self,
- 409,
- f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}",
- )
- return
+ try:
+ agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
+ if agent is not None:
+ web_util.echo_json_response(
+ self,
+ 409,
+ f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}",
+ )
+ return
- try:
- session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete()
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- session.close()
- web_util.echo_json_response(self, 500, f"Database error: {e}")
- raise
+ try:
+ session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete()
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, f"Database error: {e}")
+ raise
- # NOTE(kaifeng) 204 Can not have response body, but current helper
- # doesn't support this case.
- self.set_status(204)
- self.set_header("Content-Type", "application/json")
- self.finish()
- logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name)
+ # NOTE(kaifeng) 204 Can not have response body, but current helper
+ # doesn't support this case.
+ self.set_status(204)
+ self.set_header("Content-Type", "application/json")
+ self.finish()
+ logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name)
def __get_runtime_policy_db_format(self, runtime_policy_name: str) -> Dict[str, Any]:
"""Get the IMA policy from the request and return it in Db format"""
@@ -1022,28 +1033,30 @@ class AllowlistHandler(BaseHandler):
if not runtime_policy_db_format:
return
- session = get_session()
- # don't allow overwritting
- try:
- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
- if runtime_policy_count > 0:
- web_util.echo_json_response(self, 409, f"Runtime policy with name {runtime_policy_name} already exists")
- logger.warning("Runtime policy with name %s already exists", runtime_policy_name)
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ with session_context() as session:
+ # don't allow overwritting
+ try:
+ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
+ if runtime_policy_count > 0:
+ web_util.echo_json_response(
+ self, 409, f"Runtime policy with name {runtime_policy_name} already exists"
+ )
+ logger.warning("Runtime policy with name %s already exists", runtime_policy_name)
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- try:
- # Add the agent and data
- session.add(VerifierAllowlist(**runtime_policy_db_format))
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ try:
+ # Add the agent and data
+ session.add(VerifierAllowlist(**runtime_policy_db_format))
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- web_util.echo_json_response(self, 201)
- logger.info("POST returning 201")
+ web_util.echo_json_response(self, 201)
+ logger.info("POST returning 201")
def put(self) -> None:
"""Update an allowlist
@@ -1060,32 +1073,34 @@ class AllowlistHandler(BaseHandler):
if not runtime_policy_db_format:
return
- session = get_session()
- # don't allow creating a new policy
- try:
- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
- if runtime_policy_count != 1:
- web_util.echo_json_response(
- self, 409, f"Runtime policy with name {runtime_policy_name} does not already exist"
- )
- logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name)
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ with session_context() as session:
+ # don't allow creating a new policy
+ try:
+ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
+ if runtime_policy_count != 1:
+ web_util.echo_json_response(
+ self,
+ 404,
+ f"Runtime policy with name {runtime_policy_name} does not already exist, use POST to create",
+ )
+ logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name)
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- try:
- # Update the named runtime policy
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update(
- runtime_policy_db_format # pyright: ignore
- )
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ try:
+ # Update the named runtime policy
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update(
+ runtime_policy_db_format # pyright: ignore
+ )
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- web_util.echo_json_response(self, 201)
- logger.info("PUT returning 201")
+ web_util.echo_json_response(self, 201)
+ logger.info("PUT returning 201")
def data_received(self, chunk: Any) -> None:
raise NotImplementedError()
@@ -1113,8 +1128,6 @@ class VerifyIdentityHandler(BaseHandler):
This is useful for 3rd party tools and integrations to independently verify the state of an agent.
"""
- session = get_session()
-
# validate the parameters of our request
if self.request.uri is None:
web_util.echo_json_response(self, 400, "URI not specified")
@@ -1159,36 +1172,37 @@ class VerifyIdentityHandler(BaseHandler):
return
# get the agent information from the DB
- agent = None
- try:
- agent = (
- session.query(VerfierMain)
- .options( # type: ignore
- joinedload(VerfierMain.ima_policy).load_only(
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
+ with session_context() as session:
+ agent = None
+ try:
+ agent = (
+ session.query(VerfierMain)
+ .options( # type: ignore
+ joinedload(VerfierMain.ima_policy).load_only(
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
+ )
)
+ .filter_by(agent_id=agent_id)
+ .one_or_none()
)
- .filter_by(agent_id=agent_id)
- .one_or_none()
- )
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
- if agent is not None:
- agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id)
- failure = cloud_verifier_common.process_verify_identity_quote(
- agent, quote, nonce, hash_alg, agentAttestState
- )
- if failure:
- failure_contexts = "; ".join(x.context for x in failure.events)
- web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts})
- logger.info("GET returning 200, but validation failed")
+ if agent is not None:
+ agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id)
+ failure = cloud_verifier_common.process_verify_identity_quote(
+ agent, quote, nonce, hash_alg, agentAttestState
+ )
+ if failure:
+ failure_contexts = "; ".join(x.context for x in failure.events)
+ web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts})
+ logger.info("GET returning 200, but validation failed")
+ else:
+ web_util.echo_json_response(self, 200, "Success", {"valid": 1})
+ logger.info("GET returning 200, validation successful")
else:
- web_util.echo_json_response(self, 200, "Success", {"valid": 1})
- logger.info("GET returning 200, validation successful")
- else:
- web_util.echo_json_response(self, 404, "agent id not found")
- logger.info("GET returning 404, agaent not found")
+ web_util.echo_json_response(self, 404, "agent id not found")
+ logger.info("GET returning 404, agaent not found")
def data_received(self, chunk: Any) -> None:
raise NotImplementedError()
@@ -1231,35 +1245,35 @@ class MbpolicyHandler(BaseHandler):
if not params_valid:
return
- session = get_session()
- if mb_policy_name is None:
- try:
- names_mbpolicies = session.query(VerifierMbpolicy.name).all()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies")
- raise
+ with session_context() as session:
+ if mb_policy_name is None:
+ try:
+ names_mbpolicies = session.query(VerifierMbpolicy.name).all()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies")
+ raise
- names_response = []
- for name in names_mbpolicies:
- names_response.append(name[0])
- web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response})
+ names_response = []
+ for name in names_mbpolicies:
+ names_response.append(name[0])
+ web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response})
- else:
- try:
- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
- except NoResultFound:
- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get mb_policy")
- raise
+ else:
+ try:
+ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
+ except NoResultFound:
+ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get mb_policy")
+ raise
- response = {}
- response["name"] = getattr(mbpolicy, "name", None)
- response["mb_policy"] = getattr(mbpolicy, "mb_policy", None)
- web_util.echo_json_response(self, 200, "Success", response)
+ response = {}
+ response["name"] = getattr(mbpolicy, "name", None)
+ response["mb_policy"] = getattr(mbpolicy, "mb_policy", None)
+ web_util.echo_json_response(self, 200, "Success", response)
def delete(self) -> None:
"""Delete a mb_policy
@@ -1271,45 +1285,44 @@ class MbpolicyHandler(BaseHandler):
if not params_valid or mb_policy_name is None:
return
- session = get_session()
- try:
- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
- except NoResultFound:
- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- web_util.echo_json_response(self, 500, "Failed to get mb_policy")
- raise
+ with session_context() as session:
+ try:
+ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
+ except NoResultFound:
+ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, "Failed to get mb_policy")
+ raise
- try:
- agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
- if agent is not None:
- web_util.echo_json_response(
- self,
- 409,
- f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}",
- )
- return
+ try:
+ agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none()
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
+ if agent is not None:
+ web_util.echo_json_response(
+ self,
+ 409,
+ f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}",
+ )
+ return
- try:
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete()
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- session.close()
- web_util.echo_json_response(self, 500, f"Database error: {e}")
- raise
+ try:
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete()
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ web_util.echo_json_response(self, 500, f"Database error: {e}")
+ raise
- # NOTE(kaifeng) 204 Can not have response body, but current helper
- # doesn't support this case.
- self.set_status(204)
- self.set_header("Content-Type", "application/json")
- self.finish()
- logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name)
+ # NOTE(kaifeng) 204 Can not have response body, but current helper
+ # doesn't support this case.
+ self.set_status(204)
+ self.set_header("Content-Type", "application/json")
+ self.finish()
+ logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name)
def __get_mb_policy_db_format(self, mb_policy_name: str) -> Dict[str, Any]:
"""Get the measured boot policy from the request and return it in Db format"""
@@ -1341,30 +1354,30 @@ class MbpolicyHandler(BaseHandler):
if not mb_policy_db_format:
return
- session = get_session()
- # don't allow overwritting
- try:
- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
- if mbpolicy_count > 0:
- web_util.echo_json_response(
- self, 409, f"Measured boot policy with name {mb_policy_name} already exists"
- )
- logger.warning("Measured boot policy with name %s already exists", mb_policy_name)
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ with session_context() as session:
+ # don't allow overwritting
+ try:
+ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
+ if mbpolicy_count > 0:
+ web_util.echo_json_response(
+ self, 409, f"Measured boot policy with name {mb_policy_name} already exists"
+ )
+ logger.warning("Measured boot policy with name %s already exists", mb_policy_name)
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- try:
- # Add the data
- session.add(VerifierMbpolicy(**mb_policy_db_format))
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ try:
+ # Add the data
+ session.add(VerifierMbpolicy(**mb_policy_db_format))
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- web_util.echo_json_response(self, 201)
- logger.info("POST returning 201")
+ web_util.echo_json_response(self, 201)
+ logger.info("POST returning 201")
def put(self) -> None:
"""Update an mb_policy
@@ -1381,32 +1394,32 @@ class MbpolicyHandler(BaseHandler):
if not mb_policy_db_format:
return
- session = get_session()
- # don't allow creating a new policy
- try:
- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
- if mbpolicy_count != 1:
- web_util.echo_json_response(
- self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist"
- )
- logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name)
- return
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ with session_context() as session:
+ # don't allow creating a new policy
+ try:
+ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
+ if mbpolicy_count != 1:
+ web_util.echo_json_response(
+ self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist"
+ )
+ logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name)
+ return
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- try:
- # Update the named mb_policy
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update(
- mb_policy_db_format # pyright: ignore
- )
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
- raise
+ try:
+ # Update the named mb_policy
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update(
+ mb_policy_db_format # pyright: ignore
+ )
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
+ raise
- web_util.echo_json_response(self, 201)
- logger.info("PUT returning 201")
+ web_util.echo_json_response(self, 201)
+ logger.info("PUT returning 201")
def data_received(self, chunk: Any) -> None:
raise NotImplementedError()
@@ -1460,17 +1473,18 @@ async def update_agent_api_version(agent: Dict[str, Any], timeout: float = 60.0)
return None
logger.info("Agent %s new API version %s is supported", agent_id, new_version)
- session = get_session()
- agent["supported_version"] = new_version
- # Remove keys that should not go to the DB
- agent_db = dict(agent)
- for key in exclude_db:
- if key in agent_db:
- del agent_db[key]
+ with session_context() as session:
+ agent["supported_version"] = new_version
- session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore
- session.commit()
+ # Remove keys that should not go to the DB
+ agent_db = dict(agent)
+ for key in exclude_db:
+ if key in agent_db:
+ del agent_db[key]
+
+ session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore
+ # session.commit() is automatically called by context manager
else:
logger.warning("Agent %s new API version %s is not supported", agent_id, new_version)
return None
@@ -1718,50 +1732,68 @@ async def notify_error(
revocation_notifier.notify(tosend)
if "agent" in notifiers:
verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
- session = get_session()
- agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
- futures = []
- loop = asyncio.get_event_loop()
- # Notify all agents asynchronously through a thread pool
- with ThreadPoolExecutor() as pool:
- for agent_db_obj in agents:
- if agent_db_obj.agent_id != agent["agent_id"]:
- agent = _from_db_obj(agent_db_obj)
- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
- agent["ssl_context"] = web_util.generate_agent_tls_context(
- "verifier", agent["mtls_cert"], logger=logger
- )
- func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout)
- futures.append(await loop.run_in_executor(pool, func))
- # Wait for all tasks complete in 60 seconds
- try:
- for f in asyncio.as_completed(futures, timeout=60):
- await f
- except asyncio.TimeoutError as e:
- logger.error("Timeout during notifying error to agents: %s", e)
+ with session_context() as session:
+ agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
+ futures = []
+ loop = asyncio.get_event_loop()
+ # Notify all agents asynchronously through a thread pool
+ with ThreadPoolExecutor() as pool:
+ for agent_db_obj in agents:
+ if agent_db_obj.agent_id != agent["agent_id"]:
+ agent = _from_db_obj(agent_db_obj)
+ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
+ agent["ssl_context"] = web_util.generate_agent_tls_context(
+ "verifier", agent["mtls_cert"], logger=logger
+ )
+ func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout)
+ futures.append(await loop.run_in_executor(pool, func))
+ # Wait for all tasks complete in 60 seconds
+ try:
+ for f in asyncio.as_completed(futures, timeout=60):
+ await f
+ except asyncio.TimeoutError as e:
+ logger.error("Timeout during notifying error to agents: %s", e)
async def process_agent(
agent: Dict[str, Any], new_operational_state: int, failure: Failure = Failure(Component.INTERNAL, ["verifier"])
) -> None:
- session = get_session()
try: # pylint: disable=R1702
main_agent_operational_state = agent["operational_state"]
stored_agent = None
- try:
- stored_agent = (
- session.query(VerfierMain)
- .options( # type: ignore
- joinedload(VerfierMain.ima_policy).load_only(VerifierAllowlist.checksum) # pyright: ignore
- )
- .options( # type: ignore
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
+
+ # First database operation - read agent data and extract all needed data within session context
+ ima_policy_data = {}
+ mb_policy_data = None
+ with session_context() as session:
+ try:
+ stored_agent = (
+ session.query(VerfierMain)
+ .options( # type: ignore
+ joinedload(VerfierMain.ima_policy) # Load full IMA policy object including content
+ )
+ .options( # type: ignore
+ joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
+ )
+ .filter_by(agent_id=str(agent["agent_id"]))
+ .first()
)
- .filter_by(agent_id=str(agent["agent_id"]))
- .first()
- )
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
+
+ # Extract IMA policy data within session context to avoid DetachedInstanceError
+ if stored_agent and stored_agent.ima_policy:
+ ima_policy_data = {
+ "checksum": str(stored_agent.ima_policy.checksum),
+ "name": stored_agent.ima_policy.name,
+ "agent_id": str(stored_agent.agent_id),
+ "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too
+ }
+
+ # Extract MB policy data within session context
+ if stored_agent and stored_agent.mb_policy:
+ mb_policy_data = stored_agent.mb_policy.mb_policy
+
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
# if the stored agent could not be recovered from the database, stop polling
if not stored_agent:
@@ -1775,7 +1807,10 @@ async def process_agent(
logger.warning("Agent %s terminated by user.", agent["agent_id"])
if agent["pending_event"] is not None:
tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"])
- verifier_db_delete_agent(session, agent["agent_id"])
+
+ # Second database operation - delete agent
+ with session_context() as session:
+ verifier_db_delete_agent(session, agent["agent_id"])
return
# if the user tells us to stop polling because the tenant quote check failed
@@ -1808,11 +1843,16 @@ async def process_agent(
if not failure.recoverable or failure.highest_severity == MAX_SEVERITY_LABEL:
if agent["pending_event"] is not None:
tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"])
- for key in exclude_db:
- if key in agent:
- del agent[key]
- session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(agent) # pyright: ignore
- session.commit()
+
+ # Third database operation - update agent with failure state
+ with session_context() as session:
+ for key in exclude_db:
+ if key in agent:
+ del agent[key]
+ session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(
+ agent # type: ignore[arg-type]
+ )
+ # session.commit() is automatically called by context manager
# propagate all state, but remove none DB keys first (using exclude_db)
try:
@@ -1821,18 +1861,18 @@ async def process_agent(
if key in agent_db:
del agent_db[key]
- session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore
- session.commit()
+ # Fourth database operation - update agent state
+ with session_context() as session:
+ session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore
+ # session.commit() is automatically called by context manager
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
# Load agent's IMA policy
- runtime_policy = verifier_read_policy_from_cache(stored_agent)
+ runtime_policy = verifier_read_policy_from_cache(ima_policy_data)
# Get agent's measured boot policy
- mb_policy = None
- if stored_agent.mb_policy is not None:
- mb_policy = stored_agent.mb_policy.mb_policy
+ mb_policy = mb_policy_data
# If agent was in a failed state we check if we either stop polling
# or just add it again to the event loop
@@ -1876,7 +1916,14 @@ async def process_agent(
)
pending = tornado.ioloop.IOLoop.current().call_later(
- interval, invoke_get_quote, agent, mb_policy, runtime_policy, False, timeout=timeout # type: ignore # due to python <3.9
+ # type: ignore # due to python <3.9
+ interval,
+ invoke_get_quote,
+ agent,
+ mb_policy,
+ runtime_policy,
+ False,
+ timeout=timeout,
)
agent["pending_event"] = pending
return
@@ -1911,7 +1958,14 @@ async def process_agent(
next_retry,
)
tornado.ioloop.IOLoop.current().call_later(
- next_retry, invoke_get_quote, agent, mb_policy, runtime_policy, True, timeout=timeout # type: ignore # due to python <3.9
+ # type: ignore # due to python <3.9
+ next_retry,
+ invoke_get_quote,
+ agent,
+ mb_policy,
+ runtime_policy,
+ True,
+ timeout=timeout,
)
return
@@ -1980,9 +2034,9 @@ async def activate_agents(agents: List[VerfierMain], verifier_ip: str, verifier_
def get_agents_by_verifier_id(verifier_id: str) -> List[VerfierMain]:
- session = get_session()
try:
- return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
+ with session_context() as session:
+ return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
except SQLAlchemyError as e:
logger.error("SQLAlchemy Error: %s", e)
return []
@@ -2007,20 +2061,20 @@ def main() -> None:
os.umask(0o077)
VerfierMain.metadata.create_all(engine, checkfirst=True) # pyright: ignore
- session = get_session()
- try:
- query_all = session.query(VerfierMain).all()
- for row in query_all:
- if row.operational_state in states.APPROVED_REACTIVATE_STATES:
- row.operational_state = states.START # pyright: ignore
- session.commit()
- except SQLAlchemyError as e:
- logger.error("SQLAlchemy Error: %s", e)
+ with session_context() as session:
+ try:
+ query_all = session.query(VerfierMain).all()
+ for row in query_all:
+ if row.operational_state in states.APPROVED_REACTIVATE_STATES:
+ row.operational_state = states.START # pyright: ignore
+ # session.commit() is automatically called by context manager
+ except SQLAlchemyError as e:
+ logger.error("SQLAlchemy Error: %s", e)
- num = session.query(VerfierMain.agent_id).count()
- if num > 0:
- agent_ids = session.query(VerfierMain.agent_id).all()
- logger.info("Agent ids in db loaded from file: %s", agent_ids)
+ num = session.query(VerfierMain.agent_id).count()
+ if num > 0:
+ agent_ids = session.query(VerfierMain.agent_id).all()
+ logger.info("Agent ids in db loaded from file: %s", agent_ids)
logger.info("Starting Cloud Verifier (tornado) on port %s, use <Ctrl-C> to stop", verifier_port)
diff --git a/keylime/da/examples/sqldb.py b/keylime/da/examples/sqldb.py
index 8efc84e..04a8afb 100644
--- a/keylime/da/examples/sqldb.py
+++ b/keylime/da/examples/sqldb.py
@@ -1,7 +1,10 @@
import time
+from contextlib import contextmanager
+from typing import Iterator
import sqlalchemy
import sqlalchemy.ext.declarative
+from sqlalchemy.orm import sessionmaker
from keylime import keylime_logging
from keylime.da.record import BaseRecordManagement, base_build_key_list
@@ -45,23 +48,23 @@ class RecordManagement(BaseRecordManagement):
BaseRecordManagement.__init__(self, service)
self.engine = sqlalchemy.create_engine(self.ps_url._replace(fragment="").geturl(), pool_recycle=1800)
- sm = sqlalchemy.orm.sessionmaker()
- self.session = sqlalchemy.orm.scoped_session(sm)
- self.session.configure(bind=self.engine)
- TableBase.metadata.create_all(self.engine)
-
- def agent_list_retrieval(self, record_prefix="auto", service="auto"):
- if record_prefix == "auto":
- record_prefix = ""
-
- agent_list = []
+ self.SessionLocal = sessionmaker(bind=self.engine)
- recordtype = self.get_record_type(service)
- tbl = type2table(recordtype)
- for agentid in self.session.query(tbl.agentid).distinct(): # pylint: disable=no-member
- agent_list.append(agentid[0])
+ # Create tables if they don't exist
+ TableBase.metadata.create_all(self.engine)
- return agent_list
+ @contextmanager
+ def session_context(self) -> Iterator:
+ """Context manager for database sessions that ensures proper cleanup."""
+ session = self.SessionLocal()
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ session.close()
def record_create(
self,
@@ -84,8 +87,9 @@ class RecordManagement(BaseRecordManagement):
d = {"time": recordtime, "agentid": agentid, "record": rcrd}
try:
- self.session.add((type2table(recordtype))(**d)) # pylint: disable=no-member
- self.session.commit() # pylint: disable=no-member
+ with self.session_context() as session:
+ session.add((type2table(recordtype))(**d))
+ # session.commit() is automatically called by context manager
except Exception as e:
logger.error("Failed to create attestation record: %s", e)
@@ -106,23 +110,23 @@ class RecordManagement(BaseRecordManagement):
if f"{end_date}" == "auto":
end_date = self.end_of_times
- if self.only_last_record_wanted(start_date, end_date):
- attestion_record_rows = (
- self.session.query(tbl) # pylint: disable=no-member
- .filter(tbl.agentid == record_identifier)
- .order_by(sqlalchemy.desc(tbl.time))
- .limit(1)
- )
-
- else:
- attestion_record_rows = self.session.query(tbl).filter( # pylint: disable=no-member
- tbl.agentid == record_identifier
- )
-
- for row in attestion_record_rows:
- decoded_record_object = self.record_deserialize(row.record)
- self.record_signature_check(decoded_record_object, record_identifier)
- record_list.append(decoded_record_object)
+ with self.session_context() as session:
+ if self.only_last_record_wanted(start_date, end_date):
+ attestion_record_rows = (
+ session.query(tbl)
+ .filter(tbl.agentid == record_identifier)
+ .order_by(sqlalchemy.desc(tbl.time))
+ .limit(1)
+ )
+
+ else:
+ attestion_record_rows = session.query(tbl).filter(tbl.agentid == record_identifier)
+
+ for row in attestion_record_rows:
+ decoded_record_object = self.record_deserialize(row.record)
+ self.record_signature_check(decoded_record_object, record_identifier)
+ record_list.append(decoded_record_object)
+
return record_list
def build_key_list(self, agent_identifier, service="auto"):
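The durable-attestation SQL driver above applies the same idea at the class level: hold one engine and a sessionmaker on the manager, and open a short session per operation instead of keeping a scoped_session for the object's lifetime. A rough sketch under those assumptions; RecordStore and Record are illustrative names, not the driver's real classes:

from contextlib import contextmanager
from typing import Iterator

import sqlalchemy
from sqlalchemy.orm import Session, declarative_base, sessionmaker

Base = declarative_base()


class Record(Base):  # illustrative table, not the driver's real schema
    __tablename__ = "records"
    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
    agentid = sqlalchemy.Column(sqlalchemy.String)
    record = sqlalchemy.Column(sqlalchemy.LargeBinary)


class RecordStore:
    def __init__(self, url: str = "sqlite:///:memory:") -> None:
        self.engine = sqlalchemy.create_engine(url, pool_recycle=1800)
        self.SessionLocal = sessionmaker(bind=self.engine)
        Base.metadata.create_all(self.engine)

    @contextmanager
    def session_context(self) -> Iterator[Session]:
        # Commit on success, roll back on error, always close.
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def record_create(self, agentid: str, blob: bytes) -> None:
        with self.session_context() as session:
            session.add(Record(agentid=agentid, record=blob))


store = RecordStore()
store.record_create("agent-1", b"payload")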
diff --git a/keylime/db/keylime_db.py b/keylime/db/keylime_db.py
index 5620a28..aa49e51 100644
--- a/keylime/db/keylime_db.py
+++ b/keylime/db/keylime_db.py
@@ -1,7 +1,8 @@
import os
from configparser import NoOptionError
+from contextlib import contextmanager
from sqlite3 import Connection as SQLite3Connection
-from typing import Any, Dict, Optional, cast
+from typing import Any, Iterator, Optional, cast
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
@@ -22,90 +23,108 @@ def _set_sqlite_pragma(dbapi_connection: SQLite3Connection, _) -> None:
cursor.close()
-class DBEngineManager:
- service: Optional[str]
-
- def __init__(self) -> None:
- self.service = None
-
- def make_engine(self, service: str) -> Engine:
- """
- To use: engine = self.make_engine('cloud_verifier')
- """
-
- # Keep DB related stuff as it is, but read configuration from new
- # configs
- if service == "cloud_verifier":
- config_service = "verifier"
- else:
- config_service = service
-
- self.service = service
-
- try:
- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
- p_sz, m_ovfl = p_sz_m_ovfl.split(",")
- except NoOptionError:
- p_sz = "5"
- m_ovfl = "10"
-
- engine_args: Dict[str, Any] = {}
-
- url = config.get(config_service, "database_url")
- if url:
- logger.info("database_url is set, using it to establish database connection")
-
- # If the keyword sqlite is provided as the database url, use the
- # cv_data.sqlite for the verifier or the file reg_data.sqlite for
- # the registrar, located at the config.WORK_DIR directory
- if url == "sqlite":
+def make_engine(service: str, **engine_args: Any) -> Engine:
+ """Create a database engine for a keylime service."""
+ # Keep DB related stuff as it is, but read configuration from new
+ # configs
+ if service == "cloud_verifier":
+ config_service = "verifier"
+ else:
+ config_service = service
+
+ url = config.get(config_service, "database_url")
+ if url:
+ logger.info("database_url is set, using it to establish database connection")
+
+ # If the keyword sqlite is provided as the database url, use the
+ # cv_data.sqlite for the verifier or the file reg_data.sqlite for
+ # the registrar, located at the config.WORK_DIR directory
+ if url == "sqlite":
+ logger.info(
+ "database_url is set as 'sqlite' keyword, using default values to establish database connection"
+ )
+ if service == "cloud_verifier":
+ database = "cv_data.sqlite"
+ elif service == "registrar":
+ database = "reg_data.sqlite"
+ else:
+ logger.error("Tried to setup database access for unknown service '%s'", service)
+ raise Exception(f"Unknown service '{service}' for database setup")
+
+ database_file = os.path.abspath(os.path.join(config.WORK_DIR, database))
+ url = f"sqlite:///{database_file}"
+
+ kl_dir = os.path.dirname(os.path.abspath(database_file))
+ if not os.path.exists(kl_dir):
+ os.makedirs(kl_dir, 0o700)
+
+ engine_args["connect_args"] = {"check_same_thread": False}
+
+ if not url.count("sqlite:"):
+ # sqlite does not support setting pool size and max overflow, so only
+ # read them from the config when they are actually going to be used
+ try:
+ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
+ p_sz, m_ovfl = p_sz_m_ovfl.split(",")
+ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl)
+ except NoOptionError:
+ p_sz = "5"
+ m_ovfl = "10"
logger.info(
- "database_url is set as 'sqlite' keyword, using default values to establish database connection"
+ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", p_sz, m_ovfl
)
- if service == "cloud_verifier":
- database = "cv_data.sqlite"
- elif service == "registrar":
- database = "reg_data.sqlite"
- else:
- logger.error("Tried to setup database access for unknown service '%s'", service)
- raise Exception(f"Unknown service '{service}' for database setup")
-
- database_file = os.path.abspath(os.path.join(config.WORK_DIR, database))
- url = f"sqlite:///{database_file}"
-
- kl_dir = os.path.dirname(os.path.abspath(database_file))
- if not os.path.exists(kl_dir):
- os.makedirs(kl_dir, 0o700)
-
- engine_args["connect_args"] = {"check_same_thread": False}
- if not url.count("sqlite:"):
- engine_args["pool_size"] = int(p_sz)
- engine_args["max_overflow"] = int(m_ovfl)
- engine_args["pool_pre_ping"] = True
+ engine_args["pool_size"] = int(p_sz)
+ engine_args["max_overflow"] = int(m_ovfl)
+ engine_args["pool_pre_ping"] = True
- # Enable DB debugging
- if config.DEBUG_DB and config.INSECURE_DEBUG:
- engine_args["echo"] = True
+ # Enable DB debugging
+ if config.DEBUG_DB and config.INSECURE_DEBUG:
+ engine_args["echo"] = True
- engine = create_engine(url, **engine_args)
- return engine
+ engine = create_engine(url, **engine_args)
+ return engine
class SessionManager:
engine: Optional[Engine]
+ _scoped_session: Optional[scoped_session]
def __init__(self) -> None:
self.engine = None
+ self._scoped_session = None
def make_session(self, engine: Engine) -> Session:
"""
To use: session = self.make_session(engine)
"""
self.engine = engine
- my_session = scoped_session(sessionmaker())
+ if self._scoped_session is None:
+ self._scoped_session = scoped_session(sessionmaker())
try:
- my_session.configure(bind=self.engine) # type: ignore
+ self._scoped_session.configure(bind=self.engine) # type: ignore
+ self._scoped_session.configure(expire_on_commit=False) # type: ignore
except SQLAlchemyError as err:
logger.error("Error creating SQL session manager %s", err)
- return cast(Session, my_session())
+ return cast(Session, self._scoped_session())
+
+ @contextmanager
+ def session_context(self, engine: Engine) -> Iterator[Session]:
+ """
+ Context manager for database sessions that ensures proper cleanup.
+ To use:
+ with session_manager.session_context(engine) as session:
+ # use session
+ """
+ session = self.make_session(engine)
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+ finally:
+ # Important: remove the session from the scoped session registry
+ # to prevent connection leaks with scoped_session
+ if self._scoped_session is not None:
+ self._scoped_session.remove() # type: ignore[no-untyped-call]
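The reworked SessionManager above keeps a single scoped_session registry and calls remove() when the context exits, which closes the thread-local session and returns its connection to the pool. A condensed sketch of that lifecycle, with an illustrative in-memory engine:

from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, scoped_session, sessionmaker

engine = create_engine("sqlite:///:memory:")
# One registry per engine; it hands out thread-local sessions.
registry = scoped_session(sessionmaker(bind=engine, expire_on_commit=False))


@contextmanager
def session_context() -> Iterator[Session]:
    session = registry()  # get (or reuse) this thread's session
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        # Close the session and release its connection back to the pool.
        registry.remove()


with session_context() as session:
    session.execute(text("SELECT 1"))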
diff --git a/keylime/migrations/env.py b/keylime/migrations/env.py
index ac98349..a1881f2 100644
--- a/keylime/migrations/env.py
+++ b/keylime/migrations/env.py
@@ -8,7 +8,7 @@ import sys
from alembic import context
-from keylime.db.keylime_db import DBEngineManager
+from keylime.db.keylime_db import make_engine
from keylime.db.registrar_db import Base as RegistrarBase
from keylime.db.verifier_db import Base as VerifierBase
@@ -74,7 +74,7 @@ def run_migrations_offline():
logger.info("Writing output to %s", file_)
with open(file_, "w", encoding="utf-8") as buffer:
- engine = DBEngineManager().make_engine(name)
+ engine = make_engine(name)
connection = engine.connect()
context.configure(
connection=connection,
@@ -102,7 +102,7 @@ def run_migrations_online():
engines = {}
for name in re.split(r",\s*", db_names):
engines[name] = rec = {}
- rec["engine"] = DBEngineManager().make_engine(name)
+ rec["engine"] = make_engine(name)
for name, rec in engines.items():
engine = rec["engine"]
diff --git a/keylime/models/base/db.py b/keylime/models/base/db.py
index dd47d63..0229765 100644
--- a/keylime/models/base/db.py
+++ b/keylime/models/base/db.py
@@ -41,13 +41,6 @@ class DBManager:
self._service = service
- try:
- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
- p_sz, m_ovfl = p_sz_m_ovfl.split(",")
- except NoOptionError:
- p_sz = "5"
- m_ovfl = "10"
-
engine_args: Dict[str, Any] = {}
url = config.get(config_service, "database_url")
@@ -79,6 +72,21 @@ class DBManager:
engine_args["connect_args"] = {"check_same_thread": False}
if not url.count("sqlite:"):
+ # sqlite does not support setting pool size and max overflow, so only
+ # read them from the config when they are actually going to be used
+ try:
+ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
+ p_sz, m_ovfl = p_sz_m_ovfl.split(",")
+ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl)
+ except NoOptionError:
+ p_sz = "5"
+ m_ovfl = "10"
+ logger.info(
+ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s",
+ p_sz,
+ m_ovfl,
+ )
+
engine_args["pool_size"] = int(p_sz)
engine_args["max_overflow"] = int(m_ovfl)
engine_args["pool_pre_ping"] = True
diff --git a/keylime/models/base/persistable_model.py b/keylime/models/base/persistable_model.py
index 18f7d0d..a779f0b 100644
--- a/keylime/models/base/persistable_model.py
+++ b/keylime/models/base/persistable_model.py
@@ -207,10 +207,16 @@ class PersistableModel(BasicModel, metaclass=PersistableModelMeta):
setattr(self._db_mapping_inst, name, field.data_type.db_dump(value, db_manager.engine.dialect))
with db_manager.session_context() as session:
- session.add(self._db_mapping_inst)
+ # Merge the potentially detached object into the new session
+ merged_instance = session.merge(self._db_mapping_inst)
+ session.add(merged_instance)
+ # Update our reference to the merged instance
+ self._db_mapping_inst = merged_instance # pylint: disable=attribute-defined-outside-init
self.clear_changes()
def delete(self) -> None:
with db_manager.session_context() as session:
- session.delete(self._db_mapping_inst) # type: ignore[no-untyped-call]
+ # Merge the potentially detached object into the new session before deleting
+ merged_instance = session.merge(self._db_mapping_inst)
+ session.delete(merged_instance) # type: ignore[no-untyped-call]
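The persistable-model change above re-attaches a possibly detached mapped instance with session.merge() before adding or deleting it, since each save or delete now runs in its own short session. A minimal sketch of that merge pattern; the Item model and SQLite URL are illustrative:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Item(Base):  # illustrative model
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)

# Load in one session and let it close: the instance becomes detached.
with SessionLocal() as session:
    session.add(Item(id=1, name="original"))
    session.commit()
    detached = session.get(Item, 1)

# Later, in a different session, merge() returns a copy owned by that
# session which can safely be modified or deleted.
with SessionLocal() as session:
    managed = session.merge(detached)
    managed.name = "updated"
    session.commit()

with SessionLocal() as session:
    session.delete(session.merge(detached))
    session.commit()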
diff --git a/packit-ci.fmf b/packit-ci.fmf
index 2d1e5e5..cb64faf 100644
--- a/packit-ci.fmf
+++ b/packit-ci.fmf
@@ -101,6 +101,7 @@ adjust:
- /regression/CVE-2023-3674
- /regression/issue-1380-agent-removed-and-re-added
- /regression/keylime-agent-option-override-through-envvar
+ - /regression/db-connection-leak-reproducer
- /sanity/keylime-secure_mount
- /sanity/opened-conf-files
- /upstream/run_keylime_tests
diff --git a/test/test_verifier_db.py b/test/test_verifier_db.py
index ad72fa6..aae8f8a 100644
--- a/test/test_verifier_db.py
+++ b/test/test_verifier_db.py
@@ -172,3 +172,102 @@ class TestVerfierDB(unittest.TestCase):
def tearDown(self):
self.session.close()
+
+ def test_11_relationship_access_after_session_commit(self):
+ """Test that relationships can be accessed after session commits (DetachedInstanceError fix)"""
+ # This test reproduces the problematic pattern from cloud_verifier_tornado.py
+ # where objects are loaded with joinedload and then accessed after session closes
+
+ # Create a new session manager and context (like in cloud_verifier_tornado.py)
+ session_manager = SessionManager()
+
+ # First, load the agent with eager loading for relationships
+ stored_agent = None
+ with session_manager.session_context(self.engine) as session:
+ stored_agent = (
+ session.query(VerfierMain)
+ .options(joinedload(VerfierMain.ima_policy))
+ .options(joinedload(VerfierMain.mb_policy))
+ .filter_by(agent_id=agent_id)
+ .first()
+ )
+ # Verify agent was loaded correctly
+ self.assertIsNotNone(stored_agent)
+ # session.commit() is automatically called by context manager when exiting
+
+ # Now verify we can access relationships AFTER the session has been closed
+ # This would previously trigger DetachedInstanceError
+
+ # Ensure stored_agent is not None before proceeding
+ assert stored_agent is not None
+
+ # Test accessing ima_policy relationship
+ self.assertIsNotNone(stored_agent.ima_policy)
+ assert stored_agent.ima_policy is not None # Type narrowing for linter
+ self.assertEqual(stored_agent.ima_policy.name, "test-allowlist")
+ # checksum is not set in test data
+ self.assertEqual(stored_agent.ima_policy.checksum, None)
+
+ # Test accessing the ima_policy.ima_policy attribute (similar to verifier_read_policy_from_cache)
+ ima_policy_content = stored_agent.ima_policy.ima_policy
+ self.assertEqual(ima_policy_content, test_allowlist_data["ima_policy"])
+
+ # Test accessing mb_policy relationship
+ self.assertIsNotNone(stored_agent.mb_policy)
+ assert stored_agent.mb_policy is not None # Type narrowing for linter
+ self.assertEqual(stored_agent.mb_policy.name, "test-mbpolicy")
+
+ # Test accessing the mb_policy.mb_policy attribute (similar to process_agent function)
+ mb_policy_content = stored_agent.mb_policy.mb_policy
+ self.assertEqual(mb_policy_content, test_mbpolicy_data["mb_policy"])
+
+ # Test that we can access these relationships multiple times without issues
+ for _ in range(3):
+ self.assertIsNotNone(stored_agent.ima_policy.ima_policy)
+ self.assertIsNotNone(stored_agent.mb_policy.mb_policy)
+
+ def test_12_persistable_model_cross_session_fix(self):
+ """Test that PersistableModel can handle cross-session operations safely"""
+ # This test would previously fail with DetachedInstanceError before the fix
+ # Note: This is a conceptual test since we don't have actual PersistableModel
+ # subclasses in the test environment, but demonstrates the pattern
+
+ # Simulate creating a SQLAlchemy object in one session
+ session_manager = SessionManager()
+
+ # Load an object in one session context
+ test_agent = None
+ with session_manager.session_context(self.engine) as session:
+ test_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
+ self.assertIsNotNone(test_agent)
+ # Session closes here
+
+ # Ensure test_agent is not None before proceeding
+ assert test_agent is not None
+
+ # Now simulate using this object in a different session context
+ # This tests the pattern where PersistableModel would use session.add() or session.delete()
+ # on a cross-session object
+ with session_manager.session_context(self.engine) as session:
+ # Before the fix, this would cause DetachedInstanceError
+ # The fix uses session.merge() to handle detached objects safely
+ merged_agent = session.merge(test_agent)
+ assert merged_agent is not None # Type narrowing for linter
+
+ # Test that we can modify and save the merged object
+ original_port = merged_agent.port
+ # Use setattr to avoid linter issues with Column assignment
+ setattr(merged_agent, "port", 9999)
+ session.add(merged_agent)
+ # session.commit() called automatically by context manager
+
+ # Verify the change was persisted
+ with session_manager.session_context(self.engine) as session:
+ updated_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
+ assert updated_agent is not None # Type narrowing for linter
+ self.assertEqual(updated_agent.port, 9999)
+
+ # Restore original value
+ # Use setattr to avoid linter issues
+ setattr(updated_agent, "port", original_port)
+ session.add(updated_agent)