From e5035b8992603b6e1e3bb7eefdc63935c4645d10 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 26 Dec 2023 00:38:29 -0800 Subject: [PATCH] Move some util functions around (#883) --- backend/danswer/chat/chat_utils.py | 130 +++++++++++++++++++++++ backend/danswer/chat/process_message.py | 131 +----------------------- 2 files changed, 133 insertions(+), 128 deletions(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 4a8a539fa5..efc8323f4d 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,16 +1,22 @@ +import re from collections.abc import Callable +from collections.abc import Iterator from functools import lru_cache from typing import cast +from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from sqlalchemy.orm import Session +from danswer.chat.models import CitationInfo +from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS from danswer.db.chat import get_chat_messages_by_session from danswer.db.models import ChatMessage from danswer.db.models import Prompt @@ -347,3 +353,127 @@ def combine_message_chain( total_token_count += message_token_count return "\n\n".join(message_strs) + + +def find_last_index( + lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS +) -> int: + """From the back, find the index of the last element to include + before the list exceeds the maximum""" + running_sum = 0 + + last_ind = 0 + for i in range(len(lst) - 1, -1, -1): + running_sum += lst[i] + if running_sum > max_prompt_tokens: + last_ind = i + 1 + break + if last_ind >= len(lst): + raise ValueError("Last message alone is too large!") + return last_ind + + +def drop_messages_history_overflow( + system_msg: BaseMessage | None, + system_token_count: int, + history_msgs: list[BaseMessage], + history_token_counts: list[int], + final_msg: BaseMessage, + final_msg_token_count: int, +) -> list[BaseMessage]: + """As message history grows, messages need to be dropped starting from the furthest in the past. + The System message should be kept if at all possible and the latest user input which is inserted in the + prompt template must be included""" + + if len(history_msgs) != len(history_token_counts): + # This should never happen + raise ValueError("Need exactly 1 token count per message for tracking overflow") + + prompt: list[BaseMessage] = [] + + # Start dropping from the history if necessary + all_tokens = history_token_counts + [system_token_count, final_msg_token_count] + ind_prev_msg_start = find_last_index(all_tokens) + + if system_msg and ind_prev_msg_start <= len(history_msgs): + prompt.append(system_msg) + + prompt.extend(history_msgs[ind_prev_msg_start:]) + + prompt.append(final_msg) + + return prompt + + +def extract_citations_from_stream( + tokens: Iterator[str], + context_docs: list[LlmDoc], + doc_id_to_rank_map: dict[str, int], +) -> Iterator[DanswerAnswerPiece | CitationInfo]: + max_citation_num = len(context_docs) + curr_segment = "" + prepend_bracket = False + cited_inds = set() + for token in tokens: + # Special case of [1][ where ][ is a single token + # This is where the model attempts to do consecutive citations like [1][2] + if prepend_bracket: + curr_segment += "[" + curr_segment + prepend_bracket = False + + curr_segment += token + + possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_found = re.search(possible_citation_pattern, curr_segment) + + citation_pattern = r"\[(\d+)\]" # [1], [2] etc + citation_found = re.search(citation_pattern, curr_segment) + + if citation_found: + numerical_value = int(citation_found.group(1)) + if 1 <= numerical_value <= max_citation_num: + context_llm_doc = context_docs[ + numerical_value - 1 + ] # remove 1 index offset + + link = context_llm_doc.link + target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] + + # Use the citation number for the document's rank in + # the search (or selected docs) results + curr_segment = re.sub( + rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment + ) + + if target_citation_num not in cited_inds: + cited_inds.add(target_citation_num) + yield CitationInfo( + citation_num=target_citation_num, + document_id=context_llm_doc.document_id, + ) + + if link: + curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) + curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) + + # In case there's another open bracket like [1][, don't want to match this + possible_citation_found = None + + # if we see "[", but haven't seen the right side, hold back - this may be a + # citation that needs to be replaced with a link + if possible_citation_found: + continue + + # Special case with back to back citations [1][2] + if curr_segment and curr_segment[-1] == "[": + curr_segment = curr_segment[:-1] + prepend_bracket = True + + yield DanswerAnswerPiece(answer_piece=curr_segment) + curr_segment = "" + + if curr_segment: + if prepend_bracket: + yield DanswerAnswerPiece(answer_piece="[" + curr_segment) + else: + yield DanswerAnswerPiece(answer_piece=curr_segment) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index e7bf455c96..8f53467cab 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -1,15 +1,15 @@ -import re from collections.abc import Callable from collections.abc import Iterator from functools import partial from typing import cast -from langchain.schema.messages import BaseMessage from sqlalchemy.orm import Session from danswer.chat.chat_utils import build_chat_system_message from danswer.chat.chat_utils import build_chat_user_message from danswer.chat.chat_utils import create_chat_chain +from danswer.chat.chat_utils import drop_messages_history_overflow +from danswer.chat.chat_utils import extract_citations_from_stream from danswer.chat.chat_utils import get_chunks_for_qa from danswer.chat.chat_utils import llm_doc_from_inference_chunk from danswer.chat.chat_utils import map_document_id_order @@ -22,7 +22,6 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHUNK_SIZE from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT from danswer.configs.constants import MessageType -from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -57,130 +56,6 @@ from danswer.utils.timing import log_generator_function_time logger = setup_logger() -def _find_last_index( - lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS -) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def _drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. - The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = _find_last_index(all_tokens) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - -def extract_citations_from_stream( - tokens: Iterator[str], - context_docs: list[LlmDoc], - doc_id_to_rank_map: dict[str, int], -) -> Iterator[DanswerAnswerPiece | CitationInfo]: - max_citation_num = len(context_docs) - curr_segment = "" - prepend_bracket = False - cited_inds = set() - for token in tokens: - # Special case of [1][ where ][ is a single token - # This is where the model attempts to do consecutive citations like [1][2] - if prepend_bracket: - curr_segment += "[" + curr_segment - prepend_bracket = False - - curr_segment += token - - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc - possible_citation_found = re.search(possible_citation_pattern, curr_segment) - - citation_pattern = r"\[(\d+)\]" # [1], [2] etc - citation_found = re.search(citation_pattern, curr_segment) - - if citation_found: - numerical_value = int(citation_found.group(1)) - if 1 <= numerical_value <= max_citation_num: - context_llm_doc = context_docs[ - numerical_value - 1 - ] # remove 1 index offset - - link = context_llm_doc.link - target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id] - - # Use the citation number for the document's rank in - # the search (or selected docs) results - curr_segment = re.sub( - rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment - ) - - if target_citation_num not in cited_inds: - cited_inds.add(target_citation_num) - yield CitationInfo( - citation_num=target_citation_num, - document_id=context_llm_doc.document_id, - ) - - if link: - curr_segment = re.sub(r"\[", "[[", curr_segment, count=1) - curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1) - - # In case there's another open bracket like [1][, don't want to match this - possible_citation_found = None - - # if we see "[", but haven't seen the right side, hold back - this may be a - # citation that needs to be replaced with a link - if possible_citation_found: - continue - - # Special case with back to back citations [1][2] - if curr_segment and curr_segment[-1] == "[": - curr_segment = curr_segment[:-1] - prepend_bracket = True - - yield DanswerAnswerPiece(answer_piece=curr_segment) - curr_segment = "" - - if curr_segment: - if prepend_bracket: - yield DanswerAnswerPiece(answer_piece="[" + curr_segment) - else: - yield DanswerAnswerPiece(answer_piece=curr_segment) - - def generate_ai_chat_response( query_message: ChatMessage, history: list[ChatMessage], @@ -216,7 +91,7 @@ def generate_ai_chat_response( all_doc_useful=all_doc_useful, ) - prompt = _drop_messages_history_overflow( + prompt = drop_messages_history_overflow( system_msg=system_message_or_none, system_token_count=system_tokens, history_msgs=history_basemessages,