Add Retrieval to Chat History (#577)

Yuhong Sun 2023-10-15 13:40:07 -07:00 committed by GitHub
parent d2f7dff464
commit 595f61ea3a
9 changed files with 163 additions and 30 deletions

View File

@@ -0,0 +1,31 @@
"""Store Chat Retrieval Docs
Revision ID: 7ccea01261f6
Revises: a570b80a5f20
Create Date: 2023-10-15 10:39:23.317453
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "7ccea01261f6"
down_revision = "a570b80a5f20"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column(
"reference_docs",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "reference_docs")
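For orientation, a sketch of what the new column ends up holding: the chat backend stores a serialized RetrievalDocs payload as plain JSON (the per-document keys below are illustrative assumptions, not taken from this commit). After pulling this revision, `alembic upgrade head` applies it.

# Sketch: the JSONB column stores whatever dict the caller passes in, e.g. a
# RetrievalDocs payload; the keys inside "top_documents" are hypothetical.
reference_docs = {
    "top_documents": [
        {"document_id": "doc-1", "link": "https://example.com/doc-1"},
    ]
}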

View File

@@ -1,7 +1,6 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from uuid import UUID
from langchain.schema.messages import AIMessage
@@ -219,6 +218,61 @@ def llm_contextless_chat_answer(
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
def extract_citations_from_stream(
tokens: Iterator[str], links: list[str | None]
) -> Iterator[str]:
if not links:
yield from tokens
return
    max_citation_num = len(links)  # LLM is prompted to 1-index these; valid values are 1..len(links)
curr_segment = ""
prepend_bracket = False
for token in tokens:
# Special case of [1][ where ][ is a single token
if prepend_bracket:
            curr_segment = "[" + curr_segment
prepend_bracket = False
curr_segment += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found:
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
link = links[numerical_value - 1]
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield curr_segment
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield "[" + curr_segment
else:
yield curr_segment
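A quick usage sketch of the streaming rewriter (the token boundaries and link are invented for illustration; with real LLM output the tokens arrive incrementally):

tokens = iter(["The answer ", "[", "1", "]", " covers it."])
links: list[str | None] = ["https://example.com/source-1"]
print("".join(extract_citations_from_stream(tokens, links)))
# -> The answer [[1]](https://example.com/source-1) covers it.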
def llm_contextual_chat_answer(
messages: list[ChatMessage],
persona: Persona,
@@ -261,7 +315,6 @@ def llm_contextual_chat_answer(
# Model will output "Yes Search" if search is useful
# Be a little forgiving though, if we match yes, it's good enough
-        citation_max_num: int | None = None
retrieved_chunks: list[InferenceChunk] = []
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
retrieved_chunks = danswer_chat_retrieval(
@@ -270,7 +323,6 @@
llm=llm,
user_id=user_id,
)
-            citation_max_num = len(retrieved_chunks) + 1
yield retrieved_chunks
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
@@ -299,23 +351,13 @@
final_msg_token_count=last_user_msg_tokens,
)
-        curr_segment = ""
-        for token in llm.stream(prompt):
-            curr_segment += token
+        tokens = llm.stream(prompt)
+        links = [
+            chunk.source_links[0] if chunk.source_links else None
+            for chunk in retrieved_chunks
+        ]
-            pattern = r"\[(\d+)\]"  # [1], [2] etc
-            found = re.search(pattern, curr_segment)
-            if found:
-                numerical_value = int(found.group(1))
-                if citation_max_num and 1 <= numerical_value <= citation_max_num:
-                    reference_chunk = retrieved_chunks[numerical_value - 1]
-                    if reference_chunk.source_links and reference_chunk.source_links[0]:
-                        link = reference_chunk.source_links[0]
-                        token = re.sub("]", f"]({link})", token)
-            curr_segment = ""
-            yield token
+        yield from extract_citations_from_stream(tokens, links)
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
@@ -327,7 +369,7 @@ def llm_tools_enabled_chat_answer(
persona: Persona,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]:
) -> Iterator[str | list[InferenceChunk]]:
retrieval_enabled = persona.retrieval_enabled
system_text = persona.system_text
hint_text = persona.hint_text
@@ -405,6 +447,7 @@ def llm_tools_enabled_chat_answer(
llm=llm,
user_id=user_id,
)
yield retrieved_chunks
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
@@ -451,6 +494,15 @@
yield LLM_CHAT_FAILURE_MSG
def wrap_chat_package_in_model(
package: str | list[InferenceChunk],
) -> DanswerAnswerPiece | RetrievalDocs:
if isinstance(package, str):
return DanswerAnswerPiece(answer_piece=package)
elif isinstance(package, list):
return RetrievalDocs(top_documents=chunks_to_search_docs(package))
def llm_chat_answer(
messages: list[ChatMessage],
persona: Persona | None,
@@ -483,18 +535,11 @@ def llm_chat_answer(
for package in llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
):
-            if isinstance(package, str):
-                yield DanswerAnswerPiece(answer_piece=package)
-            elif isinstance(package, list):
-                yield RetrievalDocs(
-                    top_documents=chunks_to_search_docs(
-                        cast(list[InferenceChunk], package)
-                    )
-                )
+            yield wrap_chat_package_in_model(package)
# Use most flexible/complex prompt format
else:
-        for token in llm_tools_enabled_chat_answer(
+        for package in llm_tools_enabled_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
):
-            yield DanswerAnswerPiece(answer_piece=token)
+            yield wrap_chat_package_in_model(package)

View File

@@ -1,3 +1,4 @@
from typing import Any
from uuid import UUID
from sqlalchemy import and_
@@ -191,6 +192,7 @@ def create_new_chat_message(
parent_edit_number: int | None,
message_type: MessageType,
db_session: Session,
retrieval_docs: dict[str, Any] | None = None,
) -> ChatMessage:
"""Creates a new chat message and sets it to the latest message of its parent message"""
# Get the count of existing edits at the provided message number
@@ -213,6 +215,7 @@
parent_edit_number=parent_edit_number,
edit_number=new_edit_number,
message=message,
reference_docs=retrieval_docs,
token_count=token_count,
message_type=message_type,
)

View File

@@ -488,6 +488,9 @@ class ChatMessage(Base):
message: Mapped[str] = mapped_column(Text)
token_count: Mapped[int] = mapped_column(Integer)
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
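Because the column is nullable JSONB, it round-trips as a plain dict or None; a minimal query sketch, assuming an open SQLAlchemy db_session:

# Sketch: fetch messages that actually persisted retrieval context.
messages_with_docs = (
    db_session.query(ChatMessage)
    .filter(ChatMessage.reference_docs.isnot(None))
    .all()
)
for msg in messages_with_docs:
    docs = msg.reference_docs  # e.g. {"top_documents": [...]}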

View File

@@ -37,6 +37,7 @@ from danswer.server.models import CreateChatMessageRequest
from danswer.server.models import CreateChatSessionID
from danswer.server.models import RegenerateMessageRequest
from danswer.server.models import RenameChatSessionResponse
from danswer.server.models import RetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -110,6 +111,9 @@ def get_chat_session_messages(
parent_edit_number=msg.parent_edit_number,
latest=msg.latest,
message=msg.message,
context_docs=RetrievalDocs(**msg.reference_docs)
if msg.reference_docs
else None,
message_type=msg.message_type,
time_sent=msg.time_sent,
)
@@ -308,11 +312,14 @@ def handle_new_chat_message(
tokenizer=llm_tokenizer,
)
llm_output = ""
fetched_docs: RetrievalDocs | None = None
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
elif isinstance(packet, RetrievalDocs):
fetched_docs = packet
yield get_json_line(packet.dict())
create_new_chat_message(
@@ -322,6 +329,7 @@
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
message_type=MessageType.ASSISTANT,
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
db_session=db_session,
)
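The accumulate-and-persist loop above reappears verbatim in regenerate_message_given_parent below; one possible follow-up refactor (a sketch, not part of this commit, assuming collections.abc.Iterator and typing.Any are imported alongside the module's existing imports) is a shared helper that forwards packets while collecting what must be saved:

def stream_packets_and_collect(
    response_packets: Iterator[DanswerAnswerPiece | RetrievalDocs],
    collected: dict[str, Any],
) -> Iterator[str]:
    # Hypothetical helper: forward each packet as a JSON line while
    # accumulating the answer text and any RetrievalDocs packet.
    collected["llm_output"] = ""
    collected["fetched_docs"] = None
    for packet in response_packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            collected["llm_output"] += packet.answer_piece
        elif isinstance(packet, RetrievalDocs):
            collected["fetched_docs"] = packet
        yield get_json_line(packet.dict())

Both handlers could then `yield from stream_packets_and_collect(response_packets, collected)` and read the answer text and docs back out of `collected` before calling create_new_chat_message.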
@@ -393,11 +401,14 @@ def regenerate_message_given_parent(
tokenizer=llm_tokenizer,
)
llm_output = ""
fetched_docs: RetrievalDocs | None = None
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
elif isinstance(packet, RetrievalDocs):
fetched_docs = packet
yield get_json_line(packet.dict())
create_new_chat_message(
@@ -407,6 +418,7 @@
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
message_type=MessageType.ASSISTANT,
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
db_session=db_session,
)

View File

@@ -236,6 +236,7 @@ class ChatMessageDetail(BaseModel):
parent_edit_number: int | None
latest: bool
message: str
context_docs: RetrievalDocs | None
message_type: MessageType
time_sent: datetime

View File

@@ -59,6 +59,7 @@ def send_chat_message(
def run_chat(contextual: bool) -> None:
try:
new_session_id = create_new_session()
print(f"Chat Session ID: {new_session_id}")
except requests.exceptions.ConnectionError:
print(
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"

View File

@@ -0,0 +1,37 @@
import unittest
from danswer.chat.chat_llm import extract_citations_from_stream
class TestChatLlm(unittest.TestCase):
def test_citation_extraction(self) -> None:
links: list[str | None] = [f"link_{i}" for i in range(1, 21)]
test_1 = "Something [1]"
res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
self.assertEqual(res, "Something [[1]](link_1)")
test_2 = "Something [14]"
res = "".join(list(extract_citations_from_stream(iter(test_2), links)))
self.assertEqual(res, "Something [[14]](link_14)")
test_3 = "Something [14][15]"
res = "".join(list(extract_citations_from_stream(iter(test_3), links)))
self.assertEqual(res, "Something [[14]](link_14)[[15]](link_15)")
test_4 = ["Something ", "[", "3", "][", "4", "]."]
res = "".join(list(extract_citations_from_stream(iter(test_4), links)))
self.assertEqual(res, "Something [[3]](link_3)[[4]](link_4).")
test_5 = ["Something ", "[", "31", "][", "4", "]."]
res = "".join(list(extract_citations_from_stream(iter(test_5), links)))
self.assertEqual(res, "Something [31][[4]](link_4).")
links[3] = None
test_1 = "Something [2][4][5]"
res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
self.assertEqual(res, "Something [[2]](link_2)[4][[5]](link_5)")
if __name__ == "__main__":
unittest.main()