From 1d0fb6d012c1e58d82aed687439449452baf7736 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 7 Nov 2024 10:41:42 -0800 Subject: [PATCH] Evaluate None to default (#3069) * add sentinel value * update typing * clearer * update comments * ensure proper attribution --- backend/danswer/db/engine.py | 28 +++++++++++++------ backend/danswer/db/search_settings.py | 4 +-- .../custom/custom_tool.py | 6 ++-- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 0d085eb7c..639f6addc 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -323,16 +323,28 @@ async def get_async_session_with_tenant( yield session +@contextmanager +def get_session_with_default_tenant() -> Generator[Session, None, None]: + """ + Get a database session using the current tenant ID from the context variable. + """ + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + with get_session_with_tenant(tenant_id) as session: + yield session + + @contextmanager def get_session_with_tenant( tenant_id: str | None = None, ) -> Generator[Session, None, None]: """ - Generate a database session bound to a connection with the appropriate tenant schema set. - This preserves the tenant ID across the session and reverts to the previous tenant ID - after the session is closed. - If tenant ID is set, we save the previous tenant ID from the context var to set - after the session is closed. The value `None` evaluates to the default schema. + Generate a database session for a specific tenant. + + This function: + 1. Sets the database schema to the specified tenant's schema. + 2. Preserves the tenant ID across the session. + 3. Reverts to the previous tenant ID after the session is closed. + 4. Uses the default schema if no tenant ID is provided. """ engine = get_sqlalchemy_engine() @@ -340,9 +352,9 @@ def get_session_with_tenant( previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA if tenant_id is None: - tenant_id = previous_tenant_id - else: - CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + tenant_id = POSTGRES_DEFAULT_SCHEMA + + CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) event.listen(engine, "checkout", set_search_path_on_checkout) diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 5392ec234..342bb70b2 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS -from danswer.db.engine import get_session_with_tenant +from danswer.db.engine import get_session_with_default_tenant from danswer.db.llm import fetch_embedding_provider from danswer.db.models import CloudEmbeddingProvider from danswer.db.models import IndexAttempt @@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]: def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: - with get_session_with_tenant() as db_session: + with get_session_with_default_tenant() as db_session: search_settings = get_current_search_settings(db_session) else: search_settings = get_current_search_settings(db_session) diff --git a/backend/danswer/tools/tool_implementations/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py index eace6d53a..9e9aa1216 100644 --- a/backend/danswer/tools/tool_implementations/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage from pydantic import BaseModel from danswer.configs.constants import FileOrigin -from danswer.db.engine import get_session_with_tenant +from danswer.db.engine import get_session_with_default_tenant from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile @@ -187,7 +187,7 @@ class CustomTool(BaseTool): def _save_and_get_file_references( self, file_content: bytes | str, content_type: str ) -> List[str]: - with get_session_with_tenant() as db_session: + with get_session_with_default_tenant() as db_session: file_store = get_default_file_store(db_session) file_id = str(uuid.uuid4()) @@ -299,7 +299,7 @@ class CustomTool(BaseTool): # Load files from storage files = [] - with get_session_with_tenant() as db_session: + with get_session_with_default_tenant() as db_session: file_store = get_default_file_store(db_session) for file_id in response.tool_result.file_ids: