Move some util functions around (#883)

Mirror of https://github.com/danswer-ai/danswer.git
Commit e5035b8992 (parent 2e9af3086a)

This commit promotes the private helpers _find_last_index and _drop_messages_history_overflow to public functions in danswer/chat/chat_utils.py, moves extract_citations_from_stream there as well, and updates the one call site accordingly.
danswer/chat/chat_utils.py:

@@ -1,16 +1,22 @@
+import re
 from collections.abc import Callable
+from collections.abc import Iterator
 from functools import lru_cache
 from typing import cast
 
+from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from sqlalchemy.orm import Session
 
+from danswer.chat.models import CitationInfo
+from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import LlmDoc
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
+from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.models import ChatMessage
 from danswer.db.models import Prompt
@@ -347,3 +353,127 @@ def combine_message_chain(
         total_token_count += message_token_count
 
     return "\n\n".join(message_strs)
+
+
+def find_last_index(
+    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
+) -> int:
+    """From the back, find the index of the last element to include
+    before the list exceeds the maximum"""
+    running_sum = 0
+
+    last_ind = 0
+    for i in range(len(lst) - 1, -1, -1):
+        running_sum += lst[i]
+        if running_sum > max_prompt_tokens:
+            last_ind = i + 1
+            break
+    if last_ind >= len(lst):
+        raise ValueError("Last message alone is too large!")
+    return last_ind
+
+
+def drop_messages_history_overflow(
+    system_msg: BaseMessage | None,
+    system_token_count: int,
+    history_msgs: list[BaseMessage],
+    history_token_counts: list[int],
+    final_msg: BaseMessage,
+    final_msg_token_count: int,
+) -> list[BaseMessage]:
+    """As message history grows, messages need to be dropped starting from the furthest in the past.
+    The System message should be kept if at all possible and the latest user input which is inserted in the
+    prompt template must be included"""
+
+    if len(history_msgs) != len(history_token_counts):
+        # This should never happen
+        raise ValueError("Need exactly 1 token count per message for tracking overflow")
+
+    prompt: list[BaseMessage] = []
+
+    # Start dropping from the history if necessary
+    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
+    ind_prev_msg_start = find_last_index(all_tokens)
+
+    if system_msg and ind_prev_msg_start <= len(history_msgs):
+        prompt.append(system_msg)
+
+    prompt.extend(history_msgs[ind_prev_msg_start:])
+
+    prompt.append(final_msg)
+
+    return prompt
+
+
+def extract_citations_from_stream(
+    tokens: Iterator[str],
+    context_docs: list[LlmDoc],
+    doc_id_to_rank_map: dict[str, int],
+) -> Iterator[DanswerAnswerPiece | CitationInfo]:
+    max_citation_num = len(context_docs)
+    curr_segment = ""
+    prepend_bracket = False
+    cited_inds = set()
+    for token in tokens:
+        # Special case of [1][ where ][ is a single token
+        # This is where the model attempts to do consecutive citations like [1][2]
+        if prepend_bracket:
+            curr_segment += "[" + curr_segment
+            prepend_bracket = False
+
+        curr_segment += token
+
+        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
+        possible_citation_found = re.search(possible_citation_pattern, curr_segment)
+
+        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
+        citation_found = re.search(citation_pattern, curr_segment)
+
+        if citation_found:
+            numerical_value = int(citation_found.group(1))
+            if 1 <= numerical_value <= max_citation_num:
+                context_llm_doc = context_docs[
+                    numerical_value - 1
+                ]  # remove 1 index offset
+
+                link = context_llm_doc.link
+                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
+
+                # Use the citation number for the document's rank in
+                # the search (or selected docs) results
+                curr_segment = re.sub(
+                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
+                )
+
+                if target_citation_num not in cited_inds:
+                    cited_inds.add(target_citation_num)
+                    yield CitationInfo(
+                        citation_num=target_citation_num,
+                        document_id=context_llm_doc.document_id,
+                    )
+
+                if link:
+                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
+                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
+
+                # In case there's another open bracket like [1][, don't want to match this
+                possible_citation_found = None
+
+        # if we see "[", but haven't seen the right side, hold back - this may be a
+        # citation that needs to be replaced with a link
+        if possible_citation_found:
+            continue
+
+        # Special case with back to back citations [1][2]
+        if curr_segment and curr_segment[-1] == "[":
+            curr_segment = curr_segment[:-1]
+            prepend_bracket = True
+
+        yield DanswerAnswerPiece(answer_piece=curr_segment)
+        curr_segment = ""
+
+    if curr_segment:
+        if prepend_bracket:
+            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
+        else:
+            yield DanswerAnswerPiece(answer_piece=curr_segment)
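For context on the helpers being made public, a quick sanity check of find_last_index; the token counts and the budget of 10 here are illustrative, not from the commit:

```python
from danswer.chat.chat_utils import find_last_index

# Suffix sums from the back: 4 -> 7 -> 12. A budget of 10 is exceeded once
# index 1 (value 5) is included, so the longest suffix that fits starts at
# index 2 ([3, 4], totalling 7 <= 10) and the function returns 2.
assert find_last_index([2, 5, 3, 4], max_prompt_tokens=10) == 2

# If even the last element alone exceeds the budget, the function raises:
# find_last_index([20], max_prompt_tokens=10) -> ValueError
```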
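A minimal sketch of the overflow trimming in context, assuming an effective GEN_AI_MAX_INPUT_TOKENS of 4096; the messages and per-message token counts below are made up for illustration:

```python
from langchain.schema.messages import HumanMessage, SystemMessage

from danswer.chat.chat_utils import drop_messages_history_overflow

system_msg = SystemMessage(content="You are a helpful assistant.")
history = [HumanMessage(content=f"older message {i}") for i in range(10)]
history_token_counts = [500] * 10  # illustrative per-message counts

prompt = drop_messages_history_overflow(
    system_msg=system_msg,
    system_token_count=20,
    history_msgs=history,
    history_token_counts=history_token_counts,
    final_msg=HumanMessage(content="latest user question"),
    final_msg_token_count=100,
)
# Assuming a 4096-token budget: the final and system messages cost
# 100 + 20 = 120, and history fits from the back up to 7 * 500 + 120 = 3620;
# an eighth message would push the sum to 4120, so the oldest 3 are dropped
# and prompt = [system] + the last 7 history messages + the final message.
```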
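Finally, a hedged sketch of driving extract_citations_from_stream with a fake token stream. LlmDoc's full constructor isn't shown in this diff, so the example duck-types it with a stand-in exposing only the two fields the function reads (document_id and link):

```python
from dataclasses import dataclass

from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.models import CitationInfo

@dataclass
class FakeLlmDoc:  # stand-in for LlmDoc; only the fields the function reads
    document_id: str
    link: str | None

docs = [FakeLlmDoc("doc-a", "https://example.com/a"), FakeLlmDoc("doc-b", None)]
rank_map = {"doc-a": 1, "doc-b": 2}

# A "[" arriving mid-citation is held back until the citation resolves.
tokens = iter(["The answer", " is here [", "1", "]", " and also [2]."])
for piece in extract_citations_from_stream(tokens, docs, rank_map):
    if isinstance(piece, CitationInfo):
        print("citation:", piece.citation_num, piece.document_id)
    else:
        print("text:", repr(piece.answer_piece))
# The [1] citation is rewritten as the markdown link
# "[[1]](https://example.com/a)"; [2] has no link and stays plain text.
# One CitationInfo is emitted per newly cited document.
```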
In the module that defines generate_ai_chat_response (its file name is not preserved in this view):

@@ -1,15 +1,15 @@
-import re
 from collections.abc import Callable
 from collections.abc import Iterator
 from functools import partial
 from typing import cast
 
-from langchain.schema.messages import BaseMessage
 from sqlalchemy.orm import Session
 
 from danswer.chat.chat_utils import build_chat_system_message
 from danswer.chat.chat_utils import build_chat_user_message
 from danswer.chat.chat_utils import create_chat_chain
+from danswer.chat.chat_utils import drop_messages_history_overflow
+from danswer.chat.chat_utils import extract_citations_from_stream
 from danswer.chat.chat_utils import get_chunks_for_qa
 from danswer.chat.chat_utils import llm_doc_from_inference_chunk
 from danswer.chat.chat_utils import map_document_id_order
@@ -22,7 +22,6 @@ from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import CHUNK_SIZE
 from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_chat_message
@@ -57,130 +56,6 @@ from danswer.utils.timing import log_generator_function_time
 logger = setup_logger()
 
 
-def _find_last_index(
-    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
-) -> int:
-    """From the back, find the index of the last element to include
-    before the list exceeds the maximum"""
-    running_sum = 0
-
-    last_ind = 0
-    for i in range(len(lst) - 1, -1, -1):
-        running_sum += lst[i]
-        if running_sum > max_prompt_tokens:
-            last_ind = i + 1
-            break
-    if last_ind >= len(lst):
-        raise ValueError("Last message alone is too large!")
-    return last_ind
-
-
-def _drop_messages_history_overflow(
-    system_msg: BaseMessage | None,
-    system_token_count: int,
-    history_msgs: list[BaseMessage],
-    history_token_counts: list[int],
-    final_msg: BaseMessage,
-    final_msg_token_count: int,
-) -> list[BaseMessage]:
-    """As message history grows, messages need to be dropped starting from the furthest in the past.
-    The System message should be kept if at all possible and the latest user input which is inserted in the
-    prompt template must be included"""
-
-    if len(history_msgs) != len(history_token_counts):
-        # This should never happen
-        raise ValueError("Need exactly 1 token count per message for tracking overflow")
-
-    prompt: list[BaseMessage] = []
-
-    # Start dropping from the history if necessary
-    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
-    ind_prev_msg_start = _find_last_index(all_tokens)
-
-    if system_msg and ind_prev_msg_start <= len(history_msgs):
-        prompt.append(system_msg)
-
-    prompt.extend(history_msgs[ind_prev_msg_start:])
-
-    prompt.append(final_msg)
-
-    return prompt
-
-
-def extract_citations_from_stream(
-    tokens: Iterator[str],
-    context_docs: list[LlmDoc],
-    doc_id_to_rank_map: dict[str, int],
-) -> Iterator[DanswerAnswerPiece | CitationInfo]:
-    max_citation_num = len(context_docs)
-    curr_segment = ""
-    prepend_bracket = False
-    cited_inds = set()
-    for token in tokens:
-        # Special case of [1][ where ][ is a single token
-        # This is where the model attempts to do consecutive citations like [1][2]
-        if prepend_bracket:
-            curr_segment += "[" + curr_segment
-            prepend_bracket = False
-
-        curr_segment += token
-
-        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
-        possible_citation_found = re.search(possible_citation_pattern, curr_segment)
-
-        citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
-        citation_found = re.search(citation_pattern, curr_segment)
-
-        if citation_found:
-            numerical_value = int(citation_found.group(1))
-            if 1 <= numerical_value <= max_citation_num:
-                context_llm_doc = context_docs[
-                    numerical_value - 1
-                ]  # remove 1 index offset
-
-                link = context_llm_doc.link
-                target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
-
-                # Use the citation number for the document's rank in
-                # the search (or selected docs) results
-                curr_segment = re.sub(
-                    rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
-                )
-
-                if target_citation_num not in cited_inds:
-                    cited_inds.add(target_citation_num)
-                    yield CitationInfo(
-                        citation_num=target_citation_num,
-                        document_id=context_llm_doc.document_id,
-                    )
-
-                if link:
-                    curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
-                    curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
-
-                # In case there's another open bracket like [1][, don't want to match this
-                possible_citation_found = None
-
-        # if we see "[", but haven't seen the right side, hold back - this may be a
-        # citation that needs to be replaced with a link
-        if possible_citation_found:
-            continue
-
-        # Special case with back to back citations [1][2]
-        if curr_segment and curr_segment[-1] == "[":
-            curr_segment = curr_segment[:-1]
-            prepend_bracket = True
-
-        yield DanswerAnswerPiece(answer_piece=curr_segment)
-        curr_segment = ""
-
-    if curr_segment:
-        if prepend_bracket:
-            yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
-        else:
-            yield DanswerAnswerPiece(answer_piece=curr_segment)
-
-
 def generate_ai_chat_response(
     query_message: ChatMessage,
     history: list[ChatMessage],
@@ -216,7 +91,7 @@ def generate_ai_chat_response(
         all_doc_useful=all_doc_useful,
     )
 
-    prompt = _drop_messages_history_overflow(
+    prompt = drop_messages_history_overflow(
         system_msg=system_message_or_none,
         system_token_count=system_tokens,
        history_msgs=history_basemessages,