valid but janky db session handling

This commit is contained in:
pablodanswer 2024-09-25 13:09:55 -07:00
parent 198f80d224
commit a69a0333a5
5 changed files with 46 additions and 9 deletions

View File

@ -0,0 +1,24 @@
"""initial persona set up
Revision ID: 6eb78875dbe0
Revises: b25c363470f3
Create Date: 2024-09-25 12:47:44.877589
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6eb78875dbe0"
down_revision = "b25c363470f3"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass

View File

@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "703313b75876"
down_revision = "fad14119fb92"
down_revision = "fad14119fb92"
branch_labels: None = None
depends_on: None = None
@ -55,4 +55,4 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_table("token_rate_limit__user_group")
op.drop_table("token_rate_limit")
op.drop_table("token_rate_limit")

View File

@ -185,17 +185,22 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
global_tenant_id = ""
def get_current_tenant_id(request: Request) -> str:
def get_session_context_manager() -> ContextManager[Session]:
global global_tenant_id
return contextlib.contextmanager(lambda: get_session(override_tenant_id=global_tenant_id))()
def get_current_tenant_id(request: Request) -> str | None:
if not MULTI_TENANT:
return DEFAULT_SCHEMA
token = request.cookies.get("tenant_details")
if not token:
logger.warning("No token found in cookies")
raise HTTPException(status_code=401, detail="Authentication required")
return None
# raise HTTPException(status_code=401, detail="Authentication required")
try:
logger.info(f"Attempting to decode token: {token[:10]}...") # Log only first 10 characters for security
@ -207,6 +212,8 @@ def get_current_tenant_id(request: Request) -> str:
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
logger.info(f"Valid tenant_id found: {tenant_id}")
current_tenant_id.set(tenant_id)
global global_tenant_id
global_tenant_id = tenant_id
return tenant_id
except DecodeError as e:
logger.error(f"JWT decode error: {str(e)}")
@ -218,9 +225,12 @@ def get_current_tenant_id(request: Request) -> str:
logger.exception(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
def get_session(tenant_id: str | None= Depends(get_current_tenant_id)) -> Generator[Session, None, None]:
def get_session(tenant_id: str | None= Depends(get_current_tenant_id), override_tenant_id: str | None = None) -> Generator[Session, None, None]:
# try:
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
if override_tenant_id:
print("OVERRIDE TENANT ID")
print(override_tenant_id)
with Session(get_sqlalchemy_engine(schema=override_tenant_id or tenant_id), expire_on_commit=False) as session:
yield session
# finally:
# current_tenant_id.reset(tenant_id)

View File

@ -1,3 +1,4 @@
from danswer.db.engine import get_session_context_manager
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
@ -70,7 +71,8 @@ def get_default_llms(
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
llm_provider = fetch_default_provider(db_session)
with get_session_context_manager() as db_session:
llm_provider = fetch_default_provider(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")

View File

@ -296,6 +296,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),