Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-12 22:09:36 +02:00)
Add Retrieval to Chat History (#577)
This commit is contained in:
parent
d2f7dff464
commit
595f61ea3a
@ -0,0 +1,31 @@
|
||||
"""Store Chat Retrieval Docs
|
||||
|
||||
Revision ID: 7ccea01261f6
|
||||
Revises: a570b80a5f20
|
||||
Create Date: 2023-10-15 10:39:23.317453
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
revision = "7ccea01261f6"
down_revision = "a570b80a5f20"  # migration this one builds on top of
branch_labels = None
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``reference_docs`` JSONB column to ``chat_message``.

    Stores the retrieval documents referenced by an assistant chat message;
    nullable because user messages and non-retrieval answers have none.
    """
    reference_docs_column = sa.Column(
        "reference_docs",
        postgresql.JSONB(astext_type=sa.Text()),
        nullable=True,
    )
    op.add_column("chat_message", reference_docs_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``reference_docs`` column added by upgrade()."""
    op.drop_column("chat_message", "reference_docs")
|
@ -1,7 +1,6 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
@ -219,6 +218,61 @@ def llm_contextless_chat_answer(
|
||||
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
|
||||
|
||||
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str], links: list[str | None]
|
||||
) -> Iterator[str]:
|
||||
if not links:
|
||||
yield from tokens
|
||||
return
|
||||
|
||||
max_citation_num = len(links) + 1 # LLM is prompted to 1 index these
|
||||
curr_segment = ""
|
||||
prepend_bracket = False
|
||||
for token in tokens:
|
||||
# Special case of [1][ where ][ is a single token
|
||||
if prepend_bracket:
|
||||
curr_segment += "[" + curr_segment
|
||||
prepend_bracket = False
|
||||
|
||||
curr_segment += token
|
||||
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
citation_found = re.search(citation_pattern, curr_segment)
|
||||
|
||||
if citation_found:
|
||||
numerical_value = int(citation_found.group(1))
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
link = links[numerical_value - 1]
|
||||
if link:
|
||||
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||
|
||||
# In case there's another open bracket like [1][, don't want to match this
|
||||
possible_citation_found = None
|
||||
|
||||
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||
# citation that needs to be replaced with a link
|
||||
if possible_citation_found:
|
||||
continue
|
||||
|
||||
# Special case with back to back citations [1][2]
|
||||
if curr_segment and curr_segment[-1] == "[":
|
||||
curr_segment = curr_segment[:-1]
|
||||
prepend_bracket = True
|
||||
|
||||
yield curr_segment
|
||||
curr_segment = ""
|
||||
|
||||
if curr_segment:
|
||||
if prepend_bracket:
|
||||
yield "[" + curr_segment
|
||||
else:
|
||||
yield curr_segment
|
||||
|
||||
|
||||
def llm_contextual_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona,
|
||||
@ -261,7 +315,6 @@ def llm_contextual_chat_answer(
|
||||
|
||||
# Model will output "Yes Search" if search is useful
|
||||
# Be a little forgiving though, if we match yes, it's good enough
|
||||
citation_max_num: int | None = None
|
||||
retrieved_chunks: list[InferenceChunk] = []
|
||||
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
|
||||
retrieved_chunks = danswer_chat_retrieval(
|
||||
@ -270,7 +323,6 @@ def llm_contextual_chat_answer(
|
||||
llm=llm,
|
||||
user_id=user_id,
|
||||
)
|
||||
citation_max_num = len(retrieved_chunks) + 1
|
||||
yield retrieved_chunks
|
||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||
|
||||
@ -299,23 +351,13 @@ def llm_contextual_chat_answer(
|
||||
final_msg_token_count=last_user_msg_tokens,
|
||||
)
|
||||
|
||||
curr_segment = ""
|
||||
for token in llm.stream(prompt):
|
||||
curr_segment += token
|
||||
tokens = llm.stream(prompt)
|
||||
links = [
|
||||
chunk.source_links[0] if chunk.source_links else None
|
||||
for chunk in retrieved_chunks
|
||||
]
|
||||
|
||||
pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
found = re.search(pattern, curr_segment)
|
||||
|
||||
if found:
|
||||
numerical_value = int(found.group(1))
|
||||
if citation_max_num and 1 <= numerical_value <= citation_max_num:
|
||||
reference_chunk = retrieved_chunks[numerical_value - 1]
|
||||
if reference_chunk.source_links and reference_chunk.source_links[0]:
|
||||
link = reference_chunk.source_links[0]
|
||||
token = re.sub("]", f"]({link})", token)
|
||||
curr_segment = ""
|
||||
|
||||
yield token
|
||||
yield from extract_citations_from_stream(tokens, links)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
||||
@ -327,7 +369,7 @@ def llm_tools_enabled_chat_answer(
|
||||
persona: Persona,
|
||||
user_id: UUID | None,
|
||||
tokenizer: Callable,
|
||||
) -> Iterator[str]:
|
||||
) -> Iterator[str | list[InferenceChunk]]:
|
||||
retrieval_enabled = persona.retrieval_enabled
|
||||
system_text = persona.system_text
|
||||
hint_text = persona.hint_text
|
||||
@ -405,6 +447,7 @@ def llm_tools_enabled_chat_answer(
|
||||
llm=llm,
|
||||
user_id=user_id,
|
||||
)
|
||||
yield retrieved_chunks
|
||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||
else:
|
||||
tool_result_str = call_tool(final_result, user_id=user_id)
|
||||
@ -451,6 +494,15 @@ def llm_tools_enabled_chat_answer(
|
||||
yield LLM_CHAT_FAILURE_MSG
|
||||
|
||||
|
||||
def wrap_chat_package_in_model(
    package: str | list[InferenceChunk],
) -> DanswerAnswerPiece | RetrievalDocs:
    """Wrap a raw chat-stream package into its API response model.

    Strings become DanswerAnswerPiece tokens; lists of retrieved chunks become
    RetrievalDocs.

    Raises:
        ValueError: for any other package type, instead of implicitly
            returning None (which would violate the annotated return type).
    """
    if isinstance(package, str):
        return DanswerAnswerPiece(answer_piece=package)
    if isinstance(package, list):
        return RetrievalDocs(top_documents=chunks_to_search_docs(package))
    raise ValueError(f"Unexpected chat package type: {type(package)}")
|
||||
|
||||
|
||||
def llm_chat_answer(
|
||||
messages: list[ChatMessage],
|
||||
persona: Persona | None,
|
||||
@ -483,18 +535,11 @@ def llm_chat_answer(
|
||||
for package in llm_contextual_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
):
|
||||
if isinstance(package, str):
|
||||
yield DanswerAnswerPiece(answer_piece=package)
|
||||
elif isinstance(package, list):
|
||||
yield RetrievalDocs(
|
||||
top_documents=chunks_to_search_docs(
|
||||
cast(list[InferenceChunk], package)
|
||||
)
|
||||
)
|
||||
yield wrap_chat_package_in_model(package)
|
||||
|
||||
# Use most flexible/complex prompt format
|
||||
else:
|
||||
for token in llm_tools_enabled_chat_answer(
|
||||
for package in llm_tools_enabled_chat_answer(
|
||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||
):
|
||||
yield DanswerAnswerPiece(answer_piece=token)
|
||||
yield wrap_chat_package_in_model(package)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
@ -191,6 +192,7 @@ def create_new_chat_message(
|
||||
parent_edit_number: int | None,
|
||||
message_type: MessageType,
|
||||
db_session: Session,
|
||||
retrieval_docs: dict[str, Any] | None = None,
|
||||
) -> ChatMessage:
|
||||
"""Creates a new chat message and sets it to the latest message of its parent message"""
|
||||
# Get the count of existing edits at the provided message number
|
||||
@ -213,6 +215,7 @@ def create_new_chat_message(
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=new_edit_number,
|
||||
message=message,
|
||||
reference_docs=retrieval_docs,
|
||||
token_count=token_count,
|
||||
message_type=message_type,
|
||||
)
|
||||
|
@ -488,6 +488,9 @@ class ChatMessage(Base):
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
token_count: Mapped[int] = mapped_column(Integer)
|
||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
||||
reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
|
@ -37,6 +37,7 @@ from danswer.server.models import CreateChatMessageRequest
|
||||
from danswer.server.models import CreateChatSessionID
|
||||
from danswer.server.models import RegenerateMessageRequest
|
||||
from danswer.server.models import RenameChatSessionResponse
|
||||
from danswer.server.models import RetrievalDocs
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
@ -110,6 +111,9 @@ def get_chat_session_messages(
|
||||
parent_edit_number=msg.parent_edit_number,
|
||||
latest=msg.latest,
|
||||
message=msg.message,
|
||||
context_docs=RetrievalDocs(**msg.reference_docs)
|
||||
if msg.reference_docs
|
||||
else None,
|
||||
message_type=msg.message_type,
|
||||
time_sent=msg.time_sent,
|
||||
)
|
||||
@ -308,11 +312,14 @@ def handle_new_chat_message(
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
llm_output = ""
|
||||
fetched_docs: RetrievalDocs | None = None
|
||||
for packet in response_packets:
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
token = packet.answer_piece
|
||||
if token:
|
||||
llm_output += token
|
||||
elif isinstance(packet, RetrievalDocs):
|
||||
fetched_docs = packet
|
||||
yield get_json_line(packet.dict())
|
||||
|
||||
create_new_chat_message(
|
||||
@ -322,6 +329,7 @@ def handle_new_chat_message(
|
||||
message=llm_output,
|
||||
token_count=len(llm_tokenizer(llm_output)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@ -393,11 +401,14 @@ def regenerate_message_given_parent(
|
||||
tokenizer=llm_tokenizer,
|
||||
)
|
||||
llm_output = ""
|
||||
fetched_docs: RetrievalDocs | None = None
|
||||
for packet in response_packets:
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
token = packet.answer_piece
|
||||
if token:
|
||||
llm_output += token
|
||||
elif isinstance(packet, RetrievalDocs):
|
||||
fetched_docs = packet
|
||||
yield get_json_line(packet.dict())
|
||||
|
||||
create_new_chat_message(
|
||||
@ -407,6 +418,7 @@ def regenerate_message_given_parent(
|
||||
message=llm_output,
|
||||
token_count=len(llm_tokenizer(llm_output)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
@ -236,6 +236,7 @@ class ChatMessageDetail(BaseModel):
|
||||
parent_edit_number: int | None
|
||||
latest: bool
|
||||
message: str
|
||||
context_docs: RetrievalDocs | None
|
||||
message_type: MessageType
|
||||
time_sent: datetime
|
||||
|
||||
|
@ -59,6 +59,7 @@ def send_chat_message(
|
||||
def run_chat(contextual: bool) -> None:
|
||||
try:
|
||||
new_session_id = create_new_session()
|
||||
print(f"Chat Session ID: {new_session_id}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print(
|
||||
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
|
||||
|
37
backend/tests/unit/danswer/chat/test_chat_llm.py
Normal file
37
backend/tests/unit/danswer/chat/test_chat_llm.py
Normal file
@ -0,0 +1,37 @@
|
||||
import unittest
|
||||
|
||||
from danswer.chat.chat_llm import extract_citations_from_stream
|
||||
|
||||
|
||||
class TestChatLlm(unittest.TestCase):
    """Tests for extract_citations_from_stream citation-to-link rewriting."""

    @staticmethod
    def _run(tokens, links) -> str:
        # Collapse the streamed output into one string for comparison.
        # Iterating a plain string feeds the stream one character per token.
        return "".join(extract_citations_from_stream(iter(tokens), links))

    def test_citation_extraction(self) -> None:
        links: list[str | None] = [f"link_{i}" for i in range(1, 21)]

        # Simple single citation.
        self.assertEqual(
            self._run("Something [1]", links), "Something [[1]](link_1)"
        )

        # Multi-digit citation number.
        self.assertEqual(
            self._run("Something [14]", links), "Something [[14]](link_14)"
        )

        # Back-to-back citations.
        self.assertEqual(
            self._run("Something [14][15]", links),
            "Something [[14]](link_14)[[15]](link_15)",
        )

        # "][" arriving as a single token.
        self.assertEqual(
            self._run(["Something ", "[", "3", "][", "4", "]."], links),
            "Something [[3]](link_3)[[4]](link_4).",
        )

        # Out-of-range citation number is left untouched.
        self.assertEqual(
            self._run(["Something ", "[", "31", "][", "4", "]."], links),
            "Something [31][[4]](link_4).",
        )

        # A citation whose link is None is left untouched.
        links[3] = None
        self.assertEqual(
            self._run("Something [2][4][5]", links),
            "Something [[2]](link_2)[4][[5]](link_5)",
        )
|
||||
|
||||
|
||||
# Allow running this test module directly (outside of a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
|
Loading…
x
Reference in New Issue
Block a user