diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index fe97b0b392..ee2f582c95 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,97 +1,29 @@ import re -from collections.abc import Callable -from collections.abc import Iterator from collections.abc import Sequence -from functools import lru_cache -from typing import cast -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage -from langchain.schema.messages import SystemMessage from sqlalchemy.orm import Session -from tiktoken.core import Encoding from danswer.chat.models import CitationInfo -from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import STOP_STREAM_PAT -from danswer.configs.constants import IGNORE_FOR_QA -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE -from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session -from danswer.db.chat import get_default_prompt from danswer.db.models import ChatMessage -from danswer.db.models import Persona -from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.prompts.chat_prompts import ADDITIONAL_INFO -from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.chat_prompts import CHAT_USER_PROMPT -from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT -from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT -from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.constants import TRIPLE_BACKTICK -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders -from danswer.prompts.prompt_utils import get_current_llm_day_time -from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT -from danswer.prompts.token_counts import ( - CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, -) -from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT -from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT -from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.utils.logger import setup_logger logger = setup_logger() -@lru_cache() -def build_chat_system_message( - prompt: Prompt, - context_exists: bool, - llm_tokenizer_encode_func: Callable, - citation_line: str = REQUIRE_CITATION_STATEMENT, - no_citation_line: str = NO_CITATION_STATEMENT, -) -> tuple[SystemMessage | None, int]: - system_prompt = prompt.system_prompt.strip() - if prompt.include_citations: - if context_exists: - system_prompt += citation_line - else: - system_prompt += no_citation_line - if prompt.datetime_aware: - if system_prompt: - system_prompt += ADDITIONAL_INFO.format( - datetime_info=get_current_llm_day_time() - ) - else: - system_prompt = get_current_llm_day_time() - - if not system_prompt: - return None, 0 - - token_count = len(llm_tokenizer_encode_func(system_prompt)) - system_msg = SystemMessage(content=system_prompt) - - return 
system_msg, token_count - - def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc: return LlmDoc( document_id=inf_chunk.document_id, content=inf_chunk.content, + blurb=inf_chunk.blurb, semantic_identifier=inf_chunk.semantic_identifier, source_type=inf_chunk.source_type, metadata=inf_chunk.metadata, updated_at=inf_chunk.updated_at, link=inf_chunk.source_links[0] if inf_chunk.source_links else None, + source_links=inf_chunk.source_links, ) @@ -108,170 +40,6 @@ def map_document_id_order( return order_mapping -def build_chat_user_message( - chat_message: ChatMessage, - prompt: Prompt, - context_docs: list[LlmDoc], - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, - user_prompt_template: str = CHAT_USER_PROMPT, - context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT, - ignore_str: str = DEFAULT_IGNORE_STATEMENT, -) -> tuple[HumanMessage, int]: - user_query = chat_message.message - - if not context_docs: - # Simpler prompt for cases where there is no context - user_prompt = ( - context_free_template.format( - task_prompt=prompt.task_prompt, user_query=user_query - ) - if prompt.task_prompt - else user_query - ) - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - return user_msg, token_count - - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_docs) - ) - optional_ignore = "" if all_doc_useful else ignore_str - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - user_prompt = user_prompt_template.format( - optional_ignore_statement=optional_ignore, - context_docs_str=context_docs_str, - task_prompt=task_prompt_with_reminder, - user_query=user_query, - ) - - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage(content=user_prompt) - - return user_msg, token_count - - -def _get_usable_chunks( - chunks: list[InferenceChunk], token_limit: int -) -> list[InferenceChunk]: - total_token_count = 0 - usable_chunks = [] - for chunk in chunks: - chunk_token_count = check_number_of_tokens(chunk.content) - if total_token_count + chunk_token_count > token_limit: - break - - total_token_count += chunk_token_count - usable_chunks.append(chunk) - - # try and return at least one chunk if possible. This chunk will - # get truncated later on in the pipeline. 
This would only occur if - # the first chunk is larger than the token limit (usually due to character - # count -> token count mismatches caused by special characters / non-ascii - # languages) - if not usable_chunks and chunks: - usable_chunks = [chunks[0]] - - return usable_chunks - - -def get_usable_chunks( - chunks: list[InferenceChunk], - token_limit: int, - offset: int = 0, -) -> list[InferenceChunk]: - offset_into_chunks = 0 - usable_chunks: list[InferenceChunk] = [] - for _ in range(min(offset + 1, 1)): # go through this process at least once - if offset_into_chunks >= len(chunks) and offset_into_chunks > 0: - raise ValueError( - "Chunks offset too large, should not retry this many times" - ) - - usable_chunks = _get_usable_chunks( - chunks=chunks[offset_into_chunks:], token_limit=token_limit - ) - offset_into_chunks += len(usable_chunks) - - return usable_chunks - - -def get_chunks_for_qa( - chunks: list[InferenceChunk], - llm_chunk_selection: list[bool], - token_limit: int | None, - llm_tokenizer: Encoding | None = None, - batch_offset: int = 0, -) -> list[int]: - """ - Gives back indices of chunks to pass into the LLM for Q&A. - - Only selects chunks viable for Q&A, within the token limit, and prioritize those selected - by the LLM in a separate flow (this can be turned off) - - Note, the batch_offset calculation has to count the batches from the beginning each time as - there's no way to know which chunks were included in the prior batches without recounting atm, - this is somewhat slow as it requires tokenizing all the chunks again - """ - token_leeway = 50 - batch_index = 0 - latest_batch_indices: list[int] = [] - token_count = 0 - - # First iterate the LLM selected chunks, then iterate the rest if tokens remaining - for selection_target in [True, False]: - for ind, chunk in enumerate(chunks): - if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get( - IGNORE_FOR_QA - ): - continue - - # We calculate it live in case the user uses a different LLM + tokenizer - chunk_token = check_number_of_tokens(chunk.content) - if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway: - logger.warning( - "Found more tokens in chunk than expected, " - "likely mismatch between embedding and LLM tokenizers. Trimming content..." 
- ) - chunk.content = tokenizer_trim_content( - content=chunk.content, - desired_length=DOC_EMBEDDING_CONTEXT_SIZE, - tokenizer=llm_tokenizer or get_default_llm_tokenizer(), - ) - - # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk - token_count += chunk_token + token_leeway - - # Always use at least 1 chunk - if ( - token_limit is None - or token_count <= token_limit - or not latest_batch_indices - ): - latest_batch_indices.append(ind) - current_chunk_unused = False - else: - current_chunk_unused = True - - if token_limit is not None and token_count >= token_limit: - if batch_index < batch_offset: - batch_index += 1 - if current_chunk_unused: - latest_batch_indices = [ind] - token_count = chunk_token - else: - latest_batch_indices = [] - token_count = 0 - else: - return latest_batch_indices - - return latest_batch_indices - - def create_chat_chain( chat_session_id: int, db_session: Session, @@ -341,157 +109,6 @@ def combine_message_chain( return "\n\n".join(message_strs) -_PER_MESSAGE_TOKEN_BUFFER = 7 - - -def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, - max_allowed_tokens: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. 
- The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = find_last_index( - all_tokens, max_prompt_tokens=max_allowed_tokens - ) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - -def in_code_block(llm_text: str) -> bool: - count = llm_text.count(TRIPLE_BACKTICK) - return count % 2 != 0 - - -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - stop_stream: str | None = STOP_STREAM_PAT, -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - llm_out = "" - max_citation_num = len(context_docs) - curr_segment = "" - prepend_bracket = False - cited_inds = set() - hold = "" - for raw_token in tokens: - if stop_stream: - next_hold = hold + raw_token - - if stop_stream in next_hold: - break - - if next_hold == stop_stream[: len(next_hold)]: - hold = next_hold - continue - - token = next_hold - hold = "" - else: - token = raw_token - - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - - curr_segment += token - llm_out += token - - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) - - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found and not in_code_block(llm_out): - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, - ) - - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link - if possible_citation_found: - continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield 
DanswerAnswerPiece(answer_piece=curr_segment) - - def reorganize_citations( answer: str, citations: list[CitationInfo] ) -> tuple[str, list[CitationInfo]]: @@ -547,72 +164,3 @@ def reorganize_citations( new_citation_info[citation.citation_num] = citation return new_answer, list(new_citation_info.values()) - - -def get_prompt_tokens(prompt: Prompt) -> int: - # Note: currently custom prompts do not allow datetime aware, only default prompts - return ( - check_number_of_tokens(prompt.system_prompt) - + check_number_of_tokens(prompt.task_prompt) - + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT - + CITATION_STATEMENT_TOKEN_CNT - + CITATION_REMINDER_TOKEN_CNT - + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) - + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) - ) - - -# buffer just to be safe so that we don't overflow the token limit due to -# a small miscalculation -_MISC_BUFFER = 40 - - -def compute_max_document_tokens( - persona: Persona, - actual_user_input: str | None = None, - max_llm_token_override: int | None = None, -) -> int: - """Estimates the number of tokens available for context documents. Formula is roughly: - - ( - model_context_window - reserved_output_tokens - prompt_tokens - - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) - ) - - The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. - if we're trying to determine if the user should be able to select another document) then we just set an - arbitrary "upper bound". - """ - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - # if we can't find a number of tokens, just assume some common default - max_input_tokens = ( - max_llm_token_override - if max_llm_token_override - else get_max_input_tokens(model_name=llm_name) - ) - if persona.prompts: - # TODO this may not always be the first prompt - prompt_tokens = get_prompt_tokens(persona.prompts[0]) - else: - prompt_tokens = get_prompt_tokens(get_default_prompt()) - - user_input_tokens = ( - check_number_of_tokens(actual_user_input) - if actual_user_input is not None - else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS - ) - - return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER - - -def compute_max_llm_input_tokens(persona: Persona) -> int: - """Maximum tokens allows in the input to the LLM (of any type).""" - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - input_tokens = get_max_input_tokens(model_name=llm_name) - return input_tokens - _MISC_BUFFER diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 47d554de77..d2dd9f31fa 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -16,11 +16,13 @@ class LlmDoc(BaseModel): document_id: str content: str + blurb: str semantic_identifier: str source_type: DocumentSource metadata: dict[str, str | list[str]] updated_at: datetime | None link: str | None + source_links: dict[int, str] | None # First chunk of info for streaming QA @@ -100,9 +102,12 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None -AnswerQuestionStreamReturn = Iterator[ - DanswerAnswerPiece | DanswerQuotes | DanswerContexts | StreamingError -] +AnswerQuestionPossibleReturn = ( + DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError +) + + 
+AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn] class LLMMetricsContainer(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 9cd78c963b..270afc67e2 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -5,16 +5,8 @@ from typing import cast from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import build_chat_user_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import compute_max_llm_input_tokens from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import drop_messages_history_overflow -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc @@ -23,9 +15,7 @@ from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT -from danswer.configs.constants import DISABLED_GEN_AI_MSG from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -37,21 +27,17 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager -from danswer.db.models import ChatMessage -from danswer.db.models import Persona from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User from danswer.document_index.factory import get_default_document_index -from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm -from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer -from danswer.llm.utils import get_default_llm_version -from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import tokenizer_trim_content -from danswer.llm.utils import translate_history_to_basemessages -from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting from danswer.search.models import SearchRequest from danswer.search.pipeline import SearchPipeline @@ -68,72 +54,6 @@ from danswer.utils.timing import log_generator_function_time logger = setup_logger() -def generate_ai_chat_response( - query_message: ChatMessage, - history: list[ChatMessage], - persona: Persona, - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], - llm: LLM | 
None, - llm_tokenizer_encode_func: Callable, - all_doc_useful: bool, -) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]: - if llm is None: - try: - llm = get_default_llm() - except GenAIDisabledException: - # Not an error if it's a user configuration - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - return - - if query_message.prompt is None: - raise RuntimeError("No prompt received for generating Gen AI answer.") - - try: - context_exists = len(context_docs) > 0 - - system_message_or_none, system_tokens = build_chat_system_message( - prompt=query_message.prompt, - context_exists=context_exists, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - ) - - history_basemessages, history_token_counts = translate_history_to_basemessages( - history - ) - - # Be sure the context_docs passed to build_chat_user_message - # Is the same as passed in later for extracting citations - user_message, user_tokens = build_chat_user_message( - chat_message=query_message, - prompt=query_message.prompt, - context_docs=context_docs, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=all_doc_useful, - ) - - prompt = drop_messages_history_overflow( - system_msg=system_message_or_none, - system_token_count=system_tokens, - history_msgs=history_basemessages, - history_token_counts=history_token_counts, - final_msg=user_message, - final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(persona), - ) - - # Good Debug/Breakpoint - tokens = llm.stream(prompt) - - yield from extract_citations_from_stream( - tokens, context_docs, doc_id_to_rank_map - ) - - except Exception as e: - logger.exception(f"LLM failed to produce valid chat message, error: {e}") - yield StreamingError(error=str(e)) - - def translate_citations( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] ) -> dict[int, int]: @@ -154,24 +74,26 @@ def translate_citations( return citation_to_saved_doc_id_map -def stream_chat_message_objects( - new_msg_req: CreateChatMessageRequest, - user: User | None, - db_session: Session, - # Needed to translate persona num_chunks to tokens to the LLM - default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, - # For flow with search, don't include as many chunks as possible since we need to leave space - # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks - max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, -) -> Iterator[ +ChatPacketStream = Iterator[ StreamingError | QADocsResponse | LLMRelevanceFilterResponse | ChatMessageDetail | DanswerAnswerPiece | CitationInfo -]: +] + + +def stream_chat_message_objects( + new_msg_req: CreateChatMessageRequest, + user: User | None, + db_session: Session, + # Needed to translate persona num_chunks to tokens to the LLM + default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, + # For flow with search, don't include as many chunks as possible since we need to leave space + # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks + max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE, +) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run 2. 
[conditional] LLM selected chunk indices if LLM chunk filtering is turned on @@ -277,10 +199,6 @@ def stream_chat_message_objects( query_message=final_msg, history=history_msgs, llm=llm ) - max_document_tokens = compute_max_document_tokens( - persona=persona, actual_user_input=message_text - ) - rephrased_query = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( @@ -296,64 +214,8 @@ def stream_chat_message_objects( doc_identifiers=identifier_tuples, document_index=document_index, ) - - # truncate the last document if it exceeds the token limit - tokens_per_doc = [ - len( - llm_tokenizer_encode_func( - build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) - ) - ) - for ind, llm_doc in enumerate(llm_docs) - ] - final_doc_ind = None - total_tokens = 0 - for ind, tokens in enumerate(tokens_per_doc): - total_tokens += tokens - if total_tokens > max_document_tokens: - final_doc_ind = ind - break - if final_doc_ind is not None: - # only allow the final document to get truncated - # if more than that, then the user message is too long - if final_doc_ind != len(tokens_per_doc) - 1: - yield StreamingError( - error="LLM context window exceeded. Please de-select some documents or shorten your query." - ) - return - - final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( - total_tokens - max_document_tokens - ) - # 75 tokens is a reasonable over-estimate of the metadata and title - final_doc_content_length = final_doc_desired_length - 75 - # this could occur if we only have space for the title / metadata - # not ideal, but it's the most reasonable thing to do - # NOTE: the frontend prevents documents from being selected if - # less than 75 tokens are available to try and avoid this situation - # from occuring in the first place - if final_doc_content_length <= 0: - logger.error( - f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content " - "length is less than 0. Removing this doc from the final prompt." 
- ) - llm_docs.pop() - else: - llm_docs[final_doc_ind].content = tokenizer_trim_content( - content=llm_docs[final_doc_ind].content, - desired_length=final_doc_content_length, - tokenizer=llm_tokenizer, - ) - - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], llm_docs) + document_pruning_config = DocumentPruningConfig( + is_manually_selected_docs=True ) # In case the search doc is deleted, just don't include it @@ -393,9 +255,6 @@ def stream_chat_message_objects( top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) - # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order(top_chunks) - reference_db_search_docs = [ create_db_search_doc(server_search_doc=top_doc, db_session=db_session) for top_doc in top_docs @@ -423,41 +282,21 @@ def stream_chat_message_objects( ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - persona.num_chunks - if persona.num_chunks is not None - else default_num_chunks + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + persona.num_chunks + if persona.num_chunks is not None + else default_num_chunks + ), + max_window_percentage=max_document_percentage, ) - llm_name = get_default_llm_version()[0] - if persona.llm_model_version_override: - llm_name = persona.llm_model_version_override - - llm_max_input_tokens = get_max_input_tokens(model_name=llm_name) - - llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens - - chunk_token_limit = int( - min( - num_llm_chunks * default_chunk_size, - max_document_tokens, - llm_token_based_chunk_lim, - ) - ) - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=search_pipeline.chunk_relevance_list, - token_limit=chunk_token_limit, - llm_tokenizer=llm_tokenizer, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks] + llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks] else: llm_docs = [] - doc_id_to_rank_map = {} reference_db_search_docs = None + document_pruning_config = DocumentPruningConfig() # Cannot determine these without the LLM step or breaking out early partial_response = partial( @@ -495,33 +334,24 @@ def stream_chat_message_objects( return # LLM prompt building, response capturing, etc. 
- response_packets = generate_ai_chat_response( - query_message=final_msg, - history=history_msgs, + answer = Answer( + question=final_msg.message, + docs=llm_docs, + answer_style_config=AnswerStyleConfig( + citation_config=CitationConfig( + all_docs_useful=reference_db_search_docs is not None + ), + document_pruning_config=document_pruning_config, + ), + prompt=final_msg.prompt, persona=persona, - context_docs=llm_docs, - doc_id_to_rank_map=doc_id_to_rank_map, - llm=llm, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - all_doc_useful=reference_doc_ids is not None, + message_history=[ + PreviousMessage.from_chat_message(msg) for msg in history_msgs + ], ) + # generator will not include quotes, so we can cast + yield from cast(ChatPacketStream, answer.processed_streamed_output) - # Capture outputs and errors - llm_output = "" - error: str | None = None - citations: list[CitationInfo] = [] - for packet in response_packets: - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - elif isinstance(packet, CitationInfo): - citations.append(packet) - continue - - yield packet except Exception as e: logger.exception(e) @@ -535,16 +365,16 @@ def stream_chat_message_objects( db_citations = None if reference_db_search_docs: db_citations = translate_citations( - citations_list=citations, + citations_list=answer.citations, db_docs=reference_db_search_docs, ) # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( - message=llm_output, - token_count=len(llm_tokenizer_encode_func(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, - error=error, + error=None, ) msg_detail_response = translate_db_message_to_chat_message_detail( diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index 1e065dd1da..b3fdb79c88 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -12,7 +12,6 @@ from slack_sdk.errors import SlackApiError from slack_sdk.models.blocks import DividerBlock from sqlalchemy.orm import Session -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER @@ -39,6 +38,7 @@ from danswer.danswerbot.slack.utils import update_emote_react from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import get_default_llm_version from danswer.llm.utils import get_max_input_tokens diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py new file mode 100644 index 0000000000..76d399d8bd --- /dev/null +++ b/backend/danswer/llm/answering/answer.py @@ -0,0 +1,176 @@ +from collections.abc import Iterator +from typing import cast + +from langchain.schema.messages import BaseMessage + +from danswer.chat.models import AnswerQuestionPossibleReturn +from danswer.chat.models import 
AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.configs.chat_configs import QA_TIMEOUT +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.llm.answering.doc_pruning import prune_documents +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt +from danswer.llm.answering.prompts.quotes_prompt import ( + build_quotes_prompt, +) +from danswer.llm.answering.stream_processing.citation_processing import ( + build_citation_processor, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + build_quotes_processor, +) +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import get_default_llm_tokenizer + + +def _get_stream_processor( + docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig +) -> StreamProcessor: + if answer_style_configs.citation_config: + return build_citation_processor( + context_docs=docs, + ) + if answer_style_configs.quotes_config: + return build_quotes_processor( + context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak") + ) + + raise RuntimeError("Not implemented yet") + + +class Answer: + def __init__( + self, + question: str, + docs: list[LlmDoc], + answer_style_config: AnswerStyleConfig, + prompt: Prompt, + persona: Persona, + # must be the same length as `docs`. If None, all docs are considered "relevant" + doc_relevance_list: list[bool] | None = None, + message_history: list[PreviousMessage] | None = None, + single_message_history: str | None = None, + timeout: int = QA_TIMEOUT, + ) -> None: + if single_message_history and message_history: + raise ValueError( + "Cannot provide both `message_history` and `single_message_history`" + ) + + self.question = question + self.docs = docs + self.doc_relevance_list = doc_relevance_list + self.message_history = message_history or [] + # used for QA flow where we only want to send a single message + self.single_message_history = single_message_history + + self.answer_style_config = answer_style_config + + self.llm = get_default_llm( + gen_ai_model_version_override=persona.llm_model_version_override, + timeout=timeout, + ) + self.llm_tokenizer = get_default_llm_tokenizer() + + self.prompt = prompt + self.persona = persona + + self.process_stream_fn = _get_stream_processor(docs, answer_style_config) + + self._final_prompt: list[BaseMessage] | None = None + + self._pruned_docs: list[LlmDoc] | None = None + + self._streamed_output: list[str] | None = None + self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None + + @property + def pruned_docs(self) -> list[LlmDoc]: + if self._pruned_docs is not None: + return self._pruned_docs + + self._pruned_docs = prune_documents( + docs=self.docs, + doc_relevance_list=self.doc_relevance_list, + persona=self.persona, + question=self.question, + document_pruning_config=self.answer_style_config.document_pruning_config, + ) + return self._pruned_docs + + @property + def final_prompt(self) -> list[BaseMessage]: + if self._final_prompt is not None: + return self._final_prompt + + if self.answer_style_config.citation_config: + self._final_prompt = build_citations_prompt( + question=self.question, + 
message_history=self.message_history, + persona=self.persona, + prompt=self.prompt, + context_docs=self.pruned_docs, + all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, + llm_tokenizer_encode_func=self.llm_tokenizer.encode, + history_message=self.single_message_history or "", + ) + elif self.answer_style_config.quotes_config: + self._final_prompt = build_quotes_prompt( + question=self.question, + context_docs=self.pruned_docs, + history_str=self.single_message_history or "", + prompt=self.prompt, + ) + + return cast(list[BaseMessage], self._final_prompt) + + @property + def raw_streamed_output(self) -> Iterator[str]: + if self._streamed_output is not None: + yield from self._streamed_output + return + + streamed_output = [] + for message in self.llm.stream(self.final_prompt): + streamed_output.append(message) + yield message + + self._streamed_output = streamed_output + + @property + def processed_streamed_output(self) -> AnswerQuestionStreamReturn: + if self._processed_stream is not None: + yield from self._processed_stream + return + + processed_stream = [] + for processed_packet in self.process_stream_fn(self.raw_streamed_output): + processed_stream.append(processed_packet) + yield processed_packet + + self._processed_stream = processed_stream + + @property + def llm_answer(self) -> str: + answer = "" + for packet in self.processed_streamed_output: + if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece: + answer += packet.answer_piece + + return answer + + @property + def citations(self) -> list[CitationInfo]: + citations: list[CitationInfo] = [] + for packet in self.processed_streamed_output: + if isinstance(packet, CitationInfo): + citations.append(packet) + + return citations diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py new file mode 100644 index 0000000000..29c913673d --- /dev/null +++ b/backend/danswer/llm/answering/doc_pruning.py @@ -0,0 +1,205 @@ +from copy import deepcopy +from typing import TypeVar + +from danswer.chat.models import ( + LlmDoc, +) +from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.db.models import Persona +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import tokenizer_trim_content +from danswer.prompts.prompt_utils import build_doc_context_str +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +T = TypeVar("T", bound=LlmDoc | InferenceChunk) + +_METADATA_TOKEN_ESTIMATE = 75 + + +class PruningError(Exception): + pass + + +def _compute_limit( + persona: Persona, + question: str, + max_chunks: int | None, + max_window_percentage: float | None, + max_tokens: int | None, +) -> int: + llm_max_document_tokens = compute_max_document_tokens( + persona=persona, actual_user_input=question + ) + + window_percentage_based_limit = ( + max_window_percentage * llm_max_document_tokens + if max_window_percentage + else None + ) + chunk_count_based_limit = ( + max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None + ) + + limit_options = [ + lim + for lim in [ + window_percentage_based_limit, + chunk_count_based_limit, + max_tokens, + llm_max_document_tokens, + ] + if lim + ] + return int(min(limit_options)) + + +def reorder_docs( + 
docs: list[T], + doc_relevance_list: list[bool] | None, +) -> list[T]: + if doc_relevance_list is None: + return docs + + reordered_docs: list[T] = [] + if doc_relevance_list is not None: + for selection_target in [True, False]: + for doc, is_relevant in zip(docs, doc_relevance_list): + if is_relevant == selection_target: + reordered_docs.append(doc) + return reordered_docs + + +def _remove_docs_to_ignore(docs: list[LlmDoc]) -> list[LlmDoc]: + return [doc for doc in docs if not doc.metadata.get(IGNORE_FOR_QA)] + + +def _apply_pruning( + docs: list[LlmDoc], + doc_relevance_list: list[bool] | None, + token_limit: int, + is_manually_selected_docs: bool, +) -> list[LlmDoc]: + llm_tokenizer = get_default_llm_tokenizer() + docs = deepcopy(docs) # don't modify in place + + # re-order docs with all the "relevant" docs at the front + docs = reorder_docs(docs=docs, doc_relevance_list=doc_relevance_list) + # remove docs that are explicitly marked as not for QA + docs = _remove_docs_to_ignore(docs=docs) + + tokens_per_doc: list[int] = [] + final_doc_ind = None + total_tokens = 0 + for ind, llm_doc in enumerate(docs): + doc_tokens = len( + llm_tokenizer.encode( + build_doc_context_str( + semantic_identifier=llm_doc.semantic_identifier, + source_type=llm_doc.source_type, + content=llm_doc.content, + metadata_dict=llm_doc.metadata, + updated_at=llm_doc.updated_at, + ind=ind, + ) + ) + ) + # if chunks, truncate chunks that are way too long + # this can happen if the embedding model tokenizer is different + # than the LLM tokenizer + if ( + not is_manually_selected_docs + and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE + ): + logger.warning( + "Found more tokens in chunk than expected, " + "likely mismatch between embedding and LLM tokenizers. Trimming content..." + ) + llm_doc.content = tokenizer_trim_content( + content=llm_doc.content, + desired_length=DOC_EMBEDDING_CONTEXT_SIZE, + tokenizer=llm_tokenizer, + ) + doc_tokens = DOC_EMBEDDING_CONTEXT_SIZE + tokens_per_doc.append(doc_tokens) + total_tokens += doc_tokens + if total_tokens > token_limit: + final_doc_ind = ind + break + + if final_doc_ind is not None: + if is_manually_selected_docs: + # for document selection, only allow the final document to get truncated + # if more than that, then the user message is too long + if final_doc_ind != len(docs) - 1: + raise PruningError( + "LLM context window exceeded. Please de-select some documents or shorten your query." + ) + + final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( + total_tokens - token_limit + ) + final_doc_content_length = ( + final_doc_desired_length - _METADATA_TOKEN_ESTIMATE + ) + # this could occur if we only have space for the title / metadata + # not ideal, but it's the most reasonable thing to do + # NOTE: the frontend prevents documents from being selected if + # less than 75 tokens are available to try and avoid this situation + # from occuring in the first place + if final_doc_content_length <= 0: + logger.error( + f"Final doc ({docs[final_doc_ind].semantic_identifier}) content " + "length is less than 0. Removing this doc from the final prompt." 
+ ) + docs.pop() + else: + docs[final_doc_ind].content = tokenizer_trim_content( + content=docs[final_doc_ind].content, + desired_length=final_doc_content_length, + tokenizer=llm_tokenizer, + ) + else: + # for regular search, don't truncate the final document unless it's the only one + if final_doc_ind != 0: + docs = docs[:final_doc_ind] + else: + docs[0].content = tokenizer_trim_content( + content=docs[0].content, + desired_length=token_limit - _METADATA_TOKEN_ESTIMATE, + tokenizer=llm_tokenizer, + ) + docs = [docs[0]] + + return docs + + +def prune_documents( + docs: list[LlmDoc], + doc_relevance_list: list[bool] | None, + persona: Persona, + question: str, + document_pruning_config: DocumentPruningConfig, +) -> list[LlmDoc]: + if doc_relevance_list is not None: + assert len(docs) == len(doc_relevance_list) + + doc_token_limit = _compute_limit( + persona=persona, + question=question, + max_chunks=document_pruning_config.max_chunks, + max_window_percentage=document_pruning_config.max_window_percentage, + max_tokens=document_pruning_config.max_tokens, + ) + return _apply_pruning( + docs=docs, + doc_relevance_list=doc_relevance_list, + token_limit=doc_token_limit, + is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, + ) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py new file mode 100644 index 0000000000..360535ac80 --- /dev/null +++ b/backend/danswer/llm/answering/models.py @@ -0,0 +1,77 @@ +from collections.abc import Callable +from collections.abc import Iterator +from typing import Any +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from pydantic import Field +from pydantic import root_validator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.configs.constants import MessageType + +if TYPE_CHECKING: + from danswer.db.models import ChatMessage + + +StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn] + + +class PreviousMessage(BaseModel): + """Simplified version of `ChatMessage`""" + + message: str + token_count: int + message_type: MessageType + + @classmethod + def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage": + return cls( + message=chat_message.message, + token_count=chat_message.token_count, + message_type=chat_message.message_type, + ) + + +class DocumentPruningConfig(BaseModel): + max_chunks: int | None = None + max_window_percentage: float | None = None + max_tokens: int | None = None + # different pruning behavior is expected when the + # user manually selects documents they want to chat with + # e.g. 
we don't want to truncate each document to be no more + # than one chunk long + is_manually_selected_docs: bool = False + + +class CitationConfig(BaseModel): + all_docs_useful: bool = False + + +class QuotesConfig(BaseModel): + pass + + +class AnswerStyleConfig(BaseModel): + citation_config: CitationConfig | None = None + quotes_config: QuotesConfig | None = None + document_pruning_config: DocumentPruningConfig = Field( + default_factory=DocumentPruningConfig + ) + + @root_validator + def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]: + citation_config = values.get("citation_config") + quotes_config = values.get("quotes_config") + + if citation_config is None and quotes_config is None: + raise ValueError( + "One of `citation_config` or `quotes_config` must be provided" + ) + + if citation_config is not None and quotes_config is not None: + raise ValueError( + "Only one of `citation_config` or `quotes_config` must be provided" + ) + + return values diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py new file mode 100644 index 0000000000..61c42c19c7 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -0,0 +1,281 @@ +from collections.abc import Callable +from functools import lru_cache +from typing import cast + +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS +from danswer.db.chat import get_default_prompt +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import get_default_llm_version +from danswer.llm.utils import get_max_input_tokens +from danswer.llm.utils import translate_history_to_basemessages +from danswer.prompts.chat_prompts import ADDITIONAL_INFO +from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT +from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT +from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT +from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT +from danswer.prompts.direct_qa_prompts import ( + CITATIONS_PROMPT, +) +from danswer.prompts.prompt_utils import build_complete_context_str +from danswer.prompts.prompt_utils import build_task_prompt_reminders +from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT +from danswer.prompts.token_counts import ( + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, +) +from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT +from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT +from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT + + +_PER_MESSAGE_TOKEN_BUFFER = 7 + + +def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: + """From the back, find the index of the last element to include + before the list exceeds the maximum""" + running_sum = 0 + + last_ind = 0 + for i in range(len(lst) - 1, -1, -1): + running_sum += lst[i] + 
_PER_MESSAGE_TOKEN_BUFFER + if running_sum > max_prompt_tokens: + last_ind = i + 1 + break + if last_ind >= len(lst): + raise ValueError("Last message alone is too large!") + return last_ind + + +def drop_messages_history_overflow( + system_msg: BaseMessage | None, + system_token_count: int, + history_msgs: list[BaseMessage], + history_token_counts: list[int], + final_msg: BaseMessage, + final_msg_token_count: int, + max_allowed_tokens: int, +) -> list[BaseMessage]: + """As message history grows, messages need to be dropped starting from the furthest in the past. + The System message should be kept if at all possible and the latest user input which is inserted in the + prompt template must be included""" + if len(history_msgs) != len(history_token_counts): + # This should never happen + raise ValueError("Need exactly 1 token count per message for tracking overflow") + + prompt: list[BaseMessage] = [] + + # Start dropping from the history if necessary + all_tokens = history_token_counts + [system_token_count, final_msg_token_count] + ind_prev_msg_start = find_last_index( + all_tokens, max_prompt_tokens=max_allowed_tokens + ) + + if system_msg and ind_prev_msg_start <= len(history_msgs): + prompt.append(system_msg) + + prompt.extend(history_msgs[ind_prev_msg_start:]) + + prompt.append(final_msg) + + return prompt + + +def get_prompt_tokens(prompt: Prompt) -> int: + # Note: currently custom prompts do not allow datetime aware, only default prompts + return ( + check_number_of_tokens(prompt.system_prompt) + + check_number_of_tokens(prompt.task_prompt) + + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT + + CITATION_STATEMENT_TOKEN_CNT + + CITATION_REMINDER_TOKEN_CNT + + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) + + (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0) + ) + + +# buffer just to be safe so that we don't overflow the token limit due to +# a small miscalculation +_MISC_BUFFER = 40 + + +def compute_max_document_tokens( + persona: Persona, + actual_user_input: str | None = None, + max_llm_token_override: int | None = None, +) -> int: + """Estimates the number of tokens available for context documents. Formula is roughly: + + ( + model_context_window - reserved_output_tokens - prompt_tokens + - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) + ) + + The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. + if we're trying to determine if the user should be able to select another document) then we just set an + arbitrary "upper bound". 
+ """ + llm_name = get_default_llm_version()[0] + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + # if we can't find a number of tokens, just assume some common default + max_input_tokens = ( + max_llm_token_override + if max_llm_token_override + else get_max_input_tokens(model_name=llm_name) + ) + if persona.prompts: + # TODO this may not always be the first prompt + prompt_tokens = get_prompt_tokens(persona.prompts[0]) + else: + prompt_tokens = get_prompt_tokens(get_default_prompt()) + + user_input_tokens = ( + check_number_of_tokens(actual_user_input) + if actual_user_input is not None + else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS + ) + + return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER + + +def compute_max_llm_input_tokens(persona: Persona) -> int: + """Maximum tokens allows in the input to the LLM (of any type).""" + llm_name = get_default_llm_version()[0] + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + input_tokens = get_max_input_tokens(model_name=llm_name) + return input_tokens - _MISC_BUFFER + + +@lru_cache() +def build_system_message( + prompt: Prompt, + context_exists: bool, + llm_tokenizer_encode_func: Callable, + citation_line: str = REQUIRE_CITATION_STATEMENT, + no_citation_line: str = NO_CITATION_STATEMENT, +) -> tuple[SystemMessage | None, int]: + system_prompt = prompt.system_prompt.strip() + if prompt.include_citations: + if context_exists: + system_prompt += citation_line + else: + system_prompt += no_citation_line + if prompt.datetime_aware: + if system_prompt: + system_prompt += ADDITIONAL_INFO.format( + datetime_info=get_current_llm_day_time() + ) + else: + system_prompt = get_current_llm_day_time() + + if not system_prompt: + return None, 0 + + token_count = len(llm_tokenizer_encode_func(system_prompt)) + system_msg = SystemMessage(content=system_prompt) + + return system_msg, token_count + + +def build_user_message( + question: str, + prompt: Prompt, + context_docs: list[LlmDoc] | list[InferenceChunk], + all_doc_useful: bool, + history_message: str, +) -> tuple[HumanMessage, int]: + llm_tokenizer = get_default_llm_tokenizer() + llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode) + + if not context_docs: + # Simpler prompt for cases where there is no context + user_prompt = ( + CHAT_USER_CONTEXT_FREE_PROMPT.format( + task_prompt=prompt.task_prompt, user_query=question + ) + if prompt.task_prompt + else question + ) + user_prompt = user_prompt.strip() + token_count = len(llm_tokenizer_encode_func(user_prompt)) + user_msg = HumanMessage(content=user_prompt) + return user_msg, token_count + + context_docs_str = build_complete_context_str(context_docs) + optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT + + task_prompt_with_reminder = build_task_prompt_reminders(prompt) + + user_prompt = CITATIONS_PROMPT.format( + optional_ignore_statement=optional_ignore, + context_docs_str=context_docs_str, + task_prompt=task_prompt_with_reminder, + user_query=question, + history_block=history_message, + ) + + user_prompt = user_prompt.strip() + token_count = len(llm_tokenizer_encode_func(user_prompt)) + user_msg = HumanMessage(content=user_prompt) + + return user_msg, token_count + + +def build_citations_prompt( + question: str, + message_history: list[PreviousMessage], + persona: Persona, + prompt: Prompt, + context_docs: list[LlmDoc] | list[InferenceChunk], + all_doc_useful: bool, + history_message: str, + 
llm_tokenizer_encode_func: Callable, +) -> list[BaseMessage]: + context_exists = len(context_docs) > 0 + + system_message_or_none, system_tokens = build_system_message( + prompt=prompt, + context_exists=context_exists, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, + ) + + history_basemessages, history_token_counts = translate_history_to_basemessages( + message_history + ) + + # Be sure the context_docs passed to build_chat_user_message + # Is the same as passed in later for extracting citations + user_message, user_tokens = build_user_message( + question=question, + prompt=prompt, + context_docs=context_docs, + all_doc_useful=all_doc_useful, + history_message=history_message, + ) + + final_prompt_msgs = drop_messages_history_overflow( + system_msg=system_message_or_none, + system_token_count=system_tokens, + history_msgs=history_basemessages, + history_token_counts=history_token_counts, + final_msg=user_message, + final_msg_token_count=user_tokens, + max_allowed_tokens=compute_max_llm_input_tokens(persona), + ) + + return final_prompt_msgs diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py new file mode 100644 index 0000000000..c9e145e810 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -0,0 +1,88 @@ +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage + +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE +from danswer.db.models import Prompt +from danswer.indexing.models import InferenceChunk +from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK +from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK +from danswer.prompts.direct_qa_prompts import JSON_PROMPT +from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT +from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT +from danswer.prompts.prompt_utils import build_complete_context_str + + +def _build_weak_llm_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool, +) -> list[BaseMessage]: + """Since Danswer supports a variety of LLMs, this less demanding prompt is provided + as an option to use with weaker LLMs such as small version, low float precision, quantized, + or distilled models. It only uses one context document and has very weak requirements of + output format. 
+ """ + context_block = "" + if context_docs: + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content) + + prompt_str = WEAK_LLM_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + task_prompt=prompt.task_prompt, + user_query=question, + ) + return [HumanMessage(content=prompt_str)] + + +def _build_strong_llm_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool, +) -> list[BaseMessage]: + context_block = "" + if context_docs: + context_docs_str = build_complete_context_str(context_docs) + context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) + + history_block = "" + if history_str: + history_block = HISTORY_BLOCK.format(history_str=history_str) + + full_prompt = JSON_PROMPT.format( + system_prompt=prompt.system_prompt, + context_block=context_block, + history_block=history_block, + task_prompt=prompt.task_prompt, + user_query=question, + language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", + ).strip() + return [HumanMessage(content=full_prompt)] + + +def build_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: Prompt, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), +) -> list[BaseMessage]: + prompt_builder = ( + _build_weak_llm_quotes_prompt + if QA_PROMPT_OVERRIDE == "weak" + else _build_strong_llm_quotes_prompt + ) + + return prompt_builder( + question=question, + context_docs=context_docs, + history_str=history_str, + prompt=prompt, + use_language_hint=use_language_hint, + ) diff --git a/backend/danswer/llm/answering/prompts/utils.py b/backend/danswer/llm/answering/prompts/utils.py new file mode 100644 index 0000000000..bcc8b89181 --- /dev/null +++ b/backend/danswer/llm/answering/prompts/utils.py @@ -0,0 +1,20 @@ +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT +from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT + + +def build_dummy_prompt( + system_prompt: str, task_prompt: str, retrieval_disabled: bool +) -> str: + if retrieval_disabled: + return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() + + return PARAMATERIZED_PROMPT.format( + context_docs_str="", + user_query="", + system_prompt=system_prompt, + task_prompt=task_prompt, + ).strip() diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py new file mode 100644 index 0000000000..a26021835c --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -0,0 +1,126 @@ +import re +from collections.abc import Iterator + +from danswer.chat.models import AnswerQuestionStreamReturn +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import STOP_STREAM_PAT +from danswer.llm.answering.models import StreamProcessor +from danswer.llm.answering.stream_processing.utils import map_document_id_order +from danswer.prompts.constants import TRIPLE_BACKTICK +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +def in_code_block(llm_text: str) -> bool: + count = llm_text.count(TRIPLE_BACKTICK) + return count % 2 != 0 + + +def extract_citations_from_stream( + tokens: 
Iterator[str], + context_docs: list[LlmDoc], + doc_id_to_rank_map: dict[str, int], + stop_stream: str | None = STOP_STREAM_PAT, +) -> Iterator[DanswerAnswerPiece | CitationInfo]: + llm_out = "" + max_citation_num = len(context_docs) + curr_segment = "" + prepend_bracket = False + cited_inds = set() + hold = "" + for raw_token in tokens: + if stop_stream: + next_hold = hold + raw_token + + if stop_stream in next_hold: + break + + if next_hold == stop_stream[: len(next_hold)]: + hold = next_hold + continue + + token = next_hold + hold = "" + else: + token = raw_token + + # Special case of [1][ where ][ is a single token + # This is where the model attempts to do consecutive citations like [1][2] + if prepend_bracket: + curr_segment += "[" + curr_segment + prepend_bracket = False + + curr_segment += token + llm_out += token + + possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_found = re.search(possible_citation_pattern, curr_segment) + + citation_pattern = r"\[(\d+)\]" # [1], [2] etc + citation_found = re.search(citation_pattern, curr_segment) + + if citation_found and not in_code_block(llm_out): + numerical_value = int(citation_found.group(1)) + if 1 <= numerical_value <= max_citation_num: + context_llm_doc = context_docs[ + numerical_value - 1 + ] # remove 1 index offset + + link = context_llm_doc.link + target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] + + # Use the citation number for the document's rank in + # the search (or selected docs) results + curr_segment = re.sub( + rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment + ) + + if target_citation_num not in cited_inds: + cited_inds.add(target_citation_num) + yield CitationInfo( + citation_num=target_citation_num, + document_id=context_llm_doc.document_id, + ) + + if link: + curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) + curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) + + # In case there's another open bracket like [1][, don't want to match this + possible_citation_found = None + + # if we see "[", but haven't seen the right side, hold back - this may be a + # citation that needs to be replaced with a link + if possible_citation_found: + continue + + # Special case with back to back citations [1][2] + if curr_segment and curr_segment[-1] == "[": + curr_segment = curr_segment[:-1] + prepend_bracket = True + + yield DanswerAnswerPiece(answer_piece=curr_segment) + curr_segment = "" + + if curr_segment: + if prepend_bracket: + yield DanswerAnswerPiece(answer_piece="[" + curr_segment) + else: + yield DanswerAnswerPiece(answer_piece=curr_segment) + + +def build_citation_processor( + context_docs: list[LlmDoc], +) -> StreamProcessor: + def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + yield from extract_citations_from_stream( + tokens=tokens, + context_docs=context_docs, + doc_id_to_rank_map=map_document_id_order(context_docs), + ) + + return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/quotes_processing.py b/backend/danswer/llm/answering/stream_processing/quotes_processing.py new file mode 100644 index 0000000000..daa966e694 --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/quotes_processing.py @@ -0,0 +1,282 @@ +import math +import re +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Iterator +from json import JSONDecodeError +from typing import Optional + +import regex + +from danswer.chat.models import 
AnswerQuestionStreamReturn +from danswer.chat.models import DanswerAnswer +from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import DanswerQuote +from danswer.chat.models import DanswerQuotes +from danswer.chat.models import LlmDoc +from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT +from danswer.indexing.models import InferenceChunk +from danswer.prompts.constants import ANSWER_PAT +from danswer.prompts.constants import QUOTE_PAT +from danswer.prompts.constants import UNCERTAINTY_PAT +from danswer.utils.logger import setup_logger +from danswer.utils.text_processing import clean_model_quote +from danswer.utils.text_processing import clean_up_code_blocks +from danswer.utils.text_processing import extract_embedded_json +from danswer.utils.text_processing import shared_precompare_cleanup + + +logger = setup_logger() + + +def _extract_answer_quotes_freeform( + answer_raw: str, +) -> tuple[Optional[str], Optional[list[str]]]: + """Splits the model output into an Answer and 0 or more Quote sections. + Splits by the Quote pattern, if not exist then assume it's all answer and no quotes + """ + # If no answer section, don't care about the quote + if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): + return None, None + + # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt + if answer_raw.lower().startswith(ANSWER_PAT.lower()): + answer_raw = answer_raw[len(ANSWER_PAT) :] + + # Accept quote sections starting with the lower case version + answer_raw = answer_raw.replace( + f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" + ) # Just in case model unreliable + + sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) + sections_clean = [ + str(section).strip() for section in sections if str(section).strip() + ] + if not sections_clean: + return None, None + + answer = str(sections_clean[0]) + if len(sections) == 1: + return answer, None + return answer, sections_clean[1:] + + +def _extract_answer_quotes_json( + answer_dict: dict[str, str | list[str]] +) -> tuple[Optional[str], Optional[list[str]]]: + answer_dict = {k.lower(): v for k, v in answer_dict.items()} + answer = str(answer_dict.get("answer")) + quotes = answer_dict.get("quotes") or answer_dict.get("quote") + if isinstance(quotes, str): + quotes = [quotes] + return answer, quotes + + +def _extract_answer_json(raw_model_output: str) -> dict: + try: + answer_json = extract_embedded_json(raw_model_output) + except (ValueError, JSONDecodeError): + # LLMs get confused when handling the list in the json. Sometimes it doesn't attend + # enough to the previous { token so it just ends the list of quotes and stops there + # here, we add logic to try to fix this LLM error. 
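+        # For example, an output that stops right after the last quote, such as
+        # '{"answer": "...", "quotes": ["..."', gets a closing brace appended below
+        # in the hope that extract_embedded_json can still recover a usable dict.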
+ answer_json = extract_embedded_json(raw_model_output + "}") + + if "answer" not in answer_json: + raise ValueError("Model did not output an answer as expected.") + + return answer_json + + +def match_quotes_to_docs( + quotes: list[str], + docs: list[LlmDoc] | list[InferenceChunk], + max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, + fuzzy_search: bool = False, + prefix_only_length: int = 100, +) -> DanswerQuotes: + danswer_quotes: list[DanswerQuote] = [] + for quote in quotes: + max_edits = math.ceil(float(len(quote)) * max_error_percent) + + for doc in docs: + if not doc.source_links: + continue + + quote_clean = shared_precompare_cleanup( + clean_model_quote(quote, trim_length=prefix_only_length) + ) + chunk_clean = shared_precompare_cleanup(doc.content) + + # Finding the offset of the quote in the plain text + if fuzzy_search: + re_search_str = ( + r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" + ) + found = regex.search(re_search_str, chunk_clean) + if not found: + continue + offset = found.span()[0] + else: + if quote_clean not in chunk_clean: + continue + offset = chunk_clean.index(quote_clean) + + # Extracting the link from the offset + curr_link = None + for link_offset, link in doc.source_links.items(): + # Should always find one because offset is at least 0 and there + # must be a 0 link_offset + if int(link_offset) <= offset: + curr_link = link + else: + break + + danswer_quotes.append( + DanswerQuote( + quote=quote, + document_id=doc.document_id, + link=curr_link, + source_type=doc.source_type, + semantic_identifier=doc.semantic_identifier, + blurb=doc.blurb, + ) + ) + break + + return DanswerQuotes(quotes=danswer_quotes) + + +def separate_answer_quotes( + answer_raw: str, is_json_prompt: bool = False +) -> tuple[Optional[str], Optional[list[str]]]: + """Takes in a raw model output and pulls out the answer and the quotes sections.""" + if is_json_prompt: + model_raw_json = _extract_answer_json(answer_raw) + return _extract_answer_quotes_json(model_raw_json) + + return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) + + +def process_answer( + answer_raw: str, + docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> tuple[DanswerAnswer, DanswerQuotes]: + """Used (1) in the non-streaming case to process the model output + into an Answer and Quotes AND (2) after the complete streaming response + has been received to process the model output into an Answer and Quotes.""" + answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) + if answer == UNCERTAINTY_PAT or not answer: + if answer == UNCERTAINTY_PAT: + logger.debug("Answer matched UNCERTAINTY_PAT") + else: + logger.debug("No answer extracted from raw output") + return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) + + logger.info(f"Answer: {answer}") + if not quote_strings: + logger.debug("No quotes extracted from raw output") + return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) + logger.info(f"All quotes (including unmatched): {quote_strings}") + quotes = match_quotes_to_docs(quote_strings, docs) + logger.debug(f"Final quotes: {quotes}") + + return DanswerAnswer(answer=answer), quotes + + +def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: + next_token = next_token.replace('\\"', "") + # If the previous character is an escape token, don't consider the first character of next_token + # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling + if answer_so_far and answer_so_far[-1] 
== "\\": + next_token = next_token[1:] + if '"' in next_token: + return True + return False + + +def _extract_quotes_from_completed_token_stream( + model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True +) -> DanswerQuotes: + answer, quotes = process_answer(model_output, context_docs, is_json_prompt) + if answer: + logger.info(answer) + elif model_output: + logger.warning("Answer extraction from model output failed.") + + return quotes + + +def process_model_tokens( + tokens: Iterator[str], + context_docs: list[LlmDoc], + is_json_prompt: bool = True, +) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: + """Used in the streaming case to process the model output + into an Answer and Quotes + + Yields Answer tokens back out in a dict for streaming to frontend + When Answer section ends, yields dict with answer_finished key + Collects all the tokens at the end to form the complete model output""" + quote_pat = f"\n{QUOTE_PAT}" + # Sometimes worse model outputs new line instead of : + quote_loose = f"\n{quote_pat[:-1]}\n" + # Sometime model outputs two newlines before quote section + quote_pat_full = f"\n{quote_pat}" + model_output: str = "" + found_answer_start = False if is_json_prompt else True + found_answer_end = False + hold_quote = "" + for token in tokens: + model_previous = model_output + model_output += token + + if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): + # Note, if the token that completes the pattern has additional text, for example if the token is "? + # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the + # event that the model outputs the UNCERTAINTY_PAT + found_answer_start = True + + # Prevent heavy cases of hallucinations where model is not even providing a json until later + if is_json_prompt and len(model_output) > 40: + logger.warning("LLM did not produce json as prompted") + found_answer_end = True + + continue + + if found_answer_start and not found_answer_end: + if is_json_prompt and _stream_json_answer_end(model_previous, token): + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + elif not is_json_prompt: + if quote_pat in hold_quote + token or quote_loose in hold_quote + token: + found_answer_end = True + yield DanswerAnswerPiece(answer_piece=None) + continue + if hold_quote + token in quote_pat_full: + hold_quote += token + continue + yield DanswerAnswerPiece(answer_piece=hold_quote + token) + hold_quote = "" + + logger.debug(f"Raw Model QnA Output: {model_output}") + + yield _extract_quotes_from_completed_token_stream( + model_output=model_output, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + +def build_quotes_processor( + context_docs: list[LlmDoc], is_json_prompt: bool +) -> Callable[[Iterator[str]], AnswerQuestionStreamReturn]: + def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn: + yield from process_model_tokens( + tokens=tokens, + context_docs=context_docs, + is_json_prompt=is_json_prompt, + ) + + return stream_processor diff --git a/backend/danswer/llm/answering/stream_processing/utils.py b/backend/danswer/llm/answering/stream_processing/utils.py new file mode 100644 index 0000000000..1ddcdf605e --- /dev/null +++ b/backend/danswer/llm/answering/stream_processing/utils.py @@ -0,0 +1,17 @@ +from collections.abc import Sequence + +from danswer.chat.models import LlmDoc +from danswer.indexing.models import InferenceChunk + + +def map_document_id_order( + chunks: 
Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True +) -> dict[str, int]: + order_mapping = {} + current = 1 if one_indexed else 0 + for chunk in chunks: + if chunk.document_id not in order_mapping: + order_mapping[chunk.document_id] = current + current += 1 + + return order_mapping diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index f36f285461..c07b708bb5 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -33,6 +33,7 @@ from danswer.db.models import ChatMessage from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.utils.logger import setup_logger @@ -114,7 +115,9 @@ def tokenizer_trim_chunks( return new_chunks -def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: +def translate_danswer_msg_to_langchain( + msg: ChatMessage | PreviousMessage, +) -> BaseMessage: if msg.message_type == MessageType.SYSTEM: raise ValueError("System messages are not currently part of history") if msg.message_type == MessageType.ASSISTANT: @@ -126,7 +129,7 @@ def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage: def translate_history_to_basemessages( - history: list[ChatMessage], + history: list[ChatMessage] | list[PreviousMessage], ) -> tuple[list[BaseMessage], list[int]]: history_basemessages = [ translate_danswer_msg_to_langchain(msg) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index db5ef6f0f9..e863f4ac09 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,54 +1,37 @@ -import itertools from collections.abc import Callable from collections.abc import Iterator -from langchain.schema.messages import BaseMessage -from langchain.schema.messages import HumanMessage from sqlalchemy.orm import Session -from danswer.chat.chat_utils import build_chat_system_message -from danswer.chat.chat_utils import compute_max_document_tokens -from danswer.chat.chat_utils import extract_citations_from_stream -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk -from danswer.chat.chat_utils import map_document_id_order from danswer.chat.chat_utils import reorganize_citations from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerContext from danswer.chat.models import DanswerContexts from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LLMMetricsContainer from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse from danswer.chat.models import StreamingError from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.constants import MessageType -from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_or_create_root_message -from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.engine import 
get_session_context_manager -from danswer.db.models import Prompt from danswer.db.models import User -from danswer.indexing.models import InferenceChunk -from danswer.llm.factory import get_default_llm +from danswer.llm.answering.answer import Answer +from danswer.llm.answering.models import AnswerStyleConfig +from danswer.llm.answering.models import CitationConfig +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import QuotesConfig from danswer.llm.utils import get_default_llm_token_encode -from danswer.llm.utils import get_default_llm_tokenizer -from danswer.one_shot_answer.factory import get_question_answer_model from danswer.one_shot_answer.models import DirectQARequest from danswer.one_shot_answer.models import OneShotQAResponse from danswer.one_shot_answer.models import QueryRephrase -from danswer.one_shot_answer.models import ThreadMessage -from danswer.one_shot_answer.qa_block import no_gen_ai_response from danswer.one_shot_answer.qa_utils import combine_message_thread -from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc @@ -77,106 +60,6 @@ AnswerObjectIterator = Iterator[ ] -def quote_based_qa( - prompt: Prompt, - query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, - use_chain_of_thought: bool, - return_contexts: bool, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, -) -> AnswerObjectIterator: - qa_model = get_question_answer_model( - prompt=prompt, - timeout=timeout, - chain_of_thought=use_chain_of_thought, - llm_version=llm_override, - ) - - full_prompt_str = ( - qa_model.build_prompt( - query=query_message.message, - history_str=history_str, - context_chunks=context_chunks, - ) - if qa_model is not None - else "Gen AI Disabled" - ) - - response_packets = ( - qa_model.answer_question_stream( - prompt=full_prompt_str, - llm_context_docs=context_chunks, - metrics_callback=llm_metrics_callback, - ) - if qa_model is not None - else no_gen_ai_response() - ) - - if qa_model is not None and return_contexts: - contexts = DanswerContexts( - contexts=[ - DanswerContext( - content=context_chunk.content, - document_id=context_chunk.document_id, - semantic_identifier=context_chunk.semantic_identifier, - blurb=context_chunk.semantic_identifier, - ) - for context_chunk in context_chunks - ] - ) - - response_packets = itertools.chain(response_packets, [contexts]) - - yield from response_packets - - -def citation_based_qa( - prompt: Prompt, - query_message: ThreadMessage, - history_str: str, - context_chunks: list[InferenceChunk], - llm_override: str | None, - timeout: int, -) -> AnswerObjectIterator: - llm_tokenizer = get_default_llm_tokenizer() - - system_prompt_or_none, _ = build_chat_system_message( - prompt=prompt, - context_exists=True, - llm_tokenizer_encode_func=llm_tokenizer.encode, - ) - - task_prompt_with_reminder = build_task_prompt_reminders(prompt) - - context_docs_str = build_complete_context_str(context_chunks) - user_message = HumanMessage( - content=CITATIONS_PROMPT.format( - task_prompt=task_prompt_with_reminder, - user_query=query_message.message, - history_block=history_str, - 
context_docs_str=context_docs_str, - ) - ) - - llm = get_default_llm( - timeout=timeout, - gen_ai_model_version_override=llm_override, - ) - - llm_prompt: list[BaseMessage] = [user_message] - if system_prompt_or_none is not None: - llm_prompt = [system_prompt_or_none] + llm_prompt - - llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in context_chunks] - doc_id_to_rank_map = map_document_id_order(llm_docs) - - tokens = llm.stream(llm_prompt) - yield from extract_citations_from_stream(tokens, llm_docs, doc_id_to_rank_map) - - def stream_answer_objects( query_req: DirectQARequest, user: User | None, @@ -188,14 +71,12 @@ def stream_answer_objects( db_session: Session, # Needed to translate persona num_chunks to tokens to the LLM default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT, - default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE, timeout: int = QA_TIMEOUT, bypass_acl: bool = False, use_citations: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> AnswerObjectIterator: """Streams in order: 1. [always] Retrieved documents, stops flow if nothing is found @@ -273,43 +154,11 @@ def stream_answer_objects( ) yield llm_relevance_filtering_response - # Prep chunks to pass to LLM - num_llm_chunks = ( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks - ) - - chunk_token_limit = int(num_llm_chunks * default_chunk_size) - if max_document_tokens: - chunk_token_limit = min(chunk_token_limit, max_document_tokens) - else: - max_document_tokens = compute_max_document_tokens( - persona=chat_session.persona, actual_user_input=query_msg.message - ) - chunk_token_limit = min(chunk_token_limit, max_document_tokens) - - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=search_pipeline.chunk_relevance_list, - token_limit=chunk_token_limit, - ) - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - - logger.debug( - f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}" - ) - prompt = None - llm_override = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( prompt_id=query_req.prompt_id, user_id=user_id, db_session=db_session ) - persona = get_persona_by_id( - persona_id=query_req.persona_id, user_id=user_id, db_session=db_session - ) - llm_override = persona.llm_model_version_override if prompt is None: if not chat_session.persona.prompts: raise RuntimeError( @@ -329,52 +178,39 @@ def stream_answer_objects( commit=True, ) - if use_citations: - qa_stream = citation_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - ) - else: - qa_stream = quote_based_qa( - prompt=prompt, - query_message=query_msg, - history_str=history_str, - context_chunks=llm_chunks, - llm_override=llm_override, - timeout=timeout, - use_chain_of_thought=False, - return_contexts=False, - llm_metrics_callback=llm_metrics_callback, - ) - - # Capture outputs and errors - llm_output = "" - error: str | None = None - for packet in qa_stream: - logger.debug(packet) - - if isinstance(packet, DanswerAnswerPiece): - token = packet.answer_piece - if token: - llm_output += token - elif isinstance(packet, StreamingError): - error = packet.error - - yield packet + answer_config = AnswerStyleConfig( + 
citation_config=CitationConfig() if use_citations else None, + quotes_config=QuotesConfig() if not use_citations else None, + document_pruning_config=DocumentPruningConfig( + max_chunks=int( + chat_session.persona.num_chunks + if chat_session.persona.num_chunks is not None + else default_num_chunks + ), + max_tokens=max_document_tokens, + ), + ) + answer = Answer( + question=query_msg.message, + docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks], + answer_style_config=answer_config, + prompt=prompt, + persona=chat_session.persona, + doc_relevance_list=search_pipeline.chunk_relevance_list, + single_message_history=history_str, + timeout=timeout, + ) + yield from answer.processed_streamed_output # Saving Gen AI answer and responding with message info gen_ai_response_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=new_user_message, prompt_id=query_req.prompt_id, - message=llm_output, - token_count=len(llm_tokenizer(llm_output)), + message=answer.llm_answer, + token_count=len(llm_tokenizer(answer.llm_answer)), message_type=MessageType.ASSISTANT, - error=error, + error=None, reference_docs=None, # Don't need to save reference docs for one shot flow db_session=db_session, commit=True, @@ -419,7 +255,6 @@ def get_search_answer( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, ) -> OneShotQAResponse: """Collects the streamed one shot answer responses into a single object""" qa_response = OneShotQAResponse() @@ -435,7 +270,6 @@ def get_search_answer( timeout=answer_generation_timeout, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, - llm_metrics_callback=llm_metrics_callback, ) answer = "" diff --git a/backend/danswer/one_shot_answer/factory.py b/backend/danswer/one_shot_answer/factory.py deleted file mode 100644 index 122ed6ac06..0000000000 --- a/backend/danswer/one_shot_answer/factory.py +++ /dev/null @@ -1,48 +0,0 @@ -from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE -from danswer.configs.chat_configs import QA_TIMEOUT -from danswer.db.models import Prompt -from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_block import QABlock -from danswer.one_shot_answer.qa_block import QAHandler -from danswer.one_shot_answer.qa_block import SingleMessageQAHandler -from danswer.one_shot_answer.qa_block import WeakLLMQAHandler -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def get_question_answer_model( - prompt: Prompt | None, - api_key: str | None = None, - timeout: int = QA_TIMEOUT, - chain_of_thought: bool = False, - llm_version: str | None = None, - qa_model_version: str | None = QA_PROMPT_OVERRIDE, -) -> QAModel | None: - if chain_of_thought: - raise NotImplementedError("COT has been disabled") - - system_prompt = prompt.system_prompt if prompt is not None else None - task_prompt = prompt.task_prompt if prompt is not None else None - - try: - llm = get_default_llm( - api_key=api_key, - timeout=timeout, - gen_ai_model_version_override=llm_version, - ) - except GenAIDisabledException: - return None - - if qa_model_version == "weak": - qa_handler: QAHandler = WeakLLMQAHandler( - system_prompt=system_prompt, 
task_prompt=task_prompt - ) - else: - qa_handler = SingleMessageQAHandler( - system_prompt=system_prompt, task_prompt=task_prompt - ) - - return QABlock(llm=llm, qa_handler=qa_handler) diff --git a/backend/danswer/one_shot_answer/interfaces.py b/backend/danswer/one_shot_answer/interfaces.py deleted file mode 100644 index ca916d699d..0000000000 --- a/backend/danswer/one_shot_answer/interfaces.py +++ /dev/null @@ -1,26 +0,0 @@ -import abc -from collections.abc import Callable - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import LLMMetricsContainer -from danswer.indexing.models import InferenceChunk - - -class QAModel: - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - @abc.abstractmethod - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - raise NotImplementedError diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py deleted file mode 100644 index 68cb6e4a82..0000000000 --- a/backend/danswer/one_shot_answer/qa_block.py +++ /dev/null @@ -1,313 +0,0 @@ -import abc -import re -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -from danswer.chat.models import AnswerQuestionStreamReturn -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuotes -from danswer.chat.models import LlmDoc -from danswer.chat.models import LLMMetricsContainer -from danswer.chat.models import StreamingError -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.constants import DISABLED_GEN_AI_MSG -from danswer.indexing.models import InferenceChunk -from danswer.llm.interfaces import LLM -from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_token_encode -from danswer.one_shot_answer.interfaces import QAModel -from danswer.one_shot_answer.qa_utils import process_answer -from danswer.one_shot_answer.qa_utils import process_model_tokens -from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK -from danswer.prompts.direct_qa_prompts import COT_PROMPT -from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK -from danswer.prompts.direct_qa_prompts import JSON_PROMPT -from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import ONE_SHOT_TASK_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT -from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_SYSTEM_PROMPT -from danswer.prompts.direct_qa_prompts import WEAK_MODEL_TASK_PROMPT -from danswer.prompts.prompt_utils import build_complete_context_str -from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import escape_newlines - -logger = setup_logger() - - -class QAHandler(abc.ABC): - @property - @abc.abstractmethod - def is_json_output(self) -> bool: - """Does the model output a valid json with answer and quotes keys? 
Most flows with a - capable model should output a json. This hints to the model that the output is used - with a downstream system rather than freeform creative output. Most models should be - finetuned to recognize this.""" - raise NotImplementedError - - @abc.abstractmethod - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - raise NotImplementedError - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - yield from process_model_tokens( - tokens=tokens, - context_docs=context_chunks, - is_json_prompt=self.is_json_output, - ) - - -class WeakLLMQAHandler(QAHandler): - """Since Danswer supports a variety of LLMs, this less demanding prompt is provided - as an option to use with weaker LLMs such as small version, low float precision, quantized, - or distilled models. It only uses one context document and has very weak requirements of - output format. - """ - - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - ) -> None: - if not system_prompt and not task_prompt: - self.system_prompt = WEAK_MODEL_SYSTEM_PROMPT - self.task_prompt = WEAK_MODEL_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return False - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - context_block = "" - if context_chunks: - context_block = CONTEXT_BLOCK.format( - context_docs_str=context_chunks[0].content - ) - - prompt_str = WEAK_LLM_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - task_prompt=self.task_prompt, - user_query=query, - ) - return prompt_str - - -class SingleMessageQAHandler(QAHandler): - def __init__( - self, - system_prompt: str | None, - task_prompt: str | None, - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_block = "" - if context_chunks: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) - - history_block = "" - if history_str: - history_block = HISTORY_BLOCK.format(history_str=history_str) - - full_prompt = JSON_PROMPT.format( - system_prompt=self.system_prompt, - context_block=context_block, - history_block=history_block, - task_prompt=self.task_prompt, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else "", - ).strip() - return full_prompt - - -# This one isn't used, currently only streaming prompts are used -class SingleMessageScratchpadHandler(QAHandler): - def __init__( - self, - 
system_prompt: str | None, - task_prompt: str | None, - use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), - ) -> None: - self.use_language_hint = use_language_hint - if not system_prompt and not task_prompt: - self.system_prompt = ONE_SHOT_SYSTEM_PROMPT - self.task_prompt = ONE_SHOT_TASK_PROMPT - else: - self.system_prompt = system_prompt or "" - self.task_prompt = task_prompt or "" - - self.task_prompt = self.task_prompt.rstrip() - if self.task_prompt and self.task_prompt[0] != "\n": - self.task_prompt = "\n" + self.task_prompt - - @property - def is_json_output(self) -> bool: - return True - - def build_prompt( - self, query: str, history_str: str, context_chunks: list[InferenceChunk] - ) -> str: - context_docs_str = build_complete_context_str( - cast(list[LlmDoc | InferenceChunk], context_chunks) - ) - - # Outdated - prompt = COT_PROMPT.format( - context_docs_str=context_docs_str, - user_query=query, - language_hint_or_none=LANGUAGE_HINT.strip() - if self.use_language_hint - else "", - ).strip() - - return prompt - - def process_llm_output( - self, model_output: str, context_chunks: list[InferenceChunk] - ) -> tuple[DanswerAnswer, DanswerQuotes]: - logger.debug(model_output) - - model_clean = clean_up_code_blocks(model_output) - - match = re.search(r'{\s*"answer":', model_clean) - if not match: - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - final_json = escape_newlines(model_clean[match.start() :]) - - return process_answer( - final_json, context_chunks, is_json_prompt=self.is_json_output - ) - - def process_llm_token_stream( - self, tokens: Iterator[str], context_chunks: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - # Can be supported but the parsing is more involved, not handling until needed - raise ValueError( - "This Scratchpad approach is not suitable for real time uses like streaming" - ) - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - -def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]: - yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG) - - -class QABlock(QAModel): - def __init__(self, llm: LLM, qa_handler: QAHandler) -> None: - self._llm = llm - self._qa_handler = qa_handler - - def build_prompt( - self, - query: str, - history_str: str, - context_chunks: list[InferenceChunk], - ) -> str: - prompt = self._qa_handler.build_prompt( - query=query, history_str=history_str, context_chunks=context_chunks - ) - return prompt - - def answer_question_stream( - self, - prompt: str, - llm_context_docs: list[InferenceChunk], - metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, - ) -> AnswerQuestionStreamReturn: - tokens_stream = self._llm.stream(prompt) - - captured_tokens = [] - - try: - for answer_piece in self._qa_handler.process_llm_token_stream( - iter(tokens_stream), llm_context_docs - ): - if ( - isinstance(answer_piece, DanswerAnswerPiece) - and answer_piece.answer_piece - ): - captured_tokens.append(answer_piece.answer_piece) - yield answer_piece - - except Exception as e: - yield StreamingError(error=str(e)) - - if metrics_callback is not None: - prompt_tokens = check_number_of_tokens( - text=str(prompt), 
encode_fn=get_default_llm_token_encode() - ) - - response_tokens = check_number_of_tokens( - text="".join(captured_tokens), encode_fn=get_default_llm_token_encode() - ) - - metrics_callback( - LLMMetricsContainer( - prompt_tokens=prompt_tokens, response_tokens=response_tokens - ) - ) diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py index 032d243459..e912a915e2 100644 --- a/backend/danswer/one_shot_answer/qa_utils.py +++ b/backend/danswer/one_shot_answer/qa_utils.py @@ -1,275 +1,14 @@ -import math -import re from collections.abc import Callable from collections.abc import Generator -from collections.abc import Iterator -from json.decoder import JSONDecodeError -from typing import Optional -from typing import Tuple -import regex - -from danswer.chat.models import DanswerAnswer -from danswer.chat.models import DanswerAnswerPiece -from danswer.chat.models import DanswerQuote -from danswer.chat.models import DanswerQuotes -from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.constants import MessageType -from danswer.indexing.models import InferenceChunk from danswer.llm.utils import get_default_llm_token_encode from danswer.one_shot_answer.models import ThreadMessage -from danswer.prompts.constants import ANSWER_PAT -from danswer.prompts.constants import QUOTE_PAT -from danswer.prompts.constants import UNCERTAINTY_PAT from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_model_quote -from danswer.utils.text_processing import clean_up_code_blocks -from danswer.utils.text_processing import extract_embedded_json -from danswer.utils.text_processing import shared_precompare_cleanup logger = setup_logger() -def _extract_answer_quotes_freeform( - answer_raw: str, -) -> Tuple[Optional[str], Optional[list[str]]]: - """Splits the model output into an Answer and 0 or more Quote sections. - Splits by the Quote pattern, if not exist then assume it's all answer and no quotes - """ - # If no answer section, don't care about the quote - if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): - return None, None - - # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt - if answer_raw.lower().startswith(ANSWER_PAT.lower()): - answer_raw = answer_raw[len(ANSWER_PAT) :] - - # Accept quote sections starting with the lower case version - answer_raw = answer_raw.replace( - f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" - ) # Just in case model unreliable - - sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) - sections_clean = [ - str(section).strip() for section in sections if str(section).strip() - ] - if not sections_clean: - return None, None - - answer = str(sections_clean[0]) - if len(sections) == 1: - return answer, None - return answer, sections_clean[1:] - - -def _extract_answer_quotes_json( - answer_dict: dict[str, str | list[str]] -) -> Tuple[Optional[str], Optional[list[str]]]: - answer_dict = {k.lower(): v for k, v in answer_dict.items()} - answer = str(answer_dict.get("answer")) - quotes = answer_dict.get("quotes") or answer_dict.get("quote") - if isinstance(quotes, str): - quotes = [quotes] - return answer, quotes - - -def _extract_answer_json(raw_model_output: str) -> dict: - try: - answer_json = extract_embedded_json(raw_model_output) - except (ValueError, JSONDecodeError): - # LLMs get confused when handling the list in the json. 
Sometimes it doesn't attend - # enough to the previous { token so it just ends the list of quotes and stops there - # here, we add logic to try to fix this LLM error. - answer_json = extract_embedded_json(raw_model_output + "}") - - if "answer" not in answer_json: - raise ValueError("Model did not output an answer as expected.") - - return answer_json - - -def separate_answer_quotes( - answer_raw: str, is_json_prompt: bool = False -) -> Tuple[Optional[str], Optional[list[str]]]: - """Takes in a raw model output and pulls out the answer and the quotes sections.""" - if is_json_prompt: - model_raw_json = _extract_answer_json(answer_raw) - return _extract_answer_quotes_json(model_raw_json) - - return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw)) - - -def match_quotes_to_docs( - quotes: list[str], - chunks: list[InferenceChunk], - max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, - fuzzy_search: bool = False, - prefix_only_length: int = 100, -) -> DanswerQuotes: - danswer_quotes: list[DanswerQuote] = [] - for quote in quotes: - max_edits = math.ceil(float(len(quote)) * max_error_percent) - - for chunk in chunks: - if not chunk.source_links: - continue - - quote_clean = shared_precompare_cleanup( - clean_model_quote(quote, trim_length=prefix_only_length) - ) - chunk_clean = shared_precompare_cleanup(chunk.content) - - # Finding the offset of the quote in the plain text - if fuzzy_search: - re_search_str = ( - r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" - ) - found = regex.search(re_search_str, chunk_clean) - if not found: - continue - offset = found.span()[0] - else: - if quote_clean not in chunk_clean: - continue - offset = chunk_clean.index(quote_clean) - - # Extracting the link from the offset - curr_link = None - for link_offset, link in chunk.source_links.items(): - # Should always find one because offset is at least 0 and there - # must be a 0 link_offset - if int(link_offset) <= offset: - curr_link = link - else: - break - - danswer_quotes.append( - DanswerQuote( - quote=quote, - document_id=chunk.document_id, - link=curr_link, - source_type=chunk.source_type, - semantic_identifier=chunk.semantic_identifier, - blurb=chunk.blurb, - ) - ) - break - - return DanswerQuotes(quotes=danswer_quotes) - - -def process_answer( - answer_raw: str, - chunks: list[InferenceChunk], - is_json_prompt: bool = True, -) -> tuple[DanswerAnswer, DanswerQuotes]: - """Used (1) in the non-streaming case to process the model output - into an Answer and Quotes AND (2) after the complete streaming response - has been received to process the model output into an Answer and Quotes.""" - answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt) - if answer == UNCERTAINTY_PAT or not answer: - if answer == UNCERTAINTY_PAT: - logger.debug("Answer matched UNCERTAINTY_PAT") - else: - logger.debug("No answer extracted from raw output") - return DanswerAnswer(answer=None), DanswerQuotes(quotes=[]) - - logger.info(f"Answer: {answer}") - if not quote_strings: - logger.debug("No quotes extracted from raw output") - return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[]) - logger.info(f"All quotes (including unmatched): {quote_strings}") - quotes = match_quotes_to_docs(quote_strings, chunks) - logger.debug(f"Final quotes: {quotes}") - - return DanswerAnswer(answer=answer), quotes - - -def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: - next_token = next_token.replace('\\"', "") - # If the previous character is an escape token, don't 
consider the first character of next_token - # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling - if answer_so_far and answer_so_far[-1] == "\\": - next_token = next_token[1:] - if '"' in next_token: - return True - return False - - -def _extract_quotes_from_completed_token_stream( - model_output: str, context_chunks: list[InferenceChunk], is_json_prompt: bool = True -) -> DanswerQuotes: - answer, quotes = process_answer(model_output, context_chunks, is_json_prompt) - if answer: - logger.info(answer) - elif model_output: - logger.warning("Answer extraction from model output failed.") - - return quotes - - -def process_model_tokens( - tokens: Iterator[str], - context_docs: list[InferenceChunk], - is_json_prompt: bool = True, -) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: - """Used in the streaming case to process the model output - into an Answer and Quotes - - Yields Answer tokens back out in a dict for streaming to frontend - When Answer section ends, yields dict with answer_finished key - Collects all the tokens at the end to form the complete model output""" - quote_pat = f"\n{QUOTE_PAT}" - # Sometimes worse model outputs new line instead of : - quote_loose = f"\n{quote_pat[:-1]}\n" - # Sometime model outputs two newlines before quote section - quote_pat_full = f"\n{quote_pat}" - model_output: str = "" - found_answer_start = False if is_json_prompt else True - found_answer_end = False - hold_quote = "" - for token in tokens: - model_previous = model_output - model_output += token - - if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output): - # Note, if the token that completes the pattern has additional text, for example if the token is "? - # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? 
in the - # event that the model outputs the UNCERTAINTY_PAT - found_answer_start = True - - # Prevent heavy cases of hallucinations where model is not even providing a json until later - if is_json_prompt and len(model_output) > 40: - logger.warning("LLM did not produce json as prompted") - found_answer_end = True - - continue - - if found_answer_start and not found_answer_end: - if is_json_prompt and _stream_json_answer_end(model_previous, token): - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - elif not is_json_prompt: - if quote_pat in hold_quote + token or quote_loose in hold_quote + token: - found_answer_end = True - yield DanswerAnswerPiece(answer_piece=None) - continue - if hold_quote + token in quote_pat_full: - hold_quote += token - continue - yield DanswerAnswerPiece(answer_piece=hold_quote + token) - hold_quote = "" - - logger.debug(f"Raw Model QnA Output: {model_output}") - - yield _extract_quotes_from_completed_token_stream( - model_output=model_output, - context_chunks=context_docs, - is_json_prompt=is_json_prompt, - ) - - def simulate_streaming_response(model_out: str) -> Generator[str, None, None]: """Mock streaming by generating the passed in model output, character by character""" for token in model_out: diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index 972f510db9..5c590939b5 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from collections.abc import Generator from typing import cast from sqlalchemy.orm import Session @@ -51,6 +52,11 @@ class SearchPipeline: self._reranked_docs: list[InferenceChunk] | None = None self._relevant_chunk_indicies: list[int] | None = None + # generator state + self._postprocessing_generator: Generator[ + list[InferenceChunk] | list[str], None, None + ] | None = None + """Pre-processing""" def _run_preprocessing(self) -> None: @@ -113,36 +119,38 @@ class SearchPipeline: """Post-Processing""" - def _run_postprocessing(self) -> None: - postprocessing_generator = search_postprocessing( - search_query=self.search_query, - retrieved_chunks=self.retrieved_docs, - rerank_metrics_callback=self.rerank_metrics_callback, - ) - self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator)) - - relevant_chunk_ids = cast(list[str], next(postprocessing_generator)) - self._relevant_chunk_indicies = [ - ind - for ind, chunk in enumerate(self._reranked_docs) - if chunk.unique_id in relevant_chunk_ids - ] - @property def reranked_docs(self) -> list[InferenceChunk]: if self._reranked_docs is not None: return self._reranked_docs - self._run_postprocessing() - return cast(list[InferenceChunk], self._reranked_docs) + self._postprocessing_generator = search_postprocessing( + search_query=self.search_query, + retrieved_chunks=self.retrieved_docs, + rerank_metrics_callback=self.rerank_metrics_callback, + ) + self._reranked_docs = cast( + list[InferenceChunk], next(self._postprocessing_generator) + ) + return self._reranked_docs @property def relevant_chunk_indicies(self) -> list[int]: if self._relevant_chunk_indicies is not None: return self._relevant_chunk_indicies - self._run_postprocessing() - return cast(list[int], self._relevant_chunk_indicies) + # run first step of postprocessing generator if not already done + reranked_docs = self.reranked_docs + + relevant_chunk_ids = next( + cast(Generator[list[str], None, None], self._postprocessing_generator) + ) + 
self._relevant_chunk_indicies = [ + ind + for ind, chunk in enumerate(reranked_docs) + if chunk.unique_id in relevant_chunk_ids + ] + return self._relevant_chunk_indicies @property def chunk_relevance_list(self) -> list[bool]: diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 3dff76d96e..41aa3a3c7e 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -223,11 +223,13 @@ def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: return LlmDoc( document_id=first_chunk.document_id, content="\n".join(chunk_texts), + blurb=first_chunk.blurb, semantic_identifier=first_chunk.semantic_identifier, source_type=first_chunk.source_type, metadata=first_chunk.metadata, updated_at=first_chunk.updated_at, link=first_chunk.source_links[0] if first_chunk.source_links else None, + source_links=first_chunk.source_links, ) diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index 8762f40b51..d75ff69480 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -14,8 +14,8 @@ from danswer.db.chat import update_persona_visibility from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.persona import create_update_persona +from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.llm.utils import get_default_llm_version -from danswer.one_shot_answer.qa_block import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index a8076659c6..4fb98c5a15 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -6,7 +6,6 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_user -from danswer.chat.chat_utils import compute_max_document_tokens from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message from danswer.db.chat import create_chat_session @@ -25,6 +24,7 @@ from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index +from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens from danswer.secondary_llm_flows.chat_session_naming import ( get_renamed_conversation_name, ) diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py index bd2f70010e..d32f275472 100644 --- a/backend/tests/regression/answer_quality/eval_direct_qa.py +++ b/backend/tests/regression/answer_quality/eval_direct_qa.py @@ -77,7 +77,6 @@ def get_answer_for_question( str | None, RetrievalMetricsContainer | None, RerankMetricsContainer | None, - LLMMetricsContainer | None, ]: filters = IndexFilters( source_type=None, @@ -103,7 +102,6 @@ def get_answer_for_question( retrieval_metrics = MetricsHander[RetrievalMetricsContainer]() rerank_metrics = 
MetricsHander[RerankMetricsContainer]() - llm_metrics = MetricsHander[LLMMetricsContainer]() answer = get_search_answer( query_req=new_message_request, @@ -116,14 +114,12 @@ def get_answer_for_question( bypass_acl=True, retrieval_metrics_callback=retrieval_metrics.record_metric, rerank_metrics_callback=rerank_metrics.record_metric, - llm_metrics_callback=llm_metrics.record_metric, ) return ( answer.answer, retrieval_metrics.metrics, rerank_metrics.metrics, - llm_metrics.metrics, ) @@ -221,7 +217,6 @@ if __name__ == "__main__": answer, retrieval_metrics, rerank_metrics, - llm_metrics, ) = get_answer_for_question(sample["question"], db_session) end_time = datetime.now() @@ -237,12 +232,6 @@ if __name__ == "__main__": else "\tFailed, either crashed or refused to answer." ) if not args.discard_metrics: - print("\nLLM Tokens Usage:") - if llm_metrics is None: - print("No LLM Metrics Available") - else: - _print_llm_metrics(llm_metrics) - print("\nRetrieval Metrics:") if retrieval_metrics is None: print("No Retrieval Metrics Available") diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py index d40ae13480..5bf9406b41 100644 --- a/backend/tests/regression/search_quality/eval_search.py +++ b/backend/tests/regression/search_quality/eval_search.py @@ -7,9 +7,9 @@ from typing import TextIO from sqlalchemy.orm import Session -from danswer.chat.chat_utils import get_chunks_for_qa from danswer.db.engine import get_sqlalchemy_engine from danswer.indexing.models import InferenceChunk +from danswer.llm.answering.doc_pruning import reorder_docs from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SearchRequest @@ -95,16 +95,8 @@ def get_search_results( top_chunks = search_pipeline.reranked_docs llm_chunk_selection = search_pipeline.chunk_relevance_list - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - token_limit=None, - ) - - llm_chunks = [top_chunks[i] for i in llm_chunks_indices] - return ( - llm_chunks, + reorder_docs(top_chunks, llm_chunk_selection), retrieval_metrics.metrics, rerank_metrics.metrics, ) diff --git a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py index b30d08b169..b7b30b63d2 100644 --- a/backend/tests/unit/danswer/direct_qa/test_qa_utils.py +++ b/backend/tests/unit/danswer/direct_qa/test_qa_utils.py @@ -3,8 +3,12 @@ import unittest from danswer.configs.constants import DocumentSource from danswer.indexing.models import InferenceChunk -from danswer.one_shot_answer.qa_utils import match_quotes_to_docs -from danswer.one_shot_answer.qa_utils import separate_answer_quotes +from danswer.llm.answering.stream_processing.quotes_processing import ( + match_quotes_to_docs, +) +from danswer.llm.answering.stream_processing.quotes_processing import ( + separate_answer_quotes, +) class TestQAPostprocessing(unittest.TestCase):