mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
valid but janky db session handling
This commit is contained in:
parent
198f80d224
commit
a69a0333a5
@ -0,0 +1,24 @@
|
||||
"""initial persona set up
|
||||
|
||||
Revision ID: 6eb78875dbe0
|
||||
Revises: b25c363470f3
|
||||
Create Date: 2024-09-25 12:47:44.877589
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6eb78875dbe0"
|
||||
down_revision = "b25c363470f3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
@ -12,7 +12,7 @@ import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "703313b75876"
|
||||
down_revision = "fad14119fb92"
|
||||
down_revision = "fad14119fb92"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
@ -55,4 +55,4 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("token_rate_limit__user_group")
|
||||
op.drop_table("token_rate_limit")
|
||||
op.drop_table("token_rate_limit")
|
@ -185,17 +185,22 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
return contextlib.contextmanager(get_session)()
|
||||
global_tenant_id = ""
|
||||
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
|
||||
global global_tenant_id
|
||||
return contextlib.contextmanager(lambda: get_session(override_tenant_id=global_tenant_id))()
|
||||
|
||||
def get_current_tenant_id(request: Request) -> str | None:
|
||||
if not MULTI_TENANT:
|
||||
return DEFAULT_SCHEMA
|
||||
|
||||
token = request.cookies.get("tenant_details")
|
||||
if not token:
|
||||
logger.warning("No token found in cookies")
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return None
|
||||
# raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
try:
|
||||
logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security
|
||||
@ -207,6 +212,8 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
|
||||
logger.info(f"Valid tenant_id found: {tenant_id}")
|
||||
current_tenant_id.set(tenant_id)
|
||||
global global_tenant_id
|
||||
global_tenant_id = tenant_id
|
||||
return tenant_id
|
||||
except DecodeError as e:
|
||||
logger.error(f"JWT decode error: {str(e)}")
|
||||
@ -218,9 +225,12 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
logger.exception(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
def get_session(tenant_id: str | None= Depends(get_current_tenant_id)) -> Generator[Session, None, None]:
|
||||
def get_session(tenant_id: str | None= Depends(get_current_tenant_id), override_tenant_id: str | None = None) -> Generator[Session, None, None]:
|
||||
# try:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
|
||||
if override_tenant_id:
|
||||
print("OVERRIDE TENANT ID")
|
||||
print(override_tenant_id)
|
||||
with Session(get_sqlalchemy_engine(schema=override_tenant_id or tenant_id), expire_on_commit=False) as session:
|
||||
yield session
|
||||
# finally:
|
||||
# current_tenant_id.reset(tenant_id)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
@ -70,7 +71,8 @@ def get_default_llms(
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
llm_provider = fetch_default_provider(db_session)
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_default_provider(db_session)
|
||||
|
||||
if not llm_provider:
|
||||
raise ValueError("No default LLM provider found")
|
||||
|
@ -296,6 +296,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
|
||||
|
Loading…
x
Reference in New Issue
Block a user