Session Dependency for Chat Streaming (#1256)

This commit is contained in:
Yuhong Sun
2024-03-24 19:40:06 -07:00
committed by GitHub
parent 3107edc921
commit 7a861ecec4
2 changed files with 10 additions and 14 deletions

View File

@ -36,6 +36,7 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import SearchDoc as DbSearchDoc
@ -582,12 +583,12 @@ def stream_chat_message_objects(
def stream_chat_message( def stream_chat_message(
new_msg_req: CreateChatMessageRequest, new_msg_req: CreateChatMessageRequest,
user: User | None, user: User | None,
db_session: Session,
) -> Iterator[str]: ) -> Iterator[str]:
objects = stream_chat_message_objects( with get_session_context_manager() as db_session:
new_msg_req=new_msg_req, objects = stream_chat_message_objects(
user=user, new_msg_req=new_msg_req,
db_session=db_session, user=user,
) db_session=db_session,
for obj in objects: )
yield get_json_line(obj.dict()) for obj in objects:
yield get_json_line(obj.dict())

View File

@ -162,7 +162,6 @@ def delete_chat_session_by_id(
def handle_new_chat_message( def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest, chat_message_req: CreateChatMessageRequest,
user: User | None = Depends(current_user), user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse: ) -> StreamingResponse:
"""This endpoint is both used for all the following purposes: """This endpoint is both used for all the following purposes:
- Sending a new message in the session - Sending a new message in the session
@ -176,11 +175,7 @@ def handle_new_chat_message(
if not chat_message_req.message and chat_message_req.prompt_id is not None: if not chat_message_req.message and chat_message_req.prompt_id is not None:
raise HTTPException(status_code=400, detail="Empty chat message is invalid") raise HTTPException(status_code=400, detail="Empty chat message is invalid")
packets = stream_chat_message( packets = stream_chat_message(new_msg_req=chat_message_req, user=user)
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
)
return StreamingResponse(packets, media_type="application/json") return StreamingResponse(packets, media_type="application/json")