mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
Evaluate None to default (#3069)
* add sentinel value * update typing * clearer * update comments * ensure proper attribution
This commit is contained in:
parent
2b1dbde829
commit
1d0fb6d012
@ -323,16 +323,28 @@ async def get_async_session_with_tenant(
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session using the current tenant ID from the context variable.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session bound to a connection with the appropriate tenant schema set.
|
||||
This preserves the tenant ID across the session and reverts to the previous tenant ID
|
||||
after the session is closed.
|
||||
If tenant ID is set, we save the previous tenant ID from the context var to set
|
||||
after the session is closed. The value `None` evaluates to the default schema.
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
3. Reverts to the previous tenant ID after the session is closed.
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
@ -340,9 +352,9 @@ def get_session_with_tenant(
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = previous_tenant_id
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
|
@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexAttempt
|
||||
@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with get_session_with_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@ -187,7 +187,7 @@ class CustomTool(BaseTool):
|
||||
def _save_and_get_file_references(
|
||||
self, file_content: bytes | str, content_type: str
|
||||
) -> List[str]:
|
||||
with get_session_with_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
@ -299,7 +299,7 @@ class CustomTool(BaseTool):
|
||||
|
||||
# Load files from storage
|
||||
files = []
|
||||
with get_session_with_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
for file_id in response.tool_result.file_ids:
|
||||
|
Loading…
x
Reference in New Issue
Block a user