From bc28d5d228d005702f72e98646c8cad73196ccfb Mon Sep 17 00:00:00 2001
From: Sergio Arroutbi
Date: Tue, 10 Mar 2026 13:22:04 +0100
Subject: [PATCH 4/6] Include thread-safe session management

Replace open-ended SQLAlchemy sessions with a context manager
that guarantees connection release, preventing QueuePool exhaustion
under multi-host push attestation load.

Key changes:
- Add double-checked locking for lazy engine initialization to
  prevent race conditions in multi-threaded verifier
- Delegate session lifecycle to SessionManager.session_context()
  which provides proper rollback on exception and
  scoped_session.remove() cleanup, eliminating thread-local
  connection leaks
- Use session.expunge(agent) before exiting context manager scope
  so VerfierMain instances safely cross session boundaries without
  DetachedInstanceError
- Scope with-blocks narrowly: connection is returned to pool before
  any subsequent DB calls (e.g. AuthSession.get_by_token) to prevent
  connection hoarding across separate pool boundaries

Co-Authored-By: Claude Opus 4.6
Signed-off-by: Sergio Arroutbi
---
 keylime/models/verifier/auth_session.py    | 15 +++---
 keylime/web/verifier/session_controller.py |  6 +++
 test/test_auth_session.py                  | 60 ++++++++++++++++------
 3 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/keylime/models/verifier/auth_session.py b/keylime/models/verifier/auth_session.py
index 545995f..918dfb4 100644
--- a/keylime/models/verifier/auth_session.py
+++ b/keylime/models/verifier/auth_session.py
@@ -1,5 +1,6 @@
 import base64
 import hmac
+import threading
 import uuid
 from contextlib import contextmanager
 from datetime import timedelta
@@ -31,19 +32,19 @@ from keylime.tpm.tpm_main import Tpm
 logger = keylime_logging.init_logging("verifier")
 
 _engine = None
+_engine_lock = threading.Lock()
+_session_manager = SessionManager()
 
 
 @contextmanager
 def get_session_context() -> Iterator[Session]:
     global _engine
     if _engine is None:
-        _engine = make_engine("cloud_verifier")
-    session_manager = SessionManager()
-    session = session_manager.make_session(_engine)
-    try:
+        with _engine_lock:
+            if _engine is None:
+                _engine = make_engine("cloud_verifier")
+    with _session_manager.session_context(_engine) as session:
         yield session
-    finally:
-        session.close()
 
 
 class AuthSession(PersistableModel):
@@ -283,6 +284,8 @@ class AuthSession(PersistableModel):
                 .filter(VerfierMain.agent_id == auth_session.agent_id)  # type: ignore[attr-defined]
                 .one_or_none()
             )
+            if agent:
+                session.expunge(agent)  # type: ignore[no-untyped-call]
 
         return agent
 
diff --git a/keylime/web/verifier/session_controller.py b/keylime/web/verifier/session_controller.py
index 49cd758..3faa310 100644
--- a/keylime/web/verifier/session_controller.py
+++ b/keylime/web/verifier/session_controller.py
@@ -187,6 +187,8 @@ class SessionController(Controller):
         # Check if agent exists - this is where we validate enrollment
         with get_session_context() as session:
             agent = session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).one_or_none()
+            if agent:
+                session.expunge(agent)  # type: ignore[no-untyped-call]
 
         if not agent:
             # Agent not enrolled - return 200 with evaluation:fail
@@ -382,6 +384,8 @@ class SessionController(Controller):
     def create(self, agent_id, **params):
         with get_session_context() as session:
             agent = session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).one_or_none()
+            if agent:
+                session.expunge(agent)  # type: ignore[no-untyped-call]
 
         if not agent:
             self.respond(404, "here")
@@ -405,6 +409,8 @@ class SessionController(Controller):
     def update(self, agent_id, token, **params):
         with get_session_context() as session:
             agent = session.query(VerfierMain).filter(VerfierMain.agent_id == agent_id).one_or_none()
