Add Retrieval to Chat History (#577)

Yuhong Sun, 2023-10-15 13:40:07 -07:00 (committed via GitHub)
parent d2f7dff464
commit 595f61ea3a
9 changed files with 163 additions and 30 deletions

View File

@@ -0,0 +1,31 @@
+"""Store Chat Retrieval Docs
+
+Revision ID: 7ccea01261f6
+Revises: a570b80a5f20
+Create Date: 2023-10-15 10:39:23.317453
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "7ccea01261f6"
+down_revision = "a570b80a5f20"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chat_message",
+        sa.Column(
+            "reference_docs",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("chat_message", "reference_docs")
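
The migration only adds a nullable JSONB column, so existing chat_message rows are untouched and no backfill runs; reference_docs simply stays NULL for messages saved before the upgrade. A minimal sketch of inspecting the new column after running "alembic upgrade head" (the DSN and row id below are hypothetical, not part of this commit):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost:5432/danswer")  # hypothetical DSN

with engine.connect() as conn:
    row = conn.execute(
        text("SELECT reference_docs FROM chat_message WHERE id = :id"),
        {"id": 1},  # hypothetical message id
    ).first()

# None for pre-migration messages, otherwise a JSON object shaped like
# RetrievalDocs.dict(), i.e. {"top_documents": [...]}
print(row.reference_docs if row else "no such message")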

View File

@@ -1,7 +1,6 @@
 import re
 from collections.abc import Callable
 from collections.abc import Iterator
-from typing import cast
 from uuid import UUID

 from langchain.schema.messages import AIMessage
@@ -219,6 +218,61 @@ def llm_contextless_chat_answer(
         return (msg for msg in [LLM_CHAT_FAILURE_MSG])  # needs to be an Iterator


+def extract_citations_from_stream(
+    tokens: Iterator[str], links: list[str | None]
+) -> Iterator[str]:
+    if not links:
+        yield from tokens
+        return
+
+    # Citations are 1-indexed in the prompt, so valid values are 1..len(links)
+    max_citation_num = len(links)
+    curr_segment = ""
+    prepend_bracket = False
+    for token in tokens:
+        # Special case of [1][ where ][ is a single token
+        if prepend_bracket:
+            # Re-attach the bracket held back from the previous token
+            curr_segment = "[" + curr_segment
+            prepend_bracket = False
+        curr_segment += token
+
+        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
+        possible_citation_found = re.search(possible_citation_pattern, curr_segment)
+
+        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
+        citation_found = re.search(citation_pattern, curr_segment)
+
+        if citation_found:
+            numerical_value = int(citation_found.group(1))
+            if 1 <= numerical_value <= max_citation_num:
+                link = links[numerical_value - 1]
+                if link:
+                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
+                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
+
+            # In case there's another open bracket like [1][, don't want to match this
+            possible_citation_found = None
+
+        # if we see "[", but haven't seen the right side, hold back - this may be a
+        # citation that needs to be replaced with a link
+        if possible_citation_found:
+            continue
+
+        # Special case with back to back citations [1][2]
+        if curr_segment and curr_segment[-1] == "[":
+            curr_segment = curr_segment[:-1]
+            prepend_bracket = True
+
+        yield curr_segment
+        curr_segment = ""
+
+    if curr_segment:
+        if prepend_bracket:
+            yield "[" + curr_segment
+        else:
+            yield curr_segment
+
+
 def llm_contextual_chat_answer(
     messages: list[ChatMessage],
     persona: Persona,
@@ -261,7 +315,6 @@ def llm_contextual_chat_answer(
         # Model will output "Yes Search" if search is useful
         # Be a little forgiving though, if we match yes, it's good enough
-        citation_max_num: int | None = None
         retrieved_chunks: list[InferenceChunk] = []
         if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
             retrieved_chunks = danswer_chat_retrieval(
@@ -270,7 +323,6 @@ def llm_contextual_chat_answer(
                 llm=llm,
                 user_id=user_id,
             )
-            citation_max_num = len(retrieved_chunks) + 1
             yield retrieved_chunks

             tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
@@ -299,23 +351,13 @@ def llm_contextual_chat_answer(
             final_msg_token_count=last_user_msg_tokens,
         )

-        curr_segment = ""
-        for token in llm.stream(prompt):
-            curr_segment += token
-
-            pattern = r"\[(\d+)\]"  # [1], [2] etc
-            found = re.search(pattern, curr_segment)
-            if found:
-                numerical_value = int(found.group(1))
-                if citation_max_num and 1 <= numerical_value <= citation_max_num:
-                    reference_chunk = retrieved_chunks[numerical_value - 1]
-                    if reference_chunk.source_links and reference_chunk.source_links[0]:
-                        link = reference_chunk.source_links[0]
-                        token = re.sub("]", f"]({link})", token)
-                curr_segment = ""
-            yield token
+        tokens = llm.stream(prompt)
+        links = [
+            chunk.source_links[0] if chunk.source_links else None
+            for chunk in retrieved_chunks
+        ]
+        yield from extract_citations_from_stream(tokens, links)
     except Exception as e:
         logger.error(f"LLM failed to produce valid chat message, error: {e}")
@@ -327,7 +369,7 @@ def llm_tools_enabled_chat_answer(
     persona: Persona,
     user_id: UUID | None,
     tokenizer: Callable,
-) -> Iterator[str]:
+) -> Iterator[str | list[InferenceChunk]]:
     retrieval_enabled = persona.retrieval_enabled
     system_text = persona.system_text
     hint_text = persona.hint_text
@@ -405,6 +447,7 @@ def llm_tools_enabled_chat_answer(
                 llm=llm,
                 user_id=user_id,
             )
+            yield retrieved_chunks
             tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
         else:
             tool_result_str = call_tool(final_result, user_id=user_id)
@@ -451,6 +494,15 @@ def llm_tools_enabled_chat_answer(
     yield LLM_CHAT_FAILURE_MSG


+def wrap_chat_package_in_model(
+    package: str | list[InferenceChunk],
+) -> DanswerAnswerPiece | RetrievalDocs:
+    if isinstance(package, str):
+        return DanswerAnswerPiece(answer_piece=package)
+    elif isinstance(package, list):
+        return RetrievalDocs(top_documents=chunks_to_search_docs(package))
+
+
 def llm_chat_answer(
     messages: list[ChatMessage],
     persona: Persona | None,
@@ -483,18 +535,11 @@ def llm_chat_answer(
         for package in llm_contextual_chat_answer(
             messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
         ):
-            if isinstance(package, str):
-                yield DanswerAnswerPiece(answer_piece=package)
-            elif isinstance(package, list):
-                yield RetrievalDocs(
-                    top_documents=chunks_to_search_docs(
-                        cast(list[InferenceChunk], package)
-                    )
-                )
+            yield wrap_chat_package_in_model(package)
     # Use most flexible/complex prompt format
     else:
-        for token in llm_tools_enabled_chat_answer(
+        for package in llm_tools_enabled_chat_answer(
             messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
         ):
-            yield DanswerAnswerPiece(answer_piece=token)
+            yield wrap_chat_package_in_model(package)
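
Taken together, these chat_llm.py changes move citation handling out of the token loop into the reusable extract_citations_from_stream helper, and make both answer paths yield the retrieved chunks themselves so callers can persist them. A small sketch of the helper's behavior (the links are made up; the unit tests added below cover the trickier token splits):

from danswer.chat.chat_llm import extract_citations_from_stream

links: list[str | None] = ["https://docs.example.com/a", None]  # hypothetical links
tokens = iter(["Answer ", "[", "1", "]", " and ", "[2]", "."])

# Citations with a known link become markdown links as they stream out;
# [2] is left untouched because its link is None.
print("".join(extract_citations_from_stream(tokens, links)))
# -> Answer [[1]](https://docs.example.com/a) and [2].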

View File

@@ -1,3 +1,4 @@
+from typing import Any
 from uuid import UUID

 from sqlalchemy import and_
@@ -191,6 +192,7 @@ def create_new_chat_message(
     parent_edit_number: int | None,
     message_type: MessageType,
     db_session: Session,
+    retrieval_docs: dict[str, Any] | None = None,
 ) -> ChatMessage:
     """Creates a new chat message and sets it to the latest message of its parent message"""
     # Get the count of existing edits at the provided message number
@@ -213,6 +215,7 @@ def create_new_chat_message(
         parent_edit_number=parent_edit_number,
         edit_number=new_edit_number,
         message=message,
+        reference_docs=retrieval_docs,
         token_count=token_count,
         message_type=message_type,
     )
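
The new retrieval_docs argument is written verbatim into the JSONB column, so callers hand over a plain dict; in the server code below this is always RetrievalDocs(...).dict(). A sketch of the expected shape (only the top_documents key is confirmed by this diff; the per-document fields come from SearchDoc, whose schema isn't shown here, so they are assumptions):

# Hypothetical payload; inner fields are illustrative only.
retrieval_docs = {
    "top_documents": [
        {"document_id": "abc-123", "link": "https://example.com/doc"},
    ]
}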

View File

@@ -488,6 +488,9 @@ class ChatMessage(Base):
     message: Mapped[str] = mapped_column(Text)
     token_count: Mapped[int] = mapped_column(Integer)
     message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
+    reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
     persona_id: Mapped[int | None] = mapped_column(
         ForeignKey("persona.id"), nullable=True
     )

View File

@@ -37,6 +37,7 @@ from danswer.server.models import CreateChatMessageRequest
 from danswer.server.models import CreateChatSessionID
 from danswer.server.models import RegenerateMessageRequest
 from danswer.server.models import RenameChatSessionResponse
+from danswer.server.models import RetrievalDocs
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
@@ -110,6 +111,9 @@ def get_chat_session_messages(
             parent_edit_number=msg.parent_edit_number,
             latest=msg.latest,
             message=msg.message,
+            context_docs=RetrievalDocs(**msg.reference_docs)
+            if msg.reference_docs
+            else None,
             message_type=msg.message_type,
             time_sent=msg.time_sent,
         )
@@ -308,11 +312,14 @@ def handle_new_chat_message(
             tokenizer=llm_tokenizer,
         )
         llm_output = ""
+        fetched_docs: RetrievalDocs | None = None
         for packet in response_packets:
             if isinstance(packet, DanswerAnswerPiece):
                 token = packet.answer_piece
                 if token:
                     llm_output += token
+            elif isinstance(packet, RetrievalDocs):
+                fetched_docs = packet
             yield get_json_line(packet.dict())

         create_new_chat_message(
@@ -322,6 +329,7 @@ def handle_new_chat_message(
             message=llm_output,
             token_count=len(llm_tokenizer(llm_output)),
             message_type=MessageType.ASSISTANT,
+            retrieval_docs=fetched_docs.dict() if fetched_docs else None,
             db_session=db_session,
         )
@@ -393,11 +401,14 @@ def regenerate_message_given_parent(
             tokenizer=llm_tokenizer,
         )
         llm_output = ""
+        fetched_docs: RetrievalDocs | None = None
         for packet in response_packets:
             if isinstance(packet, DanswerAnswerPiece):
                 token = packet.answer_piece
                 if token:
                     llm_output += token
+            elif isinstance(packet, RetrievalDocs):
+                fetched_docs = packet
             yield get_json_line(packet.dict())

         create_new_chat_message(
@@ -407,6 +418,7 @@ def regenerate_message_given_parent(
             message=llm_output,
             token_count=len(llm_tokenizer(llm_output)),
             message_type=MessageType.ASSISTANT,
+            retrieval_docs=fetched_docs.dict() if fetched_docs else None,
             db_session=db_session,
         )
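
Because every packet is serialized with get_json_line, a streaming client can tell answer tokens from the retrieval-docs packet by their keys. A minimal consumer sketch (the endpoint path and request body are assumptions, not taken from this diff):

import json
import requests

resp = requests.post(
    "http://localhost:8080/chat/send-message",  # hypothetical endpoint
    json={"chat_session_id": 1, "message": "How do I configure X?"},
    stream=True,
)

answer = ""
docs = None
for line in resp.iter_lines():
    if not line:
        continue  # skip keep-alive blank lines
    packet = json.loads(line)
    if "answer_piece" in packet:
        answer += packet["answer_piece"] or ""
    elif "top_documents" in packet:
        docs = packet["top_documents"]  # same payload that gets persisted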

View File

@@ -236,6 +236,7 @@ class ChatMessageDetail(BaseModel):
     parent_edit_number: int | None
     latest: bool
     message: str
+    context_docs: RetrievalDocs | None
     message_type: MessageType
     time_sent: datetime

View File

@@ -59,6 +59,7 @@ def send_chat_message(
 def run_chat(contextual: bool) -> None:
     try:
         new_session_id = create_new_session()
+        print(f"Chat Session ID: {new_session_id}")
     except requests.exceptions.ConnectionError:
         print(
             "Looks like you haven't started the Danswer Backend server, please run the FastAPI server"

View File

@@ -0,0 +1,37 @@
+import unittest
+
+from danswer.chat.chat_llm import extract_citations_from_stream
+
+
+class TestChatLlm(unittest.TestCase):
+    def test_citation_extraction(self) -> None:
+        links: list[str | None] = [f"link_{i}" for i in range(1, 21)]
+
+        test_1 = "Something [1]"
+        res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
+        self.assertEqual(res, "Something [[1]](link_1)")
+
+        test_2 = "Something [14]"
+        res = "".join(list(extract_citations_from_stream(iter(test_2), links)))
+        self.assertEqual(res, "Something [[14]](link_14)")
+
+        test_3 = "Something [14][15]"
+        res = "".join(list(extract_citations_from_stream(iter(test_3), links)))
+        self.assertEqual(res, "Something [[14]](link_14)[[15]](link_15)")
+
+        test_4 = ["Something ", "[", "3", "][", "4", "]."]
+        res = "".join(list(extract_citations_from_stream(iter(test_4), links)))
+        self.assertEqual(res, "Something [[3]](link_3)[[4]](link_4).")
+
+        test_5 = ["Something ", "[", "31", "][", "4", "]."]
+        res = "".join(list(extract_citations_from_stream(iter(test_5), links)))
+        self.assertEqual(res, "Something [31][[4]](link_4).")
+
+        links[3] = None
+        test_1 = "Something [2][4][5]"
+        res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
+        self.assertEqual(res, "Something [[2]](link_2)[4][[5]](link_5)")
+
+
+if __name__ == "__main__":
+    unittest.main()
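
Since the module guards unittest.main() under __main__, the new test can be run by executing the file directly with Python or through unittest discovery; extract_citations_from_stream is pure, so no database or other fixtures are needed.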