Evaluate None to default (#3069)

* add sentinel value

* update typing

* clearer

* update comments

* ensure proper attribution
This commit is contained in:
pablodanswer 2024-11-07 10:41:42 -08:00 committed by GitHub
parent 2b1dbde829
commit 1d0fb6d012
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 13 deletions

View File

@ -323,16 +323,28 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
If tenant ID is set, we save the previous tenant ID from the context var to set
after the session is closed. The value `None` evaluates to the default schema.
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
@ -340,9 +352,9 @@ def get_session_with_tenant(
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)

View File

@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
@ -187,7 +187,7 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4())
@ -299,7 +299,7 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids: