From 595f61ea3a2d379d097ef974ccd6900d92abd947 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 15 Oct 2023 13:40:07 -0700 Subject: [PATCH] Add Retrieval to Chat History (#577) --- .../7ccea01261f6_store_chat_retrieval_docs.py | 31 ++++++ backend/danswer/chat/chat_llm.py | 105 +++++++++++++----- backend/danswer/db/chat.py | 3 + backend/danswer/db/models.py | 3 + backend/danswer/server/chat_backend.py | 12 ++ backend/danswer/server/models.py | 1 + backend/scripts/simulate_chat_frontend.py | 1 + .../tests/unit/danswer/chat/test_chat_llm.py | 37 ++++++ .../direct_qa/test_qa_utils.py} | 0 9 files changed, 163 insertions(+), 30 deletions(-) create mode 100644 backend/alembic/versions/7ccea01261f6_store_chat_retrieval_docs.py create mode 100644 backend/tests/unit/danswer/chat/test_chat_llm.py rename backend/tests/unit/{qa_service/direct_qa/test_question_answer.py => danswer/direct_qa/test_qa_utils.py} (100%) diff --git a/backend/alembic/versions/7ccea01261f6_store_chat_retrieval_docs.py b/backend/alembic/versions/7ccea01261f6_store_chat_retrieval_docs.py new file mode 100644 index 000000000..24d8ff717 --- /dev/null +++ b/backend/alembic/versions/7ccea01261f6_store_chat_retrieval_docs.py @@ -0,0 +1,31 @@ +"""Store Chat Retrieval Docs + +Revision ID: 7ccea01261f6 +Revises: a570b80a5f20 +Create Date: 2023-10-15 10:39:23.317453 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "7ccea01261f6" +down_revision = "a570b80a5f20" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "chat_message", + sa.Column( + "reference_docs", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("chat_message", "reference_docs") diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index 9b9c2eee5..ec36b4891 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -1,7 +1,6 @@ import re from collections.abc import Callable from collections.abc import Iterator -from typing import cast from uuid import UUID from langchain.schema.messages import AIMessage @@ -219,6 +218,61 @@ def llm_contextless_chat_answer( return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator +def extract_citations_from_stream( + tokens: Iterator[str], links: list[str | None] +) -> Iterator[str]: + if not links: + yield from tokens + return + + max_citation_num = len(links) + 1 # LLM is prompted to 1 index these + curr_segment = "" + prepend_bracket = False + for token in tokens: + # Special case of [1][ where ][ is a single token + if prepend_bracket: + curr_segment += "[" + curr_segment + prepend_bracket = False + + curr_segment += token + + possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_found = re.search(possible_citation_pattern, curr_segment) + + citation_pattern = r"\[(\d+)\]" # [1], [2] etc + citation_found = re.search(citation_pattern, curr_segment) + + if citation_found: + numerical_value = int(citation_found.group(1)) + if 1 <= numerical_value <= max_citation_num: + link = links[numerical_value - 1] + if link: + curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) + curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) + + # In case there's another open bracket like [1][, don't want to match this + possible_citation_found = None + + # if we see "[", but 
haven't seen the right side, hold back - this may be a + # citation that needs to be replaced with a link + if possible_citation_found: + continue + + # Special case with back to back citations [1][2] + if curr_segment and curr_segment[-1] == "[": + curr_segment = curr_segment[:-1] + prepend_bracket = True + + yield curr_segment + curr_segment = "" + + if curr_segment: + if prepend_bracket: + yield "[" + curr_segment + else: + yield curr_segment + + def llm_contextual_chat_answer( messages: list[ChatMessage], persona: Persona, @@ -261,7 +315,6 @@ def llm_contextual_chat_answer( # Model will output "Yes Search" if search is useful # Be a little forgiving though, if we match yes, it's good enough - citation_max_num: int | None = None retrieved_chunks: list[InferenceChunk] = [] if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower(): retrieved_chunks = danswer_chat_retrieval( @@ -270,7 +323,6 @@ def llm_contextual_chat_answer( llm=llm, user_id=user_id, ) - citation_max_num = len(retrieved_chunks) + 1 yield retrieved_chunks tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks) @@ -299,23 +351,13 @@ def llm_contextual_chat_answer( final_msg_token_count=last_user_msg_tokens, ) - curr_segment = "" - for token in llm.stream(prompt): - curr_segment += token + tokens = llm.stream(prompt) + links = [ + chunk.source_links[0] if chunk.source_links else None + for chunk in retrieved_chunks + ] - pattern = r"\[(\d+)\]" # [1], [2] etc - found = re.search(pattern, curr_segment) - - if found: - numerical_value = int(found.group(1)) - if citation_max_num and 1 <= numerical_value <= citation_max_num: - reference_chunk = retrieved_chunks[numerical_value - 1] - if reference_chunk.source_links and reference_chunk.source_links[0]: - link = reference_chunk.source_links[0] - token = re.sub("]", f"]({link})", token) - curr_segment = "" - - yield token + yield from extract_citations_from_stream(tokens, links) except Exception as e: logger.error(f"LLM failed to produce valid 
def wrap_chat_package_in_model(
    package: str | list[InferenceChunk],
) -> DanswerAnswerPiece | RetrievalDocs:
    """Wrap one item produced by a chat-answer generator in its response model.

    Args:
        package: Either a piece of answer text or a list of retrieved chunks.

    Returns:
        A ``DanswerAnswerPiece`` carrying the text when ``package`` is a
        ``str``; a ``RetrievalDocs`` built from ``chunks_to_search_docs``
        when ``package`` is a ``list``.

    Raises:
        TypeError: If ``package`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(package, str):
        return DanswerAnswerPiece(answer_piece=package)
    if isinstance(package, list):
        return RetrievalDocs(top_documents=chunks_to_search_docs(package))

    # Fail here with a clear message instead of implicitly returning None,
    # which would contradict the declared return type.
    raise TypeError(
        f"Unsupported chat package type: {type(package).__name__}; "
        "expected str or list of chunks"
    )
sqlalchemy import and_ @@ -191,6 +192,7 @@ def create_new_chat_message( parent_edit_number: int | None, message_type: MessageType, db_session: Session, + retrieval_docs: dict[str, Any] | None = None, ) -> ChatMessage: """Creates a new chat message and sets it to the latest message of its parent message""" # Get the count of existing edits at the provided message number @@ -213,6 +215,7 @@ def create_new_chat_message( parent_edit_number=parent_edit_number, edit_number=new_edit_number, message=message, + reference_docs=retrieval_docs, token_count=token_count, message_type=message_type, ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index f4d468bfc..40ea0ac6e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -488,6 +488,9 @@ class ChatMessage(Base): message: Mapped[str] = mapped_column(Text) token_count: Mapped[int] = mapped_column(Integer) message_type: Mapped[MessageType] = mapped_column(Enum(MessageType)) + reference_docs: Mapped[dict[str, Any] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) persona_id: Mapped[int | None] = mapped_column( ForeignKey("persona.id"), nullable=True ) diff --git a/backend/danswer/server/chat_backend.py b/backend/danswer/server/chat_backend.py index b25b7b141..d682a824a 100644 --- a/backend/danswer/server/chat_backend.py +++ b/backend/danswer/server/chat_backend.py @@ -37,6 +37,7 @@ from danswer.server.models import CreateChatMessageRequest from danswer.server.models import CreateChatSessionID from danswer.server.models import RegenerateMessageRequest from danswer.server.models import RenameChatSessionResponse +from danswer.server.models import RetrievalDocs from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -110,6 +111,9 @@ def get_chat_session_messages( parent_edit_number=msg.parent_edit_number, latest=msg.latest, message=msg.message, + 
context_docs=RetrievalDocs(**msg.reference_docs) + if msg.reference_docs + else None, message_type=msg.message_type, time_sent=msg.time_sent, ) @@ -308,11 +312,14 @@ def handle_new_chat_message( tokenizer=llm_tokenizer, ) llm_output = "" + fetched_docs: RetrievalDocs | None = None for packet in response_packets: if isinstance(packet, DanswerAnswerPiece): token = packet.answer_piece if token: llm_output += token + elif isinstance(packet, RetrievalDocs): + fetched_docs = packet yield get_json_line(packet.dict()) create_new_chat_message( @@ -322,6 +329,7 @@ def handle_new_chat_message( message=llm_output, token_count=len(llm_tokenizer(llm_output)), message_type=MessageType.ASSISTANT, + retrieval_docs=fetched_docs.dict() if fetched_docs else None, db_session=db_session, ) @@ -393,11 +401,14 @@ def regenerate_message_given_parent( tokenizer=llm_tokenizer, ) llm_output = "" + fetched_docs: RetrievalDocs | None = None for packet in response_packets: if isinstance(packet, DanswerAnswerPiece): token = packet.answer_piece if token: llm_output += token + elif isinstance(packet, RetrievalDocs): + fetched_docs = packet yield get_json_line(packet.dict()) create_new_chat_message( @@ -407,6 +418,7 @@ def regenerate_message_given_parent( message=llm_output, token_count=len(llm_tokenizer(llm_output)), message_type=MessageType.ASSISTANT, + retrieval_docs=fetched_docs.dict() if fetched_docs else None, db_session=db_session, ) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index cf829ab5a..7b6a923ce 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -236,6 +236,7 @@ class ChatMessageDetail(BaseModel): parent_edit_number: int | None latest: bool message: str + context_docs: RetrievalDocs | None message_type: MessageType time_sent: datetime diff --git a/backend/scripts/simulate_chat_frontend.py b/backend/scripts/simulate_chat_frontend.py index 583623282..e3a4064c4 100644 --- a/backend/scripts/simulate_chat_frontend.py 
class TestChatLlm(unittest.TestCase):
    """Unit tests for the streaming citation-to-link rewriting helper."""

    @staticmethod
    def _joined(stream, links) -> str:
        """Feed ``stream`` through the extractor and concatenate its output.

        ``stream`` is either a string (iterated one character at a time) or
        a list of pre-split tokens.
        """
        return "".join(extract_citations_from_stream(iter(stream), links))

    def test_citation_extraction(self) -> None:
        all_links: list[str | None] = [f"link_{n}" for n in range(1, 21)]

        # (input stream, expected concatenated output), checked in order
        expectations = [
            # single-digit citation, streamed character by character
            ("Something [1]", "Something [[1]](link_1)"),
            # multi-digit citation
            ("Something [14]", "Something [[14]](link_14)"),
            # back-to-back citations
            ("Something [14][15]", "Something [[14]](link_14)[[15]](link_15)"),
            # "][" arriving as one token
            (
                ["Something ", "[", "3", "][", "4", "]."],
                "Something [[3]](link_3)[[4]](link_4).",
            ),
            # citation number with no matching link stays as plain text
            (
                ["Something ", "[", "31", "][", "4", "]."],
                "Something [31][[4]](link_4).",
            ),
        ]
        for stream, expected in expectations:
            self.assertEqual(self._joined(stream, all_links), expected)

        # A missing (None) link leaves that citation unlinked
        all_links[3] = None
        self.assertEqual(
            self._joined("Something [2][4][5]", all_links),
            "Something [[2]](link_2)[4][[5]](link_5)",
        )
a/backend/tests/unit/qa_service/direct_qa/test_question_answer.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py similarity index 100% rename from backend/tests/unit/qa_service/direct_qa/test_question_answer.py rename to backend/tests/unit/danswer/direct_qa/test_qa_utils.py