Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-26 16:01:09 +02:00)

Commit f135ba9c0c: Rework LLM answering flow (parent 1ba74ee4df)
@@ -1,97 +1,29 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from functools import lru_cache
from typing import cast

from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from tiktoken.core import Encoding

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_default_prompt
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from danswer.prompts.token_counts import (
    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.utils.logger import setup_logger

logger = setup_logger()


@lru_cache()
def build_chat_system_message(
    prompt: Prompt,
    context_exists: bool,
    llm_tokenizer_encode_func: Callable,
    citation_line: str = REQUIRE_CITATION_STATEMENT,
    no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
    system_prompt = prompt.system_prompt.strip()
    if prompt.include_citations:
        if context_exists:
            system_prompt += citation_line
        else:
            system_prompt += no_citation_line
    if prompt.datetime_aware:
        if system_prompt:
            system_prompt += ADDITIONAL_INFO.format(
                datetime_info=get_current_llm_day_time()
            )
        else:
            system_prompt = get_current_llm_day_time()

    if not system_prompt:
        return None, 0

    token_count = len(llm_tokenizer_encode_func(system_prompt))
    system_msg = SystemMessage(content=system_prompt)

    return system_msg, token_count


def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
    return LlmDoc(
        document_id=inf_chunk.document_id,
        content=inf_chunk.content,
        blurb=inf_chunk.blurb,
        semantic_identifier=inf_chunk.semantic_identifier,
        source_type=inf_chunk.source_type,
        metadata=inf_chunk.metadata,
        updated_at=inf_chunk.updated_at,
        link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
        source_links=inf_chunk.source_links,
    )
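
Editor's illustration (not part of this commit): any object exposing system_prompt, include_citations, and datetime_aware can stand in for a Prompt row when exercising build_chat_system_message, and any tokenizer encode function works for the count. A minimal sketch, assuming tiktoken is available:

# Editor's sketch only; the SimpleNamespace stands in for a Prompt ORM row.
from types import SimpleNamespace

import tiktoken

fake_prompt = SimpleNamespace(
    system_prompt="You are a helpful assistant.",
    include_citations=True,
    datetime_aware=False,
)
encode = tiktoken.get_encoding("cl100k_base").encode
system_msg, token_count = build_chat_system_message(
    prompt=fake_prompt,  # type: ignore[arg-type]
    context_exists=True,
    llm_tokenizer_encode_func=encode,
)
# system_msg.content ends with the citation requirement; token_count is its encoded length.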
@@ -108,170 +40,6 @@ def map_document_id_order(
    return order_mapping


def build_chat_user_message(
    chat_message: ChatMessage,
    prompt: Prompt,
    context_docs: list[LlmDoc],
    llm_tokenizer_encode_func: Callable,
    all_doc_useful: bool,
    user_prompt_template: str = CHAT_USER_PROMPT,
    context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
    ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
    user_query = chat_message.message

    if not context_docs:
        # Simpler prompt for cases where there is no context
        user_prompt = (
            context_free_template.format(
                task_prompt=prompt.task_prompt, user_query=user_query
            )
            if prompt.task_prompt
            else user_query
        )
        user_prompt = user_prompt.strip()
        token_count = len(llm_tokenizer_encode_func(user_prompt))
        user_msg = HumanMessage(content=user_prompt)
        return user_msg, token_count

    context_docs_str = build_complete_context_str(
        cast(list[LlmDoc | InferenceChunk], context_docs)
    )
    optional_ignore = "" if all_doc_useful else ignore_str

    task_prompt_with_reminder = build_task_prompt_reminders(prompt)

    user_prompt = user_prompt_template.format(
        optional_ignore_statement=optional_ignore,
        context_docs_str=context_docs_str,
        task_prompt=task_prompt_with_reminder,
        user_query=user_query,
    )

    user_prompt = user_prompt.strip()
    token_count = len(llm_tokenizer_encode_func(user_prompt))
    user_msg = HumanMessage(content=user_prompt)

    return user_msg, token_count


def _get_usable_chunks(
    chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
    total_token_count = 0
    usable_chunks = []
    for chunk in chunks:
        chunk_token_count = check_number_of_tokens(chunk.content)
        if total_token_count + chunk_token_count > token_limit:
            break

        total_token_count += chunk_token_count
        usable_chunks.append(chunk)

    # try and return at least one chunk if possible. This chunk will
    # get truncated later on in the pipeline. This would only occur if
    # the first chunk is larger than the token limit (usually due to character
    # count -> token count mismatches caused by special characters / non-ascii
    # languages)
    if not usable_chunks and chunks:
        usable_chunks = [chunks[0]]

    return usable_chunks


def get_usable_chunks(
    chunks: list[InferenceChunk],
    token_limit: int,
    offset: int = 0,
) -> list[InferenceChunk]:
    offset_into_chunks = 0
    usable_chunks: list[InferenceChunk] = []
    for _ in range(min(offset + 1, 1)):  # go through this process at least once
        if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
            raise ValueError(
                "Chunks offset too large, should not retry this many times"
            )

        usable_chunks = _get_usable_chunks(
            chunks=chunks[offset_into_chunks:], token_limit=token_limit
        )
        offset_into_chunks += len(usable_chunks)

    return usable_chunks


def get_chunks_for_qa(
    chunks: list[InferenceChunk],
    llm_chunk_selection: list[bool],
    token_limit: int | None,
    llm_tokenizer: Encoding | None = None,
    batch_offset: int = 0,
) -> list[int]:
    """
    Gives back indices of chunks to pass into the LLM for Q&A.

    Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
    by the LLM in a separate flow (this can be turned off)

    Note, the batch_offset calculation has to count the batches from the beginning each time as
    there's no way to know which chunks were included in the prior batches without recounting atm,
    this is somewhat slow as it requires tokenizing all the chunks again
    """
    token_leeway = 50
    batch_index = 0
    latest_batch_indices: list[int] = []
    token_count = 0

    # First iterate the LLM selected chunks, then iterate the rest if tokens remaining
    for selection_target in [True, False]:
        for ind, chunk in enumerate(chunks):
            if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
                IGNORE_FOR_QA
            ):
                continue

            # We calculate it live in case the user uses a different LLM + tokenizer
            chunk_token = check_number_of_tokens(chunk.content)
            if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
                logger.warning(
                    "Found more tokens in chunk than expected, "
                    "likely mismatch between embedding and LLM tokenizers. Trimming content..."
                )
                chunk.content = tokenizer_trim_content(
                    content=chunk.content,
                    desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                    tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
                )

            # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
            token_count += chunk_token + token_leeway

            # Always use at least 1 chunk
            if (
                token_limit is None
                or token_count <= token_limit
                or not latest_batch_indices
            ):
                latest_batch_indices.append(ind)
                current_chunk_unused = False
            else:
                current_chunk_unused = True

            if token_limit is not None and token_count >= token_limit:
                if batch_index < batch_offset:
                    batch_index += 1
                    if current_chunk_unused:
                        latest_batch_indices = [ind]
                        token_count = chunk_token
                    else:
                        latest_batch_indices = []
                        token_count = 0
                else:
                    return latest_batch_indices

    return latest_batch_indices
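
Editor's illustration of the selection policy above, as a self-contained toy with assumed numbers (no danswer imports): LLM-picked chunks are taken first, then the rest, until the token budget is hit, and at least one chunk is always kept.

# Editor's sketch only; mirrors the ordering/budget logic of get_chunks_for_qa.
def toy_chunks_for_qa(
    chunk_tokens: list[int], llm_selected: list[bool], token_limit: int
) -> list[int]:
    picked: list[int] = []
    used = 0
    for target in (True, False):  # LLM-selected chunks get priority
        for ind, tokens in enumerate(chunk_tokens):
            if llm_selected[ind] is not target:
                continue
            used += tokens
            if used <= token_limit or not picked:  # always keep at least one chunk
                picked.append(ind)
            if used >= token_limit:
                return picked
    return picked


# chunk 1 is LLM-selected, so it comes first even though chunk 0 ranks higher
assert toy_chunks_for_qa([200, 900, 300], [False, True, False], 1200) == [1, 0]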
def create_chat_chain(
    chat_session_id: int,
    db_session: Session,
@@ -341,157 +109,6 @@ def combine_message_chain(
    return "\n\n".join(message_strs)


_PER_MESSAGE_TOKEN_BUFFER = 7


def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
    """From the back, find the index of the last element to include
    before the list exceeds the maximum"""
    running_sum = 0

    last_ind = 0
    for i in range(len(lst) - 1, -1, -1):
        running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
        if running_sum > max_prompt_tokens:
            last_ind = i + 1
            break
    if last_ind >= len(lst):
        raise ValueError("Last message alone is too large!")
    return last_ind


def drop_messages_history_overflow(
    system_msg: BaseMessage | None,
    system_token_count: int,
    history_msgs: list[BaseMessage],
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
    max_allowed_tokens: int,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""
    if len(history_msgs) != len(history_token_counts):
        # This should never happen
        raise ValueError("Need exactly 1 token count per message for tracking overflow")

    prompt: list[BaseMessage] = []

    # Start dropping from the history if necessary
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = find_last_index(
        all_tokens, max_prompt_tokens=max_allowed_tokens
    )

    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)

    prompt.extend(history_msgs[ind_prev_msg_start:])

    prompt.append(final_msg)

    return prompt
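
Editor's worked example (illustrative numbers only): with the 7-token-per-message buffer, find_last_index walks the counts from the newest message backwards and returns the first index that still fits the budget.

token_counts = [400, 300, 50, 80]  # oldest ... newest
start = find_last_index(token_counts, max_prompt_tokens=200)
# running sums from the back: 87, 144, 451 -> index 1 overflows, so history starts at 2
assert start == 2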
def in_code_block(llm_text: str) -> bool:
    count = llm_text.count(TRIPLE_BACKTICK)
    return count % 2 != 0


def extract_citations_from_stream(
    tokens: Iterator[str],
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
    stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    llm_out = ""
    max_citation_num = len(context_docs)
    curr_segment = ""
    prepend_bracket = False
    cited_inds = set()
    hold = ""
    for raw_token in tokens:
        if stop_stream:
            next_hold = hold + raw_token

            if stop_stream in next_hold:
                break

            if next_hold == stop_stream[: len(next_hold)]:
                hold = next_hold
                continue

            token = next_hold
            hold = ""
        else:
            token = raw_token

        # Special case of [1][ where ][ is a single token
        # This is where the model attempts to do consecutive citations like [1][2]
        if prepend_bracket:
            curr_segment += "[" + curr_segment
            prepend_bracket = False

        curr_segment += token
        llm_out += token

        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
        citation_found = re.search(citation_pattern, curr_segment)

        if citation_found and not in_code_block(llm_out):
            numerical_value = int(citation_found.group(1))
            if 1 <= numerical_value <= max_citation_num:
                context_llm_doc = context_docs[
                    numerical_value - 1
                ]  # remove 1 index offset

                link = context_llm_doc.link
                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

                # Use the citation number for the document's rank in
                # the search (or selected docs) results
                curr_segment = re.sub(
                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
                )

                if target_citation_num not in cited_inds:
                    cited_inds.add(target_citation_num)
                    yield CitationInfo(
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )

                if link:
                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

                # In case there's another open bracket like [1][, don't want to match this
                possible_citation_found = None

        # if we see "[", but haven't seen the right side, hold back - this may be a
        # citation that needs to be replaced with a link
        if possible_citation_found:
            continue

        # Special case with back to back citations [1][2]
        if curr_segment and curr_segment[-1] == "[":
            curr_segment = curr_segment[:-1]
            prepend_bracket = True

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    if curr_segment:
        if prepend_bracket:
            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
        else:
            yield DanswerAnswerPiece(answer_piece=curr_segment)
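
Editor's standalone sketch of the two regexes used above (assumed rank and link): "[2" is only a possible citation and is held back, while "[2]" is a complete citation that gets remapped to the document's display rank and rewritten as a markdown link.

import re

held = "as shown in the guide [2"      # stream so far; might still become [2]
assert re.search(r"(\[\d*$)", held)    # "possible citation" -> hold the segment back

segment = held + "]"                   # next token completes the citation
match = re.search(r"\[(\d+)\]", segment)
assert match and int(match.group(1)) == 2

display_rank, link = 1, "https://example.com/doc"  # assumed values for the sketch
segment = re.sub(r"\[2\]", f"[{display_rank}]", segment)
segment = re.sub(r"\[", "[[", segment, count=1)
segment = re.sub("]", f"]]({link})", segment, count=1)
assert segment == "as shown in the guide [[1]](https://example.com/doc)"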


def reorganize_citations(
    answer: str, citations: list[CitationInfo]
) -> tuple[str, list[CitationInfo]]:
@@ -547,72 +164,3 @@ def reorganize_citations(
            new_citation_info[citation.citation_num] = citation

    return new_answer, list(new_citation_info.values())


def get_prompt_tokens(prompt: Prompt) -> int:
    # Note: currently custom prompts do not allow datetime aware, only default prompts
    return (
        check_number_of_tokens(prompt.system_prompt)
        + check_number_of_tokens(prompt.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
    )


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
    if we're trying to determine if the user should be able to select another document) then we just set an
    arbitrary "upper bound".
    """
    llm_name = get_default_llm_version()[0]
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(model_name=llm_name)
    )
    if persona.prompts:
        # TODO this may not always be the first prompt
        prompt_tokens = get_prompt_tokens(persona.prompts[0])
    else:
        prompt_tokens = get_prompt_tokens(get_default_prompt())

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER


def compute_max_llm_input_tokens(persona: Persona) -> int:
    """Maximum tokens allows in the input to the LLM (of any type)."""
    llm_name = get_default_llm_version()[0]
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    input_tokens = get_max_input_tokens(model_name=llm_name)
    return input_tokens - _MISC_BUFFER
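
Editor's worked example with assumed numbers, mirroring the budget formula in compute_max_document_tokens:

max_input_tokens = 4096       # from get_max_input_tokens(...)
prompt_tokens = 300           # from get_prompt_tokens(...)
user_input_tokens = 120       # check_number_of_tokens(actual_user_input)
document_token_budget = max_input_tokens - prompt_tokens - user_input_tokens - 40  # _MISC_BUFFER
assert document_token_budget == 3636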
@@ -16,11 +16,13 @@ class LlmDoc(BaseModel):

    document_id: str
    content: str
    blurb: str
    semantic_identifier: str
    source_type: DocumentSource
    metadata: dict[str, str | list[str]]
    updated_at: datetime | None
    link: str | None
    source_links: dict[int, str] | None


# First chunk of info for streaming QA
@@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer):
    error_msg: str | None = None


AnswerQuestionStreamReturn = Iterator[
    DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError
]
AnswerQuestionPossibleReturn = (
    DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError
)


AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]


class LLMMetricsContainer(BaseModel):
@@ -5,16 +5,8 @@ from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import build_chat_user_message
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import compute_max_llm_input_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import drop_messages_history_overflow
from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import map_document_id_order
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
@@ -23,9 +15,7 @@ from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -37,21 +27,17 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
@@ -68,72 +54,6 @@ from danswer.utils.timing import log_generator_function_time
logger = setup_logger()


def generate_ai_chat_response(
    query_message: ChatMessage,
    history: list[ChatMessage],
    persona: Persona,
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
    llm: LLM | None,
    llm_tokenizer_encode_func: Callable,
    all_doc_useful: bool,
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
    if llm is None:
        try:
            llm = get_default_llm()
        except GenAIDisabledException:
            # Not an error if it's a user configuration
            yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
            return

    if query_message.prompt is None:
        raise RuntimeError("No prompt received for generating Gen AI answer.")

    try:
        context_exists = len(context_docs) > 0

        system_message_or_none, system_tokens = build_chat_system_message(
            prompt=query_message.prompt,
            context_exists=context_exists,
            llm_tokenizer_encode_func=llm_tokenizer_encode_func,
        )

        history_basemessages, history_token_counts = translate_history_to_basemessages(
            history
        )

        # Be sure the context_docs passed to build_chat_user_message
        # Is the same as passed in later for extracting citations
        user_message, user_tokens = build_chat_user_message(
            chat_message=query_message,
            prompt=query_message.prompt,
            context_docs=context_docs,
            llm_tokenizer_encode_func=llm_tokenizer_encode_func,
            all_doc_useful=all_doc_useful,
        )

        prompt = drop_messages_history_overflow(
            system_msg=system_message_or_none,
            system_token_count=system_tokens,
            history_msgs=history_basemessages,
            history_token_counts=history_token_counts,
            final_msg=user_message,
            final_msg_token_count=user_tokens,
            max_allowed_tokens=compute_max_llm_input_tokens(persona),
        )

        # Good Debug/Breakpoint
        tokens = llm.stream(prompt)

        yield from extract_citations_from_stream(
            tokens, context_docs, doc_id_to_rank_map
        )

    except Exception as e:
        logger.exception(f"LLM failed to produce valid chat message, error: {e}")
        yield StreamingError(error=str(e))


def translate_citations(
    citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
@@ -154,24 +74,26 @@ def translate_citations(
    return citation_to_saved_doc_id_map


def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> Iterator[
ChatPacketStream = Iterator[
    StreamingError
    | QADocsResponse
    | LLMRelevanceFilterResponse
    | ChatMessageDetail
    | DanswerAnswerPiece
    | CitationInfo
]:
]
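
Editor's sketch of how a caller might consume the new ChatPacketStream alias, dispatching on the packet types in the union above (not part of the commit; the stream is assumed to come from stream_chat_message_objects):

def collect_answer(stream: ChatPacketStream) -> tuple[str, list[CitationInfo]]:
    answer_text = ""
    citations: list[CitationInfo] = []
    for packet in stream:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer_text += packet.answer_piece
        elif isinstance(packet, CitationInfo):
            citations.append(packet)
        elif isinstance(packet, StreamingError):
            raise RuntimeError(packet.error)
    return answer_text, citations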


def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> ChatPacketStream:
    """Streams in order:
    1. [conditional] Retrieved documents if a search needs to be run
    2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
@@ -277,10 +199,6 @@ def stream_chat_message_objects(
        query_message=final_msg, history=history_msgs, llm=llm
    )

    max_document_tokens = compute_max_document_tokens(
        persona=persona, actual_user_input=message_text
    )

    rephrased_query = None
    if reference_doc_ids:
        identifier_tuples = get_doc_query_identifiers_from_model(
@@ -296,64 +214,8 @@ def stream_chat_message_objects(
            doc_identifiers=identifier_tuples,
            document_index=document_index,
        )

        # truncate the last document if it exceeds the token limit
        tokens_per_doc = [
            len(
                llm_tokenizer_encode_func(
                    build_doc_context_str(
                        semantic_identifier=llm_doc.semantic_identifier,
                        source_type=llm_doc.source_type,
                        content=llm_doc.content,
                        metadata_dict=llm_doc.metadata,
                        updated_at=llm_doc.updated_at,
                        ind=ind,
                    )
                )
            )
            for ind, llm_doc in enumerate(llm_docs)
        ]
        final_doc_ind = None
        total_tokens = 0
        for ind, tokens in enumerate(tokens_per_doc):
            total_tokens += tokens
            if total_tokens > max_document_tokens:
                final_doc_ind = ind
                break
        if final_doc_ind is not None:
            # only allow the final document to get truncated
            # if more than that, then the user message is too long
            if final_doc_ind != len(tokens_per_doc) - 1:
                yield StreamingError(
                    error="LLM context window exceeded. Please de-select some documents or shorten your query."
                )
                return

            final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
                total_tokens - max_document_tokens
            )
            # 75 tokens is a reasonable over-estimate of the metadata and title
            final_doc_content_length = final_doc_desired_length - 75
            # this could occur if we only have space for the title / metadata
            # not ideal, but it's the most reasonable thing to do
            # NOTE: the frontend prevents documents from being selected if
            # less than 75 tokens are available to try and avoid this situation
            # from occuring in the first place
            if final_doc_content_length <= 0:
                logger.error(
                    f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content "
                    "length is less than 0. Removing this doc from the final prompt."
                )
                llm_docs.pop()
            else:
                llm_docs[final_doc_ind].content = tokenizer_trim_content(
                    content=llm_docs[final_doc_ind].content,
                    desired_length=final_doc_content_length,
                    tokenizer=llm_tokenizer,
                )

        doc_id_to_rank_map = map_document_id_order(
            cast(list[InferenceChunk | LlmDoc], llm_docs)
        document_pruning_config = DocumentPruningConfig(
            is_manually_selected_docs=True
        )

        # In case the search doc is deleted, just don't include it
@@ -393,9 +255,6 @@ def stream_chat_message_objects(
        top_chunks = search_pipeline.reranked_docs
        top_docs = chunks_to_search_docs(top_chunks)

        # Get ranking of the documents for citation purposes later
        doc_id_to_rank_map = map_document_id_order(top_chunks)

        reference_db_search_docs = [
            create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
            for top_doc in top_docs
@@ -423,41 +282,21 @@ def stream_chat_message_objects(
        )
        yield llm_relevance_filtering_response

        # Prep chunks to pass to LLM
        num_llm_chunks = (
            persona.num_chunks
            if persona.num_chunks is not None
            else default_num_chunks
        document_pruning_config = DocumentPruningConfig(
            max_chunks=int(
                persona.num_chunks
                if persona.num_chunks is not None
                else default_num_chunks
            ),
            max_window_percentage=max_document_percentage,
        )

        llm_name = get_default_llm_version()[0]
        if persona.llm_model_version_override:
            llm_name = persona.llm_model_version_override

        llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)

        llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens

        chunk_token_limit = int(
            min(
                num_llm_chunks * default_chunk_size,
                max_document_tokens,
                llm_token_based_chunk_lim,
            )
        )
        llm_chunks_indices = get_chunks_for_qa(
            chunks=top_chunks,
            llm_chunk_selection=search_pipeline.chunk_relevance_list,
            token_limit=chunk_token_limit,
            llm_tokenizer=llm_tokenizer,
        )
        llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
        llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
        llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks]

    else:
        llm_docs = []
        doc_id_to_rank_map = {}
        reference_db_search_docs = None
        document_pruning_config = DocumentPruningConfig()

    # Cannot determine these without the LLM step or breaking out early
    partial_response = partial(
@@ -495,33 +334,24 @@ def stream_chat_message_objects(
            return

        # LLM prompt building, response capturing, etc.
        response_packets = generate_ai_chat_response(
            query_message=final_msg,
            history=history_msgs,
        answer = Answer(
            question=final_msg.message,
            docs=llm_docs,
            answer_style_config=AnswerStyleConfig(
                citation_config=CitationConfig(
                    all_docs_useful=reference_db_search_docs is not None
                ),
                document_pruning_config=document_pruning_config,
            ),
            prompt=final_msg.prompt,
            persona=persona,
            context_docs=llm_docs,
            doc_id_to_rank_map=doc_id_to_rank_map,
            llm=llm,
            llm_tokenizer_encode_func=llm_tokenizer_encode_func,
            all_doc_useful=reference_doc_ids is not None,
            message_history=[
                PreviousMessage.from_chat_message(msg) for msg in history_msgs
            ],
        )
        # generator will not include quotes, so we can cast
        yield from cast(ChatPacketStream, answer.processed_streamed_output)

        # Capture outputs and errors
        llm_output = ""
        error: str | None = None
        citations: list[CitationInfo] = []
        for packet in response_packets:
            if isinstance(packet, DanswerAnswerPiece):
                token = packet.answer_piece
                if token:
                    llm_output += token
            elif isinstance(packet, StreamingError):
                error = packet.error
            elif isinstance(packet, CitationInfo):
                citations.append(packet)
                continue

            yield packet
    except Exception as e:
        logger.exception(e)

@@ -535,16 +365,16 @@ def stream_chat_message_objects(
        db_citations = None
        if reference_db_search_docs:
            db_citations = translate_citations(
                citations_list=citations,
                citations_list=answer.citations,
                db_docs=reference_db_search_docs,
            )

        # Saving Gen AI answer and responding with message info
        gen_ai_response_message = partial_response(
            message=llm_output,
            token_count=len(llm_tokenizer_encode_func(llm_output)),
            message=answer.llm_answer,
            token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
            citations=db_citations,
            error=error,
            error=None,
        )

        msg_detail_response = translate_db_message_to_chat_message_detail(

@@ -12,7 +12,6 @@ from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import DividerBlock
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
@@ -39,6 +38,7 @@ from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
backend/danswer/llm/answering/answer.py (new file, 176 lines)
@@ -0,0 +1,176 @@
from collections.abc import Iterator
from typing import cast

from langchain.schema.messages import BaseMessage

from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt
from danswer.llm.answering.prompts.quotes_prompt import (
    build_quotes_prompt,
)
from danswer.llm.answering.stream_processing.citation_processing import (
    build_citation_processor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
    build_quotes_processor,
)
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_tokenizer


def _get_stream_processor(
    docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig
) -> StreamProcessor:
    if answer_style_configs.citation_config:
        return build_citation_processor(
            context_docs=docs,
        )
    if answer_style_configs.quotes_config:
        return build_quotes_processor(
            context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
        )

    raise RuntimeError("Not implemented yet")


class Answer:
    def __init__(
        self,
        question: str,
        docs: list[LlmDoc],
        answer_style_config: AnswerStyleConfig,
        prompt: Prompt,
        persona: Persona,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        doc_relevance_list: list[bool] | None = None,
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
        timeout: int = QA_TIMEOUT,
    ) -> None:
        if single_message_history and message_history:
            raise ValueError(
                "Cannot provide both `message_history` and `single_message_history`"
            )

        self.question = question
        self.docs = docs
        self.doc_relevance_list = doc_relevance_list
        self.message_history = message_history or []
        # used for QA flow where we only want to send a single message
        self.single_message_history = single_message_history

        self.answer_style_config = answer_style_config

        self.llm = get_default_llm(
            gen_ai_model_version_override=persona.llm_model_version_override,
            timeout=timeout,
        )
        self.llm_tokenizer = get_default_llm_tokenizer()

        self.prompt = prompt
        self.persona = persona

        self.process_stream_fn = _get_stream_processor(docs, answer_style_config)

        self._final_prompt: list[BaseMessage] | None = None

        self._pruned_docs: list[LlmDoc] | None = None

        self._streamed_output: list[str] | None = None
        self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None

    @property
    def pruned_docs(self) -> list[LlmDoc]:
        if self._pruned_docs is not None:
            return self._pruned_docs

        self._pruned_docs = prune_documents(
            docs=self.docs,
            doc_relevance_list=self.doc_relevance_list,
            persona=self.persona,
            question=self.question,
            document_pruning_config=self.answer_style_config.document_pruning_config,
        )
        return self._pruned_docs

    @property
    def final_prompt(self) -> list[BaseMessage]:
        if self._final_prompt is not None:
            return self._final_prompt

        if self.answer_style_config.citation_config:
            self._final_prompt = build_citations_prompt(
                question=self.question,
                message_history=self.message_history,
                persona=self.persona,
                prompt=self.prompt,
                context_docs=self.pruned_docs,
                all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
                llm_tokenizer_encode_func=self.llm_tokenizer.encode,
                history_message=self.single_message_history or "",
            )
        elif self.answer_style_config.quotes_config:
            self._final_prompt = build_quotes_prompt(
                question=self.question,
                context_docs=self.pruned_docs,
                history_str=self.single_message_history or "",
                prompt=self.prompt,
            )

        return cast(list[BaseMessage], self._final_prompt)

    @property
    def raw_streamed_output(self) -> Iterator[str]:
        if self._streamed_output is not None:
            yield from self._streamed_output
            return

        streamed_output = []
        for message in self.llm.stream(self.final_prompt):
            streamed_output.append(message)
            yield message

        self._streamed_output = streamed_output

    @property
    def processed_streamed_output(self) -> AnswerQuestionStreamReturn:
        if self._processed_stream is not None:
            yield from self._processed_stream
            return

        processed_stream = []
        for processed_packet in self.process_stream_fn(self.raw_streamed_output):
            processed_stream.append(processed_packet)
            yield processed_packet

        self._processed_stream = processed_stream

    @property
    def llm_answer(self) -> str:
        answer = ""
        for packet in self.processed_streamed_output:
            if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
                answer += packet.answer_piece

        return answer

    @property
    def citations(self) -> list[CitationInfo]:
        citations: list[CitationInfo] = []
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
                citations.append(packet)

        return citations
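
Editor's usage sketch for the new Answer class (not part of the commit): the import paths and the constructor arguments come from this diff, while the caller is assumed to already have a Prompt row, a Persona row, and the LlmDoc list in hand.

# Editor's sketch only; prompt/persona are assumed to be ORM rows loaded elsewhere.
from danswer.chat.models import DanswerAnswerPiece, LlmDoc
from danswer.db.models import Persona, Prompt
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig, CitationConfig


def answer_with_citations(
    question: str, docs: list[LlmDoc], prompt: Prompt, persona: Persona
) -> str:
    answer = Answer(
        question=question,
        docs=docs,
        answer_style_config=AnswerStyleConfig(citation_config=CitationConfig()),
        prompt=prompt,
        persona=persona,
        message_history=[],
    )
    pieces = []
    for packet in answer.processed_streamed_output:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            pieces.append(packet.answer_piece)
    # answer.citations now holds the CitationInfo packets seen during streaming
    return "".join(pieces)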
backend/danswer/llm/answering/doc_pruning.py (new file, 205 lines)
@@ -0,0 +1,205 @@
from copy import deepcopy
from typing import TypeVar

from danswer.chat.models import (
    LlmDoc,
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import Persona
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.utils.logger import setup_logger


logger = setup_logger()

T = TypeVar("T", bound=LlmDoc | InferenceChunk)

_METADATA_TOKEN_ESTIMATE = 75


class PruningError(Exception):
    pass


def _compute_limit(
    persona: Persona,
    question: str,
    max_chunks: int | None,
    max_window_percentage: float | None,
    max_tokens: int | None,
) -> int:
    llm_max_document_tokens = compute_max_document_tokens(
        persona=persona, actual_user_input=question
    )

    window_percentage_based_limit = (
        max_window_percentage * llm_max_document_tokens
        if max_window_percentage
        else None
    )
    chunk_count_based_limit = (
        max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
    )

    limit_options = [
        lim
        for lim in [
            window_percentage_based_limit,
            chunk_count_based_limit,
            max_tokens,
            llm_max_document_tokens,
        ]
        if lim
    ]
    return int(min(limit_options))


def reorder_docs(
    docs: list[T],
    doc_relevance_list: list[bool] | None,
) -> list[T]:
    if doc_relevance_list is None:
        return docs

    reordered_docs: list[T] = []
    if doc_relevance_list is not None:
        for selection_target in [True, False]:
            for doc, is_relevant in zip(docs, doc_relevance_list):
                if is_relevant == selection_target:
                    reordered_docs.append(doc)
    return reordered_docs


def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]:
    return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)]


def _apply_pruning(
    docs: list[LlmDoc],
    doc_relevance_list: list[bool] | None,
    token_limit: int,
    is_manually_selected_docs: bool,
) -> list[LlmDoc]:
    llm_tokenizer = get_default_llm_tokenizer()
    docs = deepcopy(docs)  # don't modify in place

    # re-order docs with all the "relevant" docs at the front
    docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list)
    # remove docs that are explicitly marked as not for QA
    docs = _remove_docs_to_ignore(docs=docs)

    tokens_per_doc: list[int] = []
    final_doc_ind = None
    total_tokens = 0
    for ind, llm_doc in enumerate(docs):
        doc_tokens = len(
            llm_tokenizer.encode(
                build_doc_context_str(
                    semantic_identifier=llm_doc.semantic_identifier,
                    source_type=llm_doc.source_type,
                    content=llm_doc.content,
                    metadata_dict=llm_doc.metadata,
                    updated_at=llm_doc.updated_at,
                    ind=ind,
                )
            )
        )
        # if chunks, truncate chunks that are way too long
        # this can happen if the embedding model tokenizer is different
        # than the LLM tokenizer
        if (
            not is_manually_selected_docs
            and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
        ):
            logger.warning(
                "Found more tokens in chunk than expected, "
                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
            )
            llm_doc.content = tokenizer_trim_content(
                content=llm_doc.content,
                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                tokenizer=llm_tokenizer,
            )
            doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE
        tokens_per_doc.append(doc_tokens)
        total_tokens += doc_tokens
        if total_tokens > token_limit:
            final_doc_ind = ind
            break

    if final_doc_ind is not None:
        if is_manually_selected_docs:
            # for document selection, only allow the final document to get truncated
            # if more than that, then the user message is too long
            if final_doc_ind != len(docs) - 1:
                raise PruningError(
                    "LLM context window exceeded. Please de-select some documents or shorten your query."
                )

            final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
                total_tokens - token_limit
            )
            final_doc_content_length = (
                final_doc_desired_length - _METADATA_TOKEN_ESTIMATE
            )
            # this could occur if we only have space for the title / metadata
            # not ideal, but it's the most reasonable thing to do
            # NOTE: the frontend prevents documents from being selected if
            # less than 75 tokens are available to try and avoid this situation
            # from occuring in the first place
            if final_doc_content_length <= 0:
                logger.error(
                    f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
                    "length is less than 0. Removing this doc from the final prompt."
                )
                docs.pop()
            else:
                docs[final_doc_ind].content = tokenizer_trim_content(
                    content=docs[final_doc_ind].content,
                    desired_length=final_doc_content_length,
                    tokenizer=llm_tokenizer,
                )
        else:
            # for regular search, don't truncate the final document unless it's the only one
            if final_doc_ind != 0:
                docs = docs[:final_doc_ind]
            else:
                docs[0].content = tokenizer_trim_content(
                    content=docs[0].content,
                    desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
                    tokenizer=llm_tokenizer,
                )
                docs = [docs[0]]

    return docs


def prune_documents(
    docs: list[LlmDoc],
    doc_relevance_list: list[bool] | None,
    persona: Persona,
    question: str,
    document_pruning_config: DocumentPruningConfig,
) -> list[LlmDoc]:
    if doc_relevance_list is not None:
        assert len(docs) == len(doc_relevance_list)

    doc_token_limit = _compute_limit(
        persona=persona,
        question=question,
        max_chunks=document_pruning_config.max_chunks,
        max_window_percentage=document_pruning_config.max_window_percentage,
        max_tokens=document_pruning_config.max_tokens,
    )
    return _apply_pruning(
        docs=docs,
        doc_relevance_list=doc_relevance_list,
        token_limit=doc_token_limit,
        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
    )
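
Editor's illustration of _compute_limit with assumed numbers: the pruning budget is simply the smallest of the configured caps and whatever the LLM context leaves for documents.

llm_max_document_tokens = 3636        # from compute_max_document_tokens(...)
max_chunks, chunk_size = 10, 512      # DOC_EMBEDDING_CONTEXT_SIZE assumed to be 512
max_window_percentage = 0.5
limits = [
    max_window_percentage * llm_max_document_tokens,  # 1818.0
    max_chunks * chunk_size,                          # 5120
    llm_max_document_tokens,                          # 3636
]
assert int(min(limits)) == 1818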
backend/danswer/llm/answering/models.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import Field
from pydantic import root_validator

from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType

    @classmethod
    def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage":
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
        )


class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # different pruning behavior is expected when the
    # user manually selects documents they want to chat with
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class QuotesConfig(BaseModel):
    pass


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig | None = None
    quotes_config: QuotesConfig | None = None
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )

    @root_validator
    def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
        citation_config = values.get("citation_config")
        quotes_config = values.get("quotes_config")

        if citation_config is None and quotes_config is None:
            raise ValueError(
                "One of `citation_config` or `quotes_config` must be provided"
            )

        if citation_config is not None and quotes_config is not None:
            raise ValueError(
                "Only one of `citation_config` or `quotes_config` must be provided"
            )

        return values
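
Editor's quick sketch of the validator's behavior using the classes above: exactly one of the two style configs must be set, otherwise pydantic raises a validation error (a ValueError subclass).

AnswerStyleConfig(citation_config=CitationConfig())  # ok
AnswerStyleConfig(quotes_config=QuotesConfig())      # ok
try:
    AnswerStyleConfig()                              # neither -> rejected
except ValueError:
    pass
try:
    AnswerStyleConfig(
        citation_config=CitationConfig(), quotes_config=QuotesConfig()
    )                                                # both -> rejected
except ValueError:
    pass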
281
backend/danswer/llm/answering/prompts/citations_prompt.py
Normal file
281
backend/danswer/llm/answering/prompts/citations_prompt.py
Normal file
@ -0,0 +1,281 @@
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import cast
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from danswer.db.chat import get_default_prompt
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import get_default_llm_version
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
|
||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.direct_qa_prompts import (
|
||||
CITATIONS_PROMPT,
|
||||
)
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
||||
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import (
|
||||
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
|
||||
)
|
||||
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
|
||||
|
||||
|
||||
_PER_MESSAGE_TOKEN_BUFFER = 7
|
||||
|
||||
|
||||
def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
|
||||
"""From the back, find the index of the last element to include
|
||||
before the list exceeds the maximum"""
|
||||
running_sum = 0
|
||||
|
||||
last_ind = 0
|
||||
for i in range(len(lst) - 1, -1, -1):
|
||||
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
|
||||
if running_sum > max_prompt_tokens:
|
||||
last_ind = i + 1
|
||||
break
|
||||
if last_ind >= len(lst):
|
||||
raise ValueError("Last message alone is too large!")
|
||||
return last_ind
|
||||
|
||||
|
||||
def drop_messages_history_overflow(
|
||||
system_msg: BaseMessage | None,
|
||||
system_token_count: int,
|
||||
history_msgs: list[BaseMessage],
|
||||
history_token_counts: list[int],
|
||||
final_msg: BaseMessage,
|
||||
final_msg_token_count: int,
|
||||
max_allowed_tokens: int,
|
||||
) -> list[BaseMessage]:
|
||||
"""As message history grows, messages need to be dropped starting from the furthest in the past.
|
||||
The System message should be kept if at all possible and the latest user input which is inserted in the
|
||||
prompt template must be included"""
|
||||
if len(history_msgs) != len(history_token_counts):
|
||||
# This should never happen
|
||||
raise ValueError("Need exactly 1 token count per message for tracking overflow")
|
||||
|
||||
prompt: list[BaseMessage] = []
|
||||
|
||||
# Start dropping from the history if necessary
|
||||
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
|
||||
ind_prev_msg_start = find_last_index(
|
||||
all_tokens, max_prompt_tokens=max_allowed_tokens
|
||||
)
|
||||
|
||||
if system_msg and ind_prev_msg_start <= len(history_msgs):
|
||||
prompt.append(system_msg)
|
||||
|
||||
prompt.extend(history_msgs[ind_prev_msg_start:])
|
||||
|
||||
prompt.append(final_msg)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_prompt_tokens(prompt: Prompt) -> int:
    # Note: currently only default prompts can be datetime-aware; custom prompts cannot
    return (
        check_number_of_tokens(prompt.system_prompt)
        + check_number_of_tokens(prompt.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
    )


# Buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40

def compute_max_document_tokens(
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
    if we're trying to determine if the user should be able to select another document) then we just set an
    arbitrary "upper bound".
    """
    llm_name = get_default_llm_version()[0]
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(model_name=llm_name)
    )
    if persona.prompts:
        # TODO this may not always be the first prompt
        prompt_tokens = get_prompt_tokens(persona.prompts[0])
    else:
        prompt_tokens = get_prompt_tokens(get_default_prompt())

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER

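# Editor's illustrative sketch (not part of this commit): hypothetical numbers plugged into
# the formula from the docstring above.
#   max_input_tokens  (context window minus reserved output)   = 4096
#   prompt_tokens     (system + task + fixed overhead counts)  = 600
#   user_input_tokens (actual question, or the reserved max)   = 200
#   _MISC_BUFFER                                               = 40
#   -> tokens left for context documents = 4096 - 600 - 200 - 40 = 3256
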
def compute_max_llm_input_tokens(persona: Persona) -> int:
    """Maximum number of tokens allowed in the input to the LLM (of any type)."""
    llm_name = get_default_llm_version()[0]
    if persona.llm_model_version_override:
        llm_name = persona.llm_model_version_override

    input_tokens = get_max_input_tokens(model_name=llm_name)
    return input_tokens - _MISC_BUFFER

@lru_cache()
|
||||
def build_system_message(
|
||||
prompt: Prompt,
|
||||
context_exists: bool,
|
||||
llm_tokenizer_encode_func: Callable,
|
||||
citation_line: str = REQUIRE_CITATION_STATEMENT,
|
||||
no_citation_line: str = NO_CITATION_STATEMENT,
|
||||
) -> tuple[SystemMessage | None, int]:
|
||||
system_prompt = prompt.system_prompt.strip()
|
||||
if prompt.include_citations:
|
||||
if context_exists:
|
||||
system_prompt += citation_line
|
||||
else:
|
||||
system_prompt += no_citation_line
|
||||
if prompt.datetime_aware:
|
||||
if system_prompt:
|
||||
system_prompt += ADDITIONAL_INFO.format(
|
||||
datetime_info=get_current_llm_day_time()
|
||||
)
|
||||
else:
|
||||
system_prompt = get_current_llm_day_time()
|
||||
|
||||
if not system_prompt:
|
||||
return None, 0
|
||||
|
||||
token_count = len(llm_tokenizer_encode_func(system_prompt))
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
|
||||
return system_msg, token_count
|
||||
|
||||
|
||||
def build_user_message(
|
||||
question: str,
|
||||
prompt: Prompt,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
all_doc_useful: bool,
|
||||
history_message: str,
|
||||
) -> tuple[HumanMessage, int]:
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode)
|
||||
|
||||
if not context_docs:
|
||||
# Simpler prompt for cases where there is no context
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
task_prompt=prompt.task_prompt, user_query=question
|
||||
)
|
||||
if prompt.task_prompt
|
||||
else question
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
token_count = len(llm_tokenizer_encode_func(user_prompt))
|
||||
user_msg = HumanMessage(content=user_prompt)
|
||||
return user_msg, token_count
|
||||
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
|
||||
|
||||
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
|
||||
|
||||
user_prompt = CITATIONS_PROMPT.format(
|
||||
optional_ignore_statement=optional_ignore,
|
||||
context_docs_str=context_docs_str,
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=question,
|
||||
history_block=history_message,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
token_count = len(llm_tokenizer_encode_func(user_prompt))
|
||||
user_msg = HumanMessage(content=user_prompt)
|
||||
|
||||
return user_msg, token_count
|
||||
|
||||
|
||||
def build_citations_prompt(
|
||||
question: str,
|
||||
message_history: list[PreviousMessage],
|
||||
persona: Persona,
|
||||
prompt: Prompt,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
all_doc_useful: bool,
|
||||
history_message: str,
|
||||
llm_tokenizer_encode_func: Callable,
|
||||
) -> list[BaseMessage]:
|
||||
context_exists = len(context_docs) > 0
|
||||
|
||||
system_message_or_none, system_tokens = build_system_message(
|
||||
prompt=prompt,
|
||||
context_exists=context_exists,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
)
|
||||
|
||||
history_basemessages, history_token_counts = translate_history_to_basemessages(
|
||||
message_history
|
||||
)
|
||||
|
||||
    # Be sure the context_docs passed to build_user_message
    # are the same as those passed in later for extracting citations
|
||||
user_message, user_tokens = build_user_message(
|
||||
question=question,
|
||||
prompt=prompt,
|
||||
context_docs=context_docs,
|
||||
all_doc_useful=all_doc_useful,
|
||||
history_message=history_message,
|
||||
)
|
||||
|
||||
final_prompt_msgs = drop_messages_history_overflow(
|
||||
system_msg=system_message_or_none,
|
||||
system_token_count=system_tokens,
|
||||
history_msgs=history_basemessages,
|
||||
history_token_counts=history_token_counts,
|
||||
final_msg=user_message,
|
||||
final_msg_token_count=user_tokens,
|
||||
max_allowed_tokens=compute_max_llm_input_tokens(persona),
|
||||
)
|
||||
|
||||
return final_prompt_msgs
|
88
backend/danswer/llm/answering/prompts/quotes_prompt.py
Normal file
@ -0,0 +1,88 @@
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
|
||||
|
||||
def _build_weak_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: Prompt,
|
||||
use_language_hint: bool,
|
||||
) -> list[BaseMessage]:
|
||||
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
|
||||
as an option to use with weaker LLMs such as small version, low float precision, quantized,
|
||||
or distilled models. It only uses one context document and has very weak requirements of
|
||||
output format.
|
||||
"""
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
|
||||
|
||||
prompt_str = WEAK_LLM_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
)
|
||||
return [HumanMessage(content=prompt_str)]
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: Prompt,
|
||||
use_language_hint: bool,
|
||||
) -> list[BaseMessage]:
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
context_docs_str = build_complete_context_str(context_docs)
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
|
||||
|
||||
history_block = ""
|
||||
if history_str:
|
||||
history_block = HISTORY_BLOCK.format(history_str=history_str)
|
||||
|
||||
full_prompt = JSON_PROMPT.format(
|
||||
system_prompt=prompt.system_prompt,
|
||||
context_block=context_block,
|
||||
history_block=history_block,
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
return [HumanMessage(content=full_prompt)]
|
||||
|
||||
|
||||
def build_quotes_prompt(
|
||||
question: str,
|
||||
context_docs: list[LlmDoc] | list[InferenceChunk],
|
||||
history_str: str,
|
||||
prompt: Prompt,
|
||||
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
|
||||
) -> list[BaseMessage]:
|
||||
prompt_builder = (
|
||||
_build_weak_llm_quotes_prompt
|
||||
if QA_PROMPT_OVERRIDE == "weak"
|
||||
else _build_strong_llm_quotes_prompt
|
||||
)
|
||||
|
||||
return prompt_builder(
|
||||
question=question,
|
||||
context_docs=context_docs,
|
||||
history_str=history_str,
|
||||
prompt=prompt,
|
||||
use_language_hint=use_language_hint,
|
||||
)
|
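# Editor's illustrative sketch (not part of this commit): with QA_PROMPT_OVERRIDE == "weak"
# the single-document prompt above is used; otherwise the JSON prompt is built. A
# SimpleNamespace stands in for a real Prompt row here, since only system_prompt and
# task_prompt are read.
def _editor_sketch_build_quotes_prompt() -> None:
    from types import SimpleNamespace

    fake_prompt = SimpleNamespace(
        system_prompt="You are a helpful assistant.", task_prompt="Answer concisely."
    )
    messages = build_quotes_prompt(
        question="How do I rotate an API key?",
        context_docs=[],
        history_str="",
        prompt=fake_prompt,  # type: ignore[arg-type]
        use_language_hint=False,
    )
    assert len(messages) == 1 and isinstance(messages[0], HumanMessage)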
20
backend/danswer/llm/answering/prompts/utils.py
Normal file
@ -0,0 +1,20 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()
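# Editor's illustrative sketch (not part of this commit): the returned template contains
# the <USER_QUERY> placeholder, plus <CONTEXT_DOCS> when retrieval is enabled.
def _editor_sketch_build_dummy_prompt() -> None:
    with_context = build_dummy_prompt(
        system_prompt="You are a helpful assistant.",
        task_prompt="Answer concisely.",
        retrieval_disabled=False,
    )
    assert "<CONTEXT_DOCS>" in with_context and "<USER_QUERY>" in with_context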
@ -0,0 +1,126 @@
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import STOP_STREAM_PAT
|
||||
from danswer.llm.answering.models import StreamProcessor
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def in_code_block(llm_text: str) -> bool:
|
||||
count = llm_text.count(TRIPLE_BACKTICK)
|
||||
return count % 2 != 0
|
||||
|
||||
|
||||
def extract_citations_from_stream(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||
llm_out = ""
|
||||
max_citation_num = len(context_docs)
|
||||
curr_segment = ""
|
||||
prepend_bracket = False
|
||||
cited_inds = set()
|
||||
hold = ""
|
||||
for raw_token in tokens:
|
||||
if stop_stream:
|
||||
next_hold = hold + raw_token
|
||||
|
||||
if stop_stream in next_hold:
|
||||
break
|
||||
|
||||
if next_hold == stop_stream[: len(next_hold)]:
|
||||
hold = next_hold
|
||||
continue
|
||||
|
||||
token = next_hold
|
||||
hold = ""
|
||||
else:
|
||||
token = raw_token
|
||||
|
||||
# Special case of [1][ where ][ is a single token
|
||||
# This is where the model attempts to do consecutive citations like [1][2]
|
||||
if prepend_bracket:
|
||||
curr_segment += "[" + curr_segment
|
||||
prepend_bracket = False
|
||||
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||
|
||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||
citation_found = re.search(citation_pattern, curr_segment)
|
||||
|
||||
if citation_found and not in_code_block(llm_out):
|
||||
numerical_value = int(citation_found.group(1))
|
||||
if 1 <= numerical_value <= max_citation_num:
|
||||
context_llm_doc = context_docs[
|
||||
numerical_value - 1
|
||||
] # remove 1 index offset
|
||||
|
||||
link = context_llm_doc.link
|
||||
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
|
||||
|
||||
# Use the citation number for the document's rank in
|
||||
# the search (or selected docs) results
|
||||
curr_segment = re.sub(
|
||||
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
|
||||
)
|
||||
|
||||
if target_citation_num not in cited_inds:
|
||||
cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
if link:
|
||||
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
|
||||
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
|
||||
|
||||
# In case there's another open bracket like [1][, don't want to match this
|
||||
possible_citation_found = None
|
||||
|
||||
# if we see "[", but haven't seen the right side, hold back - this may be a
|
||||
# citation that needs to be replaced with a link
|
||||
if possible_citation_found:
|
||||
continue
|
||||
|
||||
# Special case with back to back citations [1][2]
|
||||
if curr_segment and curr_segment[-1] == "[":
|
||||
curr_segment = curr_segment[:-1]
|
||||
prepend_bracket = True
|
||||
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
curr_segment = ""
|
||||
|
||||
if curr_segment:
|
||||
if prepend_bracket:
|
||||
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
|
||||
else:
|
||||
yield DanswerAnswerPiece(answer_piece=curr_segment)
|
||||
|
||||
|
||||
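# Editor's illustrative sketch (not part of this commit): the citation regexes and the
# code-block guard used above, applied to hypothetical model output (assumes TRIPLE_BACKTICK
# is the ``` marker).
def _editor_sketch_citation_patterns() -> None:
    completed = re.search(r"\[(\d+)\]", "as noted in [2], restart the worker")
    assert completed is not None and completed.group(1) == "2"
    # A trailing "[" (or "[1" with no closing bracket yet) is held back as a possible citation
    assert re.search(r"(\[\d*$)", "as noted in [") is not None
    # An odd number of ``` markers means the stream is currently inside a code block
    assert in_code_block("```python\nids = [1]")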
def build_citation_processor(
|
||||
context_docs: list[LlmDoc],
|
||||
) -> StreamProcessor:
|
||||
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
|
||||
yield from extract_citations_from_stream(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
doc_id_to_rank_map=map_document_id_order(context_docs),
|
||||
)
|
||||
|
||||
return stream_processor
|
@ -0,0 +1,282 @@
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.prompts.constants import ANSWER_PAT
|
||||
from danswer.prompts.constants import QUOTE_PAT
|
||||
from danswer.prompts.constants import UNCERTAINTY_PAT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import clean_up_code_blocks
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
"""Splits the model output into an Answer and 0 or more Quote sections.
|
||||
    Splits on the Quote pattern; if it is not present, assume the output is all answer and no quotes
|
||||
"""
|
||||
# If no answer section, don't care about the quote
|
||||
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
|
||||
return None, None
|
||||
|
||||
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
|
||||
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
|
||||
answer_raw = answer_raw[len(ANSWER_PAT) :]
|
||||
|
||||
# Accept quote sections starting with the lower case version
|
||||
answer_raw = answer_raw.replace(
|
||||
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
|
||||
) # Just in case model unreliable
|
||||
|
||||
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
|
||||
sections_clean = [
|
||||
str(section).strip() for section in sections if str(section).strip()
|
||||
]
|
||||
if not sections_clean:
|
||||
return None, None
|
||||
|
||||
answer = str(sections_clean[0])
|
||||
if len(sections) == 1:
|
||||
return answer, None
|
||||
return answer, sections_clean[1:]
|
||||
|
||||
|
||||
def _extract_answer_quotes_json(
|
||||
answer_dict: dict[str, str | list[str]]
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
|
||||
answer = str(answer_dict.get("answer"))
|
||||
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
|
||||
if isinstance(quotes, str):
|
||||
quotes = [quotes]
|
||||
return answer, quotes
|
||||
|
||||
|
||||
def _extract_answer_json(raw_model_output: str) -> dict:
|
||||
try:
|
||||
answer_json = extract_embedded_json(raw_model_output)
|
||||
except (ValueError, JSONDecodeError):
|
||||
        # LLMs get confused when handling the list in the json. Sometimes it doesn't attend
        # enough to the previous { token so it just ends the list of quotes and stops there.
        # Here, we add logic to try to fix this LLM error.
|
||||
answer_json = extract_embedded_json(raw_model_output + "}")
|
||||
|
||||
if "answer" not in answer_json:
|
||||
raise ValueError("Model did not output an answer as expected.")
|
||||
|
||||
return answer_json
|
||||
|
||||
|
||||
def match_quotes_to_docs(
|
||||
quotes: list[str],
|
||||
docs: list[LlmDoc] | list[InferenceChunk],
|
||||
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
|
||||
fuzzy_search: bool = False,
|
||||
prefix_only_length: int = 100,
|
||||
) -> DanswerQuotes:
|
||||
danswer_quotes: list[DanswerQuote] = []
|
||||
for quote in quotes:
|
||||
max_edits = math.ceil(float(len(quote)) * max_error_percent)
|
||||
|
||||
for doc in docs:
|
||||
if not doc.source_links:
|
||||
continue
|
||||
|
||||
quote_clean = shared_precompare_cleanup(
|
||||
clean_model_quote(quote, trim_length=prefix_only_length)
|
||||
)
|
||||
chunk_clean = shared_precompare_cleanup(doc.content)
|
||||
|
||||
# Finding the offset of the quote in the plain text
|
||||
if fuzzy_search:
|
||||
re_search_str = (
|
||||
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
|
||||
)
|
||||
found = regex.search(re_search_str, chunk_clean)
|
||||
if not found:
|
||||
continue
|
||||
offset = found.span()[0]
|
||||
else:
|
||||
if quote_clean not in chunk_clean:
|
||||
continue
|
||||
offset = chunk_clean.index(quote_clean)
|
||||
|
||||
# Extracting the link from the offset
|
||||
curr_link = None
|
||||
for link_offset, link in doc.source_links.items():
|
||||
# Should always find one because offset is at least 0 and there
|
||||
# must be a 0 link_offset
|
||||
if int(link_offset) <= offset:
|
||||
curr_link = link
|
||||
else:
|
||||
break
|
||||
|
||||
danswer_quotes.append(
|
||||
DanswerQuote(
|
||||
quote=quote,
|
||||
document_id=doc.document_id,
|
||||
link=curr_link,
|
||||
source_type=doc.source_type,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
blurb=doc.blurb,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
return DanswerQuotes(quotes=danswer_quotes)
|
||||
|
||||
|
||||
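# Editor's illustrative sketch (not part of this commit): the fuzzy branch above relies on
# the `regex` module's {e<=N} syntax to tolerate small transcription errors in a quote.
# The 5% tolerance below is a stand-in for QUOTE_ALLOWED_ERROR_PERCENT.
def _editor_sketch_fuzzy_quote_match() -> None:
    quote = "the quick brown fox"
    max_edits = math.ceil(len(quote) * 0.05)  # -> 1 allowed edit for this quote
    pattern = r"(" + re.escape(quote) + r"){e<=" + str(max_edits) + r"}"
    assert regex.search(pattern, "she saw the quick brown fix jump the fence") is not None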
def separate_answer_quotes(
|
||||
answer_raw: str, is_json_prompt: bool = False
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
"""Takes in a raw model output and pulls out the answer and the quotes sections."""
|
||||
if is_json_prompt:
|
||||
model_raw_json = _extract_answer_json(answer_raw)
|
||||
return _extract_answer_quotes_json(model_raw_json)
|
||||
|
||||
return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
|
||||
|
||||
|
||||
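# Editor's illustrative sketch (not part of this commit): expected shape of the parsed output
# for a well-formed JSON response (exact behavior depends on extract_embedded_json).
#   separate_answer_quotes(
#       '{"answer": "Restart the worker.", "quotes": ["See the deployment docs."]}',
#       is_json_prompt=True,
#   )
#   -> ("Restart the worker.", ["See the deployment docs."])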
def process_answer(
|
||||
answer_raw: str,
|
||||
docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
) -> tuple[DanswerAnswer, DanswerQuotes]:
|
||||
"""Used (1) in the non-streaming case to process the model output
|
||||
into an Answer and Quotes AND (2) after the complete streaming response
|
||||
has been received to process the model output into an Answer and Quotes."""
|
||||
answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
|
||||
if answer == UNCERTAINTY_PAT or not answer:
|
||||
if answer == UNCERTAINTY_PAT:
|
||||
logger.debug("Answer matched UNCERTAINTY_PAT")
|
||||
else:
|
||||
logger.debug("No answer extracted from raw output")
|
||||
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
|
||||
|
||||
logger.info(f"Answer: {answer}")
|
||||
if not quote_strings:
|
||||
logger.debug("No quotes extracted from raw output")
|
||||
return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
|
||||
logger.info(f"All quotes (including unmatched): {quote_strings}")
|
||||
quotes = match_quotes_to_docs(quote_strings, docs)
|
||||
logger.debug(f"Final quotes: {quotes}")
|
||||
|
||||
return DanswerAnswer(answer=answer), quotes
|
||||
|
||||
|
||||
def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
next_token = next_token.replace('\\"', "")
|
||||
# If the previous character is an escape token, don't consider the first character of next_token
|
||||
# This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
|
||||
if answer_so_far and answer_so_far[-1] == "\\":
|
||||
next_token = next_token[1:]
|
||||
if '"' in next_token:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
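# Editor's illustrative sketch (not part of this commit): how the terminator check above
# treats two hypothetical streamed tokens.
#   _stream_json_answer_end('The fix is to restart', ' the worker."')  -> True  (unescaped closing quote)
#   _stream_json_answer_end('ends with a backslash\\', '"...')         -> False (the quote is escaped)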
def _extract_quotes_from_completed_token_stream(
|
||||
model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
|
||||
) -> DanswerQuotes:
|
||||
answer, quotes = process_answer(model_output, context_docs, is_json_prompt)
|
||||
if answer:
|
||||
logger.info(answer)
|
||||
elif model_output:
|
||||
logger.warning("Answer extraction from model output failed.")
|
||||
|
||||
return quotes
|
||||
|
||||
|
||||
def process_model_tokens(
|
||||
tokens: Iterator[str],
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
"""Used in the streaming case to process the model output
|
||||
into an Answer and Quotes
|
||||
|
||||
Yields Answer tokens back out in a dict for streaming to frontend
|
||||
When Answer section ends, yields dict with answer_finished key
|
||||
Collects all the tokens at the end to form the complete model output"""
|
||||
quote_pat = f"\n{QUOTE_PAT}"
|
||||
    # Sometimes a weaker model outputs a newline instead of ":"
|
||||
quote_loose = f"\n{quote_pat[:-1]}\n"
|
||||
    # Sometimes the model outputs two newlines before the quote section
|
||||
quote_pat_full = f"\n{quote_pat}"
|
||||
model_output: str = ""
|
||||
found_answer_start = False if is_json_prompt else True
|
||||
found_answer_end = False
|
||||
hold_quote = ""
|
||||
for token in tokens:
|
||||
model_previous = model_output
|
||||
model_output += token
|
||||
|
||||
if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
|
||||
            # Note: if the token that completes the pattern has additional text (for example, if the
            # token is '"?'), then the chars after '"' will not be streamed. This is fine, since it
            # prevents streaming the '?' in the event that the model outputs the UNCERTAINTY_PAT.
|
||||
found_answer_start = True
|
||||
|
||||
# Prevent heavy cases of hallucinations where model is not even providing a json until later
|
||||
if is_json_prompt and len(model_output) > 40:
|
||||
logger.warning("LLM did not produce json as prompted")
|
||||
found_answer_end = True
|
||||
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if is_json_prompt and _stream_json_answer_end(model_previous, token):
|
||||
found_answer_end = True
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
elif not is_json_prompt:
|
||||
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
|
||||
found_answer_end = True
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
if hold_quote + token in quote_pat_full:
|
||||
hold_quote += token
|
||||
continue
|
||||
yield DanswerAnswerPiece(answer_piece=hold_quote + token)
|
||||
hold_quote = ""
|
||||
|
||||
logger.debug(f"Raw Model QnA Output: {model_output}")
|
||||
|
||||
yield _extract_quotes_from_completed_token_stream(
|
||||
model_output=model_output,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
|
||||
def build_quotes_processor(
|
||||
context_docs: list[LlmDoc], is_json_prompt: bool
|
||||
) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]:
|
||||
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
return stream_processor
|
17
backend/danswer/llm/answering/stream_processing/utils.py
Normal file
@ -0,0 +1,17 @@
from collections.abc import Sequence

from danswer.chat.models import LlmDoc
from danswer.indexing.models import InferenceChunk


def map_document_id_order(
    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
    order_mapping = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
        if chunk.document_id not in order_mapping:
            order_mapping[chunk.document_id] = current
            current += 1

    return order_mapping
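# Editor's illustrative sketch (not part of this commit): chunks from the same document share
# one citation number. SimpleNamespace stands in for InferenceChunk, since only document_id is read.
def _editor_sketch_map_document_id_order() -> None:
    from types import SimpleNamespace

    chunks = [SimpleNamespace(document_id=doc_id) for doc_id in ["doc_a", "doc_b", "doc_a", "doc_c"]]
    assert map_document_id_order(chunks) == {"doc_a": 1, "doc_b": 2, "doc_c": 3}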
@ -33,6 +33,7 @@ from danswer.db.models import ChatMessage
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@ -114,7 +115,9 @@ def tokenizer_trim_chunks(
|
||||
return new_chunks
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
) -> BaseMessage:
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
@ -126,7 +129,7 @@ def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
|
||||
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage],
|
||||
history: list[ChatMessage] | list[PreviousMessage],
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_danswer_msg_to_langchain(msg)
|
||||
|
@ -1,54 +1,37 @@
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import build_chat_system_message
|
||||
from danswer.chat.chat_utils import compute_max_document_tokens
|
||||
from danswer.chat.chat_utils import extract_citations_from_stream
|
||||
from danswer.chat.chat_utils import get_chunks_for_qa
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
|
||||
from danswer.chat.chat_utils import map_document_id_order
|
||||
from danswer.chat.chat_utils import reorganize_citations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContext
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LLMMetricsContainer
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import get_persona_by_id
|
||||
from danswer.db.chat import get_prompt_by_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import User
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.one_shot_answer.factory import get_question_answer_model
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.one_shot_answer.qa_block import no_gen_ai_response
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
@ -77,106 +60,6 @@ AnswerObjectIterator = Iterator[
|
||||
]
|
||||
|
||||
|
||||
def quote_based_qa(
|
||||
prompt: Prompt,
|
||||
query_message: ThreadMessage,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
llm_override: str | None,
|
||||
timeout: int,
|
||||
use_chain_of_thought: bool,
|
||||
return_contexts: bool,
|
||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
qa_model = get_question_answer_model(
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
chain_of_thought=use_chain_of_thought,
|
||||
llm_version=llm_override,
|
||||
)
|
||||
|
||||
full_prompt_str = (
|
||||
qa_model.build_prompt(
|
||||
query=query_message.message,
|
||||
history_str=history_str,
|
||||
context_chunks=context_chunks,
|
||||
)
|
||||
if qa_model is not None
|
||||
else "Gen AI Disabled"
|
||||
)
|
||||
|
||||
response_packets = (
|
||||
qa_model.answer_question_stream(
|
||||
prompt=full_prompt_str,
|
||||
llm_context_docs=context_chunks,
|
||||
metrics_callback=llm_metrics_callback,
|
||||
)
|
||||
if qa_model is not None
|
||||
else no_gen_ai_response()
|
||||
)
|
||||
|
||||
if qa_model is not None and return_contexts:
|
||||
contexts = DanswerContexts(
|
||||
contexts=[
|
||||
DanswerContext(
|
||||
content=context_chunk.content,
|
||||
document_id=context_chunk.document_id,
|
||||
semantic_identifier=context_chunk.semantic_identifier,
|
||||
blurb=context_chunk.semantic_identifier,
|
||||
)
|
||||
for context_chunk in context_chunks
|
||||
]
|
||||
)
|
||||
|
||||
response_packets = itertools.chain(response_packets, [contexts])
|
||||
|
||||
yield from response_packets
|
||||
|
||||
|
||||
def citation_based_qa(
|
||||
prompt: Prompt,
|
||||
query_message: ThreadMessage,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
llm_override: str | None,
|
||||
timeout: int,
|
||||
) -> AnswerObjectIterator:
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
|
||||
system_prompt_or_none, _ = build_chat_system_message(
|
||||
prompt=prompt,
|
||||
context_exists=True,
|
||||
llm_tokenizer_encode_func=llm_tokenizer.encode,
|
||||
)
|
||||
|
||||
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
|
||||
|
||||
context_docs_str = build_complete_context_str(context_chunks)
|
||||
user_message = HumanMessage(
|
||||
content=CITATIONS_PROMPT.format(
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=query_message.message,
|
||||
history_block=history_str,
|
||||
context_docs_str=context_docs_str,
|
||||
)
|
||||
)
|
||||
|
||||
llm = get_default_llm(
|
||||
timeout=timeout,
|
||||
gen_ai_model_version_override=llm_override,
|
||||
)
|
||||
|
||||
llm_prompt: list[BaseMessage] = [user_message]
|
||||
if system_prompt_or_none is not None:
|
||||
llm_prompt = [system_prompt_or_none] + llm_prompt
|
||||
|
||||
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in context_chunks]
|
||||
doc_id_to_rank_map = map_document_id_order(llm_docs)
|
||||
|
||||
tokens = llm.stream(llm_prompt)
|
||||
yield from extract_citations_from_stream(tokens, llm_docs, doc_id_to_rank_map)
|
||||
|
||||
|
||||
def stream_answer_objects(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
@ -188,14 +71,12 @@ def stream_answer_objects(
|
||||
db_session: Session,
|
||||
    # Needed to translate persona num_chunks into a token limit for the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
1. [always] Retrieved documents, stops flow if nothing is found
|
||||
@ -273,43 +154,11 @@ def stream_answer_objects(
|
||||
)
|
||||
yield llm_relevance_filtering_response
|
||||
|
||||
# Prep chunks to pass to LLM
|
||||
num_llm_chunks = (
|
||||
chat_session.persona.num_chunks
|
||||
if chat_session.persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
)
|
||||
|
||||
chunk_token_limit = int(num_llm_chunks * default_chunk_size)
|
||||
if max_document_tokens:
|
||||
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
|
||||
else:
|
||||
max_document_tokens = compute_max_document_tokens(
|
||||
persona=chat_session.persona, actual_user_input=query_msg.message
|
||||
)
|
||||
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
|
||||
|
||||
llm_chunks_indices = get_chunks_for_qa(
|
||||
chunks=top_chunks,
|
||||
llm_chunk_selection=search_pipeline.chunk_relevance_list,
|
||||
token_limit=chunk_token_limit,
|
||||
)
|
||||
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
|
||||
|
||||
logger.debug(
|
||||
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
|
||||
)
|
||||
|
||||
prompt = None
|
||||
llm_override = None
|
||||
if query_req.prompt_id is not None:
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_req.persona_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
llm_override = persona.llm_model_version_override
|
||||
if prompt is None:
|
||||
if not chat_session.persona.prompts:
|
||||
raise RuntimeError(
|
||||
@ -329,52 +178,39 @@ def stream_answer_objects(
|
||||
commit=True,
|
||||
)
|
||||
|
||||
if use_citations:
|
||||
qa_stream = citation_based_qa(
|
||||
prompt=prompt,
|
||||
query_message=query_msg,
|
||||
history_str=history_str,
|
||||
context_chunks=llm_chunks,
|
||||
llm_override=llm_override,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
qa_stream = quote_based_qa(
|
||||
prompt=prompt,
|
||||
query_message=query_msg,
|
||||
history_str=history_str,
|
||||
context_chunks=llm_chunks,
|
||||
llm_override=llm_override,
|
||||
timeout=timeout,
|
||||
use_chain_of_thought=False,
|
||||
return_contexts=False,
|
||||
llm_metrics_callback=llm_metrics_callback,
|
||||
)
|
||||
|
||||
# Capture outputs and errors
|
||||
llm_output = ""
|
||||
error: str | None = None
|
||||
for packet in qa_stream:
|
||||
logger.debug(packet)
|
||||
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
token = packet.answer_piece
|
||||
if token:
|
||||
llm_output += token
|
||||
elif isinstance(packet, StreamingError):
|
||||
error = packet.error
|
||||
|
||||
yield packet
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
chat_session.persona.num_chunks
|
||||
if chat_session.persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
),
|
||||
)
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks],
|
||||
answer_style_config=answer_config,
|
||||
prompt=prompt,
|
||||
persona=chat_session.persona,
|
||||
doc_relevance_list=search_pipeline.chunk_relevance_list,
|
||||
single_message_history=history_str,
|
||||
timeout=timeout,
|
||||
)
|
||||
yield from answer.processed_streamed_output
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=llm_output,
|
||||
token_count=len(llm_tokenizer(llm_output)),
|
||||
message=answer.llm_answer,
|
||||
token_count=len(llm_tokenizer(answer.llm_answer)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=error,
|
||||
error=None,
|
||||
reference_docs=None, # Don't need to save reference docs for one shot flow
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
@ -419,7 +255,6 @@ def get_search_answer(
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
qa_response = OneShotQAResponse()
|
||||
@ -435,7 +270,6 @@ def get_search_answer(
|
||||
timeout=answer_generation_timeout,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
llm_metrics_callback=llm_metrics_callback,
|
||||
)
|
||||
|
||||
answer = ""
|
||||
|
@ -1,48 +0,0 @@
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.one_shot_answer.interfaces import QAModel
|
||||
from danswer.one_shot_answer.qa_block import QABlock
|
||||
from danswer.one_shot_answer.qa_block import QAHandler
|
||||
from danswer.one_shot_answer.qa_block import SingleMessageQAHandler
|
||||
from danswer.one_shot_answer.qa_block import WeakLLMQAHandler
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_question_answer_model(
|
||||
prompt: Prompt | None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
chain_of_thought: bool = False,
|
||||
llm_version: str | None = None,
|
||||
qa_model_version: str | None = QA_PROMPT_OVERRIDE,
|
||||
) -> QAModel | None:
|
||||
if chain_of_thought:
|
||||
raise NotImplementedError("COT has been disabled")
|
||||
|
||||
system_prompt = prompt.system_prompt if prompt is not None else None
|
||||
task_prompt = prompt.task_prompt if prompt is not None else None
|
||||
|
||||
try:
|
||||
llm = get_default_llm(
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
gen_ai_model_version_override=llm_version,
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
return None
|
||||
|
||||
if qa_model_version == "weak":
|
||||
qa_handler: QAHandler = WeakLLMQAHandler(
|
||||
system_prompt=system_prompt, task_prompt=task_prompt
|
||||
)
|
||||
else:
|
||||
qa_handler = SingleMessageQAHandler(
|
||||
system_prompt=system_prompt, task_prompt=task_prompt
|
||||
)
|
||||
|
||||
return QABlock(llm=llm, qa_handler=qa_handler)
|
@ -1,26 +0,0 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import LLMMetricsContainer
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
|
||||
|
||||
class QAModel:
|
||||
@abc.abstractmethod
|
||||
def build_prompt(
|
||||
self,
|
||||
query: str,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def answer_question_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
llm_context_docs: list[InferenceChunk],
|
||||
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
raise NotImplementedError
|
@ -1,313 +0,0 @@
|
||||
import abc
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMMetricsContainer
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.constants import DISABLED_GEN_AI_MSG
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.interfaces import QAModel
|
||||
from danswer.one_shot_answer.qa_utils import process_answer
|
||||
from danswer.one_shot_answer.qa_utils import process_model_tokens
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import COT_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_up_code_blocks
|
||||
from danswer.utils.text_processing import escape_newlines
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class QAHandler(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_json_output(self) -> bool:
|
||||
"""Does the model output a valid json with answer and quotes keys? Most flows with a
|
||||
capable model should output a json. This hints to the model that the output is used
|
||||
with a downstream system rather than freeform creative output. Most models should be
|
||||
finetuned to recognize this."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_prompt(
|
||||
self,
|
||||
query: str,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def process_llm_token_stream(
|
||||
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_chunks,
|
||||
is_json_prompt=self.is_json_output,
|
||||
)
|
||||
|
||||
|
||||
class WeakLLMQAHandler(QAHandler):
|
||||
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
|
||||
as an option to use with weaker LLMs such as small version, low float precision, quantized,
|
||||
or distilled models. It only uses one context document and has very weak requirements of
|
||||
output format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: str | None,
|
||||
task_prompt: str | None,
|
||||
) -> None:
|
||||
if not system_prompt and not task_prompt:
|
||||
self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT
|
||||
self.task_prompt = WEAK_MODEL_TASK_PROMPT
|
||||
else:
|
||||
self.system_prompt = system_prompt or ""
|
||||
self.task_prompt = task_prompt or ""
|
||||
|
||||
self.task_prompt = self.task_prompt.rstrip()
|
||||
if self.task_prompt and self.task_prompt[0] != "\n":
|
||||
self.task_prompt = "\n" + self.task_prompt
|
||||
|
||||
@property
|
||||
def is_json_output(self) -> bool:
|
||||
return False
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
query: str,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
) -> str:
|
||||
context_block = ""
|
||||
if context_chunks:
|
||||
context_block = CONTEXT_BLOCK.format(
|
||||
context_docs_str=context_chunks[0].content
|
||||
)
|
||||
|
||||
prompt_str = WEAK_LLM_PROMPT.format(
|
||||
system_prompt=self.system_prompt,
|
||||
context_block=context_block,
|
||||
task_prompt=self.task_prompt,
|
||||
user_query=query,
|
||||
)
|
||||
return prompt_str
|
||||
|
||||
|
||||
class SingleMessageQAHandler(QAHandler):
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: str | None,
|
||||
task_prompt: str | None,
|
||||
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
|
||||
) -> None:
|
||||
self.use_language_hint = use_language_hint
|
||||
if not system_prompt and not task_prompt:
|
||||
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
|
||||
self.task_prompt = ONE_SHOT_TASK_PROMPT
|
||||
else:
|
||||
self.system_prompt = system_prompt or ""
|
||||
self.task_prompt = task_prompt or ""
|
||||
|
||||
self.task_prompt = self.task_prompt.rstrip()
|
||||
if self.task_prompt and self.task_prompt[0] != "\n":
|
||||
self.task_prompt = "\n" + self.task_prompt
|
||||
|
||||
@property
|
||||
def is_json_output(self) -> bool:
|
||||
return True
|
||||
|
||||
def build_prompt(
|
||||
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
|
||||
) -> str:
|
||||
context_block = ""
|
||||
if context_chunks:
|
||||
context_docs_str = build_complete_context_str(
|
||||
cast(list[LlmDoc | InferenceChunk], context_chunks)
|
||||
)
|
||||
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
|
||||
|
||||
history_block = ""
|
||||
if history_str:
|
||||
history_block = HISTORY_BLOCK.format(history_str=history_str)
|
||||
|
||||
full_prompt = JSON_PROMPT.format(
|
||||
system_prompt=self.system_prompt,
|
||||
context_block=context_block,
|
||||
history_block=history_block,
|
||||
task_prompt=self.task_prompt,
|
||||
user_query=query,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip()
|
||||
if self.use_language_hint
|
||||
else "",
|
||||
).strip()
|
||||
return full_prompt
|
||||
|
||||
|
||||
# This one isn't used, currently only streaming prompts are used
|
||||
class SingleMessageScratchpadHandler(QAHandler):
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: str | None,
|
||||
task_prompt: str | None,
|
||||
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
|
||||
) -> None:
|
||||
self.use_language_hint = use_language_hint
|
||||
if not system_prompt and not task_prompt:
|
||||
self.system_prompt = ONE_SHOT_SYSTEM_PROMPT
|
||||
self.task_prompt = ONE_SHOT_TASK_PROMPT
|
||||
else:
|
||||
self.system_prompt = system_prompt or ""
|
||||
self.task_prompt = task_prompt or ""
|
||||
|
||||
self.task_prompt = self.task_prompt.rstrip()
|
||||
if self.task_prompt and self.task_prompt[0] != "\n":
|
||||
self.task_prompt = "\n" + self.task_prompt
|
||||
|
||||
@property
|
||||
def is_json_output(self) -> bool:
|
||||
return True
|
||||
|
||||
def build_prompt(
|
||||
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
|
||||
) -> str:
|
||||
context_docs_str = build_complete_context_str(
|
||||
cast(list[LlmDoc | InferenceChunk], context_chunks)
|
||||
)
|
||||
|
||||
# Outdated
|
||||
prompt = COT_PROMPT.format(
|
||||
context_docs_str=context_docs_str,
|
||||
user_query=query,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip()
|
||||
if self.use_language_hint
|
||||
else "",
|
||||
).strip()
|
||||
|
||||
return prompt
|
||||
|
||||
def process_llm_output(
|
||||
self, model_output: str, context_chunks: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, DanswerQuotes]:
|
||||
logger.debug(model_output)
|
||||
|
||||
model_clean = clean_up_code_blocks(model_output)
|
||||
|
||||
match = re.search(r'{\s*"answer":', model_clean)
|
||||
if not match:
|
||||
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
|
||||
|
||||
final_json = escape_newlines(model_clean[match.start() :])
|
||||
|
||||
return process_answer(
|
||||
final_json, context_chunks, is_json_prompt=self.is_json_output
|
||||
)
|
||||
|
||||
def process_llm_token_stream(
|
||||
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
# Can be supported but the parsing is more involved, not handling until needed
|
||||
raise ValueError(
|
||||
"This Scratchpad approach is not suitable for real time uses like streaming"
|
||||
)
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
|
||||
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
|
||||
|
||||
|
||||
class QABlock(QAModel):
|
||||
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
|
||||
self._llm = llm
|
||||
self._qa_handler = qa_handler
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
query: str,
|
||||
history_str: str,
|
||||
context_chunks: list[InferenceChunk],
|
||||
) -> str:
|
||||
prompt = self._qa_handler.build_prompt(
|
||||
query=query, history_str=history_str, context_chunks=context_chunks
|
||||
)
|
||||
return prompt
|
||||
|
||||
def answer_question_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
llm_context_docs: list[InferenceChunk],
|
||||
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
tokens_stream = self._llm.stream(prompt)
|
||||
|
||||
captured_tokens = []
|
||||
|
||||
try:
|
||||
for answer_piece in self._qa_handler.process_llm_token_stream(
|
||||
iter(tokens_stream), llm_context_docs
|
||||
):
|
||||
if (
|
||||
isinstance(answer_piece, DanswerAnswerPiece)
|
||||
and answer_piece.answer_piece
|
||||
):
|
||||
captured_tokens.append(answer_piece.answer_piece)
|
||||
yield answer_piece
|
||||
|
||||
except Exception as e:
|
||||
yield StreamingError(error=str(e))
|
||||
|
||||
if metrics_callback is not None:
|
||||
prompt_tokens = check_number_of_tokens(
|
||||
text=str(prompt), encode_fn=get_default_llm_token_encode()
|
||||
)
|
||||
|
||||
response_tokens = check_number_of_tokens(
|
||||
text="".join(captured_tokens), encode_fn=get_default_llm_token_encode()
|
||||
)
|
||||
|
||||
metrics_callback(
|
||||
LLMMetricsContainer(
|
||||
prompt_tokens=prompt_tokens, response_tokens=response_tokens
|
||||
)
|
||||
)
|
@ -1,275 +1,14 @@
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import regex
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.constants import ANSWER_PAT
|
||||
from danswer.prompts.constants import QUOTE_PAT
|
||||
from danswer.prompts.constants import UNCERTAINTY_PAT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import clean_up_code_blocks
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    """Splits the model output into an Answer and 0 or more Quote sections.
    Splits by the Quote pattern, if not exist then assume it's all answer and no quotes
    """
    # If no answer section, don't care about the quote
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    )  # Just in case model unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]

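# A minimal illustration of the freeform split above, assuming the prompt patterns are
# "Answer:" and "Quote:" (the real values come from danswer.prompts.constants; the
# sample text below is made up):
#
#   sample = (
#       "Answer: The sky is blue.\n"
#       "Quote: The sky appears blue to the human eye.\n"
#       "Quote: This is due to Rayleigh scattering."
#   )
#   answer, quotes = _extract_answer_quotes_freeform(sample)
#   # answer -> "The sky is blue."
#   # quotes -> ["The sky appears blue to the human eye.", "This is due to Rayleigh scattering."]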
def _extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes


def _extract_answer_json(raw_model_output: str) -> dict:
    try:
        answer_json = extract_embedded_json(raw_model_output)
    except (ValueError, JSONDecodeError):
        # LLMs get confused when handling the list in the json. Sometimes it doesn't attend
        # enough to the previous { token so it just ends the list of quotes and stops there
        # here, we add logic to try to fix this LLM error.
        answer_json = extract_embedded_json(raw_model_output + "}")

    if "answer" not in answer_json:
        raise ValueError("Model did not output an answer as expected.")

    return answer_json


def separate_answer_quotes(
    answer_raw: str, is_json_prompt: bool = False
) -> Tuple[Optional[str], Optional[list[str]]]:
    """Takes in a raw model output and pulls out the answer and the quotes sections."""
    if is_json_prompt:
        model_raw_json = _extract_answer_json(answer_raw)
        return _extract_answer_quotes_json(model_raw_json)

    return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))

def match_quotes_to_docs(
    quotes: list[str],
    chunks: list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> DanswerQuotes:
    danswer_quotes: list[DanswerQuote] = []
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)

        for chunk in chunks:
            if not chunk.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(chunk.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there
                # must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    break

            danswer_quotes.append(
                DanswerQuote(
                    quote=quote,
                    document_id=chunk.document_id,
                    link=curr_link,
                    source_type=chunk.source_type,
                    semantic_identifier=chunk.semantic_identifier,
                    blurb=chunk.blurb,
                )
            )
            break

    return DanswerQuotes(quotes=danswer_quotes)

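# A small, self-contained sketch of the fuzzy-matching primitive used above: the
# third-party `regex` package allows a bounded number of edits via the `{e<=N}`
# quantifier, which is what lets slightly paraphrased model quotes still be located
# inside a chunk. The strings and edit budget below are made up for illustration.
import regex as _regex

_quote = "the quick brown fox"
_chunk_text = "...and then the quick brown foxes jumped over the fence..."

_max_edits = 2  # roughly len(quote) * QUOTE_ALLOWED_ERROR_PERCENT in the real flow
_pattern = r"(" + _regex.escape(_quote) + r"){e<=" + str(_max_edits) + r"}"
_match = _regex.search(_pattern, _chunk_text)
_offset = _match.span()[0] if _match else None  # character offset used to pick a source link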
def process_answer(
    answer_raw: str,
    chunks: list[InferenceChunk],
    is_json_prompt: bool = True,
) -> tuple[DanswerAnswer, DanswerQuotes]:
    """Used (1) in the non-streaming case to process the model output
    into an Answer and Quotes AND (2) after the complete streaming response
    has been received to process the model output into an Answer and Quotes."""
    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
    if answer == UNCERTAINTY_PAT or not answer:
        if answer == UNCERTAINTY_PAT:
            logger.debug("Answer matched UNCERTAINTY_PAT")
        else:
            logger.debug("No answer extracted from raw output")
        return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])

    logger.info(f"Answer: {answer}")
    if not quote_strings:
        logger.debug("No quotes extracted from raw output")
        return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
    logger.info(f"All quotes (including unmatched): {quote_strings}")
    quotes = match_quotes_to_docs(quote_strings, chunks)
    logger.debug(f"Final quotes: {quotes}")

    return DanswerAnswer(answer=answer), quotes

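# Illustrative only (made-up payload): with is_json_prompt=True the raw model output
# is expected to embed a JSON object with "answer" and "quotes" keys, which
# process_answer separates and then matches back to the supplied chunks.
_raw_model_output = '{"answer": "Danswer streams answers token by token.", "quotes": ["answers are streamed token by token"]}'
_danswer_answer, _danswer_quotes = process_answer(
    answer_raw=_raw_model_output,
    chunks=[],  # with no chunks, the quotes simply remain unmatched
    is_json_prompt=True,
)
# _danswer_answer.answer -> "Danswer streams answers token by token."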
def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False


def _extract_quotes_from_completed_token_stream(
    model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True
) -> DanswerQuotes:
    answer, quotes = process_answer(model_output, context_chunks, is_json_prompt)
    if answer:
        logger.info(answer)
    elif model_output:
        logger.warning("Answer extraction from model output failed.")

    return quotes

def process_model_tokens(
    tokens: Iterator[str],
    context_docs: list[InferenceChunk],
    is_json_prompt: bool = True,
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
    """Used in the streaming case to process the model output
    into an Answer and Quotes

    Yields Answer tokens back out in a dict for streaming to frontend
    When Answer section ends, yields dict with answer_finished key
    Collects all the tokens at the end to form the complete model output"""
    quote_pat = f"\n{QUOTE_PAT}"
    # Sometimes worse model outputs new line instead of :
    quote_loose = f"\n{quote_pat[:-1]}\n"
    # Sometime model outputs two newlines before quote section
    quote_pat_full = f"\n{quote_pat}"
    model_output: str = ""
    found_answer_start = False if is_json_prompt else True
    found_answer_end = False
    hold_quote = ""
    for token in tokens:
        model_previous = model_output
        model_output += token

        if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
            # Note, if the token that completes the pattern has additional text, for example if the token is "?
            # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
            # event that the model outputs the UNCERTAINTY_PAT
            found_answer_start = True

            # Prevent heavy cases of hallucinations where model is not even providing a json until later
            if is_json_prompt and len(model_output) > 40:
                logger.warning("LLM did not produce json as prompted")
                found_answer_end = True

            continue

        if found_answer_start and not found_answer_end:
            if is_json_prompt and _stream_json_answer_end(model_previous, token):
                found_answer_end = True
                yield DanswerAnswerPiece(answer_piece=None)
                continue
            elif not is_json_prompt:
                if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
                    found_answer_end = True
                    yield DanswerAnswerPiece(answer_piece=None)
                    continue
                if hold_quote + token in quote_pat_full:
                    hold_quote += token
                    continue
            yield DanswerAnswerPiece(answer_piece=hold_quote + token)
            hold_quote = ""

    logger.debug(f"Raw Model QnA Output: {model_output}")

    yield _extract_quotes_from_completed_token_stream(
        model_output=model_output,
        context_chunks=context_docs,
        is_json_prompt=is_json_prompt,
    )

def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
    """Mock streaming by generating the passed in model output, character by character"""
    for token in model_out:
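# Illustrative sketch (made-up payload): the two helpers above can be wired together,
# e.g. in tests, to replay a canned model output through the streaming parser.
_canned = '{"answer": "Paris is the capital of France.", "quotes": ["Paris is the capital"]}'
for _packet in process_model_tokens(
    tokens=simulate_streaming_response(_canned),
    context_docs=[],
    is_json_prompt=True,
):
    # DanswerAnswerPiece packets stream the answer text; a DanswerQuotes object arrives last
    print(_packet)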
@@ -1,4 +1,5 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import cast

from sqlalchemy.orm import Session
@@ -51,6 +52,11 @@ class SearchPipeline:
        self._reranked_docs: list[InferenceChunk] | None = None
        self._relevant_chunk_indicies: list[int] | None = None

        # generator state
        self._postprocessing_generator: Generator[
            list[InferenceChunk] | list[str], None, None
        ] | None = None

    """Pre-processing"""

    def _run_preprocessing(self) -> None:
@@ -113,36 +119,38 @@ class SearchPipeline:
"""Post-Processing"""
|
||||
|
||||
def _run_postprocessing(self) -> None:
|
||||
postprocessing_generator = search_postprocessing(
|
||||
search_query=self.search_query,
|
||||
retrieved_chunks=self.retrieved_docs,
|
||||
rerank_metrics_callback=self.rerank_metrics_callback,
|
||||
)
|
||||
self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator))
|
||||
|
||||
relevant_chunk_ids = cast(list[str], next(postprocessing_generator))
|
||||
self._relevant_chunk_indicies = [
|
||||
ind
|
||||
for ind, chunk in enumerate(self._reranked_docs)
|
||||
if chunk.unique_id in relevant_chunk_ids
|
||||
]
|
||||
|
||||
@property
|
||||
def reranked_docs(self) -> list[InferenceChunk]:
|
||||
if self._reranked_docs is not None:
|
||||
return self._reranked_docs
|
||||
|
||||
self._run_postprocessing()
|
||||
return cast(list[InferenceChunk], self._reranked_docs)
|
||||
self._postprocessing_generator = search_postprocessing(
|
||||
search_query=self.search_query,
|
||||
retrieved_chunks=self.retrieved_docs,
|
||||
rerank_metrics_callback=self.rerank_metrics_callback,
|
||||
)
|
||||
self._reranked_docs = cast(
|
||||
list[InferenceChunk], next(self._postprocessing_generator)
|
||||
)
|
||||
return self._reranked_docs
|
||||
|
||||
@property
|
||||
def relevant_chunk_indicies(self) -> list[int]:
|
||||
if self._relevant_chunk_indicies is not None:
|
||||
return self._relevant_chunk_indicies
|
||||
|
||||
self._run_postprocessing()
|
||||
return cast(list[int], self._relevant_chunk_indicies)
|
||||
# run first step of postprocessing generator if not already done
|
||||
reranked_docs = self.reranked_docs
|
||||
|
||||
relevant_chunk_ids = next(
|
||||
cast(Generator[list[str], None, None], self._postprocessing_generator)
|
||||
)
|
||||
self._relevant_chunk_indicies = [
|
||||
ind
|
||||
for ind, chunk in enumerate(reranked_docs)
|
||||
if chunk.unique_id in relevant_chunk_ids
|
||||
]
|
||||
return self._relevant_chunk_indicies
|
||||
|
||||
@property
|
||||
def chunk_relevance_list(self) -> list[bool]:
|
||||
|
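# The reworked properties above follow a small pattern worth noting: an expensive
# multi-step pipeline is exposed as a generator, and each property pulls only the
# steps it needs, caching results so later accesses are free. A stripped-down sketch
# of the same idea (hypothetical names, not danswer's API):
from collections.abc import Generator


def _two_step_pipeline() -> Generator[list[str], None, None]:
    yield ["reranked-doc-1", "reranked-doc-2"]  # step 1: rerank
    yield ["reranked-doc-2"]                    # step 2: relevance filter


class LazyPipeline:
    def __init__(self) -> None:
        self._gen: Generator[list[str], None, None] | None = None
        self._reranked: list[str] | None = None
        self._relevant: list[str] | None = None

    @property
    def reranked(self) -> list[str]:
        if self._reranked is None:
            self._gen = _two_step_pipeline()
            self._reranked = next(self._gen)
        return self._reranked

    @property
    def relevant(self) -> list[str]:
        if self._relevant is None:
            _ = self.reranked  # make sure step 1 has already run
            assert self._gen is not None
            self._relevant = next(self._gen)  # then pull step 2
        return self._relevant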
@@ -223,11 +223,13 @@ def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
    return LlmDoc(
        document_id=first_chunk.document_id,
        content="\n".join(chunk_texts),
        blurb=first_chunk.blurb,
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
        metadata=first_chunk.metadata,
        updated_at=first_chunk.updated_at,
        link=first_chunk.source_links[0] if first_chunk.source_links else None,
        source_links=first_chunk.source_links,
    )

@@ -14,8 +14,8 @@ from danswer.db.chat import update_persona_visibility
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.llm.utils import get_default_llm_version
from danswer.one_shot_answer.qa_block import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@@ -6,7 +6,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.process_message import stream_chat_message
from danswer.db.chat import create_chat_session
@@ -25,6 +24,7 @@ from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.secondary_llm_flows.chat_session_naming import (
    get_renamed_conversation_name,
)
@@ -77,7 +77,6 @@ def get_answer_for_question(
    str | None,
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
    LLMMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
@@ -103,7 +102,6 @@ def get_answer_for_question(

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()
    llm_metrics = MetricsHander[LLMMetricsContainer]()

    answer = get_search_answer(
        query_req=new_message_request,
@@ -116,14 +114,12 @@ def get_answer_for_question(
        bypass_acl=True,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
        llm_metrics_callback=llm_metrics.record_metric,
    )

    return (
        answer.answer,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
        llm_metrics.metrics,
    )


@@ -221,7 +217,6 @@ if __name__ == "__main__":
            answer,
            retrieval_metrics,
            rerank_metrics,
            llm_metrics,
        ) = get_answer_for_question(sample["question"], db_session)
        end_time = datetime.now()

@@ -237,12 +232,6 @@ if __name__ == "__main__":
                else "\tFailed, either crashed or refused to answer."
            )
            if not args.discard_metrics:
                print("\nLLM Tokens Usage:")
                if llm_metrics is None:
                    print("No LLM Metrics Available")
                else:
                    _print_llm_metrics(llm_metrics)

                print("\nRetrieval Metrics:")
                if retrieval_metrics is None:
                    print("No Retrieval Metrics Available")
@@ -7,9 +7,9 @@ from typing import TextIO

from sqlalchemy.orm import Session

from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.db.engine import get_sqlalchemy_engine
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.doc_pruning import reorder_docs
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
@@ -95,16 +95,8 @@ def get_search_results(
    top_chunks = search_pipeline.reranked_docs
    llm_chunk_selection = search_pipeline.chunk_relevance_list

    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        token_limit=None,
    )

    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]

    return (
        llm_chunks,
        reorder_docs(top_chunks, llm_chunk_selection),
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
    )
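# The script above now returns a reordered chunk list rather than a filtered one. A
# plausible sketch of that kind of reordering, assuming the boolean list marks which
# chunks the LLM judged relevant (hypothetical helper, not necessarily danswer's
# exact reorder_docs implementation):
def reorder_by_relevance(chunks: list, relevance: list[bool]) -> list:
    # Keep every chunk, but move the LLM-selected ones to the front,
    # preserving relative order within each group.
    selected = [c for c, keep in zip(chunks, relevance) if keep]
    rest = [c for c, keep in zip(chunks, relevance) if not keep]
    return selected + rest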
@@ -3,8 +3,12 @@ import unittest

from danswer.configs.constants import DocumentSource
from danswer.indexing.models import InferenceChunk
from danswer.one_shot_answer.qa_utils import match_quotes_to_docs
from danswer.one_shot_answer.qa_utils import separate_answer_quotes
from danswer.llm.answering.stream_processing.quotes_processing import (
    match_quotes_to_docs,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
    separate_answer_quotes,
)


class TestQAPostprocessing(unittest.TestCase):