Mirror of https://github.com/danswer-ai/danswer.git

API support for Chat to have citations (#569)

commit af510cc965
parent f0337d2eba
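With this change, the chat streaming endpoints emit newline-delimited JSON packets rather than bare answer tokens: for retrieval-enabled personas, the retrieved documents go out once as a top_documents packet, and the answer then streams as answer_piece packets in which [n] citations are rewritten into inline markdown links. An illustrative stream is sketched below; the document fields and links are invented for illustration:

    {"top_documents": [{"semantic_identifier": "Setup Guide", "link": "https://example.com/setup-guide"}]}
    {"answer_piece": "Install the server as described in "}
    {"answer_piece": "[1](https://example.com/setup-guide)"}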
@@ -215,6 +215,7 @@ def handle_message(
         client=client,
         channel=channel,
         receiver_ids=send_to,
         text="Something has gone wrong! The Slack blocks failed to load...",
         blocks=restate_question_block + answer_blocks + document_blocks,
         thread_ts=message_ts_to_respond_to,
+        # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
@@ -1,5 +1,7 @@
+import re
 from collections.abc import Callable
 from collections.abc import Iterator
+from typing import cast
 from uuid import UUID

 from langchain.schema.messages import AIMessage
@@ -18,6 +20,7 @@ from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
 from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
 from danswer.chat.chat_prompts import YES_SEARCH
 from danswer.chat.tools import call_tool
+from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
 from danswer.configs.chat_configs import FORCE_TOOL_PROMPT
 from danswer.configs.constants import IGNORE_FOR_QA
@@ -32,7 +35,9 @@ from danswer.llm.build import get_default_llm
 from danswer.llm.llm import LLM
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import translate_danswer_msg_to_langchain
+from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.server.models import RetrievalDocs
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import extract_embedded_json
 from danswer.utils.text_processing import has_unescaped_quote
@@ -114,7 +119,7 @@ def danswer_chat_retrieval(
     history: list[ChatMessage],
     llm: LLM,
     user_id: UUID | None,
-) -> str:
+) -> list[InferenceChunk]:
     if history:
         query_combination_msgs = build_combined_query(query_message, history)
         reworded_query = llm.invoke(query_combination_msgs)
@@ -129,7 +134,7 @@ def danswer_chat_retrieval(
         datastore=get_default_document_index(),
     )
     if not ranked_chunks:
-        return "No results found"
+        return []

     if unranked_chunks:
         ranked_chunks.extend(unranked_chunks)
@@ -144,7 +149,7 @@ def danswer_chat_retrieval(
         token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
     )

-    return format_danswer_chunks_for_chat(usable_chunks)
+    return usable_chunks


 def _drop_messages_history_overflow(
@@ -220,7 +225,7 @@ def llm_contextual_chat_answer(
     user_id: UUID | None,
     tokenizer: Callable,
     run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
-) -> Iterator[str]:
+) -> Iterator[str | list[InferenceChunk]]:
     last_message = messages[-1]
     final_query_text = last_message.message
     previous_messages = messages[:-1]
@@ -256,13 +261,19 @@ def llm_contextual_chat_answer(

         # Model will output "Yes Search" if search is useful
         # Be a little forgiving though, if we match yes, it's good enough
+        citation_max_num: int | None = None
+        retrieved_chunks: list[InferenceChunk] = []
         if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
-            tool_result_str = danswer_chat_retrieval(
+            retrieved_chunks = danswer_chat_retrieval(
                 query_message=last_message,
                 history=previous_messages,
                 llm=llm,
                 user_id=user_id,
             )
+            citation_max_num = len(retrieved_chunks) + 1
+            yield retrieved_chunks
+            tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)

         last_user_msg_text = form_tool_less_followup_text(
             tool_output=tool_result_str,
             query=last_message.message,
@@ -288,11 +299,27 @@ def llm_contextual_chat_answer(
             final_msg_token_count=last_user_msg_tokens,
         )

-        return llm.stream(prompt)
+        curr_segment = ""
+        for token in llm.stream(prompt):
+            curr_segment += token
+
+            pattern = r"\[(\d+)\]"  # [1], [2] etc
+            found = re.search(pattern, curr_segment)
+
+            if found:
+                numerical_value = int(found.group(1))
+                if citation_max_num and 1 <= numerical_value <= citation_max_num:
+                    reference_chunk = retrieved_chunks[numerical_value - 1]
+                    if reference_chunk.source_links and reference_chunk.source_links[0]:
+                        link = reference_chunk.source_links[0]
+                        token = re.sub("]", f"]({link})", token)
+                curr_segment = ""
+
+            yield token

     except Exception as e:
         logger.error(f"LLM failed to produce valid chat message, error: {e}")
-        return (msg for msg in [LLM_CHAT_FAILURE_MSG])  # needs to be an Iterator
+        yield LLM_CHAT_FAILURE_MSG  # needs to be an Iterator


 def llm_tools_enabled_chat_answer(
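To make the citation rewriting in the hunk above concrete, here is a self-contained sketch of the same buffering logic. FakeChunk is a stand-in for danswer's InferenceChunk (only its source_links field is used), and the tokens and links are invented for illustration:

    # Standalone sketch of the streamed citation rewriting shown above.
    import re
    from collections.abc import Iterable, Iterator
    from dataclasses import dataclass

    @dataclass
    class FakeChunk:
        # stand-in for InferenceChunk; entry 0 is the link used for citations
        source_links: dict[int, str] | None

    def rewrite_citations(
        tokens: Iterable[str],
        retrieved_chunks: list[FakeChunk],
        citation_max_num: int | None,
    ) -> Iterator[str]:
        curr_segment = ""
        for token in tokens:
            curr_segment += token
            found = re.search(r"\[(\d+)\]", curr_segment)  # [1], [2] etc
            if found:
                num = int(found.group(1))
                if citation_max_num and 1 <= num <= citation_max_num:
                    chunk = retrieved_chunks[num - 1]
                    if chunk.source_links and chunk.source_links[0]:
                        # turn the "]" that completes "[n]" into "](link)"
                        token = re.sub("]", f"]({chunk.source_links[0]})", token)
                curr_segment = ""  # reset once a full citation was handled
            yield token

    chunks = [FakeChunk(source_links={0: "https://example.com/doc-a"})]
    print("".join(rewrite_citations(["See [", "1", "]", " here"], chunks, 2)))
    # -> "See [1](https://example.com/doc-a) here"

The buffering matters because a citation like [12] can arrive split across tokens: the segment accumulates until a complete [n] appears, and only the token carrying the closing bracket gets rewritten.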
@@ -372,12 +399,13 @@ def llm_tools_enabled_chat_answer(
             retrieval_enabled
             and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
         ):
-            tool_result_str = danswer_chat_retrieval(
+            retrieved_chunks = danswer_chat_retrieval(
                 query_message=last_message,
                 history=previous_messages,
                 llm=llm,
                 user_id=user_id,
             )
+            tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
         else:
             tool_result_str = call_tool(final_result, user_id=user_id)
@@ -428,7 +456,7 @@ def llm_chat_answer(
     persona: Persona | None,
     user_id: UUID | None,
     tokenizer: Callable,
-) -> Iterator[str]:
+) -> Iterator[DanswerAnswerPiece | RetrievalDocs]:
     # Common error cases to keep in mind:
     # - User asks question about something long ago, due to context limit, the message is dropped
     # - Tool use gives wrong/irrelevant results, model gets confused by the noise
@@ -438,24 +466,35 @@ def llm_chat_answer(

     # No setting/persona available therefore no retrieval and no additional tools
     if persona is None:
-        return llm_contextless_chat_answer(messages)
+        for token in llm_contextless_chat_answer(messages):
+            yield DanswerAnswerPiece(answer_piece=token)

     # Persona is configured but with retrieval off and no tools
     # therefore cannot retrieve any context so contextless
     elif persona.retrieval_enabled is False and not persona.tools:
-        return llm_contextless_chat_answer(
+        for token in llm_contextless_chat_answer(
             messages, system_text=persona.system_text, tokenizer=tokenizer
-        )
+        ):
+            yield DanswerAnswerPiece(answer_piece=token)

     # No additional tools outside of Danswer retrieval, can use a more basic prompt
     # Doesn't require tool calling output format (all LLM outputs are therefore valid)
     elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
-        return llm_contextual_chat_answer(
+        for package in llm_contextual_chat_answer(
             messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
-        )
+        ):
+            if isinstance(package, str):
+                yield DanswerAnswerPiece(answer_piece=package)
+            elif isinstance(package, list):
+                yield RetrievalDocs(
+                    top_documents=chunks_to_search_docs(
+                        cast(list[InferenceChunk], package)
+                    )
+                )

     # Use most flexible/complex prompt format
     else:
-        return llm_tools_enabled_chat_answer(
+        for token in llm_tools_enabled_chat_answer(
             messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
-        )
+        ):
+            yield DanswerAnswerPiece(answer_piece=token)
@@ -172,6 +172,9 @@ def form_tool_section_text(


 def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
+    if not chunks:
+        return "No Results Found"
+
     return "\n".join(
         f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
         for ind, chunk in enumerate(chunks, start=1)
@@ -5,6 +5,7 @@ personas:
       You are a question answering system that is constantly learning and improving.
       You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
       Your responses are as INFORMATIVE and DETAILED as possible.
+      Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
   # Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
   retrieval_enabled: true
   # Example of adding tools, it must follow this structure:
@@ -15,6 +15,7 @@ from requests import Response
 from danswer.chunking.models import DocMetadataAwareIndexChunk
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
+from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
 from danswer.configs.app_configs import VESPA_HOST
@@ -43,6 +44,7 @@ from danswer.datastores.interfaces import DocumentInsertionRecord
 from danswer.datastores.interfaces import IndexFilter
 from danswer.datastores.interfaces import UpdateRequest
 from danswer.datastores.vespa.utils import remove_invalid_unicode_chars
+from danswer.search.keyword_search import remove_stop_words
 from danswer.search.semantic_search import embed_query
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
@@ -536,9 +538,13 @@ class VespaIndex(DocumentIndex):

         query_embedding = embed_query(query)

+        query_keywords = (
+            " ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
+        )
+
         params = {
             "yql": yql,
-            "query": query,
+            "query": query_keywords,
             "input.query(query_embedding)": str(query_embedding),
             "ranking.profile": "semantic_search",
         }
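The keyword edit above feeds Vespa's "query" parameter a stop-word-stripped version of the user query while the semantic ranking still sees the full query embedding. danswer's real remove_stop_words lives in danswer.search.keyword_search and is not shown in this diff; the stub below is only an assumed illustration of the idea, using NLTK's English stop-word list:

    # Hypothetical sketch of a stop-word filter like the one imported above;
    # NOT the actual danswer implementation.
    from nltk.corpus import stopwords  # assumes the NLTK stopwords corpus is downloaded

    def remove_stop_words(query: str) -> list[str]:
        stop_words = set(stopwords.words("english"))
        # keep only the meaningful keywords for the Vespa "query" parameter
        return [word for word in query.split() if word.lower() not in stop_words]

    # " ".join(remove_stop_words("how do I set up the connector")) -> "set connector"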
@@ -173,6 +173,7 @@ class ConnectorCredentialPair(Base):
         "DocumentSet",
         secondary=DocumentSet__ConnectorCredentialPair.__table__,
         back_populates="connector_credential_pairs",
+        overlaps="document_set",
     )


@@ -410,6 +411,7 @@ class DocumentSet(Base):
         "ConnectorCredentialPair",
         secondary=DocumentSet__ConnectorCredentialPair.__table__,
         back_populates="document_sets",
+        overlaps="document_set",
     )
     personas: Mapped[list["Persona"]] = relationship(
         "Persona",
@@ -1,33 +1,30 @@
 import abc
 from collections.abc import Callable
-from collections.abc import Generator
-from dataclasses import dataclass
+from collections.abc import Iterator
+
+from pydantic import BaseModel

 from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.models import LLMMetricsContainer


-@dataclass
-class DanswerAnswer:
+class DanswerAnswer(BaseModel):
     answer: str | None


-@dataclass
-class DanswerChatModelOut:
+class DanswerChatModelOut(BaseModel):
     model_raw: str
     action: str
     action_input: str


-@dataclass
-class DanswerAnswerPiece:
+class DanswerAnswerPiece(BaseModel):
     """A small piece of a complete answer. Used for streaming back answers."""

     answer_piece: str | None  # if None, specifies the end of an Answer


-@dataclass
-class DanswerQuote:
+class DanswerQuote(BaseModel):
     # This is during inference so everything is a string by this point
     quote: str
     document_id: str
@@ -37,20 +34,13 @@ class DanswerQuote:
     blurb: str


-@dataclass
-class DanswerQuotes:
-    """A little clunky, but making this into a separate class so that the result from
-    `answer_question_stream` is always a subclass of `dataclass` and can thus use `asdict()`
-    """
-
+class DanswerQuotes(BaseModel):
     quotes: list[DanswerQuote]


 # Final int is for number of output tokens
 AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
-AnswerQuestionStreamReturn = Generator[
-    DanswerAnswerPiece | DanswerQuotes | None, None, None
-]
+AnswerQuestionStreamReturn = Iterator[DanswerAnswerPiece | DanswerQuotes]


 class QAModel:
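Moving these models from dataclasses to pydantic's BaseModel is what lets the endpoints below serialize packets with .dict() instead of dataclasses.asdict. A quick sketch of the pattern (pydantic v1-style API, matching the calls in this commit):

    from pydantic import BaseModel

    class DanswerAnswerPiece(BaseModel):
        answer_piece: str | None  # if None, specifies the end of an Answer

    piece = DanswerAnswerPiece(answer_piece="Hello")
    piece.dict()  # {'answer_piece': 'Hello'} -- the dict handed to get_json_line()
    piece.json()  # '{"answer_piece": "Hello"}'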
@@ -1,6 +1,5 @@
 import re
 from collections.abc import Iterator
-from dataclasses import asdict

 from danswer.configs.constants import CODE_BLOCK_PAT
 from danswer.configs.constants import GENERAL_SEP_PAT
@@ -115,7 +114,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
             remaining = model_output[reason_ind + len(REASONING_PAT) :]
             if remaining:
                 yield get_json_line(
-                    asdict(DanswerAnswerPiece(answer_piece=remaining))
+                    DanswerAnswerPiece(answer_piece=remaining).dict()
                 )
             continue
@@ -124,7 +123,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
             if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
                 continue
             yield get_json_line(
-                asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
+                DanswerAnswerPiece(answer_piece=hold_answerable).dict()
             )
             hold_answerable = ""
@@ -1,5 +1,4 @@
 from collections.abc import Iterator
-from dataclasses import asdict

 from fastapi import APIRouter
 from fastapi import Depends
@@ -302,16 +301,19 @@ def handle_new_chat_message(

     @log_generator_function_time()
     def stream_chat_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(
+        response_packets = llm_chat_answer(
             messages=mainline_messages,
             persona=persona,
             user_id=user_id,
             tokenizer=llm_tokenizer,
         )
         llm_output = ""
-        for token in tokens:
-            llm_output += token
-            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
+        for packet in response_packets:
+            if isinstance(packet, DanswerAnswerPiece):
+                token = packet.answer_piece
+                if token:
+                    llm_output += token
+            yield get_json_line(packet.dict())

         create_new_chat_message(
             chat_session_id=chat_session_id,
@@ -384,16 +386,19 @@ def regenerate_message_given_parent(

     @log_generator_function_time()
     def stream_regenerate_tokens() -> Iterator[str]:
-        tokens = llm_chat_answer(
+        response_packets = llm_chat_answer(
             messages=mainline_messages,
             persona=persona,
             user_id=user_id,
             tokenizer=llm_tokenizer,
         )
         llm_output = ""
-        for token in tokens:
-            llm_output += token
-            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
+        for packet in response_packets:
+            if isinstance(packet, DanswerAnswerPiece):
+                token = packet.answer_piece
+                if token:
+                    llm_output += token
+            yield get_json_line(packet.dict())

         create_new_chat_message(
             chat_session_id=chat_session_id,
@@ -149,6 +149,16 @@ class SearchDoc(BaseModel):
     match_highlights: list[str]


+class RetrievalDocs(BaseModel):
+    top_documents: list[SearchDoc]
+
+
+class RerankedRetrievalDocs(RetrievalDocs):
+    unranked_top_documents: list[SearchDoc]
+    predicted_flow: QueryFlow
+    predicted_search: SearchType
+
+
 class CreateChatSessionID(BaseModel):
     chat_session_id: int
@@ -1,5 +1,4 @@
 from collections.abc import Generator
-from dataclasses import asdict

 from fastapi import APIRouter
 from fastapi import Depends
@@ -38,6 +37,7 @@ from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
 from danswer.server.models import QueryValidationResponse
 from danswer.server.models import QuestionRequest
+from danswer.server.models import RerankedRetrievalDocs
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
 from danswer.server.utils import get_json_line
@@ -224,18 +224,19 @@ def stream_direct_qa(

         top_docs = chunks_to_search_docs(ranked_chunks)
         unranked_top_docs = chunks_to_search_docs(unranked_chunks)
-        initial_response_dict = {
-            top_documents_key: [top_doc.json() for top_doc in top_docs],
-            unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
+        initial_response = RerankedRetrievalDocs(
+            top_documents=top_docs,
+            unranked_top_documents=unranked_top_docs,
             # if generative AI is disabled, set flow as search so frontend
             # doesn't ask the user if they want to run QA over more documents
-            predicted_flow_key: QueryFlow.SEARCH
+            predicted_flow=QueryFlow.SEARCH
             if disable_generative_answer
             else predicted_flow,
-            predicted_search_key: predicted_search,
-        }
-        logger.debug(send_packet_debug_msg.format(initial_response_dict))
-        yield get_json_line(initial_response_dict)
+            predicted_search=predicted_search,
+        ).dict()
+
+        logger.debug(send_packet_debug_msg.format(initial_response))
+        yield get_json_line(initial_response)

         if disable_generative_answer:
             logger.debug("Skipping QA because generative AI is disabled")
@@ -277,7 +278,7 @@ def stream_direct_qa(
             ):
                 answer_so_far = answer_so_far + response_packet.answer_piece
             logger.debug(f"Sending packet: {response_packet}")
-            yield get_json_line(asdict(response_packet))
+            yield get_json_line(response_packet.dict())
         except Exception as e:
             # exception is logged in the answer_question method, no need to re-log
             yield get_json_line({"error": str(e)})
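Because pydantic's .dict() converts nested models recursively, top_documents now crosses the wire as plain JSON objects rather than JSON-encoded strings, which is why the frontend hunk at the end of this commit can drop its JSON.parse step. A small sketch with a trimmed-down SearchDoc (field set reduced for illustration):

    from pydantic import BaseModel

    class SearchDoc(BaseModel):  # trimmed stand-in for danswer's SearchDoc
        document_id: str
        semantic_identifier: str

    class RetrievalDocs(BaseModel):
        top_documents: list[SearchDoc]

    packet = RetrievalDocs(
        top_documents=[SearchDoc(document_id="a1", semantic_identifier="Doc A")]
    )
    packet.dict()
    # {'top_documents': [{'document_id': 'a1', 'semantic_identifier': 'Doc A'}]}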
@@ -37,15 +37,24 @@ def send_chat_message(
         "persona_id": persona_id,
     }

+    docs: list[dict] | None = None
    with requests.post(
        LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
    ) as r:
        for json_response in r.iter_lines():
            response_text = json.loads(json_response.decode())
            new_token = response_text.get("answer_piece")
-            print(new_token, end="", flush=True)
+            if docs is None:
+                docs = response_text.get("top_documents")
+            if new_token:
+                print(new_token, end="", flush=True)
    print()

+    if docs:
+        print("\nReference Docs:")
+        for ind, doc in enumerate(docs, start=1):
+            print(f"\t - Doc {ind}: {doc.get('semantic_identifier')}")
+

 def run_chat(contextual: bool) -> None:
     try:
@@ -149,11 +149,9 @@ export const searchRequestStreamed = async ({

       // These all come together
       if (Object.hasOwn(chunk, "top_documents")) {
-        const topDocuments = chunk.top_documents as any[] | null;
+        const topDocuments = chunk.top_documents as DanswerDocument[] | null;
         if (topDocuments) {
-          relevantDocuments = topDocuments.map(
-            (doc) => JSON.parse(doc) as DanswerDocument
-          );
+          relevantDocuments = topDocuments;
           updateDocs(relevantDocuments);
         }