Move some util functions around (#883)

Yuhong Sun 2023-12-26 00:38:29 -08:00 committed by GitHub
parent 2e9af3086a
commit e5035b8992
2 changed files with 133 additions and 128 deletions

View File

@@ -1,16 +1,22 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from functools import lru_cache
from typing import cast

from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
@@ -347,3 +353,127 @@ def combine_message_chain(
        total_token_count += message_token_count

    return "\n\n".join(message_strs)


def find_last_index(
    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
    """From the back, find the index of the last element to include
    before the list exceeds the maximum"""
    running_sum = 0

    last_ind = 0
    for i in range(len(lst) - 1, -1, -1):
        running_sum += lst[i]
        if running_sum > max_prompt_tokens:
            last_ind = i + 1
            break

    if last_ind >= len(lst):
        raise ValueError("Last message alone is too large!")

    return last_ind
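
A quick sketch of the behavior (the budget of 60 is illustrative; the real default is GEN_AI_MAX_INPUT_TOKENS):

# Token counts for four messages, oldest first.
token_counts = [50, 40, 30, 20]

# Summing from the back: 20 + 30 = 50 fits, adding 40 would exceed 60,
# so index 2 is the cutoff and only the last two messages survive.
assert find_last_index(token_counts, max_prompt_tokens=60) == 2

# If even the newest message alone is over budget, the function raises.
try:
    find_last_index([100], max_prompt_tokens=60)
except ValueError:
    pass  # "Last message alone is too large!"
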
def drop_messages_history_overflow(
    system_msg: BaseMessage | None,
    system_token_count: int,
    history_msgs: list[BaseMessage],
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""
    if len(history_msgs) != len(history_token_counts):
        # This should never happen
        raise ValueError("Need exactly 1 token count per message for tracking overflow")

    prompt: list[BaseMessage] = []

    # Start dropping from the history if necessary
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = find_last_index(all_tokens)

    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)

    prompt.extend(history_msgs[ind_prev_msg_start:])

    prompt.append(final_msg)

    return prompt
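
A hedged usage sketch; the messages and token counts below are made up, and in the real call sites the counts come from the model's tokenizer:

system = SystemMessage(content="You are a helpful assistant.")
history = [HumanMessage(content=f"older question {i}") for i in range(3)]
final = HumanMessage(content="latest question")

prompt = drop_messages_history_overflow(
    system_msg=system,
    system_token_count=10,
    history_msgs=history,
    history_token_counts=[500, 500, 500],
    final_msg=final,
    final_msg_token_count=20,
)
# If the budget forces drops, prompt comes back as
# [system message, <suffix of history>, final message].
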
def extract_citations_from_stream(
    tokens: Iterator[str],
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    max_citation_num = len(context_docs)
    curr_segment = ""
    prepend_bracket = False
    cited_inds = set()
    for token in tokens:
        # Special case of [1][ where ][ is a single token
        # This is where the model attempts to do consecutive citations like [1][2]
        if prepend_bracket:
            # curr_segment is always empty here (it was flushed on the previous
            # token), so this just restores the held-back "[".
            curr_segment = "[" + curr_segment
            prepend_bracket = False
        curr_segment += token

        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
        citation_found = re.search(citation_pattern, curr_segment)

        if citation_found:
            numerical_value = int(citation_found.group(1))
            if 1 <= numerical_value <= max_citation_num:
                context_llm_doc = context_docs[
                    numerical_value - 1
                ]  # remove 1 index offset

                link = context_llm_doc.link
                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

                # Use the citation number for the document's rank in
                # the search (or selected docs) results
                curr_segment = re.sub(
                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
                )

                if target_citation_num not in cited_inds:
                    cited_inds.add(target_citation_num)
                    yield CitationInfo(
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )

                if link:
                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

            # In case there's another open bracket like [1][, don't want to match this
            possible_citation_found = None

        # if we see "[", but haven't seen the right side, hold back - this may be a
        # citation that needs to be replaced with a link
        if possible_citation_found:
            continue

        # Special case with back to back citations [1][2]
        if curr_segment and curr_segment[-1] == "[":
            curr_segment = curr_segment[:-1]
            prepend_bracket = True

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    if curr_segment:
        if prepend_bracket:
            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
        else:
            yield DanswerAnswerPiece(answer_piece=curr_segment)
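
To see the stream handling end to end, a small driver with fabricated inputs (a stand-in doc object keeps the example self-contained; only document_id and link are read by this function):

from dataclasses import dataclass


@dataclass
class FakeDoc:
    # Stand-in for LlmDoc with just the attributes used above.
    document_id: str
    link: str | None


docs = [FakeDoc(document_id="doc-a", link="https://example.com/a")]

# "[1" arrives split across tokens; the partial citation keeps the whole
# segment buffered until the closing bracket arrives.
pieces = list(
    extract_citations_from_stream(
        tokens=iter(["Danswer is great [1", "]."]),
        context_docs=docs,  # type: ignore[arg-type]
        doc_id_to_rank_map={"doc-a": 1},
    )
)
# Yields CitationInfo(citation_num=1, document_id="doc-a") followed by an
# answer piece reading "Danswer is great [[1]](https://example.com/a)."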

View File

@@ -1,15 +1,15 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from typing import cast

from langchain.schema.messages import BaseMessage
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import build_chat_user_message
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import drop_messages_history_overflow
from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import map_document_id_order
@@ -22,7 +22,6 @@ from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHUNK_SIZE
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -57,130 +56,6 @@ from danswer.utils.timing import log_generator_function_time
logger = setup_logger()


def _find_last_index(
    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
    """From the back, find the index of the last element to include
    before the list exceeds the maximum"""
    running_sum = 0

    last_ind = 0
    for i in range(len(lst) - 1, -1, -1):
        running_sum += lst[i]
        if running_sum > max_prompt_tokens:
            last_ind = i + 1
            break

    if last_ind >= len(lst):
        raise ValueError("Last message alone is too large!")

    return last_ind

def _drop_messages_history_overflow(
    system_msg: BaseMessage | None,
    system_token_count: int,
    history_msgs: list[BaseMessage],
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""
    if len(history_msgs) != len(history_token_counts):
        # This should never happen
        raise ValueError("Need exactly 1 token count per message for tracking overflow")

    prompt: list[BaseMessage] = []

    # Start dropping from the history if necessary
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = _find_last_index(all_tokens)

    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)

    prompt.extend(history_msgs[ind_prev_msg_start:])

    prompt.append(final_msg)

    return prompt

def extract_citations_from_stream(
    tokens: Iterator[str],
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    max_citation_num = len(context_docs)
    curr_segment = ""
    prepend_bracket = False
    cited_inds = set()
    for token in tokens:
        # Special case of [1][ where ][ is a single token
        # This is where the model attempts to do consecutive citations like [1][2]
        if prepend_bracket:
            # curr_segment is always empty here (it was flushed on the previous
            # token), so this just restores the held-back "[".
            curr_segment = "[" + curr_segment
            prepend_bracket = False
        curr_segment += token

        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
        citation_found = re.search(citation_pattern, curr_segment)

        if citation_found:
            numerical_value = int(citation_found.group(1))
            if 1 <= numerical_value <= max_citation_num:
                context_llm_doc = context_docs[
                    numerical_value - 1
                ]  # remove 1 index offset

                link = context_llm_doc.link
                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]

                # Use the citation number for the document's rank in
                # the search (or selected docs) results
                curr_segment = re.sub(
                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
                )

                if target_citation_num not in cited_inds:
                    cited_inds.add(target_citation_num)
                    yield CitationInfo(
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )

                if link:
                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)

            # In case there's another open bracket like [1][, don't want to match this
            possible_citation_found = None

        # if we see "[", but haven't seen the right side, hold back - this may be a
        # citation that needs to be replaced with a link
        if possible_citation_found:
            continue

        # Special case with back to back citations [1][2]
        if curr_segment and curr_segment[-1] == "[":
            curr_segment = curr_segment[:-1]
            prepend_bracket = True

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    if curr_segment:
        if prepend_bracket:
            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
        else:
            yield DanswerAnswerPiece(answer_piece=curr_segment)

def generate_ai_chat_response(
    query_message: ChatMessage,
    history: list[ChatMessage],

@@ -216,7 +91,7 @@ def generate_ai_chat_response(
        all_doc_useful=all_doc_useful,
    )

    prompt = _drop_messages_history_overflow(
    prompt = drop_messages_history_overflow(
        system_msg=system_message_or_none,
        system_token_count=system_tokens,
        history_msgs=history_basemessages,