From a69a0333a5c80988c6b352ac46c574e1e804dfbf Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 25 Sep 2024 13:09:55 -0700 Subject: [PATCH] valid but janky db session handling --- .../6eb78875dbe0_initial_persona_set_up.py | 24 +++++++++++++++++++ .../703313b75876_add_tokenratelimit_tables.py | 4 ++-- backend/danswer/db/engine.py | 22 ++++++++++++----- backend/danswer/llm/factory.py | 4 +++- .../server/query_and_chat/chat_backend.py | 1 + 5 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 backend/alembic/versions/6eb78875dbe0_initial_persona_set_up.py diff --git a/backend/alembic/versions/6eb78875dbe0_initial_persona_set_up.py b/backend/alembic/versions/6eb78875dbe0_initial_persona_set_up.py new file mode 100644 index 000000000..840523f7a --- /dev/null +++ b/backend/alembic/versions/6eb78875dbe0_initial_persona_set_up.py @@ -0,0 +1,24 @@ +"""initial persona set up + +Revision ID: 6eb78875dbe0 +Revises: b25c363470f3 +Create Date: 2024-09-25 12:47:44.877589 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "6eb78875dbe0" +down_revision = "b25c363470f3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py index dacc79644..003a54994 100644 --- a/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py +++ b/backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "703313b75876" -down_revision = "fad14119fb92" +down_revision = "fad14119fb92" branch_labels: None = None depends_on: None = None @@ -55,4 +55,4 @@ def upgrade() -> None: def downgrade() -> None: op.drop_table("token_rate_limit__user_group") - op.drop_table("token_rate_limit") + op.drop_table("token_rate_limit") \ No newline at end of file diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 71ac9f91d..f96491606 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -185,17 +185,22 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: return _ASYNC_ENGINE -def get_session_context_manager() -> ContextManager[Session]: - return contextlib.contextmanager(get_session)() +global_tenant_id = "" -def get_current_tenant_id(request: Request) -> str: +def get_session_context_manager() -> ContextManager[Session]: + + global global_tenant_id + return contextlib.contextmanager(lambda: get_session(override_tenant_id=global_tenant_id))() + +def get_current_tenant_id(request: Request) -> str | None: if not MULTI_TENANT: return DEFAULT_SCHEMA token = request.cookies.get("tenant_details") if not token: logger.warning("No token found in cookies") - raise HTTPException(status_code=401, detail="Authentication required") + return None + # raise HTTPException(status_code=401, detail="Authentication required") try: logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security @@ -207,6 +212,8 @@ def get_current_tenant_id(request: Request) -> str: raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing") logger.info(f"Valid tenant_id found: {tenant_id}") current_tenant_id.set(tenant_id) + global global_tenant_id + global_tenant_id = tenant_id return tenant_id except DecodeError as e: logger.error(f"JWT decode error: {str(e)}") @@ -218,9 +225,12 @@ def get_current_tenant_id(request: Request) -> str: logger.exception(f"Unexpected error in get_current_tenant_id: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") -def get_session(tenant_id: str | None= Depends(get_current_tenant_id)) -> Generator[Session, None, None]: +def get_session(tenant_id: str | None= Depends(get_current_tenant_id), override_tenant_id: str | None = None) -> Generator[Session, None, None]: # try: - with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session: + if override_tenant_id: + print("OVERRIDE TENANT ID") + print(override_tenant_id) + with Session(get_sqlalchemy_engine(schema=override_tenant_id or tenant_id), expire_on_commit=False) as session: yield session # finally: # current_tenant_id.reset(tenant_id) diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index 307a081fa..c15711503 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -1,3 +1,4 @@ +from danswer.db.engine import get_session_context_manager from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.model_configs import GEN_AI_TEMPERATURE @@ -70,7 +71,8 @@ def get_default_llms( if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() - llm_provider = fetch_default_provider(db_session) + with get_session_context_manager() as db_session: + llm_provider = fetch_default_provider(db_session) if not llm_provider: raise ValueError("No default LLM provider found") diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 1c1d60aad..b41147cfb 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -296,6 +296,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]: def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, request: Request, + db_session: Session = Depends(get_session), user: User | None = Depends(current_user), _: None = Depends(check_token_rate_limits), is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),