Do not obtain DB session via Depends() (#1238)

Endpoints that obtain a DB session via Depends(get_session) and return a
StreamingResponse have a problem: the dependency releases the session as
soon as the endpoint function returns. At that point the streaming
response is not finished yet; it still holds a reference to the session
and keeps using it. Since nothing cleans up the session after the answer
stream finishes, connections accumulate in the state "idle in
transaction".

This was due to a breaking change in FastAPI 0.106.0:
https://fastapi.tiangolo.com/release-notes/#01060
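
As an illustration (not part of this commit), a minimal sketch of the leaky
pattern next to the fixed one; engine, build_packets and the routes are
placeholders, not Danswer code:

import contextlib
from collections.abc import Generator
from collections.abc import Iterator

from fastapi import Depends
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

app = FastAPI()
engine = create_engine("postgresql://localhost/example")  # placeholder URL


def get_session() -> Generator[Session, None, None]:
    with Session(engine, expire_on_commit=False) as session:
        yield session


def build_packets(session: Session) -> Iterator[str]:
    yield "{}\n"  # stand-in for real work that queries via the session


# Leaky since FastAPI 0.106.0: the dependency cleanup runs once the endpoint
# function returns, i.e. before the stream has produced a single chunk, so
# the generator keeps using a session that was already handed back and that
# nothing will clean up afterwards.
@app.get("/leaky")
def leaky(db_session: Session = Depends(get_session)) -> StreamingResponse:
    def packets() -> Iterator[str]:
        yield from build_packets(db_session)

    return StreamingResponse(packets(), media_type="application/json")


# Fixed: the generator opens (and therefore closes) its own session, so the
# session lives exactly as long as the response is being streamed.
@app.get("/fixed")
def fixed() -> StreamingResponse:
    def packets() -> Iterator[str]:
        with contextlib.contextmanager(get_session)() as session:
            yield from build_packets(session)

    return StreamingResponse(packets(), media_type="application/json")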

Co-authored-by: Johannes Vass <johannes.vass@cloudflight.io>
Johannes Vass, 2024-03-25 03:31:07 +01:00, committed by GitHub
commit 3107edc921, parent 49263ed146
3 changed files with 17 additions and 12 deletions


@@ -1,6 +1,8 @@
+import contextlib
 from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from datetime import datetime
+from typing import ContextManager
 
 from ddtrace import tracer
 from sqlalchemy import text
@@ -70,6 +72,10 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
     return _ASYNC_ENGINE
 
 
+def get_session_context_manager() -> ContextManager:
+    return contextlib.contextmanager(get_session)()
+
+
 def get_session() -> Generator[Session, None, None]:
     with tracer.trace("db.get_session"):
         with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
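
The get_session_context_manager helper added above is just
contextlib.contextmanager applied to the existing generator-style dependency,
so the same code path can also be used in a plain with block. A standalone
sketch of that mechanism (the toy get_resource is not Danswer code):

import contextlib
from collections.abc import Generator


def get_resource() -> Generator[str, None, None]:
    print("open")       # stands in for Session(...) being entered
    try:
        yield "resource"
    finally:
        print("close")  # stands in for the session being released


# contextlib.contextmanager(get_resource) builds a context-manager factory;
# calling it returns an object whose __exit__ resumes the generator past its
# yield, which is what runs the cleanup code.
with contextlib.contextmanager(get_resource)() as res:
    print(f"using {res}")
# prints: open, then "using resource", then close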


@@ -34,6 +34,7 @@ from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import get_prompt_by_id
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.engine import get_session_context_manager
 from danswer.db.models import Prompt
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
@@ -418,17 +419,17 @@ def stream_search_answer(
     user: User | None,
     max_document_tokens: int | None,
     max_history_tokens: int | None,
-    db_session: Session,
 ) -> Iterator[str]:
-    objects = stream_answer_objects(
-        query_req=query_req,
-        user=user,
-        max_document_tokens=max_document_tokens,
-        max_history_tokens=max_history_tokens,
-        db_session=db_session,
-    )
-    for obj in objects:
-        yield get_json_line(obj.dict())
+    with get_session_context_manager() as session:
+        objects = stream_answer_objects(
+            query_req=query_req,
+            user=user,
+            max_document_tokens=max_document_tokens,
+            max_history_tokens=max_history_tokens,
+            db_session=session,
+        )
+        for obj in objects:
+            yield get_json_line(obj.dict())
 
 
 def get_search_answer(
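
With this change the session is created inside the generator that the
StreamingResponse consumes, so its lifetime is tied to the stream: the with
block exits when the generator is exhausted, and also when the generator is
closed early (Python closes garbage-collected generators). A standalone
sketch of that behaviour, with fake_session_scope standing in for
get_session_context_manager:

import contextlib
from collections.abc import Iterator


@contextlib.contextmanager
def fake_session_scope() -> Iterator[str]:
    print("session opened")
    try:
        yield "session"
    finally:
        print("session closed")


def stream_answer() -> Iterator[str]:
    with fake_session_scope() as session:
        for i in range(3):
            yield f"packet {i} via {session}\n"


# Fully consumed: the with block exits when the loop ends.
for _ in stream_answer():
    pass

# Abandoned early: closing the generator also runs the with block's exit,
# so the session is released instead of lingering "idle in transaction".
gen = stream_answer()
next(gen)
gen.close()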


@@ -148,7 +148,6 @@ def stream_query_validation(
 def get_answer_with_quote(
     query_request: DirectQARequest,
     user: User = Depends(current_user),
-    db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
     query = query_request.messages[0].message
     logger.info(f"Received query for one shot answer with quotes: {query}")
@@ -157,6 +156,5 @@ def get_answer_with_quote(
         user=user,
         max_document_tokens=None,
         max_history_tokens=0,
-        db_session=db_session,
     )
     return StreamingResponse(packets, media_type="application/json")
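
Not part of the commit, but one way to confirm the symptom described above is
gone is to watch pg_stat_activity for backends stuck in "idle in transaction",
which accumulated before this fix. A generic check (the connection URL is a
placeholder):

from sqlalchemy import create_engine
from sqlalchemy import text

engine = create_engine("postgresql://localhost/example")  # placeholder URL

with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT pid, state, now() - xact_start AS xact_age "
            "FROM pg_stat_activity "
            "WHERE state = 'idle in transaction' "
            "ORDER BY xact_age DESC"
        )
    ).all()

for pid, state, xact_age in rows:
    print(f"pid={pid} state={state} open for {xact_age}")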