API support for Chat to have citations (#569)

Yuhong Sun 2023-10-13 17:38:25 -07:00 committed by GitHub
parent f0337d2eba
commit af510cc965
13 changed files with 127 additions and 63 deletions

View File

@@ -215,6 +215,7 @@ def handle_message(
client=client,
channel=channel,
receiver_ids=send_to,
text="Something has gone wrong! The Slack blocks failed to load...",
blocks=restate_question_block + answer_blocks + document_blocks,
thread_ts=message_ts_to_respond_to,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
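
Why this hunk matters: Slack's chat_postMessage treats text as the notification/fallback string when blocks are supplied, so passing an explicit message means users still see something sensible if block rendering fails. A minimal sketch of the call shape (not from this commit; token, channel, and timestamp are placeholders), assuming the standard slack_sdk client:

from slack_sdk import WebClient

client = WebClient(token="xoxb-...")  # placeholder bot token
client.chat_postMessage(
    channel="C0123456789",  # placeholder channel id
    # fallback shown in notifications or when the blocks fail to render
    text="Something has gone wrong! The Slack blocks failed to load...",
    blocks=[{"type": "section", "text": {"type": "mrkdwn", "text": "Answer..."}}],
    thread_ts="1697234567.000100",  # placeholder thread timestamp
    unfurl_links=False,  # avoid stacking several link previews under the answer
    unfurl_media=False,
)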

View File

@@ -1,5 +1,7 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from uuid import UUID
from langchain.schema.messages import AIMessage
@@ -18,6 +20,7 @@ from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
from danswer.chat.chat_prompts import YES_SEARCH
from danswer.chat.tools import call_tool
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.chat_configs import FORCE_TOOL_PROMPT
from danswer.configs.constants import IGNORE_FOR_QA
@@ -32,7 +35,9 @@ from danswer.llm.build import get_default_llm
from danswer.llm.llm import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote
@@ -114,7 +119,7 @@ def danswer_chat_retrieval(
history: list[ChatMessage],
llm: LLM,
user_id: UUID | None,
) -> str:
) -> list[InferenceChunk]:
if history:
query_combination_msgs = build_combined_query(query_message, history)
reworded_query = llm.invoke(query_combination_msgs)
@@ -129,7 +134,7 @@ def danswer_chat_retrieval(
datastore=get_default_document_index(),
)
if not ranked_chunks:
return "No results found"
return []
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
@@ -144,7 +149,7 @@
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
)
return format_danswer_chunks_for_chat(usable_chunks)
return usable_chunks
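
The return-type change above (str -> list[InferenceChunk]) is what enables citations downstream: the caller now receives structured chunks, complete with source_links, so the same retrieval result can be yielded to the API layer as documents and separately formatted into prompt text. A rough illustration of that split, using a simplified stand-in for InferenceChunk:

from dataclasses import dataclass

@dataclass
class Chunk:  # simplified stand-in for InferenceChunk
    content: str
    source_links: list[str]

def retrieve(query: str) -> list[Chunk]:
    # placeholder retrieval; the real function queries the document index
    return [Chunk(content="...", source_links=["https://example.com/doc-1"])]

chunks = retrieve("how do citations work?")
prompt_context = "\n".join(
    f"DOCUMENT {i}:\n{c.content}" for i, c in enumerate(chunks, start=1)
)  # fed to the LLM
citation_links = [c.source_links[0] for c in chunks if c.source_links]  # resolves [n] later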
def _drop_messages_history_overflow(
@@ -220,7 +225,7 @@ def llm_contextual_chat_answer(
user_id: UUID | None,
tokenizer: Callable,
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
) -> Iterator[str]:
) -> Iterator[str | list[InferenceChunk]]:
last_message = messages[-1]
final_query_text = last_message.message
previous_messages = messages[:-1]
@@ -256,13 +261,19 @@ def llm_contextual_chat_answer(
# Model will output "Yes Search" if search is useful
# Be a little forgiving though, if we match yes, it's good enough
citation_max_num: int | None = None
retrieved_chunks: list[InferenceChunk] = []
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
tool_result_str = danswer_chat_retrieval(
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
citation_max_num = len(retrieved_chunks) + 1
yield retrieved_chunks
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
last_user_msg_text = form_tool_less_followup_text(
tool_output=tool_result_str,
query=last_message.message,
@@ -288,11 +299,27 @@ def llm_contextual_chat_answer(
final_msg_token_count=last_user_msg_tokens,
)
return llm.stream(prompt)
curr_segment = ""
for token in llm.stream(prompt):
curr_segment += token
pattern = r"\[(\d+)\]" # [1], [2] etc
found = re.search(pattern, curr_segment)
if found:
numerical_value = int(found.group(1))
if citation_max_num and 1 <= numerical_value <= citation_max_num:
reference_chunk = retrieved_chunks[numerical_value - 1]
if reference_chunk.source_links and reference_chunk.source_links[0]:
link = reference_chunk.source_links[0]
token = re.sub("]", f"]({link})", token)
curr_segment = ""
yield token
except Exception as e:
logger.error(f"LLM failed to produce valid chat message, error: {e}")
return (msg for msg in [LLM_CHAT_FAILURE_MSG]) # needs to be an Iterator
yield LLM_CHAT_FAILURE_MSG # needs to be an Iterator
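
The loop above is the core of the citation feature: tokens are buffered into curr_segment until a [n] marker is complete, the number is checked against the count of retrieved chunks, and the closing bracket of the token is rewritten into a markdown link to the matching chunk's first source link. A self-contained sketch of the same idea (the token stream and link are invented):

import re
from collections.abc import Iterator

def link_citations(tokens: Iterator[str], links: list[str]) -> Iterator[str]:
    """Rewrite [n] citation markers in a streamed answer into markdown links."""
    buffer = ""
    for token in tokens:
        buffer += token
        match = re.search(r"\[(\d+)\]", buffer)
        if match:
            n = int(match.group(1))
            if 1 <= n <= len(links):
                # turn "[2]" into "[2](https://...)" in the token being emitted
                token = re.sub(r"\]", f"]({links[n - 1]})", token)
            buffer = ""  # start a fresh window once a citation has been handled
        yield token

stream = iter(["Danswer supports", " citations", " [1]", " now."])
print("".join(link_citations(stream, ["https://docs.example.com/page"])))
# -> Danswer supports citations [1](https://docs.example.com/page) now.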
def llm_tools_enabled_chat_answer(
@@ -372,12 +399,13 @@ def llm_tools_enabled_chat_answer(
retrieval_enabled
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
):
tool_result_str = danswer_chat_retrieval(
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
user_id=user_id,
)
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
else:
tool_result_str = call_tool(final_result, user_id=user_id)
@@ -428,7 +456,7 @@ def llm_chat_answer(
persona: Persona | None,
user_id: UUID | None,
tokenizer: Callable,
) -> Iterator[str]:
) -> Iterator[DanswerAnswerPiece | RetrievalDocs]:
# Common error cases to keep in mind:
# - User asks question about something long ago, due to context limit, the message is dropped
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
@@ -438,24 +466,35 @@ def llm_chat_answer(
# No setting/persona available therefore no retrieval and no additional tools
if persona is None:
return llm_contextless_chat_answer(messages)
for token in llm_contextless_chat_answer(messages):
yield DanswerAnswerPiece(answer_piece=token)
# Persona is configured but with retrieval off and no tools
# therefore cannot retrieve any context so contextless
elif persona.retrieval_enabled is False and not persona.tools:
return llm_contextless_chat_answer(
for token in llm_contextless_chat_answer(
messages, system_text=persona.system_text, tokenizer=tokenizer
)
):
yield DanswerAnswerPiece(answer_piece=token)
# No additional tools outside of Danswer retrieval, can use a more basic prompt
# Doesn't require tool calling output format (all LLM outputs are therefore valid)
elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
return llm_contextual_chat_answer(
for package in llm_contextual_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
):
if isinstance(package, str):
yield DanswerAnswerPiece(answer_piece=package)
elif isinstance(package, list):
yield RetrievalDocs(
top_documents=chunks_to_search_docs(
cast(list[InferenceChunk], package)
)
)
# Use most flexible/complex prompt format
else:
return llm_tools_enabled_chat_answer(
for token in llm_tools_enabled_chat_answer(
messages=messages, persona=persona, user_id=user_id, tokenizer=tokenizer
)
):
yield DanswerAnswerPiece(answer_piece=token)
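
The restructuring above is forced by Python generator semantics: once llm_chat_answer contains a yield, a bare return llm_contextless_chat_answer(...) no longer delivers tokens to the caller, because a return value inside a generator only ends iteration. Each branch therefore loops over the inner iterator and re-yields, wrapping plain tokens in DanswerAnswerPiece and retrieved chunk lists in RetrievalDocs. A toy illustration of the pitfall:

from collections.abc import Iterator

def tokens() -> Iterator[str]:
    yield from ["a", "b", "c"]

def broken() -> Iterator[str]:
    return tokens()  # inside a generator function this only stops iteration
    yield ""         # unreachable, but its presence makes this a generator

def fixed() -> Iterator[str]:
    for tok in tokens():
        yield tok

print(list(broken()))  # []
print(list(fixed()))   # ['a', 'b', 'c']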

View File

@@ -172,6 +172,9 @@ def form_tool_section_text(
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
if not chunks:
return "No Results Found"
return "\n".join(
f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
for ind, chunk in enumerate(chunks, start=1)
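
This formatter is what makes the [n] markers meaningful: each chunk is labeled DOCUMENT 1, DOCUMENT 2, ... in the prompt, so the numbers the model cites line up with the order of retrieved_chunks used for link substitution. Roughly, the generated context looks like the following (the exact value of CODE_BLOCK_PAT is assumed here):

CODE_BLOCK_PAT = "\n```\n{}\n```\n"  # assumed shape of the constant

chunks = ["Danswer supports Slack connectors.", "Citations reference document numbers."]
context = "\n".join(
    f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(content)}"
    for ind, content in enumerate(chunks, start=1)
)
print(context)  # DOCUMENT 1: <fenced content> ... DOCUMENT 2: <fenced content>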

View File

@@ -5,6 +5,7 @@ personas:
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
Your responses are as INFORMATIVE and DETAILED as possible.
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Example of adding tools, it must follow this structure:

View File

@@ -15,6 +15,7 @@ from requests import Response
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
@@ -43,6 +44,7 @@ from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.vespa.utils import remove_invalid_unicode_chars
from danswer.search.keyword_search import remove_stop_words
from danswer.search.semantic_search import embed_query
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -536,9 +538,13 @@ class VespaIndex(DocumentIndex):
query_embedding = embed_query(query)
query_keywords = (
" ".join(remove_stop_words(query)) if EDIT_KEYWORD_QUERY else query
)
params = {
"yql": yql,
"query": query,
"query": query_keywords,
"input.query(query_embedding)": str(query_embedding),
"ranking.profile": "semantic_search",
}
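
With EDIT_KEYWORD_QUERY enabled, the keyword half of Vespa's hybrid query is now stripped of stop words before being sent, while the embedding is still computed from the full query text. The repository's remove_stop_words helper isn't shown in this diff; a plausible sketch of such a helper using NLTK's English stop-word list (an assumption, not necessarily the real implementation):

import nltk
from nltk.corpus import stopwords

nltk.download("stopwords", quiet=True)

def remove_stop_words(query: str) -> list[str]:
    """Drop common English stop words, keeping the remaining keywords in order."""
    stop = set(stopwords.words("english"))
    return [word for word in query.split() if word.lower() not in stop]

print(" ".join(remove_stop_words("what is the status of the migration")))
# -> status migration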

View File

@@ -173,6 +173,7 @@ class ConnectorCredentialPair(Base):
"DocumentSet",
secondary=DocumentSet__ConnectorCredentialPair.__table__,
back_populates="connector_credential_pairs",
overlaps="document_set",
)
@@ -410,6 +411,7 @@ class DocumentSet(Base):
"ConnectorCredentialPair",
secondary=DocumentSet__ConnectorCredentialPair.__table__,
back_populates="document_sets",
overlaps="document_set",
)
personas: Mapped[list["Persona"]] = relationship(
"Persona",

View File

@@ -1,33 +1,30 @@
import abc
from collections.abc import Callable
from collections.abc import Generator
from dataclasses import dataclass
from collections.abc import Iterator
from pydantic import BaseModel
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.models import LLMMetricsContainer
@dataclass
class DanswerAnswer:
class DanswerAnswer(BaseModel):
answer: str | None
@dataclass
class DanswerChatModelOut:
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
@dataclass
class DanswerAnswerPiece:
class DanswerAnswerPiece(BaseModel):
"""A small piece of a complete answer. Used for streaming back answers."""
answer_piece: str | None # if None, specifies the end of an Answer
@dataclass
class DanswerQuote:
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
@@ -37,20 +34,13 @@ class DanswerQuote:
blurb: str
@dataclass
class DanswerQuotes:
"""A little clunky, but making this into a separate class so that the result from
`answer_question_stream` is always a subclass of `dataclass` and can thus use `asdict()`
"""
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
# Final int is for number of output tokens
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
AnswerQuestionStreamReturn = Generator[
DanswerAnswerPiece | DanswerQuotes | None, None, None
]
AnswerQuestionStreamReturn = Iterator[DanswerAnswerPiece | DanswerQuotes]
class QAModel:
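
Moving these streaming models from dataclasses to pydantic BaseModel means each packet is validated on construction and serialized with .dict() instead of dataclasses.asdict(), which is what the endpoint changes below rely on when writing JSON lines. A small sketch of that round trip (get_json_line is approximated here):

import json
from pydantic import BaseModel

class DanswerAnswerPiece(BaseModel):
    """A small piece of a complete answer; None marks the end of the answer."""
    answer_piece: str | None

def get_json_line(data: dict) -> str:  # approximation of danswer.server.utils.get_json_line
    return json.dumps(data) + "\n"

packet = DanswerAnswerPiece(answer_piece="Hello")
print(get_json_line(packet.dict()), end="")  # {"answer_piece": "Hello"}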

View File

@@ -1,6 +1,5 @@
import re
from collections.abc import Iterator
from dataclasses import asdict
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import GENERAL_SEP_PAT
@@ -115,7 +114,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
remaining = model_output[reason_ind + len(REASONING_PAT) :]
if remaining:
yield get_json_line(
asdict(DanswerAnswerPiece(answer_piece=remaining))
DanswerAnswerPiece(answer_piece=remaining).dict()
)
continue
@@ -124,7 +123,7 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
if hold_answerable == ANSWERABLE_PAT[: len(hold_answerable)]:
continue
yield get_json_line(
asdict(DanswerAnswerPiece(answer_piece=hold_answerable))
DanswerAnswerPiece(answer_piece=hold_answerable).dict()
)
hold_answerable = ""

View File

@@ -1,5 +1,4 @@
from collections.abc import Iterator
from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
@@ -302,16 +301,19 @@ def handle_new_chat_message(
@log_generator_function_time()
def stream_chat_tokens() -> Iterator[str]:
tokens = llm_chat_answer(
response_packets = llm_chat_answer(
messages=mainline_messages,
persona=persona,
user_id=user_id,
tokenizer=llm_tokenizer,
)
llm_output = ""
for token in tokens:
llm_output += token
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
yield get_json_line(packet.dict())
create_new_chat_message(
chat_session_id=chat_session_id,
@@ -384,16 +386,19 @@ def regenerate_message_given_parent(
@log_generator_function_time()
def stream_regenerate_tokens() -> Iterator[str]:
tokens = llm_chat_answer(
response_packets = llm_chat_answer(
messages=mainline_messages,
persona=persona,
user_id=user_id,
tokenizer=llm_tokenizer,
)
llm_output = ""
for token in tokens:
llm_output += token
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
yield get_json_line(packet.dict())
create_new_chat_message(
chat_session_id=chat_session_id,
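
Both chat endpoints now stream a mix of packet types: every packet (including the RetrievalDocs citations payload) is forwarded to the client as a JSON line, but only DanswerAnswerPiece tokens are folded into the message that gets persisted. A stripped-down sketch of that streaming shape with FastAPI (the models and generator are simplified stand-ins, not the real ones):

import json
from collections.abc import Iterator
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

class DanswerAnswerPiece(BaseModel):
    answer_piece: str | None

class RetrievalDocs(BaseModel):
    top_documents: list[dict]

app = FastAPI()

def llm_chat_answer() -> Iterator[DanswerAnswerPiece | RetrievalDocs]:
    # stand-in generator: retrieved documents first, then answer tokens
    yield RetrievalDocs(top_documents=[{"semantic_identifier": "doc-1"}])
    for tok in ["Citations ", "work ", "now."]:
        yield DanswerAnswerPiece(answer_piece=tok)

@app.post("/chat/send-message")
def send_message() -> StreamingResponse:
    def stream() -> Iterator[str]:
        llm_output = ""
        for packet in llm_chat_answer():
            if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
                llm_output += packet.answer_piece  # persisted once the stream ends
            yield json.dumps(packet.dict()) + "\n"
    return StreamingResponse(stream(), media_type="application/json")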

View File

@@ -149,6 +149,16 @@ class SearchDoc(BaseModel):
match_highlights: list[str]
class RetrievalDocs(BaseModel):
top_documents: list[SearchDoc]
class RerankedRetrievalDocs(RetrievalDocs):
unranked_top_documents: list[SearchDoc]
predicted_flow: QueryFlow
predicted_search: SearchType
class CreateChatSessionID(BaseModel):
chat_session_id: int

View File

@@ -1,5 +1,4 @@
from collections.abc import Generator
from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
@@ -38,6 +37,7 @@ from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
@@ -224,18 +224,19 @@ def stream_direct_qa(
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
initial_response_dict = {
top_documents_key: [top_doc.json() for top_doc in top_docs],
unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
initial_response = RerankedRetrievalDocs(
top_documents=top_docs,
unranked_top_documents=unranked_top_docs,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow_key: QueryFlow.SEARCH
predicted_flow=QueryFlow.SEARCH
if disable_generative_answer
else predicted_flow,
predicted_search_key: predicted_search,
}
logger.debug(send_packet_debug_msg.format(initial_response_dict))
yield get_json_line(initial_response_dict)
predicted_search=predicted_search,
).dict()
logger.debug(send_packet_debug_msg.format(initial_response))
yield get_json_line(initial_response)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
@@ -277,7 +278,7 @@
):
answer_so_far = answer_so_far + response_packet.answer_piece
logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(asdict(response_packet))
yield get_json_line(response_packet.dict())
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
yield get_json_line({"error": str(e)})

View File

@@ -37,15 +37,24 @@ def send_chat_message(
"persona_id": persona_id,
}
docs: list[dict] | None = None
with requests.post(
LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
) as r:
for json_response in r.iter_lines():
response_text = json.loads(json_response.decode())
new_token = response_text.get("answer_piece")
print(new_token, end="", flush=True)
if docs is None:
docs = response_text.get("top_documents")
if new_token:
print(new_token, end="", flush=True)
print()
if docs:
print("\nReference Docs:")
for ind, doc in enumerate(docs, start=1):
print(f"\t - Doc {ind}: {doc.get('semantic_identifier')}")
def run_chat(contextual: bool) -> None:
try:

View File

@@ -149,11 +149,9 @@ export const searchRequestStreamed = async ({
// These all come together
if (Object.hasOwn(chunk, "top_documents")) {
const topDocuments = chunk.top_documents as any[] | null;
const topDocuments = chunk.top_documents as DanswerDocument[] | null;
if (topDocuments) {
relevantDocuments = topDocuments.map(
(doc) => JSON.parse(doc) as DanswerDocument
);
relevantDocuments = topDocuments;
updateDocs(relevantDocuments);
}