mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-10 21:26:01 +02:00
111 lines
3.6 KiB
Python
111 lines
3.6 KiB
Python
import contextlib
|
|
from collections.abc import AsyncGenerator
|
|
from collections.abc import Generator
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import contextmanager
|
|
from typing import ContextManager
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from onyx.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
|
|
from onyx.db.engine import get_sqlalchemy_engine
|
|
from onyx.db.session_schema_translate_map import (
|
|
OnyxSchemaTranslateMapSession as OnyxSession,
|
|
)
|
|
from onyx.db.utils import is_valid_schema_name
|
|
from onyx.utils.logger import setup_logger
|
|
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
|
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
|
from shared_configs.contextvars import get_current_tenant_id
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
SessionFactory: sessionmaker[Session] | None = None
|
|
|
|
|
|
@contextmanager
|
|
def get_session_with_current_tenant() -> Generator[Session, None, None]:
|
|
tenant_id = get_current_tenant_id()
|
|
|
|
with OnyxSession.get_session_with_tenant(tenant_id=tenant_id) as session:
|
|
yield session
|
|
|
|
|
|
# Used in multi tenant mode when need to refer to the shared `public` schema
|
|
@contextmanager
|
|
def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
|
token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
|
with OnyxSession.get_session_with_tenant(
|
|
tenant_id=POSTGRES_DEFAULT_SCHEMA
|
|
) as session:
|
|
yield session
|
|
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
|
|
|
|
|
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
|
tenant_id = get_current_tenant_id()
|
|
with OnyxSession.get_session_with_tenant(tenant_id=tenant_id) as session:
|
|
yield session
|
|
|
|
|
|
def get_session_context_manager() -> ContextManager[Session]:
|
|
"""Context manager for database sessions."""
|
|
return contextlib.contextmanager(get_session_generator_with_tenant)()
|
|
|
|
|
|
def get_session_factory() -> sessionmaker[Session]:
|
|
global SessionFactory
|
|
if SessionFactory is None:
|
|
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
|
return SessionFactory
|
|
|
|
|
|
@contextmanager
|
|
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
|
with OnyxSession.get_session_with_tenant(tenant_id=tenant_id) as session:
|
|
yield session
|
|
|
|
|
|
def get_session() -> Generator[Session, None, None]:
|
|
yield from OnyxSession.get_session()
|
|
|
|
|
|
def get_multi_tenant_session(tenant_id: str) -> Generator[Session, None, None]:
|
|
yield from OnyxSession.get_multi_tenant_session(tenant_id)
|
|
|
|
|
|
def get_single_tenant_session() -> Generator[Session, None, None]:
|
|
yield from OnyxSession.get_single_tenant_session()
|
|
|
|
|
|
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|
"""Proxy method that simply delegates to `get_async_session`."""
|
|
async for session in OnyxSession.get_async_session():
|
|
yield session
|
|
|
|
|
|
@asynccontextmanager
|
|
async def get_async_session_with_tenant(
|
|
tenant_id: str | None = None,
|
|
) -> AsyncGenerator[AsyncSession, None]:
|
|
if tenant_id is None:
|
|
tenant_id = get_current_tenant_id()
|
|
|
|
if not is_valid_schema_name(tenant_id):
|
|
logger.error(f"Invalid tenant ID: {tenant_id}")
|
|
raise ValueError("Invalid tenant ID")
|
|
|
|
async for session in OnyxSession.get_multi_tenant_async_session(tenant_id):
|
|
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
|
await session.execute(
|
|
text(
|
|
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
|
)
|
|
)
|
|
|
|
yield session
|