Session Dependency for Chat Streaming (#1256)

This commit is contained in:
Yuhong Sun 2024-03-24 19:40:06 -07:00 committed by GitHub
parent 3107edc921
commit 7a861ecec4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 14 deletions

View File

@ -36,6 +36,7 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import SearchDoc as DbSearchDoc
@ -582,12 +583,12 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
) -> Iterator[str]:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.dict())
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.dict())

View File

@ -162,7 +162,6 @@ def delete_chat_session_by_id(
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
"""This endpoint is both used for all the following purposes:
- Sending a new message in the session
@ -176,11 +175,7 @@ def handle_new_chat_message(
if not chat_message_req.message and chat_message_req.prompt_id is not None:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
packets = stream_chat_message(
new_msg_req=chat_message_req,
user=user,
db_session=db_session,
)
packets = stream_chat_message(new_msg_req=chat_message_req, user=user)
return StreamingResponse(packets, media_type="application/json")