mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-04 03:59:25 +02:00
Add Retrieval to Chat History (#577)
This commit is contained in:
parent
d2f7dff464
commit
595f61ea3a
@ -0,0 +1,31 @@
|
|||||||
|
"""Store Chat Retrieval Docs
|
||||||
|
|
||||||
|
Revision ID: 7ccea01261f6
|
||||||
|
Revises: a570b80a5f20
|
||||||
|
Create Date: 2023-10-15 10:39:23.317453
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "7ccea01261f6"
|
||||||
|
down_revision = "a570b80a5f20"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add a nullable ``reference_docs`` JSONB column to ``chat_message``.

    Stores the retrieval documents associated with a chat message so they
    can be surfaced again from chat history.
    """
    reference_docs_column = sa.Column(
        "reference_docs",
        postgresql.JSONB(astext_type=sa.Text()),
        nullable=True,
    )
    op.add_column("chat_message", reference_docs_column)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the ``reference_docs`` column, reversing :func:`upgrade`."""
    op.drop_column(table_name="chat_message", column_name="reference_docs")
|
@ -1,7 +1,6 @@
|
|||||||
import re
|
import re
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from typing import cast
|
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain.schema.messages import AIMessage
|
from langchain.schema.messages import AIMessage
|
||||||
@ -219,6 +218,61 @@ def llm_contextless_chat_answer(
|
|||||||
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
|
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
|
||||||
|
|
||||||
|
|
||||||
|
def extract_citations_from_stream(
|
||||||
|
tokens: Iterator[str], links: list[str | None]
|
||||||
|
) -> Iterator[str]:
|
||||||
|
if not links:
|
||||||
|
yield from tokens
|
||||||
|
return
|
||||||
|
|
||||||
|
max_citation_num = len(links) + 1 # LLM is prompted to 1 index these
|
||||||
|
curr_segment = ""
|
||||||
|
prepend_bracket = False
|
||||||
|
for token in tokens:
|
||||||
|
# Special case of [1][ where ][ is a single token
|
||||||
|
if prepend_bracket:
|
||||||
|
curr_segment += "[" + curr_segment
|
||||||
|
prepend_bracket = False
|
||||||
|
|
||||||
|
curr_segment += token
|
||||||
|
|
||||||
|
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||||
|
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||||
|
|
||||||
|
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||||
|
citation_found = re.search(citation_pattern, curr_segment)
|
||||||
|
|
||||||
|
if citation_found:
|
||||||
|
numerical_value = int(citation_found.group(1))
|
||||||
|
if 1 <= numerical_value <= max_citation_num:
|
||||||
|
link = links[numerical_value - 1]
|
||||||
|
if link:
|
||||||
|
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||||
|
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||||
|
|
||||||
|
# In case there's another open bracket like [1][, don't want to match this
|
||||||
|
possible_citation_found = None
|
||||||
|
|
||||||
|
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||||
|
# citation that needs to be replaced with a link
|
||||||
|
if possible_citation_found:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Special case with back to back citations [1][2]
|
||||||
|
if curr_segment and curr_segment[-1] == "[":
|
||||||
|
curr_segment = curr_segment[:-1]
|
||||||
|
prepend_bracket = True
|
||||||
|
|
||||||
|
yield curr_segment
|
||||||
|
curr_segment = ""
|
||||||
|
|
||||||
|
if curr_segment:
|
||||||
|
if prepend_bracket:
|
||||||
|
yield "[" + curr_segment
|
||||||
|
else:
|
||||||
|
yield curr_segment
|
||||||
|
|
||||||
|
|
||||||
def llm_contextual_chat_answer(
|
def llm_contextual_chat_answer(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
persona: Persona,
|
persona: Persona,
|
||||||
@ -261,7 +315,6 @@ def llm_contextual_chat_answer(
|
|||||||
|
|
||||||
# Model will output "Yes Search" if search is useful
|
# Model will output "Yes Search" if search is useful
|
||||||
# Be a little forgiving though, if we match yes, it's good enough
|
# Be a little forgiving though, if we match yes, it's good enough
|
||||||
citation_max_num: int | None = None
|
|
||||||
retrieved_chunks: list[InferenceChunk] = []
|
retrieved_chunks: list[InferenceChunk] = []
|
||||||
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
|
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
|
||||||
retrieved_chunks = danswer_chat_retrieval(
|
retrieved_chunks = danswer_chat_retrieval(
|
||||||
@ -270,7 +323,6 @@ def llm_contextual_chat_answer(
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
citation_max_num = len(retrieved_chunks) + 1
|
|
||||||
yield retrieved_chunks
|
yield retrieved_chunks
|
||||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||||
|
|
||||||
@ -299,23 +351,13 @@ def llm_contextual_chat_answer(
|
|||||||
final_msg_token_count=last_user_msg_tokens,
|
final_msg_token_count=last_user_msg_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_segment = ""
|
tokens = llm.stream(prompt)
|
||||||
for token in llm.stream(prompt):
|
links = [
|
||||||
curr_segment += token
|
chunk.source_links[0] if chunk.source_links else None
|
||||||
|
for chunk in retrieved_chunks
|
||||||
|
]
|
||||||
|
|
||||||
pattern = r"\[(\d+)\]" # [1], [2] etc
|
yield from extract_citations_from_stream(tokens, links)
|
||||||
found = re.search(pattern, curr_segment)
|
|
||||||
|
|
||||||
if found:
|
|
||||||
numerical_value = int(found.group(1))
|
|
||||||
if citation_max_num and 1 <= numerical_value <= citation_max_num:
|
|
||||||
reference_chunk = retrieved_chunks[numerical_value - 1]
|
|
||||||
if reference_chunk.source_links and reference_chunk.source_links[0]:
|
|
||||||
link = reference_chunk.source_links[0]
|
|
||||||
token = re.sub("]", f"]({link})", token)
|
|
||||||
curr_segment = ""
|
|
||||||
|
|
||||||
yield token
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
logger.error(f"LLM failed to produce valid chat message, error: {e}")
|
||||||
@ -327,7 +369,7 @@ def llm_tools_enabled_chat_answer(
|
|||||||
persona: Persona,
|
persona: Persona,
|
||||||
user_id: UUID | None,
|
user_id: UUID | None,
|
||||||
tokenizer: Callable,
|
tokenizer: Callable,
|
||||||
) -> Iterator[str]:
|
) -> Iterator[str | list[InferenceChunk]]:
|
||||||
retrieval_enabled = persona.retrieval_enabled
|
retrieval_enabled = persona.retrieval_enabled
|
||||||
system_text = persona.system_text
|
system_text = persona.system_text
|
||||||
hint_text = persona.hint_text
|
hint_text = persona.hint_text
|
||||||
@ -405,6 +447,7 @@ def llm_tools_enabled_chat_answer(
|
|||||||
llm=llm,
|
llm=llm,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
yield retrieved_chunks
|
||||||
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
|
||||||
else:
|
else:
|
||||||
tool_result_str = call_tool(final_result, user_id=user_id)
|
tool_result_str = call_tool(final_result, user_id=user_id)
|
||||||
@ -451,6 +494,15 @@ def llm_tools_enabled_chat_answer(
|
|||||||
yield LLM_CHAT_FAILURE_MSG
|
yield LLM_CHAT_FAILURE_MSG
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_chat_package_in_model(
    package: str | list[InferenceChunk],
) -> DanswerAnswerPiece | RetrievalDocs:
    """Wrap one item from a chat-answer stream in its API response model.

    Args:
        package: either a text token (``str``) or the list of retrieved
            ``InferenceChunk``s yielded by the answer generator.

    Returns:
        ``DanswerAnswerPiece`` for a text token, ``RetrievalDocs`` for a
        chunk list.

    Raises:
        ValueError: if ``package`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(package, str):
        return DanswerAnswerPiece(answer_piece=package)
    elif isinstance(package, list):
        return RetrievalDocs(top_documents=chunks_to_search_docs(package))
    # Previously an unexpected type fell through and implicitly returned None,
    # violating the annotated return type; fail loudly instead.
    raise ValueError(f"Unexpected chat package type: {type(package)}")
|
||||||
|
|
||||||
|
|
||||||
def llm_chat_answer(
|
def llm_chat_answer(
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
persona: Persona | None,
|
persona: Persona | None,
|
||||||
@ -483,18 +535,11 @@ def llm_chat_answer(
|
|||||||
for package in llm_contextual_chat_answer(
|
for package in llm_contextual_chat_answer(
|
||||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||||
):
|
):
|
||||||
if isinstance(package, str):
|
yield wrap_chat_package_in_model(package)
|
||||||
yield DanswerAnswerPiece(answer_piece=package)
|
|
||||||
elif isinstance(package, list):
|
|
||||||
yield RetrievalDocs(
|
|
||||||
top_documents=chunks_to_search_docs(
|
|
||||||
cast(list[InferenceChunk], package)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use most flexible/complex prompt format
|
# Use most flexible/complex prompt format
|
||||||
else:
|
else:
|
||||||
for token in llm_tools_enabled_chat_answer(
|
for package in llm_tools_enabled_chat_answer(
|
||||||
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
|
||||||
):
|
):
|
||||||
yield DanswerAnswerPiece(answer_piece=token)
|
yield wrap_chat_package_in_model(package)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
@ -191,6 +192,7 @@ def create_new_chat_message(
|
|||||||
parent_edit_number: int | None,
|
parent_edit_number: int | None,
|
||||||
message_type: MessageType,
|
message_type: MessageType,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
|
retrieval_docs: dict[str, Any] | None = None,
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
"""Creates a new chat message and sets it to the latest message of its parent message"""
|
"""Creates a new chat message and sets it to the latest message of its parent message"""
|
||||||
# Get the count of existing edits at the provided message number
|
# Get the count of existing edits at the provided message number
|
||||||
@ -213,6 +215,7 @@ def create_new_chat_message(
|
|||||||
parent_edit_number=parent_edit_number,
|
parent_edit_number=parent_edit_number,
|
||||||
edit_number=new_edit_number,
|
edit_number=new_edit_number,
|
||||||
message=message,
|
message=message,
|
||||||
|
reference_docs=retrieval_docs,
|
||||||
token_count=token_count,
|
token_count=token_count,
|
||||||
message_type=message_type,
|
message_type=message_type,
|
||||||
)
|
)
|
||||||
|
@ -488,6 +488,9 @@ class ChatMessage(Base):
|
|||||||
message: Mapped[str] = mapped_column(Text)
|
message: Mapped[str] = mapped_column(Text)
|
||||||
token_count: Mapped[int] = mapped_column(Integer)
|
token_count: Mapped[int] = mapped_column(Integer)
|
||||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
||||||
|
reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
|
||||||
|
postgresql.JSONB(), nullable=True
|
||||||
|
)
|
||||||
persona_id: Mapped[int | None] = mapped_column(
|
persona_id: Mapped[int | None] = mapped_column(
|
||||||
ForeignKey("persona.id"), nullable=True
|
ForeignKey("persona.id"), nullable=True
|
||||||
)
|
)
|
||||||
|
@ -37,6 +37,7 @@ from danswer.server.models import CreateChatMessageRequest
|
|||||||
from danswer.server.models import CreateChatSessionID
|
from danswer.server.models import CreateChatSessionID
|
||||||
from danswer.server.models import RegenerateMessageRequest
|
from danswer.server.models import RegenerateMessageRequest
|
||||||
from danswer.server.models import RenameChatSessionResponse
|
from danswer.server.models import RenameChatSessionResponse
|
||||||
|
from danswer.server.models import RetrievalDocs
|
||||||
from danswer.server.utils import get_json_line
|
from danswer.server.utils import get_json_line
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from danswer.utils.timing import log_generator_function_time
|
from danswer.utils.timing import log_generator_function_time
|
||||||
@ -110,6 +111,9 @@ def get_chat_session_messages(
|
|||||||
parent_edit_number=msg.parent_edit_number,
|
parent_edit_number=msg.parent_edit_number,
|
||||||
latest=msg.latest,
|
latest=msg.latest,
|
||||||
message=msg.message,
|
message=msg.message,
|
||||||
|
context_docs=RetrievalDocs(**msg.reference_docs)
|
||||||
|
if msg.reference_docs
|
||||||
|
else None,
|
||||||
message_type=msg.message_type,
|
message_type=msg.message_type,
|
||||||
time_sent=msg.time_sent,
|
time_sent=msg.time_sent,
|
||||||
)
|
)
|
||||||
@ -308,11 +312,14 @@ def handle_new_chat_message(
|
|||||||
tokenizer=llm_tokenizer,
|
tokenizer=llm_tokenizer,
|
||||||
)
|
)
|
||||||
llm_output = ""
|
llm_output = ""
|
||||||
|
fetched_docs: RetrievalDocs | None = None
|
||||||
for packet in response_packets:
|
for packet in response_packets:
|
||||||
if isinstance(packet, DanswerAnswerPiece):
|
if isinstance(packet, DanswerAnswerPiece):
|
||||||
token = packet.answer_piece
|
token = packet.answer_piece
|
||||||
if token:
|
if token:
|
||||||
llm_output += token
|
llm_output += token
|
||||||
|
elif isinstance(packet, RetrievalDocs):
|
||||||
|
fetched_docs = packet
|
||||||
yield get_json_line(packet.dict())
|
yield get_json_line(packet.dict())
|
||||||
|
|
||||||
create_new_chat_message(
|
create_new_chat_message(
|
||||||
@ -322,6 +329,7 @@ def handle_new_chat_message(
|
|||||||
message=llm_output,
|
message=llm_output,
|
||||||
token_count=len(llm_tokenizer(llm_output)),
|
token_count=len(llm_tokenizer(llm_output)),
|
||||||
message_type=MessageType.ASSISTANT,
|
message_type=MessageType.ASSISTANT,
|
||||||
|
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -393,11 +401,14 @@ def regenerate_message_given_parent(
|
|||||||
tokenizer=llm_tokenizer,
|
tokenizer=llm_tokenizer,
|
||||||
)
|
)
|
||||||
llm_output = ""
|
llm_output = ""
|
||||||
|
fetched_docs: RetrievalDocs | None = None
|
||||||
for packet in response_packets:
|
for packet in response_packets:
|
||||||
if isinstance(packet, DanswerAnswerPiece):
|
if isinstance(packet, DanswerAnswerPiece):
|
||||||
token = packet.answer_piece
|
token = packet.answer_piece
|
||||||
if token:
|
if token:
|
||||||
llm_output += token
|
llm_output += token
|
||||||
|
elif isinstance(packet, RetrievalDocs):
|
||||||
|
fetched_docs = packet
|
||||||
yield get_json_line(packet.dict())
|
yield get_json_line(packet.dict())
|
||||||
|
|
||||||
create_new_chat_message(
|
create_new_chat_message(
|
||||||
@ -407,6 +418,7 @@ def regenerate_message_given_parent(
|
|||||||
message=llm_output,
|
message=llm_output,
|
||||||
token_count=len(llm_tokenizer(llm_output)),
|
token_count=len(llm_tokenizer(llm_output)),
|
||||||
message_type=MessageType.ASSISTANT,
|
message_type=MessageType.ASSISTANT,
|
||||||
|
retrieval_docs=fetched_docs.dict() if fetched_docs else None,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -236,6 +236,7 @@ class ChatMessageDetail(BaseModel):
|
|||||||
parent_edit_number: int | None
|
parent_edit_number: int | None
|
||||||
latest: bool
|
latest: bool
|
||||||
message: str
|
message: str
|
||||||
|
context_docs: RetrievalDocs | None
|
||||||
message_type: MessageType
|
message_type: MessageType
|
||||||
time_sent: datetime
|
time_sent: datetime
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ def send_chat_message(
|
|||||||
def run_chat(contextual: bool) -> None:
|
def run_chat(contextual: bool) -> None:
|
||||||
try:
|
try:
|
||||||
new_session_id = create_new_session()
|
new_session_id = create_new_session()
|
||||||
|
print(f"Chat Session ID: {new_session_id}")
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
print(
|
print(
|
||||||
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
|
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
|
||||||
|
37
backend/tests/unit/danswer/chat/test_chat_llm.py
Normal file
37
backend/tests/unit/danswer/chat/test_chat_llm.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from danswer.chat.chat_llm import extract_citations_from_stream
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatLlm(unittest.TestCase):
    """Exercises extract_citations_from_stream over streamed citation shapes."""

    @staticmethod
    def _collect(stream: "str | list[str]", links: "list[str | None]") -> str:
        """Run the extractor over a token stream and join the output."""
        return "".join(extract_citations_from_stream(iter(stream), links))

    def test_citation_extraction(self) -> None:
        links: list[str | None] = [f"link_{i}" for i in range(1, 21)]

        # (token stream, expected markdown output)
        cases = [
            ("Something [1]", "Something [[1]](link_1)"),
            ("Something [14]", "Something [[14]](link_14)"),
            ("Something [14][15]", "Something [[14]](link_14)[[15]](link_15)"),
            (
                ["Something ", "[", "3", "][", "4", "]."],
                "Something [[3]](link_3)[[4]](link_4).",
            ),
            # [31] is out of range and must pass through untouched.
            (
                ["Something ", "[", "31", "][", "4", "]."],
                "Something [31][[4]](link_4).",
            ),
        ]
        for stream, expected in cases:
            self.assertEqual(self._collect(stream, links), expected)

        # A None link entry means that citation is left as plain text.
        links[3] = None
        self.assertEqual(
            self._collect("Something [2][4][5]", links),
            "Something [[2]](link_2)[4][[5]](link_5)",
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly: `python test_chat_llm.py`.
if __name__ == "__main__":
    unittest.main()
|
Loading…
x
Reference in New Issue
Block a user