Evaluate None to default (#3069)

* add sentinel value

* update typing

* clearer

* update comments

* ensure proper attribution
This commit is contained in:
pablodanswer
2024-11-07 10:41:42 -08:00
committed by GitHub
parent 2b1dbde829
commit 1d0fb6d012
3 changed files with 25 additions and 13 deletions

View File

@@ -323,16 +323,28 @@ async def get_async_session_with_tenant(
yield session yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager @contextmanager
def get_session_with_tenant( def get_session_with_tenant(
tenant_id: str | None = None, tenant_id: str | None = None,
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
""" """
Generate a database session bound to a connection with the appropriate tenant schema set. Generate a database session for a specific tenant.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed. This function:
If tenant ID is set, we save the previous tenant ID from the context var to set 1. Sets the database schema to the specified tenant's schema.
after the session is closed. The value `None` evaluates to the default schema. 2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
""" """
engine = get_sqlalchemy_engine() engine = get_sqlalchemy_engine()
@@ -340,9 +352,9 @@ def get_session_with_tenant(
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None: if tenant_id is None:
tenant_id = previous_tenant_id tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout) event.listen(engine, "checkout", set_search_path_on_checkout)

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None: if db_session is None:
with get_session_with_tenant() as db_session: with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session) search_settings = get_current_search_settings(db_session)
else: else:
search_settings = get_current_search_settings(db_session) search_settings = get_current_search_settings(db_session)

View File

@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel from pydantic import BaseModel
from danswer.configs.constants import FileOrigin from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile from danswer.file_store.models import InMemoryChatFile
@@ -187,7 +187,7 @@ class CustomTool(BaseTool):
def _save_and_get_file_references( def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str self, file_content: bytes | str, content_type: str
) -> List[str]: ) -> List[str]:
with get_session_with_tenant() as db_session: with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session) file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4()) file_id = str(uuid.uuid4())
@@ -299,7 +299,7 @@ class CustomTool(BaseTool):
# Load files from storage # Load files from storage
files = [] files = []
with get_session_with_tenant() as db_session: with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session) file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids: for file_id in response.tool_result.file_ids: