mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-09 06:22:18 +02:00
Evaluate None to default (#3069)
* add sentinel value * update typing * clearer * update comments * ensure proper attribution
This commit is contained in:
@@ -323,16 +323,28 @@ async def get_async_session_with_tenant(
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||||
|
"""
|
||||||
|
Get a database session using the current tenant ID from the context variable.
|
||||||
|
"""
|
||||||
|
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||||
|
with get_session_with_tenant(tenant_id) as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_session_with_tenant(
|
def get_session_with_tenant(
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
) -> Generator[Session, None, None]:
|
) -> Generator[Session, None, None]:
|
||||||
"""
|
"""
|
||||||
Generate a database session bound to a connection with the appropriate tenant schema set.
|
Generate a database session for a specific tenant.
|
||||||
This preserves the tenant ID across the session and reverts to the previous tenant ID
|
|
||||||
after the session is closed.
|
This function:
|
||||||
If tenant ID is set, we save the previous tenant ID from the context var to set
|
1. Sets the database schema to the specified tenant's schema.
|
||||||
after the session is closed. The value `None` evaluates to the default schema.
|
2. Preserves the tenant ID across the session.
|
||||||
|
3. Reverts to the previous tenant ID after the session is closed.
|
||||||
|
4. Uses the default schema if no tenant ID is provided.
|
||||||
"""
|
"""
|
||||||
engine = get_sqlalchemy_engine()
|
engine = get_sqlalchemy_engine()
|
||||||
|
|
||||||
@@ -340,9 +352,9 @@ def get_session_with_tenant(
|
|||||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||||
|
|
||||||
if tenant_id is None:
|
if tenant_id is None:
|
||||||
tenant_id = previous_tenant_id
|
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||||
else:
|
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
|
|
||||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||||
|
|
||||||
|
@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
|||||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_default_tenant
|
||||||
from danswer.db.llm import fetch_embedding_provider
|
from danswer.db.llm import fetch_embedding_provider
|
||||||
from danswer.db.models import CloudEmbeddingProvider
|
from danswer.db.models import CloudEmbeddingProvider
|
||||||
from danswer.db.models import IndexAttempt
|
from danswer.db.models import IndexAttempt
|
||||||
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
|||||||
|
|
||||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||||
if db_session is None:
|
if db_session is None:
|
||||||
with get_session_with_tenant() as db_session:
|
with get_session_with_default_tenant() as db_session:
|
||||||
search_settings = get_current_search_settings(db_session)
|
search_settings = get_current_search_settings(db_session)
|
||||||
else:
|
else:
|
||||||
search_settings = get_current_search_settings(db_session)
|
search_settings = get_current_search_settings(db_session)
|
||||||
|
@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from danswer.configs.constants import FileOrigin
|
from danswer.configs.constants import FileOrigin
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_default_tenant
|
||||||
from danswer.file_store.file_store import get_default_file_store
|
from danswer.file_store.file_store import get_default_file_store
|
||||||
from danswer.file_store.models import ChatFileType
|
from danswer.file_store.models import ChatFileType
|
||||||
from danswer.file_store.models import InMemoryChatFile
|
from danswer.file_store.models import InMemoryChatFile
|
||||||
@@ -187,7 +187,7 @@ class CustomTool(BaseTool):
|
|||||||
def _save_and_get_file_references(
|
def _save_and_get_file_references(
|
||||||
self, file_content: bytes | str, content_type: str
|
self, file_content: bytes | str, content_type: str
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
with get_session_with_tenant() as db_session:
|
with get_session_with_default_tenant() as db_session:
|
||||||
file_store = get_default_file_store(db_session)
|
file_store = get_default_file_store(db_session)
|
||||||
|
|
||||||
file_id = str(uuid.uuid4())
|
file_id = str(uuid.uuid4())
|
||||||
@@ -299,7 +299,7 @@ class CustomTool(BaseTool):
|
|||||||
|
|
||||||
# Load files from storage
|
# Load files from storage
|
||||||
files = []
|
files = []
|
||||||
with get_session_with_tenant() as db_session:
|
with get_session_with_default_tenant() as db_session:
|
||||||
file_store = get_default_file_store(db_session)
|
file_store = get_default_file_store(db_session)
|
||||||
|
|
||||||
for file_id in response.tool_result.file_ids:
|
for file_id in response.tool_result.file_ids:
|
||||||
|
Reference in New Issue
Block a user