diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 93ad3bdd3..2e79e2006 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -271,6 +271,7 @@ def stream_chat_message_objects( use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, + enforce_chat_session_id_for_search_docs: bool = True, ) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run @@ -442,6 +443,7 @@ def stream_chat_message_objects( chat_session=chat_session, user_id=user_id, db_session=db_session, + enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs, ) # Generates full documents currently diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 8599714ce..feb2e2b4b 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -598,6 +598,7 @@ def get_doc_query_identifiers_from_model( chat_session: ChatSession, user_id: UUID | None, db_session: Session, + enforce_chat_session_id_for_search_docs: bool, ) -> list[tuple[str, int]]: """Given a list of search_doc_ids""" search_docs = ( @@ -617,7 +618,8 @@ def get_doc_query_identifiers_from_model( for doc in search_docs ] ): - raise ValueError("Invalid reference doc, not from this chat session.") + if enforce_chat_session_id_for_search_docs: + raise ValueError("Invalid reference doc, not from this chat session.") except IndexError: # This happens when the doc has no chat_messages associated with it. # which happens as an edge case where the chat message failed to save diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index f5b56b7d4..dd637dcf0 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -182,6 +182,7 @@ def handle_simplified_chat_message( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, + enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets) @@ -301,6 +302,7 @@ def handle_send_message_simple_with_history( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, + enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets)