Rework LLM answering flow

Weves 2024-03-25 12:09:27 -07:00 committed by Chris Weaver
parent 1ba74ee4df
commit f135ba9c0c
26 changed files with 1407 additions and 1568 deletions

View File

@@ -1,97 +1,29 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from functools import lru_cache
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from tiktoken.core import Encoding
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_default_prompt
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.utils.logger import setup_logger
logger = setup_logger()
@lru_cache()
def build_chat_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
system_prompt = get_current_llm_day_time()
if not system_prompt:
return None, 0
token_count = len(llm_tokenizer_encode_func(system_prompt))
system_msg = SystemMessage(content=system_prompt)
return system_msg, token_count
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
content=inf_chunk.content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
metadata=inf_chunk.metadata,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
source_links=inf_chunk.source_links,
)
@@ -108,170 +40,6 @@ def map_document_id_order(
return order_mapping
def build_chat_user_message(
chat_message: ChatMessage,
prompt: Prompt,
context_docs: list[LlmDoc],
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
user_prompt_template: str = CHAT_USER_PROMPT,
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
user_query = chat_message.message
if not context_docs:
# Simpler prompt for cases where there is no context
user_prompt = (
context_free_template.format(
task_prompt=prompt.task_prompt, user_query=user_query
)
if prompt.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_docs)
)
optional_ignore = "" if all_doc_useful else ignore_str
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
user_prompt = user_prompt_template.format(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=user_query,
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
def _get_usable_chunks(
chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
total_token_count = 0
usable_chunks = []
for chunk in chunks:
chunk_token_count = check_number_of_tokens(chunk.content)
if total_token_count + chunk_token_count > token_limit:
break
total_token_count += chunk_token_count
usable_chunks.append(chunk)
# try and return at least one chunk if possible. This chunk will
# get truncated later on in the pipeline. This would only occur if
# the first chunk is larger than the token limit (usually due to character
# count -> token count mismatches caused by special characters / non-ascii
# languages)
if not usable_chunks and chunks:
usable_chunks = [chunks[0]]
return usable_chunks
def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
usable_chunks: list[InferenceChunk] = []
for _ in range(max(offset + 1, 1)):  # go through this process at least once
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
raise ValueError(
"Chunks offset too large, should not retry this many times"
)
usable_chunks = _get_usable_chunks(
chunks=chunks[offset_into_chunks:], token_limit=token_limit
)
offset_into_chunks += len(usable_chunks)
return usable_chunks
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int | None,
llm_tokenizer: Encoding | None = None,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.
Only selects chunks viable for Q&A, within the token limit, and prioritizes those selected
by the LLM in a separate flow (this can be turned off).
Note: the batch_offset calculation has to count the batches from the beginning each time, as
there's currently no way to know which chunks were included in the prior batches without
recounting. This is somewhat slow since it requires tokenizing all the chunks again.
"""
token_leeway = 50
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
logger.warning(
"Found more tokens in chunk than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
chunk.content = tokenizer_trim_content(
content=chunk.content,
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
)
# token_leeway (50) is a slight overestimate of the number of tokens needed for the chunk's metadata
token_count += chunk_token + token_leeway
# Always use at least 1 chunk
if (
token_limit is None
or token_count <= token_limit
or not latest_batch_indices
):
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True
if token_limit is not None and token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices
return latest_batch_indices
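# A minimal, self-contained sketch (illustrative only, not from this commit) of the two-pass
# selection above: LLM-selected chunks are considered first, the rest only while token budget
# remains. Token counts are passed in directly instead of being re-tokenized, and the
# batch_offset / trimming handling is left out.
def pick_chunk_indices(
    chunk_token_counts: list[int],
    llm_selected: list[bool],
    token_limit: int,
) -> list[int]:
    chosen: list[int] = []
    used = 0
    for target in (True, False):  # relevant chunks first, then the remainder
        for ind, tokens in enumerate(chunk_token_counts):
            if llm_selected[ind] is not target:
                continue
            if chosen and used + tokens > token_limit:
                continue  # like the real flow, always keep at least one chunk
            chosen.append(ind)
            used += tokens
    return chosen
# e.g. pick_chunk_indices([120, 400, 90], [False, True, True], token_limit=500) -> [1, 2]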
def create_chat_chain(
chat_session_id: int,
db_session: Session,
@@ -341,157 +109,6 @@ def combine_message_chain(
return "\n\n".join(message_strs)
_PER_MESSAGE_TOKEN_BUFFER = 7
def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind
def drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
max_allowed_tokens: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible, and the latest user input, which is inserted into the
prompt template, must be included."""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
prompt: list[BaseMessage] = []
# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = find_last_index(
all_tokens, max_prompt_tokens=max_allowed_tokens
)
if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)
prompt.extend(history_msgs[ind_prev_msg_start:])
prompt.append(final_msg)
return prompt
def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
llm_out = ""
max_citation_num = len(context_docs)
curr_segment = ""
prepend_bracket = False
cited_inds = set()
hold = ""
for raw_token in tokens:
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
token = next_hold
hold = ""
else:
token = raw_token
# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
curr_segment += token
llm_out += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
] # remove 1 index offset
link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
def reorganize_citations(
answer: str, citations: list[CitationInfo]
) -> tuple[str, list[CitationInfo]]:
@@ -547,72 +164,3 @@ def reorganize_citations(
new_citation_info[citation.citation_num] = citation
return new_answer, list(new_citation_info.values())
def get_prompt_tokens(prompt: Prompt) -> int:
# Note: currently custom prompts do not allow datetime awareness; only default prompts do
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
)
# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40
def compute_max_document_tokens(
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:
(
model_context_window - reserved_output_tokens - prompt_tokens
- (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
)
The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_name)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
prompt_tokens = get_prompt_tokens(get_default_prompt())
user_input_tokens = (
check_number_of_tokens(actual_user_input)
if actual_user_input is not None
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
input_tokens = get_max_input_tokens(model_name=llm_name)
return input_tokens - _MISC_BUFFER

View File

@@ -16,11 +16,13 @@ class LlmDoc(BaseModel):
document_id: str
content: str
blurb: str
semantic_identifier: str
source_type: DocumentSource
metadata: dict[str, str | list[str]]
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
# First chunk of info for streaming QA
@@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None
AnswerQuestionStreamReturn = Iterator[
DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError
)
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):

View File

@@ -5,16 +5,8 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import build_chat_user_message
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import compute_max_llm_input_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import drop_messages_history_overflow
from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import map_document_id_order
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
@@ -23,9 +15,7 @@ from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -37,21 +27,17 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
@@ -68,72 +54,6 @@ from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def generate_ai_chat_response(
query_message: ChatMessage,
history: list[ChatMessage],
persona: Persona,
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
llm: LLM | None,
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
if llm is None:
try:
llm = get_default_llm()
except GenAIDisabledException:
# Not an error if it's a user configuration
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
return
if query_message.prompt is None:
raise RuntimeError("No prompt received for generating Gen AI answer.")
try:
context_exists = len(context_docs) > 0
system_message_or_none, system_tokens = build_chat_system_message(
prompt=query_message.prompt,
context_exists=context_exists,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
)
history_basemessages, history_token_counts = translate_history_to_basemessages(
history
)
# Be sure the context_docs passed to build_chat_user_message
# are the same as those passed in later for extracting citations
user_message, user_tokens = build_chat_user_message(
chat_message=query_message,
prompt=query_message.prompt,
context_docs=context_docs,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
all_doc_useful=all_doc_useful,
)
prompt = drop_messages_history_overflow(
system_msg=system_message_or_none,
system_token_count=system_tokens,
history_msgs=history_basemessages,
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
max_allowed_tokens=compute_max_llm_input_tokens(persona),
)
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
yield from extract_citations_from_stream(
tokens, context_docs, doc_id_to_rank_map
)
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
@@ -154,24 +74,26 @@ def translate_citations(
return citation_to_saved_doc_id_map
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
# For the flow with search, don't include as many chunks as possible since we need to leave space
# for the chat history; for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> Iterator[
ChatPacketStream = Iterator[
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| ChatMessageDetail
| DanswerAnswerPiece
| CitationInfo
]:
]
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
# For the flow with search, don't include as many chunks as possible since we need to leave space
# for the chat history; for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
@@ -277,10 +199,6 @@ def stream_chat_message_objects(
query_message=final_msg, history=history_msgs, llm=llm
)
max_document_tokens = compute_max_document_tokens(
persona=persona, actual_user_input=message_text
)
rephrased_query = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
@@ -296,64 +214,8 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# truncate the last document if it exceeds the token limit
tokens_per_doc = [
len(
llm_tokenizer_encode_func(
build_doc_context_str(
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
content=llm_doc.content,
metadata_dict=llm_doc.metadata,
updated_at=llm_doc.updated_at,
ind=ind,
)
)
)
for ind, llm_doc in enumerate(llm_docs)
]
final_doc_ind = None
total_tokens = 0
for ind, tokens in enumerate(tokens_per_doc):
total_tokens += tokens
if total_tokens > max_document_tokens:
final_doc_ind = ind
break
if final_doc_ind is not None:
# only allow the final document to get truncated
# if more than that, then the user message is too long
if final_doc_ind != len(tokens_per_doc) - 1:
yield StreamingError(
error="LLM context window exceeded. Please de-select some documents or shorten your query."
)
return
final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
total_tokens - max_document_tokens
)
# 75 tokens is a reasonable over-estimate of the metadata and title
final_doc_content_length = final_doc_desired_length - 75
# this could occur if we only have space for the title / metadata
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content "
"length is less than 0. Removing this doc from the final prompt."
)
llm_docs.pop()
else:
llm_docs[final_doc_ind].content = tokenizer_trim_content(
content=llm_docs[final_doc_ind].content,
desired_length=final_doc_content_length,
tokenizer=llm_tokenizer,
)
doc_id_to_rank_map = map_document_id_order(
cast(list[InferenceChunk | LlmDoc], llm_docs)
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True
)
# In case the search doc is deleted, just don't include it
@@ -393,9 +255,6 @@ def stream_chat_message_objects(
top_chunks = search_pipeline.reranked_docs
top_docs = chunks_to_search_docs(top_chunks)
# Get ranking of the documents for citation purposes later
doc_id_to_rank_map = map_document_id_order(top_chunks)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
for top_doc in top_docs
@@ -423,41 +282,21 @@ def stream_chat_message_objects(
)
yield llm_relevance_filtering_response
# Prep chunks to pass to LLM
num_llm_chunks = (
persona.num_chunks
if persona.num_chunks is not None
else default_num_chunks
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else default_num_chunks
),
max_window_percentage=max_document_percentage,
)
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)
llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
chunk_token_limit = int(
min(
num_llm_chunks * default_chunk_size,
max_document_tokens,
llm_token_based_chunk_lim,
)
)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=search_pipeline.chunk_relevance_list,
token_limit=chunk_token_limit,
llm_tokenizer=llm_tokenizer,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks]
else:
llm_docs = []
doc_id_to_rank_map = {}
reference_db_search_docs = None
document_pruning_config = DocumentPruningConfig()
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
@@ -495,33 +334,24 @@ def stream_chat_message_objects(
return
# LLM prompt building, response capturing, etc.
response_packets = generate_ai_chat_response(
query_message=final_msg,
history=history_msgs,
answer = Answer(
question=final_msg.message,
docs=llm_docs,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=reference_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
),
prompt=final_msg.prompt,
persona=persona,
context_docs=llm_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
llm=llm,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
all_doc_useful=reference_doc_ids is not None,
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
],
)
# generator will not include quotes, so we can cast
yield from cast(ChatPacketStream, answer.processed_streamed_output)
# Capture outputs and errors
llm_output = ""
error: str | None = None
citations: list[CitationInfo] = []
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
elif isinstance(packet, StreamingError):
error = packet.error
elif isinstance(packet, CitationInfo):
citations.append(packet)
continue
yield packet
except Exception as e:
logger.exception(e)
@@ -535,16 +365,16 @@ def stream_chat_message_objects(
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=citations,
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
gen_ai_response_message = partial_response(
message=llm_output,
token_count=len(llm_tokenizer_encode_func(llm_output)),
message=answer.llm_answer,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=error,
error=None,
)
msg_detail_response = translate_db_message_to_chat_message_detail(

View File

@@ -12,7 +12,6 @@ from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
@@ -39,6 +38,7 @@ from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens

View File

@@ -0,0 +1,176 @@
from collections.abc import Iterator
from typing import cast
from langchain.schema.messages import BaseMessage
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt
from danswer.llm.answering.prompts.quotes_prompt import (
build_quotes_prompt,
)
from danswer.llm.answering.stream_processing.citation_processing import (
build_citation_processor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
)
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_tokenizer
def _get_stream_processor(
docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=docs,
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
)
raise RuntimeError("Not implemented yet")
class Answer:
def __init__(
self,
question: str,
docs: list[LlmDoc],
answer_style_config: AnswerStyleConfig,
prompt: Prompt,
persona: Persona,
# must be the same length as `docs`. If None, all docs are considered "relevant"
doc_relevance_list: list[bool] | None = None,
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
timeout: int = QA_TIMEOUT,
) -> None:
if single_message_history and message_history:
raise ValueError(
"Cannot provide both `message_history` and `single_message_history`"
)
self.question = question
self.docs = docs
self.doc_relevance_list = doc_relevance_list
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.llm = get_default_llm(
gen_ai_model_version_override=persona.llm_model_version_override,
timeout=timeout,
)
self.llm_tokenizer = get_default_llm_tokenizer()
self.prompt = prompt
self.persona = persona
self.process_stream_fn = _get_stream_processor(docs, answer_style_config)
self._final_prompt: list[BaseMessage] | None = None
self._pruned_docs: list[LlmDoc] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None
@property
def pruned_docs(self) -> list[LlmDoc]:
if self._pruned_docs is not None:
return self._pruned_docs
self._pruned_docs = prune_documents(
docs=self.docs,
doc_relevance_list=self.doc_relevance_list,
persona=self.persona,
question=self.question,
document_pruning_config=self.answer_style_config.document_pruning_config,
)
return self._pruned_docs
@property
def final_prompt(self) -> list[BaseMessage]:
if self._final_prompt is not None:
return self._final_prompt
if self.answer_style_config.citation_config:
self._final_prompt = build_citations_prompt(
question=self.question,
message_history=self.message_history,
persona=self.persona,
prompt=self.prompt,
context_docs=self.pruned_docs,
all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
llm_tokenizer_encode_func=self.llm_tokenizer.encode,
history_message=self.single_message_history or "",
)
elif self.answer_style_config.quotes_config:
self._final_prompt = build_quotes_prompt(
question=self.question,
context_docs=self.pruned_docs,
history_str=self.single_message_history or "",
prompt=self.prompt,
)
return cast(list[BaseMessage], self._final_prompt)
@property
def raw_streamed_output(self) -> Iterator[str]:
if self._streamed_output is not None:
yield from self._streamed_output
return
streamed_output = []
for message in self.llm.stream(self.final_prompt):
streamed_output.append(message)
yield message
self._streamed_output = streamed_output
@property
def processed_streamed_output(self) -> AnswerQuestionStreamReturn:
if self._processed_stream is not None:
yield from self._processed_stream
return
processed_stream = []
for processed_packet in self.process_stream_fn(self.raw_streamed_output):
processed_stream.append(processed_packet)
yield processed_packet
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
citations.append(packet)
return citations
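# A standalone sketch (illustrative, not from this commit) of the caching pattern the stream
# properties above rely on: the first consumer drives the underlying iterator while the output
# is recorded, and later consumers (e.g. `llm_answer` and `citations` both walking
# `processed_streamed_output`) replay the cached copy.
from collections.abc import Iterator


class CachedStream:
    def __init__(self, source: Iterator[str]) -> None:
        self._source = source
        self._cache: list[str] | None = None

    @property
    def tokens(self) -> Iterator[str]:
        if self._cache is not None:
            yield from self._cache
            return
        cache: list[str] = []
        for tok in self._source:
            cache.append(tok)
            yield tok
        self._cache = cache


# stream = CachedStream(iter(["Hel", "lo"]))
# "".join(stream.tokens) == "Hello"  # a second call replays the cache instead of the iterator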

View File

@@ -0,0 +1,205 @@
from copy import deepcopy
from typing import TypeVar
from danswer.chat.models import (
LlmDoc,
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import Persona
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.utils.logger import setup_logger
logger = setup_logger()
T = TypeVar("T", bound=LlmDoc | InferenceChunk)
_METADATA_TOKEN_ESTIMATE = 75
class PruningError(Exception):
pass
def _compute_limit(
persona: Persona,
question: str,
max_chunks: int | None,
max_window_percentage: float | None,
max_tokens: int | None,
) -> int:
llm_max_document_tokens = compute_max_document_tokens(
persona=persona, actual_user_input=question
)
window_percentage_based_limit = (
max_window_percentage * llm_max_document_tokens
if max_window_percentage
else None
)
chunk_count_based_limit = (
max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
)
limit_options = [
lim
for lim in [
window_percentage_based_limit,
chunk_count_based_limit,
max_tokens,
llm_max_document_tokens,
]
if lim
]
return int(min(limit_options))
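# Worked example with illustrative numbers (not from this commit): if the LLM's document budget
# is 3000 tokens, max_window_percentage=0.5, max_chunks=10 and DOC_EMBEDDING_CONTEXT_SIZE=512,
# the candidate limits are [0.5 * 3000 = 1500, 10 * 512 = 5120, 3000] and the pruning limit
# becomes min(...) = 1500 tokens.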
def reorder_docs(
docs: list[T],
doc_relevance_list: list[bool] | None,
) -> list[T]:
if doc_relevance_list is None:
return docs
reordered_docs: list[T] = []
if doc_relevance_list is not None:
for selection_target in [True, False]:
for doc, is_relevant in zip(docs, doc_relevance_list):
if is_relevant == selection_target:
reordered_docs.append(doc)
return reordered_docs
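# Illustrative behavior (not from this commit): reorder_docs is a stable partition that moves
# the "relevant" entries to the front, e.g. (ignoring the type bound for the sketch)
#   reorder_docs(docs=["a", "b", "c"], doc_relevance_list=[False, True, False]) -> ["b", "a", "c"]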
def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]
def _apply_pruning(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
token_limit: int,
is_manually_selected_docs: bool,
) -> list[LlmDoc]:
llm_tokenizer = get_default_llm_tokenizer()
docs = deepcopy(docs) # don't modify in place
# re-order docs with all the "relevant" docs at the front
docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
# remove docs that are explicitly marked as not for QA
docs = _remove_docs_to_ignore(docs=docs)
tokens_per_doc: list[int] = []
final_doc_ind = None
total_tokens = 0
for ind, llm_doc in enumerate(docs):
doc_tokens = len(
llm_tokenizer.encode(
build_doc_context_str(
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
content=llm_doc.content,
metadata_dict=llm_doc.metadata,
updated_at=llm_doc.updated_at,
ind=ind,
)
)
)
# if chunks, truncate chunks that are way too long
# this can happen if the embedding model tokenizer is different
# than the LLM tokenizer
if (
not is_manually_selected_docs
and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
):
logger.warning(
"Found more tokens in chunk than expected, "
"likely mismatch between embedding and LLM tokenizers. Trimming content..."
)
llm_doc.content = tokenizer_trim_content(
content=llm_doc.content,
desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
tokenizer=llm_tokenizer,
)
doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
tokens_per_doc.append(doc_tokens)
total_tokens += doc_tokens
if total_tokens > token_limit:
final_doc_ind = ind
break
if final_doc_ind is not None:
if is_manually_selected_docs:
# for document selection, only allow the final document to get truncated
# if more than that, then the user message is too long
if final_doc_ind != len(docs) - 1:
raise PruningError(
"LLM context window exceeded. Please de-select some documents or shorten your query."
)
final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
total_tokens - token_limit
)
final_doc_content_length = (
final_doc_desired_length - _METADATA_TOKEN_ESTIMATE
)
# this could occur if we only have space for the title / metadata
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
"length is less than 0. Removing this doc from the final prompt."
)
docs.pop()
else:
docs[final_doc_ind].content = tokenizer_trim_content(
content=docs[final_doc_ind].content,
desired_length=final_doc_content_length,
tokenizer=llm_tokenizer,
)
else:
# for regular search, don't truncate the final document unless it's the only one
if final_doc_ind != 0:
docs = docs[:final_doc_ind]
else:
docs[0].content = tokenizer_trim_content(
content=docs[0].content,
desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
tokenizer=llm_tokenizer,
)
docs = [docs[0]]
return docs
def prune_documents(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
persona: Persona,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[LlmDoc]:
if doc_relevance_list is not None:
assert len(docs) == len(doc_relevance_list)
doc_token_limit = _compute_limit(
persona=persona,
question=question,
max_chunks=document_pruning_config.max_chunks,
max_window_percentage=document_pruning_config.max_window_percentage,
max_tokens=document_pruning_config.max_tokens,
)
return _apply_pruning(
docs=docs,
doc_relevance_list=doc_relevance_list,
token_limit=doc_token_limit,
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
)
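# A minimal, self-contained sketch (illustrative only, not from this commit) of the core
# pruning behavior above, using plain strings and a whitespace "tokenizer" and ignoring the
# metadata overhead and manual-selection special cases: documents are kept in order until the
# budget runs out and the last partially-fitting one is truncated.
def prune_texts(texts: list[str], token_limit: int) -> list[str]:
    kept: list[str] = []
    used = 0
    for text in texts:
        words = text.split()
        if used + len(words) <= token_limit:
            kept.append(text)
            used += len(words)
            continue
        remaining = token_limit - used
        if remaining > 0:
            kept.append(" ".join(words[:remaining]))  # truncate the final document
        break
    return kept
# prune_texts(["one two three", "four five", "six"], token_limit=4) -> ["one two three", "four"]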

View File

@@ -0,0 +1,77 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from pydantic import root_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class PreviousMessage(BaseModel):
"""Simplified version of `ChatMessage`"""
message: str
token_count: int
message_type: MessageType
@classmethod
def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage":
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
)
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
@root_validator
def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
citation_config = values.get("citation_config")
quotes_config = values.get("quotes_config")
if citation_config is None and quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if citation_config is not None and quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return values
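# Illustrative usage (not from this commit) of the exclusive-or validation above:
#   AnswerStyleConfig(citation_config=CitationConfig())   # valid
#   AnswerStyleConfig(quotes_config=QuotesConfig())        # valid
#   AnswerStyleConfig()                                     # fails validation (neither provided)
#   AnswerStyleConfig(citation_config=CitationConfig(), quotes_config=QuotesConfig())  # fails (both)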

View File

@@ -0,0 +1,281 @@
from collections.abc import Callable
from functools import lru_cache
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import (
CITATIONS_PROMPT,
)
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
_PER_MESSAGE_TOKEN_BUFFER = 7
def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind
def drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
max_allowed_tokens: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible, and the latest user input, which is inserted into the
prompt template, must be included."""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
prompt: list[BaseMessage] = []
# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = find_last_index(
all_tokens, max_prompt_tokens=max_allowed_tokens
)
if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)
prompt.extend(history_msgs[ind_prev_msg_start:])
prompt.append(final_msg)
return prompt
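# Worked example with illustrative numbers (not from this commit): with max_allowed_tokens=100,
# history token counts [50, 40, 30], a 20-token system message and a 25-token final user message,
# all_tokens is [50, 40, 30, 20, 25]; walking from the back with the 7-token per-message buffer,
# find_last_index returns 2, so the two oldest history messages are dropped and the prompt
# becomes [system_msg, history_msgs[2], final_msg].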
def get_prompt_tokens(prompt: Prompt) -> int:
# Note: currently custom prompts do not allow datetime awareness; only default prompts do
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
)
# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40
def compute_max_document_tokens(
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:
(
model_context_window - reserved_output_tokens - prompt_tokens
- (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
)
The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_name)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
prompt_tokens = get_prompt_tokens(get_default_prompt())
user_input_tokens = (
check_number_of_tokens(actual_user_input)
if actual_user_input is not None
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
input_tokens = get_max_input_tokens(model_name=llm_name)
return input_tokens - _MISC_BUFFER
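# Worked example with illustrative numbers (not from this commit): for a model with a
# 4096-token input window, a persona prompt costing ~600 tokens and a 50-token user question,
# compute_max_document_tokens leaves roughly 4096 - 600 - 50 - 40 (_MISC_BUFFER) = 3406 tokens
# for context documents; when the question is not yet known,
# GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS is used in place of the 50 above.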
@lru_cache()
def build_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
system_prompt = get_current_llm_day_time()
if not system_prompt:
return None, 0
token_count = len(llm_tokenizer_encode_func(system_prompt))
system_msg = SystemMessage(content=system_prompt)
return system_msg, token_count
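# Illustrative composition (not from this commit, made-up prompt text): for a Prompt whose
# system_prompt is "You are a helpful assistant.", with include_citations=True,
# datetime_aware=True and context present, the SystemMessage content becomes
#   "You are a helpful assistant." + REQUIRE_CITATION_STATEMENT
#   + ADDITIONAL_INFO.format(datetime_info=get_current_llm_day_time())
# and the returned count is len(llm_tokenizer_encode_func(<that string>)).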
def build_user_message(
question: str,
prompt: Prompt,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
) -> tuple[HumanMessage, int]:
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode)
if not context_docs:
# Simpler prompt for cases where there is no context
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
task_prompt=prompt.task_prompt, user_query=question
)
if prompt.task_prompt
else question
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
context_docs_str = build_complete_context_str(context_docs)
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
user_prompt = CITATIONS_PROMPT.format(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=question,
history_block=history_message,
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
def build_citations_prompt(
question: str,
message_history: list[PreviousMessage],
persona: Persona,
prompt: Prompt,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
llm_tokenizer_encode_func: Callable,
) -> list[BaseMessage]:
context_exists = len(context_docs) > 0
system_message_or_none, system_tokens = build_system_message(
prompt=prompt,
context_exists=context_exists,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
)
history_basemessages, history_token_counts = translate_history_to_basemessages(
message_history
)
# Be sure the context_docs passed to build_user_message
# are the same as those passed in later for extracting citations
user_message, user_tokens = build_user_message(
question=question,
prompt=prompt,
context_docs=context_docs,
all_doc_useful=all_doc_useful,
history_message=history_message,
)
final_prompt_msgs = drop_messages_history_overflow(
system_msg=system_message_or_none,
system_token_count=system_tokens,
history_msgs=history_basemessages,
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
max_allowed_tokens=compute_max_llm_input_tokens(persona),
)
return final_prompt_msgs
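# Illustrative shape of the result (not from this commit): the returned list is ordered
#   [SystemMessage (if any), *surviving history messages, final HumanMessage]
# where drop_messages_history_overflow trims the oldest history first so the final user
# message (which embeds the context documents) always fits within
# compute_max_llm_input_tokens(persona).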

View File

@@ -0,0 +1,88 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import build_complete_context_str
def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
use_language_hint: bool,
) -> list[BaseMessage]:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as smaller, low float-precision, quantized,
or distilled models. It only uses one context document and has very weak requirements on the
output format.
"""
context_block = ""
if context_docs:
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=prompt.system_prompt,
context_block=context_block,
task_prompt=prompt.task_prompt,
user_query=question,
)
return [HumanMessage(content=prompt_str)]
def _build_strong_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
use_language_hint: bool,
) -> list[BaseMessage]:
context_block = ""
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
history_block = ""
if history_str:
history_block = HISTORY_BLOCK.format(history_str=history_str)
full_prompt = JSON_PROMPT.format(
system_prompt=prompt.system_prompt,
context_block=context_block,
history_block=history_block,
task_prompt=prompt.task_prompt,
user_query=question,
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
).strip()
return [HumanMessage(content=full_prompt)]
def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
return prompt_builder(
question=question,
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
use_language_hint=use_language_hint,
)
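# Illustrative routing (not from this commit): with QA_PROMPT_OVERRIDE == "weak" the single-
# document WEAK_LLM_PROMPT builder is used; otherwise the JSON_PROMPT builder assembles the
# full context block, an optional history block, and the language hint for multilingual setups.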

View File

@@ -0,0 +1,20 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()

View File

@@ -0,0 +1,126 @@
import re
from collections.abc import Iterator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
logger = setup_logger()
def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
llm_out = ""
max_citation_num = len(context_docs)
curr_segment = ""
prepend_bracket = False
cited_inds = set()
hold = ""
for raw_token in tokens:
if stop_stream:
next_hold = hold + raw_token
if stop_stream in next_hold:
break
if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue
token = next_hold
hold = ""
else:
token = raw_token
# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
curr_segment += token
llm_out += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
] # remove 1 index offset
link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
def build_citation_processor(
context_docs: list[LlmDoc],
) -> StreamProcessor:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=map_document_id_order(context_docs),
)
return stream_processor
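
A minimal driver sketch (not part of this commit) showing how the returned processor could be exercised; the hard-coded token iterator stands in for an LLM stream, and `docs` is assumed to be a list of LlmDoc objects whose order determines the citation numbers.

process = build_citation_processor(context_docs=docs)
fake_tokens = iter(["The change is described ", "in the design doc ", "[1]", "."])
for packet in process(fake_tokens):
    if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
        print(packet.answer_piece, end="")
    elif isinstance(packet, CitationInfo):
        print(f"\n(cited document: {packet.document_id})")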

View File

@ -0,0 +1,282 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
import regex
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.indexing.models import InferenceChunk
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def _extract_answer_quotes_freeform(
answer_raw: str,
) -> tuple[Optional[str], Optional[list[str]]]:
"""Splits the model output into an Answer and 0 or more Quote sections.
Splits on the Quote pattern; if the pattern is absent, assume the output is all answer and no quotes
"""
# If no answer section, don't care about the quote
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
return None, None
# Sometimes the model regenerates the Answer: pattern despite it already being provided in the prompt
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
answer_raw = answer_raw[len(ANSWER_PAT) :]
# Accept quote sections starting with the lower case version
answer_raw = answer_raw.replace(
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
) # Just in case model unreliable
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
sections_clean = [
str(section).strip() for section in sections if str(section).strip()
]
if not sections_clean:
return None, None
answer = str(sections_clean[0])
if len(sections) == 1:
return answer, None
return answer, sections_clean[1:]
def _extract_answer_quotes_json(
answer_dict: dict[str, str | list[str]]
) -> tuple[Optional[str], Optional[list[str]]]:
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
answer = str(answer_dict.get("answer"))
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
if isinstance(quotes, str):
quotes = [quotes]
return answer, quotes
def _extract_answer_json(raw_model_output: str) -> dict:
try:
answer_json = extract_embedded_json(raw_model_output)
except (ValueError, JSONDecodeError):
# LLMs get confused when handling the list in the json. Sometimes the model doesn't
# attend enough to the opening { token, so it ends the list of quotes and stops there.
# Here we try to recover from that by re-parsing with a closing brace appended.
answer_json = extract_embedded_json(raw_model_output + "}")
if "answer" not in answer_json:
raise ValueError("Model did not output an answer as expected.")
return answer_json
def match_quotes_to_docs(
quotes: list[str],
docs: list[LlmDoc] | list[InferenceChunk],
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> DanswerQuotes:
danswer_quotes: list[DanswerQuote] = []
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
for doc in docs:
if not doc.source_links:
continue
quote_clean = shared_precompare_cleanup(
clean_model_quote(quote, trim_length=prefix_only_length)
)
chunk_clean = shared_precompare_cleanup(doc.content)
# Finding the offset of the quote in the plain text
if fuzzy_search:
re_search_str = (
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
)
found = regex.search(re_search_str, chunk_clean)
if not found:
continue
offset = found.span()[0]
else:
if quote_clean not in chunk_clean:
continue
offset = chunk_clean.index(quote_clean)
# Extracting the link from the offset
curr_link = None
for link_offset, link in doc.source_links.items():
# Should always find one because offset is at least 0 and there
# must be a 0 link_offset
if int(link_offset) <= offset:
curr_link = link
else:
break
danswer_quotes.append(
DanswerQuote(
quote=quote,
document_id=doc.document_id,
link=curr_link,
source_type=doc.source_type,
semantic_identifier=doc.semantic_identifier,
blurb=doc.blurb,
)
)
break
return DanswerQuotes(quotes=danswer_quotes)
def separate_answer_quotes(
answer_raw: str, is_json_prompt: bool = False
) -> tuple[Optional[str], Optional[list[str]]]:
"""Takes in a raw model output and pulls out the answer and the quotes sections."""
if is_json_prompt:
model_raw_json = _extract_answer_json(answer_raw)
return _extract_answer_quotes_json(model_raw_json)
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
def process_answer(
answer_raw: str,
docs: list[LlmDoc],
is_json_prompt: bool = True,
) -> tuple[DanswerAnswer, DanswerQuotes]:
"""Used (1) in the non-streaming case to process the model output
into an Answer and Quotes AND (2) after the complete streaming response
has been received to process the model output into an Answer and Quotes."""
answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
logger.info(f"Answer: {answer}")
if not quote_strings:
logger.debug("No quotes extracted from raw output")
return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
logger.info(f"All quotes (including unmatched): {quote_strings}")
quotes = match_quotes_to_docs(quote_strings, docs)
logger.debug(f"Final quotes: {quotes}")
return DanswerAnswer(answer=answer), quotes
def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
next_token = next_token.replace('\\"', "")
# If the previous character is an escape token, don't consider the first character of next_token.
# This breaks if an escaped backslash precedes the ", but that case is rare and not worth handling.
if answer_so_far and answer_so_far[-1] == "\\":
next_token = next_token[1:]
if '"' in next_token:
return True
return False
def _extract_quotes_from_completed_token_stream(
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
) -> DanswerQuotes:
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
if answer:
logger.info(answer)
elif model_output:
logger.warning("Answer extraction from model output failed.")
return quotes
def process_model_tokens(
tokens: Iterator[str],
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Used in the streaming case to process the model output
into an Answer and Quotes.
Yields answer tokens back out as DanswerAnswerPiece objects for streaming to the frontend.
When the answer section ends, yields a DanswerAnswerPiece with answer_piece set to None.
Collects all the tokens along the way to form the complete model output."""
quote_pat = f"\n{QUOTE_PAT}"
# Weaker models sometimes output a newline instead of the trailing :
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometimes the model outputs two newlines before the quote section
quote_pat_full = f"\n{quote_pat}"
model_output: str = ""
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
for token in tokens:
model_previous = model_output
model_output += token
if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
# Note: if the token that completes the pattern has additional text (for example if the token is "?),
# then the chars after " will not be streamed, but this is ok since it prevents streaming the ? in
# the event that the model outputs the UNCERTAINTY_PAT.
found_answer_start = True
# Prevent heavy hallucination cases where the model does not even start producing json until later
if is_json_prompt and len(model_output) > 40:
logger.warning("LLM did not produce json as prompted")
found_answer_end = True
continue
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
hold_quote = ""
logger.debug(f"Raw Model QnA Output: {model_output}")
yield _extract_quotes_from_completed_token_stream(
model_output=model_output,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
def build_quotes_processor(
context_docs: list[LlmDoc], is_json_prompt: bool
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
return stream_processor
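
As a rough illustration under the same assumptions (`docs` being the retrieved LlmDoc objects), the quotes processor can be driven the same way; the token list below mimics a JSON-prompted stream whose quotes section happens to be empty.

process = build_quotes_processor(context_docs=docs, is_json_prompt=True)
fake_tokens = iter(['{"answer":"', 'The answering flow was reworked', '."', ', "quotes": []}'])
for packet in process(fake_tokens):
    if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
        print(packet.answer_piece, end="")
    elif isinstance(packet, DanswerQuotes):
        print(f"\nquotes: {packet.quotes}")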

View File

@ -0,0 +1,17 @@
from collections.abc import Sequence
from danswer.chat.models import LlmDoc
from danswer.indexing.models import InferenceChunk
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
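
A tiny illustrative check: the helper only reads document_id from each chunk, so a small stand-in dataclass (SimpleDoc, defined here just for the example) is enough to show that the first occurrence of each document id receives the next citation number, 1-indexed by default.

from dataclasses import dataclass

@dataclass
class SimpleDoc:
    document_id: str

docs = [SimpleDoc("a"), SimpleDoc("b"), SimpleDoc("a"), SimpleDoc("c")]
assert map_document_id_order(docs) == {"a": 1, "b": 2, "c": 3}
assert map_document_id_order(docs, one_indexed=False) == {"a": 0, "b": 1, "c": 2}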

View File

@ -33,6 +33,7 @@ from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger
@ -114,7 +115,9 @@ def tokenizer_trim_chunks(
return new_chunks
def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
@ -126,7 +129,7 @@ def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
def translate_history_to_basemessages(
history: list[ChatMessage],
history: list[ChatMessage] | list[PreviousMessage],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)

View File

@ -1,54 +1,37 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import map_document_id_order
from danswer.chat.chat_utils import reorganize_citations
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import get_prompt_by_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Prompt
from danswer.db.models import User
from danswer.indexing.models import InferenceChunk
from danswer.llm.factory import get_default_llm
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.one_shot_answer.factory import get_question_answer_model
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.models import ThreadMessage
from danswer.one_shot_answer.qa_block import no_gen_ai_response
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SavedSearchDoc
@ -77,106 +60,6 @@ AnswerObjectIterator = Iterator[
]
def quote_based_qa(
prompt: Prompt,
query_message: ThreadMessage,
history_str: str,
context_chunks: list[InferenceChunk],
llm_override: str | None,
timeout: int,
use_chain_of_thought: bool,
return_contexts: bool,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
qa_model = get_question_answer_model(
prompt=prompt,
timeout=timeout,
chain_of_thought=use_chain_of_thought,
llm_version=llm_override,
)
full_prompt_str = (
qa_model.build_prompt(
query=query_message.message,
history_str=history_str,
context_chunks=context_chunks,
)
if qa_model is not None
else "Gen AI Disabled"
)
response_packets = (
qa_model.answer_question_stream(
prompt=full_prompt_str,
llm_context_docs=context_chunks,
metrics_callback=llm_metrics_callback,
)
if qa_model is not None
else no_gen_ai_response()
)
if qa_model is not None and return_contexts:
contexts = DanswerContexts(
contexts=[
DanswerContext(
content=context_chunk.content,
document_id=context_chunk.document_id,
semantic_identifier=context_chunk.semantic_identifier,
blurb=context_chunk.semantic_identifier,
)
for context_chunk in context_chunks
]
)
response_packets = itertools.chain(response_packets, [contexts])
yield from response_packets
def citation_based_qa(
prompt: Prompt,
query_message: ThreadMessage,
history_str: str,
context_chunks: list[InferenceChunk],
llm_override: str | None,
timeout: int,
) -> AnswerObjectIterator:
llm_tokenizer = get_default_llm_tokenizer()
system_prompt_or_none, _ = build_chat_system_message(
prompt=prompt,
context_exists=True,
llm_tokenizer_encode_func=llm_tokenizer.encode,
)
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
context_docs_str = build_complete_context_str(context_chunks)
user_message = HumanMessage(
content=CITATIONS_PROMPT.format(
task_prompt=task_prompt_with_reminder,
user_query=query_message.message,
history_block=history_str,
context_docs_str=context_docs_str,
)
)
llm = get_default_llm(
timeout=timeout,
gen_ai_model_version_override=llm_override,
)
llm_prompt: list[BaseMessage] = [user_message]
if system_prompt_or_none is not None:
llm_prompt = [system_prompt_or_none] + llm_prompt
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in context_chunks]
doc_id_to_rank_map = map_document_id_order(llm_docs)
tokens = llm.stream(llm_prompt)
yield from extract_citations_from_stream(tokens, llm_docs, doc_id_to_rank_map)
def stream_answer_objects(
query_req: DirectQARequest,
user: User | None,
@ -188,14 +71,12 @@ def stream_answer_objects(
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
timeout: int = QA_TIMEOUT,
bypass_acl: bool = False,
use_citations: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
"""Streams in order:
1. [always] Retrieved documents, stops flow if nothing is found
@ -273,43 +154,11 @@ def stream_answer_objects(
)
yield llm_relevance_filtering_response
# Prep chunks to pass to LLM
num_llm_chunks = (
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
)
chunk_token_limit = int(num_llm_chunks * default_chunk_size)
if max_document_tokens:
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
else:
max_document_tokens = compute_max_document_tokens(
persona=chat_session.persona, actual_user_input=query_msg.message
)
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=search_pipeline.chunk_relevance_list,
token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)
prompt = None
llm_override = None
if query_req.prompt_id is not None:
prompt = get_prompt_by_id(
prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session
)
persona = get_persona_by_id(
persona_id=query_req.persona_id, user_id=user_id, db_session=db_session
)
llm_override = persona.llm_model_version_override
if prompt is None:
if not chat_session.persona.prompts:
raise RuntimeError(
@ -329,52 +178,39 @@ def stream_answer_objects(
commit=True,
)
if use_citations:
qa_stream = citation_based_qa(
prompt=prompt,
query_message=query_msg,
history_str=history_str,
context_chunks=llm_chunks,
llm_override=llm_override,
timeout=timeout,
)
else:
qa_stream = quote_based_qa(
prompt=prompt,
query_message=query_msg,
history_str=history_str,
context_chunks=llm_chunks,
llm_override=llm_override,
timeout=timeout,
use_chain_of_thought=False,
return_contexts=False,
llm_metrics_callback=llm_metrics_callback,
)
# Capture outputs and errors
llm_output = ""
error: str | None = None
for packet in qa_stream:
logger.debug(packet)
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
elif isinstance(packet, StreamingError):
error = packet.error
yield packet
answer_config = AnswerStyleConfig(
citation_config=CitationConfig() if use_citations else None,
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
),
max_tokens=max_document_tokens,
),
)
answer = Answer(
question=query_msg.message,
docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks],
answer_style_config=answer_config,
prompt=prompt,
persona=chat_session.persona,
doc_relevance_list=search_pipeline.chunk_relevance_list,
single_message_history=history_str,
timeout=timeout,
)
yield from answer.processed_streamed_output
# Saving Gen AI answer and responding with message info
gen_ai_response_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=query_req.prompt_id,
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
message=answer.llm_answer,
token_count=len(llm_tokenizer(answer.llm_answer)),
message_type=MessageType.ASSISTANT,
error=error,
error=None,
reference_docs=None, # Don't need to save reference docs for one shot flow
db_session=db_session,
commit=True,
@ -419,7 +255,6 @@ def get_search_answer(
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> OneShotQAResponse:
"""Collects the streamed one shot answer responses into a single object"""
qa_response = OneShotQAResponse()
@ -435,7 +270,6 @@ def get_search_answer(
timeout=answer_generation_timeout,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
llm_metrics_callback=llm_metrics_callback,
)
answer = ""

View File

@ -1,48 +0,0 @@
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Prompt
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_block import QABlock
from danswer.one_shot_answer.qa_block import QAHandler
from danswer.one_shot_answer.qa_block import SingleMessageQAHandler
from danswer.one_shot_answer.qa_block import WeakLLMQAHandler
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_question_answer_model(
prompt: Prompt | None,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
chain_of_thought: bool = False,
llm_version: str | None = None,
qa_model_version: str | None = QA_PROMPT_OVERRIDE,
) -> QAModel | None:
if chain_of_thought:
raise NotImplementedError("COT has been disabled")
system_prompt = prompt.system_prompt if prompt is not None else None
task_prompt = prompt.task_prompt if prompt is not None else None
try:
llm = get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=llm_version,
)
except GenAIDisabledException:
return None
if qa_model_version == "weak":
qa_handler: QAHandler = WeakLLMQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
)
else:
qa_handler = SingleMessageQAHandler(
system_prompt=system_prompt, task_prompt=task_prompt
)
return QABlock(llm=llm, qa_handler=qa_handler)

View File

@ -1,26 +0,0 @@
import abc
from collections.abc import Callable
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import LLMMetricsContainer
from danswer.indexing.models import InferenceChunk
class QAModel:
@abc.abstractmethod
def build_prompt(
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
raise NotImplementedError
@abc.abstractmethod
def answer_question_stream(
self,
prompt: str,
llm_context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionStreamReturn:
raise NotImplementedError

View File

@ -1,313 +0,0 @@
import abc
import re
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.indexing.models import InferenceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.interfaces import QAModel
from danswer.one_shot_answer.qa_utils import process_answer
from danswer.one_shot_answer.qa_utils import process_model_tokens
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
from danswer.utils.text_processing import escape_newlines
logger = setup_logger()
class QAHandler(abc.ABC):
@property
@abc.abstractmethod
def is_json_output(self) -> bool:
"""Does the model output a valid json with answer and quotes keys? Most flows with a
capable model should output a json. This hints to the model that the output is used
with a downstream system rather than freeform creative output. Most models should be
finetuned to recognize this."""
raise NotImplementedError
@abc.abstractmethod
def build_prompt(
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
raise NotImplementedError
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=self.is_json_output,
)
class WeakLLMQAHandler(QAHandler):
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
or distilled models. It only uses one context document and has very weak requirements of
output format.
"""
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
) -> None:
if not system_prompt and not task_prompt:
self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT
self.task_prompt = WEAK_MODEL_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
return False
def build_prompt(
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
context_block = ""
if context_chunks:
context_block = CONTEXT_BLOCK.format(
context_docs_str=context_chunks[0].content
)
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=self.system_prompt,
context_block=context_block,
task_prompt=self.task_prompt,
user_query=query,
)
return prompt_str
class SingleMessageQAHandler(QAHandler):
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> None:
self.use_language_hint = use_language_hint
if not system_prompt and not task_prompt:
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
self.task_prompt = ONE_SHOT_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
return True
def build_prompt(
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
) -> str:
context_block = ""
if context_chunks:
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
history_block = ""
if history_str:
history_block = HISTORY_BLOCK.format(history_str=history_str)
full_prompt = JSON_PROMPT.format(
system_prompt=self.system_prompt,
context_block=context_block,
history_block=history_block,
task_prompt=self.task_prompt,
user_query=query,
language_hint_or_none=LANGUAGE_HINT.strip()
if self.use_language_hint
else "",
).strip()
return full_prompt
# This one isn't used; currently only streaming prompts are used
class SingleMessageScratchpadHandler(QAHandler):
def __init__(
self,
system_prompt: str | None,
task_prompt: str | None,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> None:
self.use_language_hint = use_language_hint
if not system_prompt and not task_prompt:
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
self.task_prompt = ONE_SHOT_TASK_PROMPT
else:
self.system_prompt = system_prompt or ""
self.task_prompt = task_prompt or ""
self.task_prompt = self.task_prompt.rstrip()
if self.task_prompt and self.task_prompt[0] != "\n":
self.task_prompt = "\n" + self.task_prompt
@property
def is_json_output(self) -> bool:
return True
def build_prompt(
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
) -> str:
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
# Outdated
prompt = COT_PROMPT.format(
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT.strip()
if self.use_language_hint
else "",
).strip()
return prompt
def process_llm_output(
self, model_output: str, context_chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, DanswerQuotes]:
logger.debug(model_output)
model_clean = clean_up_code_blocks(model_output)
match = re.search(r'{\s*"answer":', model_clean)
if not match:
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
final_json = escape_newlines(model_clean[match.start() :])
return process_answer(
final_json, context_chunks, is_json_prompt=self.is_json_output
)
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
# Can be supported but the parsing is more involved, not handling until needed
raise ValueError(
"This Scratchpad approach is not suitable for real time uses like streaming"
)
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
class QABlock(QAModel):
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
self._llm = llm
self._qa_handler = qa_handler
def build_prompt(
self,
query: str,
history_str: str,
context_chunks: list[InferenceChunk],
) -> str:
prompt = self._qa_handler.build_prompt(
query=query, history_str=history_str, context_chunks=context_chunks
)
return prompt
def answer_question_stream(
self,
prompt: str,
llm_context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionStreamReturn:
tokens_stream = self._llm.stream(prompt)
captured_tokens = []
try:
for answer_piece in self._qa_handler.process_llm_token_stream(
iter(tokens_stream), llm_context_docs
):
if (
isinstance(answer_piece, DanswerAnswerPiece)
and answer_piece.answer_piece
):
captured_tokens.append(answer_piece.answer_piece)
yield answer_piece
except Exception as e:
yield StreamingError(error=str(e))
if metrics_callback is not None:
prompt_tokens = check_number_of_tokens(
text=str(prompt), encode_fn=get_default_llm_token_encode()
)
response_tokens = check_number_of_tokens(
text="".join(captured_tokens), encode_fn=get_default_llm_token_encode()
)
metrics_callback(
LLMMetricsContainer(
prompt_tokens=prompt_tokens, response_tokens=response_tokens
)
)

View File

@ -1,275 +1,14 @@
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from json.decoder import JSONDecodeError
from typing import Optional
from typing import Tuple
import regex
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import MessageType
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def _extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
"""Splits the model output into an Answer and 0 or more Quote sections.
Splits on the Quote pattern; if the pattern is absent, assume the output is all answer and no quotes
"""
# If no answer section, don't care about the quote
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
return None, None
# Sometimes the model regenerates the Answer: pattern despite it already being provided in the prompt
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
answer_raw = answer_raw[len(ANSWER_PAT) :]
# Accept quote sections starting with the lower case version
answer_raw = answer_raw.replace(
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
) # Just in case model unreliable
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
sections_clean = [
str(section).strip() for section in sections if str(section).strip()
]
if not sections_clean:
return None, None
answer = str(sections_clean[0])
if len(sections) == 1:
return answer, None
return answer, sections_clean[1:]
def _extract_answer_quotes_json(
answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
answer = str(answer_dict.get("answer"))
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
if isinstance(quotes, str):
quotes = [quotes]
return answer, quotes
def _extract_answer_json(raw_model_output: str) -> dict:
try:
answer_json = extract_embedded_json(raw_model_output)
except (ValueError, JSONDecodeError):
# LLMs get confused when handling the list in the json. Sometimes the model doesn't
# attend enough to the opening { token, so it ends the list of quotes and stops there.
# Here we try to recover from that by re-parsing with a closing brace appended.
answer_json = extract_embedded_json(raw_model_output + "}")
if "answer" not in answer_json:
raise ValueError("Model did not output an answer as expected.")
return answer_json
def separate_answer_quotes(
answer_raw: str, is_json_prompt: bool = False
) -> Tuple[Optional[str], Optional[list[str]]]:
"""Takes in a raw model output and pulls out the answer and the quotes sections."""
if is_json_prompt:
model_raw_json = _extract_answer_json(answer_raw)
return _extract_answer_quotes_json(model_raw_json)
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
def match_quotes_to_docs(
quotes: list[str],
chunks: list[InferenceChunk],
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> DanswerQuotes:
danswer_quotes: list[DanswerQuote] = []
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
for chunk in chunks:
if not chunk.source_links:
continue
quote_clean = shared_precompare_cleanup(
clean_model_quote(quote, trim_length=prefix_only_length)
)
chunk_clean = shared_precompare_cleanup(chunk.content)
# Finding the offset of the quote in the plain text
if fuzzy_search:
re_search_str = (
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
)
found = regex.search(re_search_str, chunk_clean)
if not found:
continue
offset = found.span()[0]
else:
if quote_clean not in chunk_clean:
continue
offset = chunk_clean.index(quote_clean)
# Extracting the link from the offset
curr_link = None
for link_offset, link in chunk.source_links.items():
# Should always find one because offset is at least 0 and there
# must be a 0 link_offset
if int(link_offset) <= offset:
curr_link = link
else:
break
danswer_quotes.append(
DanswerQuote(
quote=quote,
document_id=chunk.document_id,
link=curr_link,
source_type=chunk.source_type,
semantic_identifier=chunk.semantic_identifier,
blurb=chunk.blurb,
)
)
break
return DanswerQuotes(quotes=danswer_quotes)
def process_answer(
answer_raw: str,
chunks: list[InferenceChunk],
is_json_prompt: bool = True,
) -> tuple[DanswerAnswer, DanswerQuotes]:
"""Used (1) in the non-streaming case to process the model output
into an Answer and Quotes AND (2) after the complete streaming response
has been received to process the model output into an Answer and Quotes."""
answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
logger.info(f"Answer: {answer}")
if not quote_strings:
logger.debug("No quotes extracted from raw output")
return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
logger.info(f"All quotes (including unmatched): {quote_strings}")
quotes = match_quotes_to_docs(quote_strings, chunks)
logger.debug(f"Final quotes: {quotes}")
return DanswerAnswer(answer=answer), quotes
def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
next_token = next_token.replace('\\"', "")
# If the previous character is an escape token, don't consider the first character of next_token.
# This breaks if an escaped backslash precedes the ", but that case is rare and not worth handling.
if answer_so_far and answer_so_far[-1] == "\\":
next_token = next_token[1:]
if '"' in next_token:
return True
return False
def _extract_quotes_from_completed_token_stream(
model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True
) -> DanswerQuotes:
answer, quotes = process_answer(model_output, context_chunks, is_json_prompt)
if answer:
logger.info(answer)
elif model_output:
logger.warning("Answer extraction from model output failed.")
return quotes
def process_model_tokens(
tokens: Iterator[str],
context_docs: list[InferenceChunk],
is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Used in the streaming case to process the model output
into an Answer and Quotes.
Yields answer tokens back out as DanswerAnswerPiece objects for streaming to the frontend.
When the answer section ends, yields a DanswerAnswerPiece with answer_piece set to None.
Collects all the tokens along the way to form the complete model output."""
quote_pat = f"\n{QUOTE_PAT}"
# Weaker models sometimes output a newline instead of the trailing :
quote_loose = f"\n{quote_pat[:-1]}\n"
# Sometimes the model outputs two newlines before the quote section
quote_pat_full = f"\n{quote_pat}"
model_output: str = ""
found_answer_start = False if is_json_prompt else True
found_answer_end = False
hold_quote = ""
for token in tokens:
model_previous = model_output
model_output += token
if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
# Note: if the token that completes the pattern has additional text (for example if the token is "?),
# then the chars after " will not be streamed, but this is ok since it prevents streaming the ? in
# the event that the model outputs the UNCERTAINTY_PAT.
found_answer_start = True
# Prevent heavy hallucination cases where the model does not even start producing json until later
if is_json_prompt and len(model_output) > 40:
logger.warning("LLM did not produce json as prompted")
found_answer_end = True
continue
if found_answer_start and not found_answer_end:
if is_json_prompt and _stream_json_answer_end(model_previous, token):
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
yield DanswerAnswerPiece(answer_piece=None)
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
hold_quote = ""
logger.debug(f"Raw Model QnA Output: {model_output}")
yield _extract_quotes_from_completed_token_stream(
model_output=model_output,
context_chunks=context_docs,
is_json_prompt=is_json_prompt,
)
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:

View File

@ -1,4 +1,5 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import cast
from sqlalchemy.orm import Session
@ -51,6 +52,11 @@ class SearchPipeline:
self._reranked_docs: list[InferenceChunk] | None = None
self._relevant_chunk_indicies: list[int] | None = None
# generator state
self._postprocessing_generator: Generator[
list[InferenceChunk] | list[str], None, None
] | None = None
"""Pre-processing"""
def _run_preprocessing(self) -> None:
@ -113,36 +119,38 @@ class SearchPipeline:
"""Post-Processing"""
def _run_postprocessing(self) -> None:
postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_docs,
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator))
relevant_chunk_ids = cast(list[str], next(postprocessing_generator))
self._relevant_chunk_indicies = [
ind
for ind, chunk in enumerate(self._reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
@property
def reranked_docs(self) -> list[InferenceChunk]:
if self._reranked_docs is not None:
return self._reranked_docs
self._run_postprocessing()
return cast(list[InferenceChunk], self._reranked_docs)
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_docs,
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_docs = cast(
list[InferenceChunk], next(self._postprocessing_generator)
)
return self._reranked_docs
@property
def relevant_chunk_indicies(self) -> list[int]:
if self._relevant_chunk_indicies is not None:
return self._relevant_chunk_indicies
self._run_postprocessing()
return cast(list[int], self._relevant_chunk_indicies)
# run first step of postprocessing generator if not already done
reranked_docs = self.reranked_docs
relevant_chunk_ids = next(
cast(Generator[list[str], None, None], self._postprocessing_generator)
)
self._relevant_chunk_indicies = [
ind
for ind, chunk in enumerate(reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
return self._relevant_chunk_indicies
@property
def chunk_relevance_list(self) -> list[bool]:

View File

@ -223,11 +223,13 @@ def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
return LlmDoc(
document_id=first_chunk.document_id,
content="\n".join(chunk_texts),
blurb=first_chunk.blurb,
semantic_identifier=first_chunk.semantic_identifier,
source_type=first_chunk.source_type,
metadata=first_chunk.metadata,
updated_at=first_chunk.updated_at,
link=first_chunk.source_links[0] if first_chunk.source_links else None,
source_links=first_chunk.source_links,
)

View File

@ -14,8 +14,8 @@ from danswer.db.chat import update_persona_visibility
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.llm.utils import get_default_llm_version
from danswer.one_shot_answer.qa_block import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse

View File

@ -6,7 +6,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.process_message import stream_chat_message
from danswer.db.chat import create_chat_session
@ -25,6 +24,7 @@ from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)

View File

@ -77,7 +77,6 @@ def get_answer_for_question(
str | None,
RetrievalMetricsContainer | None,
RerankMetricsContainer | None,
LLMMetricsContainer | None,
]:
filters = IndexFilters(
source_type=None,
@ -103,7 +102,6 @@ def get_answer_for_question(
retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
rerank_metrics = MetricsHander[RerankMetricsContainer]()
llm_metrics = MetricsHander[LLMMetricsContainer]()
answer = get_search_answer(
query_req=new_message_request,
@ -116,14 +114,12 @@ def get_answer_for_question(
bypass_acl=True,
retrieval_metrics_callback=retrieval_metrics.record_metric,
rerank_metrics_callback=rerank_metrics.record_metric,
llm_metrics_callback=llm_metrics.record_metric,
)
return (
answer.answer,
retrieval_metrics.metrics,
rerank_metrics.metrics,
llm_metrics.metrics,
)
@ -221,7 +217,6 @@ if __name__ == "__main__":
answer,
retrieval_metrics,
rerank_metrics,
llm_metrics,
) = get_answer_for_question(sample["question"], db_session)
end_time = datetime.now()
@ -237,12 +232,6 @@ if __name__ == "__main__":
else "\tFailed, either crashed or refused to answer."
)
if not args.discard_metrics:
print("\nLLM Tokens Usage:")
if llm_metrics is None:
print("No LLM Metrics Available")
else:
_print_llm_metrics(llm_metrics)
print("\nRetrieval Metrics:")
if retrieval_metrics is None:
print("No Retrieval Metrics Available")

View File

@ -7,9 +7,9 @@ from typing import TextIO
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.db.engine import get_sqlalchemy_engine
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.doc_pruning import reorder_docs
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
@ -95,16 +95,8 @@ def get_search_results(
top_chunks = search_pipeline.reranked_docs
llm_chunk_selection = search_pipeline.chunk_relevance_list
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=None,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
return (
llm_chunks,
reorder_docs(top_chunks, llm_chunk_selection),
retrieval_metrics.metrics,
rerank_metrics.metrics,
)

View File

@ -3,8 +3,12 @@ import unittest
from danswer.configs.constants import DocumentSource
from danswer.indexing.models import InferenceChunk
from danswer.one_shot_answer.qa_utils import match_quotes_to_docs
from danswer.one_shot_answer.qa_utils import separate_answer_quotes
from danswer.llm.answering.stream_processing.quotes_processing import (
match_quotes_to_docs,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
separate_answer_quotes,
)
class TestQAPostprocessing(unittest.TestCase):