2209 lines
106 KiB
Diff
2209 lines
106 KiB
Diff
diff --git a/keylime/cloud_verifier_tornado.py b/keylime/cloud_verifier_tornado.py
|
|
index 8ab81d1..7553ac8 100644
|
|
--- a/keylime/cloud_verifier_tornado.py
|
|
+++ b/keylime/cloud_verifier_tornado.py
|
|
@@ -7,7 +7,8 @@ import sys
|
|
import traceback
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from multiprocessing import Process
|
|
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
|
+from contextlib import contextmanager
|
|
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
|
|
|
|
import tornado.httpserver
|
|
import tornado.ioloop
|
|
@@ -34,7 +35,7 @@ from keylime.agentstates import AgentAttestState, AgentAttestStates
|
|
from keylime.common import retry, states, validators
|
|
from keylime.common.version import str_to_version
|
|
from keylime.da import record
|
|
-from keylime.db.keylime_db import DBEngineManager, SessionManager
|
|
+from keylime.db.keylime_db import SessionManager, make_engine
|
|
from keylime.db.verifier_db import VerfierMain, VerifierAllowlist, VerifierMbpolicy
|
|
from keylime.failure import MAX_SEVERITY_LABEL, Component, Event, Failure, set_severity_config
|
|
from keylime.ima import ima
|
|
@@ -47,7 +48,7 @@ GLOBAL_POLICY_CACHE: Dict[str, Dict[str, str]] = {}
|
|
set_severity_config(config.getlist("verifier", "severity_labels"), config.getlist("verifier", "severity_policy"))
|
|
|
|
try:
|
|
- engine = DBEngineManager().make_engine("cloud_verifier")
|
|
+ engine = make_engine("cloud_verifier")
|
|
except SQLAlchemyError as err:
|
|
logger.error("Error creating SQL engine or session: %s", err)
|
|
sys.exit(1)
|
|
@@ -61,8 +62,17 @@ except record.RecordManagementException as rme:
|
|
sys.exit(1)
|
|
|
|
|
|
-def get_session() -> Session:
|
|
- return SessionManager().make_session(engine)
|
|
+@contextmanager
|
|
+def session_context() -> Iterator[Session]:
|
|
+ """
|
|
+ Context manager for database sessions that ensures proper cleanup.
|
|
+ To use:
|
|
+ with session_context() as session:
|
|
+ # use session
|
|
+ """
|
|
+ session_manager = SessionManager()
|
|
+ with session_manager.session_context(engine) as session:
|
|
+ yield session
|
|
|
|
|
|
def get_AgentAttestStates() -> AgentAttestStates:
|
|
@@ -130,19 +140,18 @@ def _from_db_obj(agent_db_obj: VerfierMain) -> Dict[str, Any]:
|
|
return agent_dict
|
|
|
|
|
|
-def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str:
|
|
- checksum = ""
|
|
- name = "empty"
|
|
- agent_id = str(stored_agent.agent_id)
|
|
+def verifier_read_policy_from_cache(ima_policy_data: Dict[str, str]) -> str:
|
|
+ checksum = ima_policy_data.get("checksum", "")
|
|
+ name = ima_policy_data.get("name", "empty")
|
|
+ agent_id = ima_policy_data.get("agent_id", "")
|
|
+
|
|
+ if not agent_id:
|
|
+ return ""
|
|
|
|
if agent_id not in GLOBAL_POLICY_CACHE:
|
|
GLOBAL_POLICY_CACHE[agent_id] = {}
|
|
GLOBAL_POLICY_CACHE[agent_id][""] = ""
|
|
|
|
- if stored_agent.ima_policy:
|
|
- checksum = str(stored_agent.ima_policy.checksum)
|
|
- name = stored_agent.ima_policy.name
|
|
-
|
|
if checksum not in GLOBAL_POLICY_CACHE[agent_id]:
|
|
if len(GLOBAL_POLICY_CACHE[agent_id]) > 1:
|
|
# Perform a cleanup of the contents, IMA policy checksum changed
|
|
@@ -162,8 +171,9 @@ def verifier_read_policy_from_cache(stored_agent: VerfierMain) -> str:
|
|
checksum,
|
|
agent_id,
|
|
)
|
|
- # Actually contacts the database and load the (large) ima_policy column for "allowlists" table
|
|
- ima_policy = stored_agent.ima_policy.ima_policy
|
|
+
|
|
+ # Get the large ima_policy content - it's already loaded in ima_policy_data
|
|
+ ima_policy = ima_policy_data.get("ima_policy", "")
|
|
assert isinstance(ima_policy, str)
|
|
GLOBAL_POLICY_CACHE[agent_id][checksum] = ima_policy
|
|
|
|
@@ -182,22 +192,19 @@ def store_attestation_state(agentAttestState: AgentAttestState) -> None:
|
|
# Only store if IMA log was evaluated
|
|
if agentAttestState.get_ima_pcrs():
|
|
agent_id = agentAttestState.agent_id
|
|
- session = get_session()
|
|
try:
|
|
- update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id())
|
|
- assert update_agent
|
|
- update_agent.boottime = agentAttestState.get_boottime()
|
|
- update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry()
|
|
- ima_pcrs_dict = agentAttestState.get_ima_pcrs()
|
|
- update_agent.ima_pcrs = list(ima_pcrs_dict.keys())
|
|
- for pcr_num, value in ima_pcrs_dict.items():
|
|
- setattr(update_agent, f"pcr{pcr_num}", value)
|
|
- update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json()
|
|
- try:
|
|
+ with session_context() as session:
|
|
+ update_agent = session.query(VerfierMain).get(agentAttestState.get_agent_id())
|
|
+ assert update_agent
|
|
+ update_agent.boottime = agentAttestState.get_boottime()
|
|
+ update_agent.next_ima_ml_entry = agentAttestState.get_next_ima_ml_entry()
|
|
+ ima_pcrs_dict = agentAttestState.get_ima_pcrs()
|
|
+ update_agent.ima_pcrs = list(ima_pcrs_dict.keys())
|
|
+ for pcr_num, value in ima_pcrs_dict.items():
|
|
+ setattr(update_agent, f"pcr{pcr_num}", value)
|
|
+ update_agent.learned_ima_keyrings = agentAttestState.get_ima_keyrings().to_json()
|
|
session.add(update_agent)
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e)
|
|
- session.commit()
|
|
+ # session.commit() is automatically called by context manager
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error on storing attestation state for agent %s: %s", agent_id, e)
|
|
|
|
@@ -354,45 +361,17 @@ class AgentsHandler(BaseHandler):
|
|
was not found, it either completed successfully, or failed. If found, the agent_id is still polling
|
|
to contact the Cloud Agent.
|
|
"""
|
|
- session = get_session()
|
|
-
|
|
rest_params, agent_id = self.__validate_input("GET")
|
|
if not rest_params:
|
|
return
|
|
|
|
- if (agent_id is not None) and (agent_id != ""):
|
|
- # If the agent ID is not valid (wrong set of characters),
|
|
- # just do nothing.
|
|
- agent = None
|
|
- try:
|
|
- agent = (
|
|
- session.query(VerfierMain)
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.ima_policy).load_only(
|
|
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
- )
|
|
- )
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
|
|
- )
|
|
- .filter_by(agent_id=agent_id)
|
|
- .one_or_none()
|
|
- )
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
-
|
|
- if agent is not None:
|
|
- response = cloud_verifier_common.process_get_status(agent)
|
|
- web_util.echo_json_response(self, 200, "Success", response)
|
|
- else:
|
|
- web_util.echo_json_response(self, 404, "agent id not found")
|
|
- else:
|
|
- json_response = None
|
|
- if "bulk" in rest_params:
|
|
- agent_list = None
|
|
-
|
|
- if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
|
|
- agent_list = (
|
|
+ with session_context() as session:
|
|
+ if (agent_id is not None) and (agent_id != ""):
|
|
+ # If the agent ID is not valid (wrong set of characters),
|
|
+ # just do nothing.
|
|
+ agent = None
|
|
+ try:
|
|
+ agent = (
|
|
session.query(VerfierMain)
|
|
.options( # type: ignore
|
|
joinedload(VerfierMain.ima_policy).load_only(
|
|
@@ -402,39 +381,70 @@ class AgentsHandler(BaseHandler):
|
|
.options( # type: ignore
|
|
joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
|
|
)
|
|
- .filter_by(verifier_id=rest_params["verifier"])
|
|
- .all()
|
|
+ .filter_by(agent_id=agent_id)
|
|
+ .one_or_none()
|
|
)
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+
|
|
+ if agent is not None:
|
|
+ response = cloud_verifier_common.process_get_status(agent)
|
|
+ web_util.echo_json_response(self, 200, "Success", response)
|
|
else:
|
|
- agent_list = (
|
|
- session.query(VerfierMain)
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.ima_policy).load_only(
|
|
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
+ web_util.echo_json_response(self, 404, "agent id not found")
|
|
+ else:
|
|
+ json_response = None
|
|
+ if "bulk" in rest_params:
|
|
+ agent_list = None
|
|
+
|
|
+ if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
|
|
+ agent_list = (
|
|
+ session.query(VerfierMain)
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.ima_policy).load_only(
|
|
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
+ )
|
|
)
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.mb_policy).load_only(
|
|
+ VerifierMbpolicy.mb_policy # type: ignore[arg-type]
|
|
+ )
|
|
+ )
|
|
+ .filter_by(verifier_id=rest_params["verifier"])
|
|
+ .all()
|
|
)
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
|
|
+ else:
|
|
+ agent_list = (
|
|
+ session.query(VerfierMain)
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.ima_policy).load_only(
|
|
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
+ )
|
|
+ )
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.mb_policy).load_only(
|
|
+ VerifierMbpolicy.mb_policy # type: ignore[arg-type]
|
|
+ )
|
|
+ )
|
|
+ .all()
|
|
)
|
|
- .all()
|
|
- )
|
|
|
|
- json_response = {}
|
|
- for agent in agent_list:
|
|
- json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent)
|
|
+ json_response = {}
|
|
+ for agent in agent_list:
|
|
+ json_response[agent.agent_id] = cloud_verifier_common.process_get_status(agent)
|
|
|
|
- web_util.echo_json_response(self, 200, "Success", json_response)
|
|
- else:
|
|
- if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
|
|
- json_response_list = (
|
|
- session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all()
|
|
- )
|
|
+ web_util.echo_json_response(self, 200, "Success", json_response)
|
|
else:
|
|
- json_response_list = session.query(VerfierMain.agent_id).all()
|
|
+ if ("verifier" in rest_params) and (rest_params["verifier"] != ""):
|
|
+ json_response_list = (
|
|
+ session.query(VerfierMain.agent_id).filter_by(verifier_id=rest_params["verifier"]).all()
|
|
+ )
|
|
+ else:
|
|
+ json_response_list = session.query(VerfierMain.agent_id).all()
|
|
|
|
- web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list})
|
|
+ web_util.echo_json_response(self, 200, "Success", {"uuids": json_response_list})
|
|
|
|
- logger.info("GET returning 200 response for agent_id list")
|
|
+ logger.info("GET returning 200 response for agent_id list")
|
|
|
|
def delete(self) -> None:
|
|
"""This method handles the DELETE requests to remove agents from the Cloud Verifier.
|
|
@@ -442,59 +452,55 @@ class AgentsHandler(BaseHandler):
|
|
Currently, only agents resources are available for DELETEing, i.e. /agents. All other DELETE uri's will return errors.
|
|
agents requests require a single agent_id parameter which identifies the agent to be deleted.
|
|
"""
|
|
- session = get_session()
|
|
-
|
|
rest_params, agent_id = self.__validate_input("DELETE")
|
|
if not rest_params or not agent_id:
|
|
return
|
|
|
|
- agent = None
|
|
- try:
|
|
- agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ with session_context() as session:
|
|
+ agent = None
|
|
+ try:
|
|
+ agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
|
|
- if agent is None:
|
|
- web_util.echo_json_response(self, 404, "agent id not found")
|
|
- logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id)
|
|
- return
|
|
+ if agent is None:
|
|
+ web_util.echo_json_response(self, 404, "agent id not found")
|
|
+ logger.info("DELETE returning 404 response. agent id: %s not found.", agent_id)
|
|
+ return
|
|
|
|
- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
|
|
- if verifier_id != agent.verifier_id:
|
|
- web_util.echo_json_response(self, 404, "agent id associated to this verifier")
|
|
- logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id)
|
|
- return
|
|
+ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
|
|
+ if verifier_id != agent.verifier_id:
|
|
+ web_util.echo_json_response(self, 404, "agent id associated to this verifier")
|
|
+ logger.info("DELETE returning 404 response. agent id: %s not associated to this verifer.", agent_id)
|
|
+ return
|
|
|
|
- # Cleanup the cache when the agent is deleted. Do it early.
|
|
- if agent_id in GLOBAL_POLICY_CACHE:
|
|
- del GLOBAL_POLICY_CACHE[agent_id]
|
|
- logger.debug(
|
|
- "Cleaned up policy cache from all entries used by agent %s",
|
|
- agent_id,
|
|
- )
|
|
+ # Cleanup the cache when the agent is deleted. Do it early.
|
|
+ if agent_id in GLOBAL_POLICY_CACHE:
|
|
+ del GLOBAL_POLICY_CACHE[agent_id]
|
|
+ logger.debug(
|
|
+ "Cleaned up policy cache from all entries used by agent %s",
|
|
+ agent_id,
|
|
+ )
|
|
|
|
- op_state = agent.operational_state
|
|
- if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE):
|
|
- try:
|
|
- verifier_db_delete_agent(session, agent_id)
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 200, "Success")
|
|
- logger.info("DELETE returning 200 response for agent id: %s", agent_id)
|
|
- else:
|
|
- try:
|
|
- update_agent = session.query(VerfierMain).get(agent_id)
|
|
- assert update_agent
|
|
- update_agent.operational_state = states.TERMINATED
|
|
+ op_state = agent.operational_state
|
|
+ if op_state in (states.SAVED, states.FAILED, states.TERMINATED, states.TENANT_FAILED, states.INVALID_QUOTE):
|
|
try:
|
|
+ verifier_db_delete_agent(session, agent_id)
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 200, "Success")
|
|
+ logger.info("DELETE returning 200 response for agent id: %s", agent_id)
|
|
+ else:
|
|
+ try:
|
|
+ update_agent = session.query(VerfierMain).get(agent_id)
|
|
+ assert update_agent
|
|
+ update_agent.operational_state = states.TERMINATED
|
|
session.add(update_agent)
|
|
+ # session.commit() is automatically called by context manager
|
|
+ web_util.echo_json_response(self, 202, "Accepted")
|
|
+ logger.info("DELETE returning 202 response for agent id: %s", agent_id)
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- session.commit()
|
|
- web_util.echo_json_response(self, 202, "Accepted")
|
|
- logger.info("DELETE returning 202 response for agent id: %s", agent_id)
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
|
|
def post(self) -> None:
|
|
"""This method handles the POST requests to add agents to the Cloud Verifier.
|
|
@@ -502,7 +508,6 @@ class AgentsHandler(BaseHandler):
|
|
Currently, only agents resources are available for POSTing, i.e. /agents. All other POST uri's will return errors.
|
|
agents requests require a json block sent in the body
|
|
"""
|
|
- session = get_session()
|
|
# TODO: exception handling needs fixing
|
|
# Maybe handle exceptions with if/else if/else blocks ... simple and avoids nesting
|
|
try: # pylint: disable=too-many-nested-blocks
|
|
@@ -585,201 +590,208 @@ class AgentsHandler(BaseHandler):
|
|
runtime_policy = base64.b64decode(json_body.get("runtime_policy")).decode()
|
|
runtime_policy_stored = None
|
|
|
|
- if runtime_policy_name:
|
|
+ with session_context() as session:
|
|
+ if runtime_policy_name:
|
|
+ try:
|
|
+ runtime_policy_stored = (
|
|
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
|
|
+ )
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ raise
|
|
+
|
|
+ # Prevent overwriting existing IMA policies with name provided in request
|
|
+ if runtime_policy and runtime_policy_stored:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 409,
|
|
+ f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.",
|
|
+ )
|
|
+ logger.warning("IMA policy with name %s already exists", runtime_policy_name)
|
|
+ return
|
|
+
|
|
+ # Return an error code if the named allowlist does not exist in the database
|
|
+ if not runtime_policy and not runtime_policy_stored:
|
|
+ web_util.echo_json_response(
|
|
+ self, 404, f"Could not find IMA policy with name {runtime_policy_name}!"
|
|
+ )
|
|
+ logger.warning("Could not find IMA policy with name %s", runtime_policy_name)
|
|
+ return
|
|
+
|
|
+ # Prevent overwriting existing agents with UUID provided in request
|
|
try:
|
|
- runtime_policy_stored = (
|
|
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
|
|
- )
|
|
+ new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count()
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise
|
|
+ raise e
|
|
|
|
- # Prevent overwriting existing IMA policies with name provided in request
|
|
- if runtime_policy and runtime_policy_stored:
|
|
+ if new_agent_count > 0:
|
|
web_util.echo_json_response(
|
|
self,
|
|
409,
|
|
- f"IMA policy with name {runtime_policy_name} already exists. Please use a different name or delete the allowlist from the verifier.",
|
|
+ f"Agent of uuid {agent_id} already exists. Please use delete or update.",
|
|
)
|
|
- logger.warning("IMA policy with name %s already exists", runtime_policy_name)
|
|
+ logger.warning("Agent of uuid %s already exists", agent_id)
|
|
return
|
|
|
|
- # Return an error code if the named allowlist does not exist in the database
|
|
- if not runtime_policy and not runtime_policy_stored:
|
|
- web_util.echo_json_response(
|
|
- self, 404, f"Could not find IMA policy with name {runtime_policy_name}!"
|
|
- )
|
|
- logger.warning("Could not find IMA policy with name %s", runtime_policy_name)
|
|
- return
|
|
-
|
|
- # Prevent overwriting existing agents with UUID provided in request
|
|
- try:
|
|
- new_agent_count = session.query(VerfierMain).filter_by(agent_id=agent_id).count()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise e
|
|
-
|
|
- if new_agent_count > 0:
|
|
- web_util.echo_json_response(
|
|
- self,
|
|
- 409,
|
|
- f"Agent of uuid {agent_id} already exists. Please use delete or update.",
|
|
- )
|
|
- logger.warning("Agent of uuid %s already exists", agent_id)
|
|
- return
|
|
-
|
|
- # Write IMA policy to database if needed
|
|
- if not runtime_policy_name and not runtime_policy:
|
|
- logger.info("IMA policy data not provided with request! Using default empty IMA policy.")
|
|
- runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY))
|
|
+ # Write IMA policy to database if needed
|
|
+ if not runtime_policy_name and not runtime_policy:
|
|
+ logger.info("IMA policy data not provided with request! Using default empty IMA policy.")
|
|
+ runtime_policy = json.dumps(cast(Dict[str, Any], ima.EMPTY_RUNTIME_POLICY))
|
|
|
|
- if runtime_policy:
|
|
- runtime_policy_key_bytes = signing.get_runtime_policy_keys(
|
|
- runtime_policy.encode(),
|
|
- json_body.get("runtime_policy_key"),
|
|
- )
|
|
-
|
|
- try:
|
|
- ima.verify_runtime_policy(
|
|
+ if runtime_policy:
|
|
+ runtime_policy_key_bytes = signing.get_runtime_policy_keys(
|
|
runtime_policy.encode(),
|
|
- runtime_policy_key_bytes,
|
|
- verify_sig=config.getboolean(
|
|
- "verifier", "require_allow_list_signatures", fallback=False
|
|
- ),
|
|
+ json_body.get("runtime_policy_key"),
|
|
)
|
|
- except ima.ImaValidationError as e:
|
|
- web_util.echo_json_response(self, e.code, e.message)
|
|
- logger.warning(e.message)
|
|
- return
|
|
|
|
- if not runtime_policy_name:
|
|
- runtime_policy_name = agent_id
|
|
-
|
|
- try:
|
|
- runtime_policy_db_format = ima.runtime_policy_db_contents(
|
|
- runtime_policy_name, runtime_policy
|
|
- )
|
|
- except ima.ImaValidationError as e:
|
|
- message = f"Runtime policy is malformatted: {e.message}"
|
|
- web_util.echo_json_response(self, e.code, message)
|
|
- logger.warning(message)
|
|
- return
|
|
-
|
|
- try:
|
|
- runtime_policy_stored = (
|
|
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
|
|
- )
|
|
- except SQLAlchemyError as e:
|
|
- logger.error(
|
|
- "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s", agent_id, e
|
|
- )
|
|
- raise
|
|
- try:
|
|
- if runtime_policy_stored is None:
|
|
- runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format)
|
|
- session.add(runtime_policy_stored)
|
|
+ try:
|
|
+ ima.verify_runtime_policy(
|
|
+ runtime_policy.encode(),
|
|
+ runtime_policy_key_bytes,
|
|
+ verify_sig=config.getboolean(
|
|
+ "verifier", "require_allow_list_signatures", fallback=False
|
|
+ ),
|
|
+ )
|
|
+ except ima.ImaValidationError as e:
|
|
+ web_util.echo_json_response(self, e.code, e.message)
|
|
+ logger.warning(e.message)
|
|
+ return
|
|
+
|
|
+ if not runtime_policy_name:
|
|
+ runtime_policy_name = agent_id
|
|
+
|
|
+ try:
|
|
+ runtime_policy_db_format = ima.runtime_policy_db_contents(
|
|
+ runtime_policy_name, runtime_policy
|
|
+ )
|
|
+ except ima.ImaValidationError as e:
|
|
+ message = f"Runtime policy is malformatted: {e.message}"
|
|
+ web_util.echo_json_response(self, e.code, message)
|
|
+ logger.warning(message)
|
|
+ return
|
|
+
|
|
+ try:
|
|
+ runtime_policy_stored = (
|
|
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).one_or_none()
|
|
+ )
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error(
|
|
+ "SQLAlchemy Error while retrieving stored ima policy for agent ID %s: %s",
|
|
+ agent_id,
|
|
+ e,
|
|
+ )
|
|
+ raise
|
|
+ try:
|
|
+ if runtime_policy_stored is None:
|
|
+ runtime_policy_stored = VerifierAllowlist(**runtime_policy_db_format)
|
|
+ session.add(runtime_policy_stored)
|
|
+ session.commit()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error(
|
|
+ "SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e
|
|
+ )
|
|
+ raise
|
|
+
|
|
+ # Handle measured boot policy
|
|
+ # - No name, mb_policy : store mb_policy using agent UUID as name
|
|
+ # - Name, no mb_policy : fetch existing mb_policy from DB
|
|
+ # - Name, mb_policy : store mb_policy using name
|
|
+
|
|
+ mb_policy_name = json_body["mb_policy_name"]
|
|
+ mb_policy = json_body["mb_policy"]
|
|
+ mb_policy_stored = None
|
|
+
|
|
+ if mb_policy_name:
|
|
+ try:
|
|
+ mb_policy_stored = (
|
|
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
|
|
+ )
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ raise
|
|
+
|
|
+ # Prevent overwriting existing mb_policy with name provided in request
|
|
+ if mb_policy and mb_policy_stored:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 409,
|
|
+ f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.",
|
|
+ )
|
|
+ logger.warning("mb_policy with name %s already exists", mb_policy_name)
|
|
+ return
|
|
+
|
|
+ # Return error if the mb_policy is neither provided nor stored.
|
|
+ if not mb_policy and not mb_policy_stored:
|
|
+ web_util.echo_json_response(
|
|
+ self, 404, f"Could not find mb_policy with name {mb_policy_name}!"
|
|
+ )
|
|
+ logger.warning("Could not find mb_policy with name %s", mb_policy_name)
|
|
+ return
|
|
+
|
|
+ else:
|
|
+ # Use the UUID of the agent
|
|
+ mb_policy_name = agent_id
|
|
+ try:
|
|
+ mb_policy_stored = (
|
|
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
|
|
+ )
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ raise
|
|
+
|
|
+ # Prevent overwriting existing mb_policy
|
|
+ if mb_policy and mb_policy_stored:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 409,
|
|
+ f"mb_policy with name {mb_policy_name} already exists. You can delete the mb_policy from the verifier.",
|
|
+ )
|
|
+ logger.warning("mb_policy with name %s already exists", mb_policy_name)
|
|
+ return
|
|
+
|
|
+ # Store the policy into database if not stored
|
|
+ if mb_policy_stored is None:
|
|
+ try:
|
|
+ mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy)
|
|
+ mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format)
|
|
+ session.add(mb_policy_stored)
|
|
session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error while updating ima policy for agent ID %s: %s", agent_id, e)
|
|
- raise
|
|
-
|
|
- # Handle measured boot policy
|
|
- # - No name, mb_policy : store mb_policy using agent UUID as name
|
|
- # - Name, no mb_policy : fetch existing mb_policy from DB
|
|
- # - Name, mb_policy : store mb_policy using name
|
|
-
|
|
- mb_policy_name = json_body["mb_policy_name"]
|
|
- mb_policy = json_body["mb_policy"]
|
|
- mb_policy_stored = None
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error(
|
|
+ "SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e
|
|
+ )
|
|
+ raise
|
|
|
|
- if mb_policy_name:
|
|
+ # Write the agent to the database, attaching associated stored ima_policy and mb_policy
|
|
try:
|
|
- mb_policy_stored = (
|
|
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
|
|
+ assert runtime_policy_stored
|
|
+ assert mb_policy_stored
|
|
+ session.add(
|
|
+ VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored)
|
|
)
|
|
+ session.commit()
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise
|
|
+ raise e
|
|
|
|
- # Prevent overwriting existing mb_policy with name provided in request
|
|
- if mb_policy and mb_policy_stored:
|
|
- web_util.echo_json_response(
|
|
- self,
|
|
- 409,
|
|
- f"mb_policy with name {mb_policy_name} already exists. Please use a different name or delete the mb_policy from the verifier.",
|
|
- )
|
|
- logger.warning("mb_policy with name %s already exists", mb_policy_name)
|
|
- return
|
|
+ # add default fields that are ephemeral
|
|
+ for key, val in exclude_db.items():
|
|
+ agent_data[key] = val
|
|
|
|
- # Return error if the mb_policy is neither provided nor stored.
|
|
- if not mb_policy and not mb_policy_stored:
|
|
- web_util.echo_json_response(
|
|
- self, 404, f"Could not find mb_policy with name {mb_policy_name}!"
|
|
+ # Prepare SSLContext for mTLS connections
|
|
+ agent_data["ssl_context"] = None
|
|
+ if agent_mtls_cert_enabled:
|
|
+ agent_data["ssl_context"] = web_util.generate_agent_tls_context(
|
|
+ "verifier", agent_data["mtls_cert"], logger=logger
|
|
)
|
|
- logger.warning("Could not find mb_policy with name %s", mb_policy_name)
|
|
- return
|
|
|
|
- else:
|
|
- # Use the UUID of the agent
|
|
- mb_policy_name = agent_id
|
|
- try:
|
|
- mb_policy_stored = (
|
|
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one_or_none()
|
|
- )
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise
|
|
-
|
|
- # Prevent overwriting existing mb_policy
|
|
- if mb_policy and mb_policy_stored:
|
|
- web_util.echo_json_response(
|
|
- self,
|
|
- 409,
|
|
- f"mb_policy with name {mb_policy_name} already exists. You can delete the mb_policy from the verifier.",
|
|
- )
|
|
- logger.warning("mb_policy with name %s already exists", mb_policy_name)
|
|
- return
|
|
-
|
|
- # Store the policy into database if not stored
|
|
- if mb_policy_stored is None:
|
|
- try:
|
|
- mb_policy_db_format = mba.mb_policy_db_contents(mb_policy_name, mb_policy)
|
|
- mb_policy_stored = VerifierMbpolicy(**mb_policy_db_format)
|
|
- session.add(mb_policy_stored)
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error while updating mb_policy for agent ID %s: %s", agent_id, e)
|
|
- raise
|
|
+ if agent_data["ssl_context"] is None:
|
|
+ logger.warning("Connecting to agent without mTLS: %s", agent_id)
|
|
|
|
- # Write the agent to the database, attaching associated stored ima_policy and mb_policy
|
|
- try:
|
|
- assert runtime_policy_stored
|
|
- assert mb_policy_stored
|
|
- session.add(
|
|
- VerfierMain(**agent_data, ima_policy=runtime_policy_stored, mb_policy=mb_policy_stored)
|
|
- )
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise e
|
|
-
|
|
- # add default fields that are ephemeral
|
|
- for key, val in exclude_db.items():
|
|
- agent_data[key] = val
|
|
-
|
|
- # Prepare SSLContext for mTLS connections
|
|
- agent_data["ssl_context"] = None
|
|
- if agent_mtls_cert_enabled:
|
|
- agent_data["ssl_context"] = web_util.generate_agent_tls_context(
|
|
- "verifier", agent_data["mtls_cert"], logger=logger
|
|
- )
|
|
-
|
|
- if agent_data["ssl_context"] is None:
|
|
- logger.warning("Connecting to agent without mTLS: %s", agent_id)
|
|
-
|
|
- asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE))
|
|
- web_util.echo_json_response(self, 200, "Success")
|
|
- logger.info("POST returning 200 response for adding agent id: %s", agent_id)
|
|
+ asyncio.ensure_future(process_agent(agent_data, states.GET_QUOTE))
|
|
+ web_util.echo_json_response(self, 200, "Success")
|
|
+ logger.info("POST returning 200 response for adding agent id: %s", agent_id)
|
|
else:
|
|
web_util.echo_json_response(self, 400, "uri not supported")
|
|
logger.warning("POST returning 400 response. uri not supported")
|
|
@@ -794,54 +806,54 @@ class AgentsHandler(BaseHandler):
|
|
Currently, only agents resources are available for PUTing, i.e. /agents. All other PUT uri's will return errors.
|
|
agents requests require a json block sent in the body
|
|
"""
|
|
- session = get_session()
|
|
try:
|
|
rest_params, agent_id = self.__validate_input("PUT")
|
|
if not rest_params:
|
|
return
|
|
|
|
- try:
|
|
- verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
|
|
- db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
- raise e
|
|
+ with session_context() as session:
|
|
+ try:
|
|
+ verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
|
|
+ db_agent = session.query(VerfierMain).filter_by(agent_id=agent_id, verifier_id=verifier_id).one()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ raise e
|
|
|
|
- if db_agent is None:
|
|
- web_util.echo_json_response(self, 404, "agent id not found")
|
|
- logger.info("PUT returning 404 response. agent id: %s not found.", agent_id)
|
|
- return
|
|
+ if db_agent is None:
|
|
+ web_util.echo_json_response(self, 404, "agent id not found")
|
|
+ logger.info("PUT returning 404 response. agent id: %s not found.", agent_id)
|
|
+ return
|
|
|
|
- if "reactivate" in rest_params:
|
|
- agent = _from_db_obj(db_agent)
|
|
+ if "reactivate" in rest_params:
|
|
+ agent = _from_db_obj(db_agent)
|
|
|
|
- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
|
|
- agent["ssl_context"] = web_util.generate_agent_tls_context(
|
|
- "verifier", agent["mtls_cert"], logger=logger
|
|
- )
|
|
- if agent["ssl_context"] is None:
|
|
- logger.warning("Connecting to agent without mTLS: %s", agent_id)
|
|
+ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
|
|
+ agent["ssl_context"] = web_util.generate_agent_tls_context(
|
|
+ "verifier", agent["mtls_cert"], logger=logger
|
|
+ )
|
|
+ if agent["ssl_context"] is None:
|
|
+ logger.warning("Connecting to agent without mTLS: %s", agent_id)
|
|
|
|
- agent["operational_state"] = states.START
|
|
- asyncio.ensure_future(process_agent(agent, states.GET_QUOTE))
|
|
- web_util.echo_json_response(self, 200, "Success")
|
|
- logger.info("PUT returning 200 response for agent id: %s", agent_id)
|
|
- elif "stop" in rest_params:
|
|
- # do stuff for terminate
|
|
- logger.debug("Stopping polling on %s", agent_id)
|
|
- try:
|
|
- session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore
|
|
- {"operational_state": states.TENANT_FAILED}
|
|
- )
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
+ agent["operational_state"] = states.START
|
|
+ asyncio.ensure_future(process_agent(agent, states.GET_QUOTE))
|
|
+ web_util.echo_json_response(self, 200, "Success")
|
|
+ logger.info("PUT returning 200 response for agent id: %s", agent_id)
|
|
+ elif "stop" in rest_params:
|
|
+ # do stuff for terminate
|
|
+ logger.debug("Stopping polling on %s", agent_id)
|
|
+ try:
|
|
+ session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).update( # pyright: ignore
|
|
+ {"operational_state": states.TENANT_FAILED}
|
|
+ )
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
|
|
- web_util.echo_json_response(self, 200, "Success")
|
|
- logger.info("PUT returning 200 response for agent id: %s", agent_id)
|
|
- else:
|
|
- web_util.echo_json_response(self, 400, "uri not supported")
|
|
- logger.warning("PUT returning 400 response. uri not supported")
|
|
+ web_util.echo_json_response(self, 200, "Success")
|
|
+ logger.info("PUT returning 200 response for agent id: %s", agent_id)
|
|
+ else:
|
|
+ web_util.echo_json_response(self, 400, "uri not supported")
|
|
+ logger.warning("PUT returning 400 response. uri not supported")
|
|
|
|
except Exception as e:
|
|
web_util.echo_json_response(self, 400, f"Exception error: {str(e)}")
|
|
@@ -887,36 +899,36 @@ class AllowlistHandler(BaseHandler):
|
|
if not params_valid:
|
|
return
|
|
|
|
- session = get_session()
|
|
- if allowlist_name is None:
|
|
- try:
|
|
- names_allowlists = session.query(VerifierAllowlist.name).all()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get names of allowlists")
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ if allowlist_name is None:
|
|
+ try:
|
|
+ names_allowlists = session.query(VerifierAllowlist.name).all()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get names of allowlists")
|
|
+ raise
|
|
|
|
- names_response = []
|
|
- for name in names_allowlists:
|
|
- names_response.append(name[0])
|
|
- web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response})
|
|
+ names_response = []
|
|
+ for name in names_allowlists:
|
|
+ names_response.append(name[0])
|
|
+ web_util.echo_json_response(self, 200, "Success", {"runtimepolicy names": names_response})
|
|
|
|
- else:
|
|
- try:
|
|
- allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
|
|
- except NoResultFound:
|
|
- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get allowlist")
|
|
- raise
|
|
+ else:
|
|
+ try:
|
|
+ allowlist = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
|
|
+ except NoResultFound:
|
|
+ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get allowlist")
|
|
+ raise
|
|
|
|
- response = {}
|
|
- for field in ("name", "tpm_policy"):
|
|
- response[field] = getattr(allowlist, field, None)
|
|
- response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
|
|
- web_util.echo_json_response(self, 200, "Success", response)
|
|
+ response = {}
|
|
+ for field in ("name", "tpm_policy"):
|
|
+ response[field] = getattr(allowlist, field, None)
|
|
+ response["runtime_policy"] = getattr(allowlist, "ima_policy", None)
|
|
+ web_util.echo_json_response(self, 200, "Success", response)
|
|
|
|
def delete(self) -> None:
|
|
"""Delete an allowlist
|
|
@@ -928,45 +940,44 @@ class AllowlistHandler(BaseHandler):
|
|
if not params_valid or allowlist_name is None:
|
|
return
|
|
|
|
- session = get_session()
|
|
- try:
|
|
- runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
|
|
- except NoResultFound:
|
|
- web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get allowlist")
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ try:
|
|
+ runtime_policy = session.query(VerifierAllowlist).filter_by(name=allowlist_name).one()
|
|
+ except NoResultFound:
|
|
+ web_util.echo_json_response(self, 404, f"Runtime policy {allowlist_name} not found")
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get allowlist")
|
|
+ raise
|
|
|
|
- try:
|
|
- agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
- if agent is not None:
|
|
- web_util.echo_json_response(
|
|
- self,
|
|
- 409,
|
|
- f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}",
|
|
- )
|
|
- return
|
|
+ try:
|
|
+ agent = session.query(VerfierMain).filter_by(ima_policy_id=runtime_policy.id).one_or_none()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
+ if agent is not None:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 409,
|
|
+ f"Can't delete allowlist as it's currently in use by agent {agent.agent_id}",
|
|
+ )
|
|
+ return
|
|
|
|
- try:
|
|
- session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete()
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- session.close()
|
|
- web_util.echo_json_response(self, 500, f"Database error: {e}")
|
|
- raise
|
|
+ try:
|
|
+ session.query(VerifierAllowlist).filter_by(name=allowlist_name).delete()
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, f"Database error: {e}")
|
|
+ raise
|
|
|
|
- # NOTE(kaifeng) 204 Can not have response body, but current helper
|
|
- # doesn't support this case.
|
|
- self.set_status(204)
|
|
- self.set_header("Content-Type", "application/json")
|
|
- self.finish()
|
|
- logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name)
|
|
+ # NOTE(kaifeng) 204 Can not have response body, but current helper
|
|
+ # doesn't support this case.
|
|
+ self.set_status(204)
|
|
+ self.set_header("Content-Type", "application/json")
|
|
+ self.finish()
|
|
+ logger.info("DELETE returning 204 response for allowlist: %s", allowlist_name)
|
|
|
|
def __get_runtime_policy_db_format(self, runtime_policy_name: str) -> Dict[str, Any]:
|
|
"""Get the IMA policy from the request and return it in Db format"""
|
|
@@ -1022,28 +1033,30 @@ class AllowlistHandler(BaseHandler):
|
|
if not runtime_policy_db_format:
|
|
return
|
|
|
|
- session = get_session()
|
|
- # don't allow overwritting
|
|
- try:
|
|
- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
|
|
- if runtime_policy_count > 0:
|
|
- web_util.echo_json_response(self, 409, f"Runtime policy with name {runtime_policy_name} already exists")
|
|
- logger.warning("Runtime policy with name %s already exists", runtime_policy_name)
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ # don't allow overwriting
|
|
+ try:
|
|
+ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
|
|
+ if runtime_policy_count > 0:
|
|
+ web_util.echo_json_response(
|
|
+ self, 409, f"Runtime policy with name {runtime_policy_name} already exists"
|
|
+ )
|
|
+ logger.warning("Runtime policy with name %s already exists", runtime_policy_name)
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- try:
|
|
- # Add the agent and data
|
|
- session.add(VerifierAllowlist(**runtime_policy_db_format))
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ try:
|
|
+ # Add the agent and data
|
|
+ session.add(VerifierAllowlist(**runtime_policy_db_format))
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- web_util.echo_json_response(self, 201)
|
|
- logger.info("POST returning 201")
|
|
+ web_util.echo_json_response(self, 201)
|
|
+ logger.info("POST returning 201")
|
|
|
|
def put(self) -> None:
|
|
"""Update an allowlist
|
|
@@ -1060,32 +1073,34 @@ class AllowlistHandler(BaseHandler):
|
|
if not runtime_policy_db_format:
|
|
return
|
|
|
|
- session = get_session()
|
|
- # don't allow creating a new policy
|
|
- try:
|
|
- runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
|
|
- if runtime_policy_count != 1:
|
|
- web_util.echo_json_response(
|
|
- self, 409, f"Runtime policy with name {runtime_policy_name} does not already exist"
|
|
- )
|
|
- logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name)
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ # don't allow creating a new policy
|
|
+ try:
|
|
+ runtime_policy_count = session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).count()
|
|
+ if runtime_policy_count != 1:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 404,
|
|
+ f"Runtime policy with name {runtime_policy_name} does not already exist, use POST to create",
|
|
+ )
|
|
+ logger.warning("Runtime policy with name %s does not already exist", runtime_policy_name)
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- try:
|
|
- # Update the named runtime policy
|
|
- session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update(
|
|
- runtime_policy_db_format # pyright: ignore
|
|
- )
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ try:
|
|
+ # Update the named runtime policy
|
|
+ session.query(VerifierAllowlist).filter_by(name=runtime_policy_name).update(
|
|
+ runtime_policy_db_format # pyright: ignore
|
|
+ )
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- web_util.echo_json_response(self, 201)
|
|
- logger.info("PUT returning 201")
|
|
+ web_util.echo_json_response(self, 201)
|
|
+ logger.info("PUT returning 201")
|
|
|
|
def data_received(self, chunk: Any) -> None:
|
|
raise NotImplementedError()
|
|
@@ -1113,8 +1128,6 @@ class VerifyIdentityHandler(BaseHandler):
|
|
|
|
This is useful for 3rd party tools and integrations to independently verify the state of an agent.
|
|
"""
|
|
- session = get_session()
|
|
-
|
|
# validate the parameters of our request
|
|
if self.request.uri is None:
|
|
web_util.echo_json_response(self, 400, "URI not specified")
|
|
@@ -1159,36 +1172,37 @@ class VerifyIdentityHandler(BaseHandler):
|
|
return
|
|
|
|
# get the agent information from the DB
|
|
- agent = None
|
|
- try:
|
|
- agent = (
|
|
- session.query(VerfierMain)
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.ima_policy).load_only(
|
|
- VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
+ with session_context() as session:
|
|
+ agent = None
|
|
+ try:
|
|
+ agent = (
|
|
+ session.query(VerfierMain)
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.ima_policy).load_only(
|
|
+ VerifierAllowlist.checksum, VerifierAllowlist.generator # pyright: ignore
|
|
+ )
|
|
)
|
|
+ .filter_by(agent_id=agent_id)
|
|
+ .one_or_none()
|
|
)
|
|
- .filter_by(agent_id=agent_id)
|
|
- .one_or_none()
|
|
- )
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent_id, e)
|
|
|
|
- if agent is not None:
|
|
- agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id)
|
|
- failure = cloud_verifier_common.process_verify_identity_quote(
|
|
- agent, quote, nonce, hash_alg, agentAttestState
|
|
- )
|
|
- if failure:
|
|
- failure_contexts = "; ".join(x.context for x in failure.events)
|
|
- web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts})
|
|
- logger.info("GET returning 200, but validation failed")
|
|
+ if agent is not None:
|
|
+ agentAttestState = get_AgentAttestStates().get_by_agent_id(agent_id)
|
|
+ failure = cloud_verifier_common.process_verify_identity_quote(
|
|
+ agent, quote, nonce, hash_alg, agentAttestState
|
|
+ )
|
|
+ if failure:
|
|
+ failure_contexts = "; ".join(x.context for x in failure.events)
|
|
+ web_util.echo_json_response(self, 200, "Success", {"valid": 0, "reason": failure_contexts})
|
|
+ logger.info("GET returning 200, but validation failed")
|
|
+ else:
|
|
+ web_util.echo_json_response(self, 200, "Success", {"valid": 1})
|
|
+ logger.info("GET returning 200, validation successful")
|
|
else:
|
|
- web_util.echo_json_response(self, 200, "Success", {"valid": 1})
|
|
- logger.info("GET returning 200, validation successful")
|
|
- else:
|
|
- web_util.echo_json_response(self, 404, "agent id not found")
|
|
- logger.info("GET returning 404, agaent not found")
|
|
+ web_util.echo_json_response(self, 404, "agent id not found")
|
|
+ logger.info("GET returning 404, agaent not found")
|
|
|
|
def data_received(self, chunk: Any) -> None:
|
|
raise NotImplementedError()
|
|
@@ -1231,35 +1245,35 @@ class MbpolicyHandler(BaseHandler):
|
|
if not params_valid:
|
|
return
|
|
|
|
- session = get_session()
|
|
- if mb_policy_name is None:
|
|
- try:
|
|
- names_mbpolicies = session.query(VerifierMbpolicy.name).all()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies")
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ if mb_policy_name is None:
|
|
+ try:
|
|
+ names_mbpolicies = session.query(VerifierMbpolicy.name).all()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get names of mbpolicies")
|
|
+ raise
|
|
|
|
- names_response = []
|
|
- for name in names_mbpolicies:
|
|
- names_response.append(name[0])
|
|
- web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response})
|
|
+ names_response = []
|
|
+ for name in names_mbpolicies:
|
|
+ names_response.append(name[0])
|
|
+ web_util.echo_json_response(self, 200, "Success", {"mbpolicy names": names_response})
|
|
|
|
- else:
|
|
- try:
|
|
- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
|
|
- except NoResultFound:
|
|
- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get mb_policy")
|
|
- raise
|
|
+ else:
|
|
+ try:
|
|
+ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
|
|
+ except NoResultFound:
|
|
+ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get mb_policy")
|
|
+ raise
|
|
|
|
- response = {}
|
|
- response["name"] = getattr(mbpolicy, "name", None)
|
|
- response["mb_policy"] = getattr(mbpolicy, "mb_policy", None)
|
|
- web_util.echo_json_response(self, 200, "Success", response)
|
|
+ response = {}
|
|
+ response["name"] = getattr(mbpolicy, "name", None)
|
|
+ response["mb_policy"] = getattr(mbpolicy, "mb_policy", None)
|
|
+ web_util.echo_json_response(self, 200, "Success", response)
|
|
|
|
def delete(self) -> None:
|
|
"""Delete a mb_policy
|
|
@@ -1271,45 +1285,44 @@ class MbpolicyHandler(BaseHandler):
|
|
if not params_valid or mb_policy_name is None:
|
|
return
|
|
|
|
- session = get_session()
|
|
- try:
|
|
- mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
|
|
- except NoResultFound:
|
|
- web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- web_util.echo_json_response(self, 500, "Failed to get mb_policy")
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ try:
|
|
+ mbpolicy = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).one()
|
|
+ except NoResultFound:
|
|
+ web_util.echo_json_response(self, 404, f"Measured boot policy {mb_policy_name} not found")
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, "Failed to get mb_policy")
|
|
+ raise
|
|
|
|
- try:
|
|
- agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
- if agent is not None:
|
|
- web_util.echo_json_response(
|
|
- self,
|
|
- 409,
|
|
- f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}",
|
|
- )
|
|
- return
|
|
+ try:
|
|
+ agent = session.query(VerfierMain).filter_by(mb_policy_id=mbpolicy.id).one_or_none()
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
+ if agent is not None:
|
|
+ web_util.echo_json_response(
|
|
+ self,
|
|
+ 409,
|
|
+ f"Can't delete mb_policy as it's currently in use by agent {agent.agent_id}",
|
|
+ )
|
|
+ return
|
|
|
|
- try:
|
|
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete()
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- session.close()
|
|
- web_util.echo_json_response(self, 500, f"Database error: {e}")
|
|
- raise
|
|
+ try:
|
|
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).delete()
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ web_util.echo_json_response(self, 500, f"Database error: {e}")
|
|
+ raise
|
|
|
|
- # NOTE(kaifeng) 204 Can not have response body, but current helper
|
|
- # doesn't support this case.
|
|
- self.set_status(204)
|
|
- self.set_header("Content-Type", "application/json")
|
|
- self.finish()
|
|
- logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name)
|
|
+ # NOTE(kaifeng) 204 Can not have response body, but current helper
|
|
+ # doesn't support this case.
|
|
+ self.set_status(204)
|
|
+ self.set_header("Content-Type", "application/json")
|
|
+ self.finish()
|
|
+ logger.info("DELETE returning 204 response for mb_policy: %s", mb_policy_name)
|
|
|
|
def __get_mb_policy_db_format(self, mb_policy_name: str) -> Dict[str, Any]:
|
|
"""Get the measured boot policy from the request and return it in Db format"""
|
|
@@ -1341,30 +1354,30 @@ class MbpolicyHandler(BaseHandler):
|
|
if not mb_policy_db_format:
|
|
return
|
|
|
|
- session = get_session()
|
|
- # don't allow overwritting
|
|
- try:
|
|
- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
|
|
- if mbpolicy_count > 0:
|
|
- web_util.echo_json_response(
|
|
- self, 409, f"Measured boot policy with name {mb_policy_name} already exists"
|
|
- )
|
|
- logger.warning("Measured boot policy with name %s already exists", mb_policy_name)
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ # don't allow overwriting
|
|
+ try:
|
|
+ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
|
|
+ if mbpolicy_count > 0:
|
|
+ web_util.echo_json_response(
|
|
+ self, 409, f"Measured boot policy with name {mb_policy_name} already exists"
|
|
+ )
|
|
+ logger.warning("Measured boot policy with name %s already exists", mb_policy_name)
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- try:
|
|
- # Add the data
|
|
- session.add(VerifierMbpolicy(**mb_policy_db_format))
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ try:
|
|
+ # Add the data
|
|
+ session.add(VerifierMbpolicy(**mb_policy_db_format))
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- web_util.echo_json_response(self, 201)
|
|
- logger.info("POST returning 201")
|
|
+ web_util.echo_json_response(self, 201)
|
|
+ logger.info("POST returning 201")
|
|
|
|
def put(self) -> None:
|
|
"""Update an mb_policy
|
|
@@ -1381,32 +1394,32 @@ class MbpolicyHandler(BaseHandler):
|
|
if not mb_policy_db_format:
|
|
return
|
|
|
|
- session = get_session()
|
|
- # don't allow creating a new policy
|
|
- try:
|
|
- mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
|
|
- if mbpolicy_count != 1:
|
|
- web_util.echo_json_response(
|
|
- self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist"
|
|
- )
|
|
- logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name)
|
|
- return
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ with session_context() as session:
|
|
+ # don't allow creating a new policy
|
|
+ try:
|
|
+ mbpolicy_count = session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).count()
|
|
+ if mbpolicy_count != 1:
|
|
+ web_util.echo_json_response(
|
|
+ self, 409, f"Measured boot policy with name {mb_policy_name} does not already exist"
|
|
+ )
|
|
+ logger.warning("Measured boot policy with name %s does not already exist", mb_policy_name)
|
|
+ return
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- try:
|
|
- # Update the named mb_policy
|
|
- session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update(
|
|
- mb_policy_db_format # pyright: ignore
|
|
- )
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
- raise
|
|
+ try:
|
|
+ # Update the named mb_policy
|
|
+ session.query(VerifierMbpolicy).filter_by(name=mb_policy_name).update(
|
|
+ mb_policy_db_format # pyright: ignore
|
|
+ )
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
+ raise
|
|
|
|
- web_util.echo_json_response(self, 201)
|
|
- logger.info("PUT returning 201")
|
|
+ web_util.echo_json_response(self, 201)
|
|
+ logger.info("PUT returning 201")
|
|
|
|
def data_received(self, chunk: Any) -> None:
|
|
raise NotImplementedError()
|
|
@@ -1460,17 +1473,18 @@ async def update_agent_api_version(agent: Dict[str, Any], timeout: float = 60.0)
|
|
return None
|
|
|
|
logger.info("Agent %s new API version %s is supported", agent_id, new_version)
|
|
- session = get_session()
|
|
- agent["supported_version"] = new_version
|
|
|
|
- # Remove keys that should not go to the DB
|
|
- agent_db = dict(agent)
|
|
- for key in exclude_db:
|
|
- if key in agent_db:
|
|
- del agent_db[key]
|
|
+ with session_context() as session:
|
|
+ agent["supported_version"] = new_version
|
|
|
|
- session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore
|
|
- session.commit()
|
|
+ # Remove keys that should not go to the DB
|
|
+ agent_db = dict(agent)
|
|
+ for key in exclude_db:
|
|
+ if key in agent_db:
|
|
+ del agent_db[key]
|
|
+
|
|
+ session.query(VerfierMain).filter_by(agent_id=agent_id).update(agent_db) # pyright: ignore
|
|
+ # session.commit() is automatically called by context manager
|
|
else:
|
|
logger.warning("Agent %s new API version %s is not supported", agent_id, new_version)
|
|
return None
|
|
@@ -1718,50 +1732,68 @@ async def notify_error(
|
|
revocation_notifier.notify(tosend)
|
|
if "agent" in notifiers:
|
|
verifier_id = config.get("verifier", "uuid", fallback=cloud_verifier_common.DEFAULT_VERIFIER_ID)
|
|
- session = get_session()
|
|
- agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
|
|
- futures = []
|
|
- loop = asyncio.get_event_loop()
|
|
- # Notify all agents asynchronously through a thread pool
|
|
- with ThreadPoolExecutor() as pool:
|
|
- for agent_db_obj in agents:
|
|
- if agent_db_obj.agent_id != agent["agent_id"]:
|
|
- agent = _from_db_obj(agent_db_obj)
|
|
- if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
|
|
- agent["ssl_context"] = web_util.generate_agent_tls_context(
|
|
- "verifier", agent["mtls_cert"], logger=logger
|
|
- )
|
|
- func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout)
|
|
- futures.append(await loop.run_in_executor(pool, func))
|
|
- # Wait for all tasks complete in 60 seconds
|
|
- try:
|
|
- for f in asyncio.as_completed(futures, timeout=60):
|
|
- await f
|
|
- except asyncio.TimeoutError as e:
|
|
- logger.error("Timeout during notifying error to agents: %s", e)
|
|
+ with session_context() as session:
|
|
+ agents = session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
|
|
+ futures = []
|
|
+ loop = asyncio.get_event_loop()
|
|
+ # Notify all agents asynchronously through a thread pool
|
|
+ with ThreadPoolExecutor() as pool:
|
|
+ for agent_db_obj in agents:
|
|
+ if agent_db_obj.agent_id != agent["agent_id"]:
|
|
+ agent = _from_db_obj(agent_db_obj)
|
|
+ if agent["mtls_cert"] and agent["mtls_cert"] != "disabled":
|
|
+ agent["ssl_context"] = web_util.generate_agent_tls_context(
|
|
+ "verifier", agent["mtls_cert"], logger=logger
|
|
+ )
|
|
+ func = functools.partial(invoke_notify_error, agent, tosend, timeout=timeout)
|
|
+ futures.append(await loop.run_in_executor(pool, func))
|
|
+ # Wait for all tasks complete in 60 seconds
|
|
+ try:
|
|
+ for f in asyncio.as_completed(futures, timeout=60):
|
|
+ await f
|
|
+ except asyncio.TimeoutError as e:
|
|
+ logger.error("Timeout during notifying error to agents: %s", e)
|
|
|
|
|
|
async def process_agent(
|
|
agent: Dict[str, Any], new_operational_state: int, failure: Failure = Failure(Component.INTERNAL, ["verifier"])
|
|
) -> None:
|
|
- session = get_session()
|
|
try: # pylint: disable=R1702
|
|
main_agent_operational_state = agent["operational_state"]
|
|
stored_agent = None
|
|
- try:
|
|
- stored_agent = (
|
|
- session.query(VerfierMain)
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.ima_policy).load_only(VerifierAllowlist.checksum) # pyright: ignore
|
|
- )
|
|
- .options( # type: ignore
|
|
- joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
|
|
+
|
|
+ # First database operation - read agent data and extract all needed data within session context
|
|
+ ima_policy_data = {}
|
|
+ mb_policy_data = None
|
|
+ with session_context() as session:
|
|
+ try:
|
|
+ stored_agent = (
|
|
+ session.query(VerfierMain)
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.ima_policy) # Load full IMA policy object including content
|
|
+ )
|
|
+ .options( # type: ignore
|
|
+ joinedload(VerfierMain.mb_policy).load_only(VerifierMbpolicy.mb_policy) # pyright: ignore
|
|
+ )
|
|
+ .filter_by(agent_id=str(agent["agent_id"]))
|
|
+ .first()
|
|
)
|
|
- .filter_by(agent_id=str(agent["agent_id"]))
|
|
- .first()
|
|
- )
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
|
|
+
|
|
+ # Extract IMA policy data within session context to avoid DetachedInstanceError
|
|
+ if stored_agent and stored_agent.ima_policy:
|
|
+ ima_policy_data = {
|
|
+ "checksum": str(stored_agent.ima_policy.checksum),
|
|
+ "name": stored_agent.ima_policy.name,
|
|
+ "agent_id": str(stored_agent.agent_id),
|
|
+ "ima_policy": stored_agent.ima_policy.ima_policy, # Extract the large content too
|
|
+ }
|
|
+
|
|
+ # Extract MB policy data within session context
|
|
+ if stored_agent and stored_agent.mb_policy:
|
|
+ mb_policy_data = stored_agent.mb_policy.mb_policy
|
|
+
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
|
|
|
|
# if the stored agent could not be recovered from the database, stop polling
|
|
if not stored_agent:
|
|
@@ -1775,7 +1807,10 @@ async def process_agent(
|
|
logger.warning("Agent %s terminated by user.", agent["agent_id"])
|
|
if agent["pending_event"] is not None:
|
|
tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"])
|
|
- verifier_db_delete_agent(session, agent["agent_id"])
|
|
+
|
|
+ # Second database operation - delete agent
|
|
+ with session_context() as session:
|
|
+ verifier_db_delete_agent(session, agent["agent_id"])
|
|
return
|
|
|
|
# if the user tells us to stop polling because the tenant quote check failed
|
|
@@ -1808,11 +1843,16 @@ async def process_agent(
|
|
if not failure.recoverable or failure.highest_severity == MAX_SEVERITY_LABEL:
|
|
if agent["pending_event"] is not None:
|
|
tornado.ioloop.IOLoop.current().remove_timeout(agent["pending_event"])
|
|
- for key in exclude_db:
|
|
- if key in agent:
|
|
- del agent[key]
|
|
- session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(agent) # pyright: ignore
|
|
- session.commit()
|
|
+
|
|
+ # Third database operation - update agent with failure state
|
|
+ with session_context() as session:
|
|
+ for key in exclude_db:
|
|
+ if key in agent:
|
|
+ del agent[key]
|
|
+ session.query(VerfierMain).filter_by(agent_id=agent["agent_id"]).update(
|
|
+ agent # type: ignore[arg-type]
|
|
+ )
|
|
+ # session.commit() is automatically called by context manager
|
|
|
|
# propagate all state, but remove none DB keys first (using exclude_db)
|
|
try:
|
|
@@ -1821,18 +1861,18 @@ async def process_agent(
|
|
if key in agent_db:
|
|
del agent_db[key]
|
|
|
|
- session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore
|
|
- session.commit()
|
|
+ # Fourth database operation - update agent state
|
|
+ with session_context() as session:
|
|
+ session.query(VerfierMain).filter_by(agent_id=agent_db["agent_id"]).update(agent_db) # pyright: ignore
|
|
+ # session.commit() is automatically called by context manager
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error for agent ID %s: %s", agent["agent_id"], e)
|
|
|
|
# Load agent's IMA policy
|
|
- runtime_policy = verifier_read_policy_from_cache(stored_agent)
|
|
+ runtime_policy = verifier_read_policy_from_cache(ima_policy_data)
|
|
|
|
# Get agent's measured boot policy
|
|
- mb_policy = None
|
|
- if stored_agent.mb_policy is not None:
|
|
- mb_policy = stored_agent.mb_policy.mb_policy
|
|
+ mb_policy = mb_policy_data
|
|
|
|
# If agent was in a failed state we check if we either stop polling
|
|
# or just add it again to the event loop
|
|
@@ -1876,7 +1916,14 @@ async def process_agent(
|
|
)
|
|
|
|
pending = tornado.ioloop.IOLoop.current().call_later(
|
|
- interval, invoke_get_quote, agent, mb_policy, runtime_policy, False, timeout=timeout # type: ignore # due to python <3.9
|
|
+ # type: ignore # due to python <3.9
|
|
+ interval,
|
|
+ invoke_get_quote,
|
|
+ agent,
|
|
+ mb_policy,
|
|
+ runtime_policy,
|
|
+ False,
|
|
+ timeout=timeout,
|
|
)
|
|
agent["pending_event"] = pending
|
|
return
|
|
@@ -1911,7 +1958,14 @@ async def process_agent(
|
|
next_retry,
|
|
)
|
|
tornado.ioloop.IOLoop.current().call_later(
|
|
- next_retry, invoke_get_quote, agent, mb_policy, runtime_policy, True, timeout=timeout # type: ignore # due to python <3.9
|
|
+ # type: ignore # due to python <3.9
|
|
+ next_retry,
|
|
+ invoke_get_quote,
|
|
+ agent,
|
|
+ mb_policy,
|
|
+ runtime_policy,
|
|
+ True,
|
|
+ timeout=timeout,
|
|
)
|
|
return
|
|
|
|
@@ -1980,9 +2034,9 @@ async def activate_agents(agents: List[VerfierMain], verifier_ip: str, verifier_
|
|
|
|
|
|
def get_agents_by_verifier_id(verifier_id: str) -> List[VerfierMain]:
|
|
- session = get_session()
|
|
try:
|
|
- return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
|
|
+ with session_context() as session:
|
|
+ return session.query(VerfierMain).filter_by(verifier_id=verifier_id).all()
|
|
except SQLAlchemyError as e:
|
|
logger.error("SQLAlchemy Error: %s", e)
|
|
return []
|
|
@@ -2007,20 +2061,20 @@ def main() -> None:
|
|
os.umask(0o077)
|
|
|
|
VerfierMain.metadata.create_all(engine, checkfirst=True) # pyright: ignore
|
|
- session = get_session()
|
|
- try:
|
|
- query_all = session.query(VerfierMain).all()
|
|
- for row in query_all:
|
|
- if row.operational_state in states.APPROVED_REACTIVATE_STATES:
|
|
- row.operational_state = states.START # pyright: ignore
|
|
- session.commit()
|
|
- except SQLAlchemyError as e:
|
|
- logger.error("SQLAlchemy Error: %s", e)
|
|
+ with session_context() as session:
|
|
+ try:
|
|
+ query_all = session.query(VerfierMain).all()
|
|
+ for row in query_all:
|
|
+ if row.operational_state in states.APPROVED_REACTIVATE_STATES:
|
|
+ row.operational_state = states.START # pyright: ignore
|
|
+ # session.commit() is automatically called by context manager
|
|
+ except SQLAlchemyError as e:
|
|
+ logger.error("SQLAlchemy Error: %s", e)
|
|
|
|
- num = session.query(VerfierMain.agent_id).count()
|
|
- if num > 0:
|
|
- agent_ids = session.query(VerfierMain.agent_id).all()
|
|
- logger.info("Agent ids in db loaded from file: %s", agent_ids)
|
|
+ num = session.query(VerfierMain.agent_id).count()
|
|
+ if num > 0:
|
|
+ agent_ids = session.query(VerfierMain.agent_id).all()
|
|
+ logger.info("Agent ids in db loaded from file: %s", agent_ids)
|
|
|
|
logger.info("Starting Cloud Verifier (tornado) on port %s, use <Ctrl-C> to stop", verifier_port)
|
|
|
|
diff --git a/keylime/da/examples/sqldb.py b/keylime/da/examples/sqldb.py
|
|
index 8efc84e..04a8afb 100644
|
|
--- a/keylime/da/examples/sqldb.py
|
|
+++ b/keylime/da/examples/sqldb.py
|
|
@@ -1,7 +1,10 @@
|
|
import time
|
|
+from contextlib import contextmanager
|
|
+from typing import Iterator
|
|
|
|
import sqlalchemy
|
|
import sqlalchemy.ext.declarative
|
|
+from sqlalchemy.orm import sessionmaker
|
|
|
|
from keylime import keylime_logging
|
|
from keylime.da.record import BaseRecordManagement, base_build_key_list
|
|
@@ -45,23 +48,23 @@ class RecordManagement(BaseRecordManagement):
|
|
BaseRecordManagement.__init__(self, service)
|
|
|
|
self.engine = sqlalchemy.create_engine(self.ps_url._replace(fragment="").geturl(), pool_recycle=1800)
|
|
- sm = sqlalchemy.orm.sessionmaker()
|
|
- self.session = sqlalchemy.orm.scoped_session(sm)
|
|
- self.session.configure(bind=self.engine)
|
|
- TableBase.metadata.create_all(self.engine)
|
|
-
|
|
- def agent_list_retrieval(self, record_prefix="auto", service="auto"):
|
|
- if record_prefix == "auto":
|
|
- record_prefix = ""
|
|
-
|
|
- agent_list = []
|
|
+ self.SessionLocal = sessionmaker(bind=self.engine)
|
|
|
|
- recordtype = self.get_record_type(service)
|
|
- tbl = type2table(recordtype)
|
|
- for agentid in self.session.query(tbl.agentid).distinct(): # pylint: disable=no-member
|
|
- agent_list.append(agentid[0])
|
|
+ # Create tables if they don't exist
|
|
+ TableBase.metadata.create_all(self.engine)
|
|
|
|
- return agent_list
|
|
+ @contextmanager
|
|
+ def session_context(self) -> Iterator:
|
|
+ """Context manager for database sessions that ensures proper cleanup."""
|
|
+ session = self.SessionLocal()
|
|
+ try:
|
|
+ yield session
|
|
+ session.commit()
|
|
+ except Exception:
|
|
+ session.rollback()
|
|
+ raise
|
|
+ finally:
|
|
+ session.close()
|
|
|
|
def record_create(
|
|
self,
|
|
@@ -84,8 +87,9 @@ class RecordManagement(BaseRecordManagement):
|
|
d = {"time": recordtime, "agentid": agentid, "record": rcrd}
|
|
|
|
try:
|
|
- self.session.add((type2table(recordtype))(**d)) # pylint: disable=no-member
|
|
- self.session.commit() # pylint: disable=no-member
|
|
+ with self.session_context() as session:
|
|
+ session.add((type2table(recordtype))(**d))
|
|
+ # session.commit() is automatically called by context manager
|
|
except Exception as e:
|
|
logger.error("Failed to create attestation record: %s", e)
|
|
|
|
@@ -106,23 +110,23 @@ class RecordManagement(BaseRecordManagement):
|
|
if f"{end_date}" == "auto":
|
|
end_date = self.end_of_times
|
|
|
|
- if self.only_last_record_wanted(start_date, end_date):
|
|
- attestion_record_rows = (
|
|
- self.session.query(tbl) # pylint: disable=no-member
|
|
- .filter(tbl.agentid == record_identifier)
|
|
- .order_by(sqlalchemy.desc(tbl.time))
|
|
- .limit(1)
|
|
- )
|
|
-
|
|
- else:
|
|
- attestion_record_rows = self.session.query(tbl).filter( # pylint: disable=no-member
|
|
- tbl.agentid == record_identifier
|
|
- )
|
|
-
|
|
- for row in attestion_record_rows:
|
|
- decoded_record_object = self.record_deserialize(row.record)
|
|
- self.record_signature_check(decoded_record_object, record_identifier)
|
|
- record_list.append(decoded_record_object)
|
|
+ with self.session_context() as session:
|
|
+ if self.only_last_record_wanted(start_date, end_date):
|
|
+ attestion_record_rows = (
|
|
+ session.query(tbl)
|
|
+ .filter(tbl.agentid == record_identifier)
|
|
+ .order_by(sqlalchemy.desc(tbl.time))
|
|
+ .limit(1)
|
|
+ )
|
|
+
|
|
+ else:
|
|
+ attestion_record_rows = session.query(tbl).filter(tbl.agentid == record_identifier)
|
|
+
|
|
+ for row in attestion_record_rows:
|
|
+ decoded_record_object = self.record_deserialize(row.record)
|
|
+ self.record_signature_check(decoded_record_object, record_identifier)
|
|
+ record_list.append(decoded_record_object)
|
|
+
|
|
return record_list
|
|
|
|
def build_key_list(self, agent_identifier, service="auto"):
|
|
diff --git a/keylime/db/keylime_db.py b/keylime/db/keylime_db.py
|
|
index 5620a28..aa49e51 100644
|
|
--- a/keylime/db/keylime_db.py
|
|
+++ b/keylime/db/keylime_db.py
|
|
@@ -1,7 +1,8 @@
|
|
import os
|
|
from configparser import NoOptionError
|
|
+from contextlib import contextmanager
|
|
from sqlite3 import Connection as SQLite3Connection
|
|
-from typing import Any, Dict, Optional, cast
|
|
+from typing import Any, Iterator, Optional, cast
|
|
|
|
from sqlalchemy import create_engine, event
|
|
from sqlalchemy.engine import Engine
|
|
@@ -22,90 +23,108 @@ def _set_sqlite_pragma(dbapi_connection: SQLite3Connection, _) -> None:
|
|
cursor.close()
|
|
|
|
|
|
-class DBEngineManager:
|
|
- service: Optional[str]
|
|
-
|
|
- def __init__(self) -> None:
|
|
- self.service = None
|
|
-
|
|
- def make_engine(self, service: str) -> Engine:
|
|
- """
|
|
- To use: engine = self.make_engine('cloud_verifier')
|
|
- """
|
|
-
|
|
- # Keep DB related stuff as it is, but read configuration from new
|
|
- # configs
|
|
- if service == "cloud_verifier":
|
|
- config_service = "verifier"
|
|
- else:
|
|
- config_service = service
|
|
-
|
|
- self.service = service
|
|
-
|
|
- try:
|
|
- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
|
|
- p_sz, m_ovfl = p_sz_m_ovfl.split(",")
|
|
- except NoOptionError:
|
|
- p_sz = "5"
|
|
- m_ovfl = "10"
|
|
-
|
|
- engine_args: Dict[str, Any] = {}
|
|
-
|
|
- url = config.get(config_service, "database_url")
|
|
- if url:
|
|
- logger.info("database_url is set, using it to establish database connection")
|
|
-
|
|
- # If the keyword sqlite is provided as the database url, use the
|
|
- # cv_data.sqlite for the verifier or the file reg_data.sqlite for
|
|
- # the registrar, located at the config.WORK_DIR directory
|
|
- if url == "sqlite":
|
|
+def make_engine(service: str, **engine_args: Any) -> Engine:
|
|
+ """Create a database engine for a keylime service."""
|
|
+ # Keep DB related stuff as it is, but read configuration from new
|
|
+ # configs
|
|
+ if service == "cloud_verifier":
|
|
+ config_service = "verifier"
|
|
+ else:
|
|
+ config_service = service
|
|
+
|
|
+ url = config.get(config_service, "database_url")
|
|
+ if url:
|
|
+ logger.info("database_url is set, using it to establish database connection")
|
|
+
|
|
+ # If the keyword sqlite is provided as the database url, use the
|
|
+ # cv_data.sqlite for the verifier or the file reg_data.sqlite for
|
|
+ # the registrar, located at the config.WORK_DIR directory
|
|
+ if url == "sqlite":
|
|
+ logger.info(
|
|
+ "database_url is set as 'sqlite' keyword, using default values to establish database connection"
|
|
+ )
|
|
+ if service == "cloud_verifier":
|
|
+ database = "cv_data.sqlite"
|
|
+ elif service == "registrar":
|
|
+ database = "reg_data.sqlite"
|
|
+ else:
|
|
+ logger.error("Tried to setup database access for unknown service '%s'", service)
|
|
+ raise Exception(f"Unknown service '{service}' for database setup")
|
|
+
|
|
+ database_file = os.path.abspath(os.path.join(config.WORK_DIR, database))
|
|
+ url = f"sqlite:///{database_file}"
|
|
+
|
|
+ kl_dir = os.path.dirname(os.path.abspath(database_file))
|
|
+ if not os.path.exists(kl_dir):
|
|
+ os.makedirs(kl_dir, 0o700)
|
|
+
|
|
+ engine_args["connect_args"] = {"check_same_thread": False}
|
|
+
|
|
+ if not url.count("sqlite:"):
|
|
+ # sqlite does not support setting pool size and max overflow, only
|
|
+ # read from the config when it is going to be used
|
|
+ try:
|
|
+ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
|
|
+ p_sz, m_ovfl = p_sz_m_ovfl.split(",")
|
|
+ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl)
|
|
+ except NoOptionError:
|
|
+ p_sz = "5"
|
|
+ m_ovfl = "10"
|
|
logger.info(
|
|
- "database_url is set as 'sqlite' keyword, using default values to establish database connection"
|
|
+ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s", p_sz, m_ovfl
|
|
)
|
|
- if service == "cloud_verifier":
|
|
- database = "cv_data.sqlite"
|
|
- elif service == "registrar":
|
|
- database = "reg_data.sqlite"
|
|
- else:
|
|
- logger.error("Tried to setup database access for unknown service '%s'", service)
|
|
- raise Exception(f"Unknown service '{service}' for database setup")
|
|
-
|
|
- database_file = os.path.abspath(os.path.join(config.WORK_DIR, database))
|
|
- url = f"sqlite:///{database_file}"
|
|
-
|
|
- kl_dir = os.path.dirname(os.path.abspath(database_file))
|
|
- if not os.path.exists(kl_dir):
|
|
- os.makedirs(kl_dir, 0o700)
|
|
-
|
|
- engine_args["connect_args"] = {"check_same_thread": False}
|
|
|
|
- if not url.count("sqlite:"):
|
|
- engine_args["pool_size"] = int(p_sz)
|
|
- engine_args["max_overflow"] = int(m_ovfl)
|
|
- engine_args["pool_pre_ping"] = True
|
|
+ engine_args["pool_size"] = int(p_sz)
|
|
+ engine_args["max_overflow"] = int(m_ovfl)
|
|
+ engine_args["pool_pre_ping"] = True
|
|
|
|
- # Enable DB debugging
|
|
- if config.DEBUG_DB and config.INSECURE_DEBUG:
|
|
- engine_args["echo"] = True
|
|
+ # Enable DB debugging
|
|
+ if config.DEBUG_DB and config.INSECURE_DEBUG:
|
|
+ engine_args["echo"] = True
|
|
|
|
- engine = create_engine(url, **engine_args)
|
|
- return engine
|
|
+ engine = create_engine(url, **engine_args)
|
|
+ return engine
|
|
|
|
|
|
class SessionManager:
|
|
engine: Optional[Engine]
|
|
+ _scoped_session: Optional[scoped_session]
|
|
|
|
def __init__(self) -> None:
|
|
self.engine = None
|
|
+ self._scoped_session = None
|
|
|
|
def make_session(self, engine: Engine) -> Session:
|
|
"""
|
|
To use: session = self.make_session(engine)
|
|
"""
|
|
self.engine = engine
|
|
- my_session = scoped_session(sessionmaker())
|
|
+ if self._scoped_session is None:
|
|
+ self._scoped_session = scoped_session(sessionmaker())
|
|
try:
|
|
- my_session.configure(bind=self.engine) # type: ignore
|
|
+ self._scoped_session.configure(bind=self.engine) # type: ignore
|
|
+ self._scoped_session.configure(expire_on_commit=False) # type: ignore
|
|
except SQLAlchemyError as err:
|
|
logger.error("Error creating SQL session manager %s", err)
|
|
- return cast(Session, my_session())
|
|
+ return cast(Session, self._scoped_session())
|
|
+
|
|
+ @contextmanager
|
|
+ def session_context(self, engine: Engine) -> Iterator[Session]:
|
|
+ """
|
|
+ Context manager for database sessions that ensures proper cleanup.
|
|
+ To use:
|
|
+ with session_manager.session_context(engine) as session:
|
|
+ # use session
|
|
+ """
|
|
+ session = self.make_session(engine)
|
|
+ try:
|
|
+ yield session
|
|
+ session.commit()
|
|
+ except Exception:
|
|
+ session.rollback()
|
|
+ raise
|
|
+ finally:
|
|
+ # Important: remove the session from the scoped session registry
|
|
+ # to prevent connection leaks with scoped_session
|
|
+ if self._scoped_session is not None:
|
|
+ self._scoped_session.remove() # type: ignore[no-untyped-call]
|
|
diff --git a/keylime/migrations/env.py b/keylime/migrations/env.py
|
|
index ac98349..a1881f2 100644
|
|
--- a/keylime/migrations/env.py
|
|
+++ b/keylime/migrations/env.py
|
|
@@ -8,7 +8,7 @@ import sys
|
|
|
|
from alembic import context
|
|
|
|
-from keylime.db.keylime_db import DBEngineManager
|
|
+from keylime.db.keylime_db import make_engine
|
|
from keylime.db.registrar_db import Base as RegistrarBase
|
|
from keylime.db.verifier_db import Base as VerifierBase
|
|
|
|
@@ -74,7 +74,7 @@ def run_migrations_offline():
|
|
logger.info("Writing output to %s", file_)
|
|
|
|
with open(file_, "w", encoding="utf-8") as buffer:
|
|
- engine = DBEngineManager().make_engine(name)
|
|
+ engine = make_engine(name)
|
|
connection = engine.connect()
|
|
context.configure(
|
|
connection=connection,
|
|
@@ -102,7 +102,7 @@ def run_migrations_online():
|
|
engines = {}
|
|
for name in re.split(r",\s*", db_names):
|
|
engines[name] = rec = {}
|
|
- rec["engine"] = DBEngineManager().make_engine(name)
|
|
+ rec["engine"] = make_engine(name)
|
|
|
|
for name, rec in engines.items():
|
|
engine = rec["engine"]
|
|
diff --git a/keylime/models/base/db.py b/keylime/models/base/db.py
|
|
index dd47d63..0229765 100644
|
|
--- a/keylime/models/base/db.py
|
|
+++ b/keylime/models/base/db.py
|
|
@@ -41,13 +41,6 @@ class DBManager:
|
|
|
|
self._service = service
|
|
|
|
- try:
|
|
- p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
|
|
- p_sz, m_ovfl = p_sz_m_ovfl.split(",")
|
|
- except NoOptionError:
|
|
- p_sz = "5"
|
|
- m_ovfl = "10"
|
|
-
|
|
engine_args: Dict[str, Any] = {}
|
|
|
|
url = config.get(config_service, "database_url")
|
|
@@ -79,6 +72,21 @@ class DBManager:
|
|
engine_args["connect_args"] = {"check_same_thread": False}
|
|
|
|
if not url.count("sqlite:"):
|
|
+ # sqlite does not support setting pool size and max overflow, only
|
|
+ # read from the config when it is going to be used
|
|
+ try:
|
|
+ p_sz_m_ovfl = config.get(config_service, "database_pool_sz_ovfl")
|
|
+ p_sz, m_ovfl = p_sz_m_ovfl.split(",")
|
|
+ logger.info("database_pool_sz_ovfl is set, pool size = %s, max overflow = %s", p_sz, m_ovfl)
|
|
+ except NoOptionError:
|
|
+ p_sz = "5"
|
|
+ m_ovfl = "10"
|
|
+ logger.info(
|
|
+ "database_pool_sz_ovfl is not set, using default pool size = %s, max overflow = %s",
|
|
+ p_sz,
|
|
+ m_ovfl,
|
|
+ )
|
|
+
|
|
engine_args["pool_size"] = int(p_sz)
|
|
engine_args["max_overflow"] = int(m_ovfl)
|
|
engine_args["pool_pre_ping"] = True
|
|
diff --git a/keylime/models/base/persistable_model.py b/keylime/models/base/persistable_model.py
|
|
index 18f7d0d..a779f0b 100644
|
|
--- a/keylime/models/base/persistable_model.py
|
|
+++ b/keylime/models/base/persistable_model.py
|
|
@@ -207,10 +207,16 @@ class PersistableModel(BasicModel, metaclass=PersistableModelMeta):
|
|
setattr(self._db_mapping_inst, name, field.data_type.db_dump(value, db_manager.engine.dialect))
|
|
|
|
with db_manager.session_context() as session:
|
|
- session.add(self._db_mapping_inst)
|
|
+ # Merge the potentially detached object into the new session
|
|
+ merged_instance = session.merge(self._db_mapping_inst)
|
|
+ session.add(merged_instance)
|
|
+ # Update our reference to the merged instance
|
|
+ self._db_mapping_inst = merged_instance # pylint: disable=attribute-defined-outside-init
|
|
|
|
self.clear_changes()
|
|
|
|
def delete(self) -> None:
|
|
with db_manager.session_context() as session:
|
|
- session.delete(self._db_mapping_inst) # type: ignore[no-untyped-call]
|
|
+ # Merge the potentially detached object into the new session before deleting
|
|
+ merged_instance = session.merge(self._db_mapping_inst)
|
|
+ session.delete(merged_instance) # type: ignore[no-untyped-call]
|
|
diff --git a/packit-ci.fmf b/packit-ci.fmf
|
|
index 2d1e5e5..cb64faf 100644
|
|
--- a/packit-ci.fmf
|
|
+++ b/packit-ci.fmf
|
|
@@ -101,6 +101,7 @@ adjust:
|
|
- /regression/CVE-2023-3674
|
|
- /regression/issue-1380-agent-removed-and-re-added
|
|
- /regression/keylime-agent-option-override-through-envvar
|
|
+ - /regression/db-connection-leak-reproducer
|
|
- /sanity/keylime-secure_mount
|
|
- /sanity/opened-conf-files
|
|
- /upstream/run_keylime_tests
|
|
diff --git a/test/test_verifier_db.py b/test/test_verifier_db.py
|
|
index ad72fa6..aae8f8a 100644
|
|
--- a/test/test_verifier_db.py
|
|
+++ b/test/test_verifier_db.py
|
|
@@ -172,3 +172,102 @@ class TestVerfierDB(unittest.TestCase):
|
|
|
|
def tearDown(self):
|
|
self.session.close()
|
|
+
|
|
+ def test_11_relationship_access_after_session_commit(self):
|
|
+ """Test that relationships can be accessed after session commits (DetachedInstanceError fix)"""
|
|
+ # This test reproduces the problematic pattern from cloud_verifier_tornado.py
|
|
+        # where objects are loaded with joinedload and then accessed after the session closes
|
|
+
|
|
+ # Create a new session manager and context (like in cloud_verifier_tornado.py)
|
|
+ session_manager = SessionManager()
|
|
+
|
|
+ # First, load the agent with eager loading for relationships
|
|
+ stored_agent = None
|
|
+ with session_manager.session_context(self.engine) as session:
|
|
+ stored_agent = (
|
|
+ session.query(VerfierMain)
|
|
+ .options(joinedload(VerfierMain.ima_policy))
|
|
+ .options(joinedload(VerfierMain.mb_policy))
|
|
+ .filter_by(agent_id=agent_id)
|
|
+ .first()
|
|
+ )
|
|
+ # Verify agent was loaded correctly
|
|
+ self.assertIsNotNone(stored_agent)
|
|
+ # session.commit() is automatically called by context manager when exiting
|
|
+
|
|
+ # Now verify we can access relationships AFTER the session has been closed
|
|
+ # This would previously trigger DetachedInstanceError
|
|
+
|
|
+ # Ensure stored_agent is not None before proceeding
|
|
+ assert stored_agent is not None
|
|
+
|
|
+ # Test accessing ima_policy relationship
|
|
+ self.assertIsNotNone(stored_agent.ima_policy)
|
|
+ assert stored_agent.ima_policy is not None # Type narrowing for linter
|
|
+ self.assertEqual(stored_agent.ima_policy.name, "test-allowlist")
|
|
+ # checksum is not set in test data
|
|
+ self.assertEqual(stored_agent.ima_policy.checksum, None)
|
|
+
|
|
+ # Test accessing the ima_policy.ima_policy attribute (similar to verifier_read_policy_from_cache)
|
|
+ ima_policy_content = stored_agent.ima_policy.ima_policy
|
|
+ self.assertEqual(ima_policy_content, test_allowlist_data["ima_policy"])
|
|
+
|
|
+ # Test accessing mb_policy relationship
|
|
+ self.assertIsNotNone(stored_agent.mb_policy)
|
|
+ assert stored_agent.mb_policy is not None # Type narrowing for linter
|
|
+ self.assertEqual(stored_agent.mb_policy.name, "test-mbpolicy")
|
|
+
|
|
+ # Test accessing the mb_policy.mb_policy attribute (similar to process_agent function)
|
|
+ mb_policy_content = stored_agent.mb_policy.mb_policy
|
|
+ self.assertEqual(mb_policy_content, test_mbpolicy_data["mb_policy"])
|
|
+
|
|
+ # Test that we can access these relationships multiple times without issues
|
|
+ for _ in range(3):
|
|
+ self.assertIsNotNone(stored_agent.ima_policy.ima_policy)
|
|
+ self.assertIsNotNone(stored_agent.mb_policy.mb_policy)
|
|
+
|
|
+ def test_12_persistable_model_cross_session_fix(self):
|
|
+ """Test that PersistableModel can handle cross-session operations safely"""
|
|
+        # Before the fix, this test would fail with DetachedInstanceError
|
|
+ # Note: This is a conceptual test since we don't have actual PersistableModel
|
|
+ # subclasses in the test environment, but demonstrates the pattern
|
|
+
|
|
+ # Simulate creating a SQLAlchemy object in one session
|
|
+ session_manager = SessionManager()
|
|
+
|
|
+ # Load an object in one session context
|
|
+ test_agent = None
|
|
+ with session_manager.session_context(self.engine) as session:
|
|
+ test_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
|
|
+ self.assertIsNotNone(test_agent)
|
|
+            # The session is committed and released when this with-block exits
|
|
+
|
|
+ # Ensure test_agent is not None before proceeding
|
|
+ assert test_agent is not None
|
|
+
|
|
+ # Now simulate using this object in a different session context
|
|
+ # This tests the pattern where PersistableModel would use session.add() or session.delete()
|
|
+ # on a cross-session object
|
|
+ with session_manager.session_context(self.engine) as session:
|
|
+ # Before the fix, this would cause DetachedInstanceError
|
|
+ # The fix uses session.merge() to handle detached objects safely
|
|
+ merged_agent = session.merge(test_agent)
|
|
+ assert merged_agent is not None # Type narrowing for linter
|
|
+
|
|
+ # Test that we can modify and save the merged object
|
|
+ original_port = merged_agent.port
|
|
+ # Use setattr to avoid linter issues with Column assignment
|
|
+ setattr(merged_agent, "port", 9999)
|
|
+ session.add(merged_agent)
|
|
+ # session.commit() called automatically by context manager
|
|
+
|
|
+ # Verify the change was persisted
|
|
+ with session_manager.session_context(self.engine) as session:
|
|
+ updated_agent = session.query(VerfierMain).filter_by(agent_id=agent_id).first()
|
|
+ assert updated_agent is not None # Type narrowing for linter
|
|
+ self.assertEqual(updated_agent.port, 9999)
|
|
+
|
|
+ # Restore original value
|
|
+ # Use setattr to avoid linter issues
|
|
+ setattr(updated_agent, "port", original_port)
|
|
+ session.add(updated_agent)
|