+            if agent:
+                session.expunge(agent)  # type: ignore[no-untyped-call]
 
         # Look up session by token hash (tokens are never stored in plaintext)
         auth_session = AuthSession.get_by_token(token)
diff --git a/test/test_auth_session.py b/test/test_auth_session.py
index 8e9ec98..2c78547 100644
--- a/test/test_auth_session.py
+++ b/test/test_auth_session.py
@@ -2,6 +2,7 @@
 
 import base64
 import unittest
+from contextlib import contextmanager
 from datetime import timedelta
 from unittest.mock import MagicMock, PropertyMock, patch
 
@@ -14,32 +15,59 @@ from keylime.shared_data import cleanup_global_shared_memory, get_shared_memory
 class TestGetSessionContext(unittest.TestCase):
     """Test cases for get_session_context context manager."""
 
+    def _make_mock_session_manager(self, mock_session):
+        """Create a mock SessionManager whose session_context() mirrors real lifecycle."""
+        mock_scoped = MagicMock()
+        mock_session_manager = MagicMock()
+        mock_session_manager.make_session.return_value = mock_session
+        mock_session_manager._scoped_session = mock_scoped  # pylint: disable=protected-access
+
+        @contextmanager
+        def fake_session_context(engine):  # pylint: disable=unused-argument
+            session = mock_session_manager.make_session(engine)
+            try:
+                yield session
+                session.commit()
+            except Exception:
+                session.rollback()
+                raise
+            finally:
+                scoped = mock_session_manager._scoped_session  # pylint: disable=protected-access
+                if scoped is not None:
+                    scoped.remove()
+
+        mock_session_manager.session_context = fake_session_context
+        return mock_session_manager, mock_scoped
+
     @patch("keylime.models.verifier.auth_session.make_engine")
-    @patch("keylime.models.verifier.auth_session.SessionManager")
-    def test_session_closed_on_normal_exit(self, mock_session_manager_cls, _mock_make_engine):
-        """Test that session.close() is called when context manager exits normally."""
+    def test_session_cleanup_on_normal_exit(self, _mock_make_engine):
+        """Test that session is committed and cleaned up when context manager exits normally."""
         mock_session = MagicMock()
-        mock_session_manager_cls.return_value.make_session.return_value = mock_session
+        mock_session_manager, mock_scoped = self._make_mock_session_manager(mock_session)
 
         with patch("keylime.models.verifier.auth_session._engine", None):
-            with get_session_context() as session:
-                self.assertIs(session, mock_session)
+            with patch("keylime.models.verifier.auth_session._session_manager", mock_session_manager):
+                with get_session_context() as session:
+                    self.assertIs(session, mock_session)
 
-        mock_session.close.assert_called_once()
+        mock_session.commit.assert_called_once()
+        mock_scoped.remove.assert_called_once()
 
     @patch("keylime.models.verifier.auth_session.make_engine")
-    @patch("keylime.models.verifier.auth_session.SessionManager")
-    def test_session_closed_on_exception(self, mock_session_manager_cls, _mock_make_engine):
-        """Test that session.close() is called even when an exception occurs."""
+    def test_session_rollback_on_exception(self, _mock_make_engine):
+        """Test that session is rolled back and cleaned up when an exception occurs."""
         mock_session = MagicMock()
-        mock_session_manager_cls.return_value.make_session.return_value = mock_session
+        mock_session_manager, mock_scoped = self._make_mock_session_manager(mock_session)
 
         with patch("keylime.models.verifier.auth_session._engine", None):
-            with self.assertRaises(RuntimeError):
-                with get_session_context():
-                    raise RuntimeError("simulated error")
-
-        mock_session.close.assert_called_once()
+            with patch("keylime.models.verifier.auth_session._session_manager", mock_session_manager):
+                with self.assertRaises(RuntimeError):
+                    with get_session_context():
+                        raise RuntimeError("simulated error")
+
+        mock_session.rollback.assert_called_once()
+        mock_session.commit.assert_not_called()
+        mock_scoped.remove.assert_called_once()
 
 
 class TestAuthSessionHelpers(unittest.TestCase):
-- 
2.53.0