Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-10-09 20:55:06 +02:00)

Commit: Enable selection of long documents
@@ -1,6 +1,7 @@
 import re
 from collections.abc import Callable
 from collections.abc import Iterator
+from datetime import datetime
 from functools import lru_cache
 from typing import cast
 
@@ -15,14 +16,19 @@ from danswer.chat.models import LlmDoc
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.chat_configs import STOP_STREAM_PAT
+from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
-from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.models import ChatMessage
+from danswer.db.models import Persona
 from danswer.db.models import Prompt
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_llm_max_tokens
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
 from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -33,6 +39,12 @@ from danswer.prompts.constants import CODE_BLOCK_PAT
 from danswer.prompts.constants import TRIPLE_BACKTICK
 from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
 from danswer.prompts.prompt_utils import get_current_llm_day_time
+from danswer.prompts.token_counts import (
+    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
+)
+from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
+from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
+from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
 
 # Maps connector enum string to a more natural language representation for the LLM
 # If not on the list, uses the original but slightly cleaned up, see below
@@ -50,19 +62,39 @@ def clean_up_source(source_str: str) -> str:
     return source_str.replace("_", " ").title()
 
 
-def build_context_str(
+def build_doc_context_str(
+    semantic_identifier: str,
+    source_type: DocumentSource,
+    content: str,
+    ind: int,
+    include_metadata: bool = True,
+    updated_at: datetime | None = None,
+) -> str:
+    context_str = ""
+    if include_metadata:
+        context_str += f"DOCUMENT {ind}: {semantic_identifier}\n"
+        context_str += f"Source: {clean_up_source(source_type)}\n"
+        if updated_at:
+            update_str = updated_at.strftime("%B %d, %Y %H:%M")
+            context_str += f"Updated: {update_str}\n"
+    context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n"
+    return context_str
+
+
+def build_complete_context_str(
     context_docs: list[LlmDoc | InferenceChunk],
     include_metadata: bool = True,
 ) -> str:
     context_str = ""
     for ind, doc in enumerate(context_docs, start=1):
-        if include_metadata:
-            context_str += f"DOCUMENT {ind}: {doc.semantic_identifier}\n"
-            context_str += f"Source: {clean_up_source(doc.source_type)}\n"
-            if doc.updated_at:
-                update_str = doc.updated_at.strftime("%B %d, %Y %H:%M")
-                context_str += f"Updated: {update_str}\n"
-        context_str += f"{CODE_BLOCK_PAT.format(doc.content.strip())}\n\n\n"
+        context_str += build_doc_context_str(
+            semantic_identifier=doc.semantic_identifier,
+            source_type=doc.source_type,
+            content=doc.content,
+            updated_at=doc.updated_at,
+            ind=ind,
+            include_metadata=include_metadata,
+        )
 
     return context_str.strip()
@@ -71,7 +103,7 @@ def build_context_str(
 def build_chat_system_message(
     prompt: Prompt,
     context_exists: bool,
-    llm_tokenizer: Callable,
+    llm_tokenizer_encode_func: Callable,
     citation_line: str = REQUIRE_CITATION_STATEMENT,
     no_citation_line: str = NO_CITATION_STATEMENT,
 ) -> tuple[SystemMessage | None, int]:
@@ -92,7 +124,7 @@ def build_chat_system_message(
     if not system_prompt:
         return None, 0
 
-    token_count = len(llm_tokenizer(system_prompt))
+    token_count = len(llm_tokenizer_encode_func(system_prompt))
     system_msg = SystemMessage(content=system_prompt)
 
     return system_msg, token_count
@@ -138,7 +170,7 @@ def build_chat_user_message(
     chat_message: ChatMessage,
     prompt: Prompt,
     context_docs: list[LlmDoc],
-    llm_tokenizer: Callable,
+    llm_tokenizer_encode_func: Callable,
     all_doc_useful: bool,
     user_prompt_template: str = CHAT_USER_PROMPT,
     context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
@@ -156,11 +188,11 @@ def build_chat_user_message(
             else user_query
         )
         user_prompt = user_prompt.strip()
-        token_count = len(llm_tokenizer(user_prompt))
+        token_count = len(llm_tokenizer_encode_func(user_prompt))
         user_msg = HumanMessage(content=user_prompt)
         return user_msg, token_count
 
-    context_docs_str = build_context_str(
+    context_docs_str = build_complete_context_str(
         cast(list[LlmDoc | InferenceChunk], context_docs)
     )
     optional_ignore = "" if all_doc_useful else ignore_str
@@ -175,7 +207,7 @@ def build_chat_user_message(
     )
 
     user_prompt = user_prompt.strip()
-    token_count = len(llm_tokenizer(user_prompt))
+    token_count = len(llm_tokenizer_encode_func(user_prompt))
     user_msg = HumanMessage(content=user_prompt)
 
     return user_msg, token_count
@@ -357,16 +389,17 @@ def combine_message_chain(
     return "\n\n".join(message_strs)
 
 
-def find_last_index(
-    lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
-) -> int:
+_PER_MESSAGE_TOKEN_BUFFER = 7
+
+
+def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
     """From the back, find the index of the last element to include
     before the list exceeds the maximum"""
     running_sum = 0
 
     last_ind = 0
     for i in range(len(lst) - 1, -1, -1):
-        running_sum += lst[i]
+        running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
         if running_sum > max_prompt_tokens:
             last_ind = i + 1
             break
@@ -382,14 +415,11 @@ def drop_messages_history_overflow(
     history_token_counts: list[int],
     final_msg: BaseMessage,
     final_msg_token_count: int,
-    max_allowed_tokens: int | None,
+    max_allowed_tokens: int,
 ) -> list[BaseMessage]:
     """As message history grows, messages need to be dropped starting from the furthest in the past.
     The System message should be kept if at all possible and the latest user input which is inserted in the
     prompt template must be included"""
-    if max_allowed_tokens is None:
-        max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS
-
     if len(history_msgs) != len(history_token_counts):
         # This should never happen
         raise ValueError("Need exactly 1 token count per message for tracking overflow")
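For reference, a minimal sketch of how the reworked find_last_index walks the history from the back; the per-message token counts and limit below are made up, not values from this commit:

```python
# Made-up token counts for four history messages (oldest first).
history_token_counts = [100, 200, 300, 400]
_PER_MESSAGE_TOKEN_BUFFER = 7
max_prompt_tokens = 600

# Summing from the back: 400 + 7 = 407, then 407 + 300 + 7 = 714 > 600,
# so the last index that still fits is 3 and only the newest message is kept.
running_sum = 0
last_ind = 0
for i in range(len(history_token_counts) - 1, -1, -1):
    running_sum += history_token_counts[i] + _PER_MESSAGE_TOKEN_BUFFER
    if running_sum > max_prompt_tokens:
        last_ind = i + 1
        break
print(last_ind)  # 3
```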
@@ -508,3 +538,68 @@ def extract_citations_from_stream(
             yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
         else:
             yield DanswerAnswerPiece(answer_piece=curr_segment)
+
+
+def get_prompt_tokens(prompt: Prompt) -> int:
+    return (
+        check_number_of_tokens(prompt.system_prompt)
+        + check_number_of_tokens(prompt.task_prompt)
+        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+        + CITATION_STATEMENT_TOKEN_CNT
+        + CITATION_REMINDER_TOKEN_CNT
+        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+    )
+
+
+# buffer just to be safe so that we don't overflow the token limit due to
+# a small miscalculation
+_MISC_BUFFER = 40
+
+
+def compute_max_document_tokens(
+    persona: Persona, actual_user_input: str | None = None
+) -> int:
+    """Estimates the number of tokens available for context documents. Formula is roughly:
+
+    (
+        model_context_window - reserved_output_tokens - prompt_tokens
+        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
+    )
+
+    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
+    if we're trying to determine if the user should be able to select another document) then we just set an
+    arbitrary "upper bound".
+    """
+    llm_name = GEN_AI_MODEL_VERSION
+    if persona.llm_model_version_override:
+        llm_name = persona.llm_model_version_override
+
+    # if we can't find a number of tokens, just assume some common default
+    model_full_context_window = get_llm_max_tokens(llm_name) or 4096
+    if persona.prompts:
+        prompt_tokens = get_prompt_tokens(persona.prompts[0])
+    else:
+        raise RuntimeError("Persona has no prompts - this should never happen")
+    user_input_tokens = (
+        check_number_of_tokens(actual_user_input)
+        if actual_user_input is not None
+        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
+    )
+
+    return (
+        model_full_context_window
+        - GEN_AI_MAX_OUTPUT_TOKENS
+        - prompt_tokens
+        - user_input_tokens
+        - _MISC_BUFFER
+    )
+
+
+def compute_max_llm_input_tokens(persona: Persona) -> int:
+    """Maximum tokens allows in the input to the LLM (of any type)."""
+    llm_name = GEN_AI_MODEL_VERSION
+    if persona.llm_model_version_override:
+        llm_name = persona.llm_model_version_override
+
+    model_full_context_window = get_llm_max_tokens(llm_name) or 4096
+    return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER
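A rough sketch of the budget arithmetic that compute_max_document_tokens performs; all of the numbers below are illustrative assumptions, not values pulled from the repo's config:

```python
# Hypothetical numbers, only to illustrate the formula; none are real config values.
model_context_window = 4096    # get_llm_max_tokens(llm_name), or the 4096 fallback
reserved_output_tokens = 1024  # stand-in for GEN_AI_MAX_OUTPUT_TOKENS
prompt_tokens = 700            # system + task prompt + fixed prompt overheads
user_input_tokens = 512        # the "expected" user message size when the query is unknown
misc_buffer = 40               # _MISC_BUFFER

max_document_tokens = (
    model_context_window
    - reserved_output_tokens
    - prompt_tokens
    - user_input_tokens
    - misc_buffer
)
print(max_document_tokens)  # 1820 tokens left for selected documents
```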
@@ -7,6 +7,9 @@ from sqlalchemy.orm import Session
 
 from danswer.chat.chat_utils import build_chat_system_message
 from danswer.chat.chat_utils import build_chat_user_message
+from danswer.chat.chat_utils import build_doc_context_str
+from danswer.chat.chat_utils import compute_max_document_tokens
+from danswer.chat.chat_utils import compute_max_llm_input_tokens
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.chat_utils import drop_messages_history_overflow
 from danswer.chat.chat_utils import extract_citations_from_stream
@@ -42,8 +45,8 @@ from danswer.indexing.models import InferenceChunk
 from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_token_encode
-from danswer.llm.utils import get_llm_max_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.llm.utils import tokenizer_trim_content
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
@@ -68,7 +71,7 @@ def generate_ai_chat_response(
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
     llm: LLM | None,
-    llm_tokenizer: Callable,
+    llm_tokenizer_encode_func: Callable,
     all_doc_useful: bool,
 ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
     if llm is None:
@@ -88,7 +91,7 @@ def generate_ai_chat_response(
     system_message_or_none, system_tokens = build_chat_system_message(
         prompt=query_message.prompt,
         context_exists=context_exists,
-        llm_tokenizer=llm_tokenizer,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
     )
 
     history_basemessages, history_token_counts = translate_history_to_basemessages(
@@ -101,7 +104,7 @@ def generate_ai_chat_response(
         chat_message=query_message,
         prompt=query_message.prompt,
         context_docs=context_docs,
-        llm_tokenizer=llm_tokenizer,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
         all_doc_useful=all_doc_useful,
     )
 
@@ -112,9 +115,7 @@ def generate_ai_chat_response(
         history_token_counts=history_token_counts,
         final_msg=user_message,
         final_msg_token_count=user_tokens,
-        max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override)
-        if persona.llm_model_version_override
-        else None,
+        max_allowed_tokens=compute_max_llm_input_tokens(persona),
     )
 
     # Good Debug/Breakpoint
@@ -195,7 +196,10 @@ def stream_chat_message(
     except GenAIDisabledException:
         llm = None
 
-    llm_tokenizer = get_default_llm_token_encode()
+    llm_tokenizer = get_default_llm_tokenizer()
+    llm_tokenizer_encode_func = cast(
+        Callable[[str], list[int]], llm_tokenizer.encode
+    )
 
     embedding_model = get_current_db_embedding_model(db_session)
     document_index = get_default_document_index(
@@ -223,7 +227,7 @@ def stream_chat_message(
         parent_message=parent_message,
         prompt_id=prompt_id,
         message=message_text,
-        token_count=len(llm_tokenizer(message_text)),
+        token_count=len(llm_tokenizer_encode_func(message_text)),
         message_type=MessageType.USER,
         db_session=db_session,
         commit=False,
@@ -271,6 +275,66 @@ def stream_chat_message(
             doc_identifiers=identifier_tuples,
             document_index=document_index,
         )
+
+        # truncate the last document if it exceeds the token limit
+        max_document_tokens = compute_max_document_tokens(
+            persona, actual_user_input=message_text
+        )
+        tokens_per_doc = [
+            len(
+                llm_tokenizer_encode_func(
+                    build_doc_context_str(
+                        semantic_identifier=llm_doc.semantic_identifier,
+                        source_type=llm_doc.source_type,
+                        content=llm_doc.content,
+                        updated_at=llm_doc.updated_at,
+                        ind=ind,
+                    )
+                )
+            )
+            for ind, llm_doc in enumerate(llm_docs)
+        ]
+        final_doc_ind = None
+        total_tokens = 0
+        for ind, tokens in enumerate(tokens_per_doc):
+            total_tokens += tokens
+            if total_tokens > max_document_tokens:
+                final_doc_ind = ind
+                break
+        if final_doc_ind is not None:
+            # only allow the final document to get truncated
+            # if more than that, then the user message is too long
+            if final_doc_ind != len(tokens_per_doc) - 1:
+                yield get_json_line(
+                    StreamingError(
+                        error="LLM context window exceeded. Please de-select some documents or shorten your query."
+                    ).dict()
+                )
+                return
+
+            final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
+                total_tokens - max_document_tokens
+            )
+            # 75 tokens is a reasonable over-estimate of the metadata and title
+            final_doc_content_length = final_doc_desired_length - 75
+            # this could occur if we only have space for the title / metadata
+            # not ideal, but it's the most reasonable thing to do
+            # NOTE: the frontend prevents documents from being selected if
+            # less than 75 tokens are available to try and avoid this situation
+            # from occuring in the first place
+            if final_doc_content_length <= 0:
+                logger.error(
+                    f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content "
+                    "length is less than 0. Removing this doc from the final prompt."
+                )
+                llm_docs.pop()
+            else:
+                llm_docs[final_doc_ind].content = tokenizer_trim_content(
+                    content=llm_docs[final_doc_ind].content,
+                    desired_length=final_doc_content_length,
+                    tokenizer=llm_tokenizer,
+                )
+
         doc_id_to_rank_map = map_document_id_order(
             cast(list[InferenceChunk | LlmDoc], llm_docs)
         )
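To make the trimming step concrete, here is a small worked example of the arithmetic above; the token counts are made up and do not come from this commit:

```python
# Hypothetical token counts for three selected documents (in context-string form).
tokens_per_doc = [800, 700, 600]
max_document_tokens = 1820

# Walking front to back, the budget is exceeded while adding the last doc,
# so final_doc_ind lands on the final document and only it gets trimmed.
total_tokens = 800 + 700 + 600                              # 2100
overshoot = total_tokens - max_document_tokens              # 280
final_doc_desired_length = tokens_per_doc[-1] - overshoot   # 600 - 280 = 320
final_doc_content_length = final_doc_desired_length - 75    # 245 tokens of content kept
```

Had the overflow already happened on an earlier document, the code instead streams a StreamingError asking the user to de-select documents or shorten the query.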
@@ -423,7 +487,7 @@ def stream_chat_message(
         context_docs=llm_docs,
         doc_id_to_rank_map=doc_id_to_rank_map,
         llm=llm,
-        llm_tokenizer=llm_tokenizer,
+        llm_tokenizer_encode_func=llm_tokenizer_encode_func,
         all_doc_useful=reference_doc_ids is not None,
     )
 
@@ -467,7 +531,7 @@ def stream_chat_message(
     # Saving Gen AI answer and responding with message info
     gen_ai_response_message = partial_response(
         message=llm_output,
-        token_count=len(llm_tokenizer(llm_output)),
+        token_count=len(llm_tokenizer_encode_func(llm_output)),
         citations=db_citations,
         error=error,
     )
@@ -104,4 +104,10 @@ GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
 # History for secondary LLM flows, not primary chat flow, generally we don't need to
 # include as much as possible as this just bumps up the cost unnecessarily
 GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
+# This is used when computing how much context space is available for documents
+# ahead of time in order to let the user know if they can "select" more documents
+# It represents a maximum "expected" number of input tokens from the latest user
+# message. At query time, we don't actually enforce this - we will only throw an
+# error if the total # of tokens exceeds the max input tokens.
+GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS = 512
 GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)
@@ -35,7 +35,7 @@ _LLM_TOKENIZER: Any = None
 _LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
 
 
-def get_default_llm_tokenizer() -> Any:
+def get_default_llm_tokenizer() -> Encoding:
     """Currently only supports the OpenAI default tokenizer: tiktoken"""
     global _LLM_TOKENIZER
     if _LLM_TOKENIZER is None:
@@ -56,16 +56,26 @@ def get_default_llm_token_encode() -> Callable[[str], Any]:
     return _LLM_TOKENIZER_ENCODE
 
 
+def tokenizer_trim_content(
+    content: str, desired_length: int, tokenizer: Encoding
+) -> str:
+    tokenizer = get_default_llm_tokenizer()
+    tokens = tokenizer.encode(content)
+    if len(tokens) > desired_length:
+        content = tokenizer.decode(tokens[:desired_length])
+    return content
+
+
 def tokenizer_trim_chunks(
     chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
 ) -> list[InferenceChunk]:
     tokenizer = get_default_llm_tokenizer()
     new_chunks = copy(chunks)
     for ind, chunk in enumerate(new_chunks):
-        tokens = tokenizer.encode(chunk.content)
-        if len(tokens) > max_chunk_toks:
+        new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
+        if len(new_content) != len(chunk.content):
             new_chunk = copy(chunk)
-            new_chunk.content = tokenizer.decode(tokens[:max_chunk_toks])
+            new_chunk.content = new_content
             new_chunks[ind] = new_chunk
     return new_chunks
 
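A minimal sketch of what the new tokenizer_trim_content helper does, assuming the tiktoken cl100k_base encoding that the default tokenizer wraps; the sample string and limit are arbitrary:

```python
import tiktoken

def trim_content(content: str, desired_length: int, tokenizer) -> str:
    # Mirrors the helper added above: keep at most desired_length tokens.
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content

tokenizer = tiktoken.get_encoding("cl100k_base")
long_text = "word " * 500
trimmed = trim_content(long_text, desired_length=100, tokenizer=tokenizer)
assert len(tokenizer.encode(trimmed)) <= 100
```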
@@ -4,7 +4,7 @@ from collections.abc import Callable
 from collections.abc import Iterator
 from typing import cast
 
-from danswer.chat.chat_utils import build_context_str
+from danswer.chat.chat_utils import build_complete_context_str
 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.chat.models import DanswerAnswer
 from danswer.chat.models import DanswerAnswerPiece
@@ -145,7 +145,7 @@ class SingleMessageQAHandler(QAHandler):
     ) -> str:
         context_block = ""
         if context_chunks:
-            context_docs_str = build_context_str(
+            context_docs_str = build_complete_context_str(
                 cast(list[LlmDoc | InferenceChunk], context_chunks)
             )
             context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
@@ -194,7 +194,7 @@ class SingleMessageScratchpadHandler(QAHandler):
     def build_prompt(
         self, query: str, history_str: str, context_chunks: list[InferenceChunk]
     ) -> str:
-        context_docs_str = build_context_str(
+        context_docs_str = build_complete_context_str(
             cast(list[LlmDoc | InferenceChunk], context_chunks)
         )
 
backend/danswer/prompts/token_counts.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from danswer.llm.utils import check_number_of_tokens
+from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
+from danswer.prompts.chat_prompts import CITATION_REMINDER
+from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
+from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
+from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
+
+
+# tokens outside of the actual persona's "user_prompt" that make up the end
+# user message
+CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT = check_number_of_tokens(
+    CHAT_USER_PROMPT.format(
+        context_docs_str="",
+        task_prompt="",
+        user_query="",
+        optional_ignore_statement=DEFAULT_IGNORE_STATEMENT,
+    )
+)
+
+CITATION_STATEMENT_TOKEN_CNT = check_number_of_tokens(REQUIRE_CITATION_STATEMENT)
+
+CITATION_REMINDER_TOKEN_CNT = check_number_of_tokens(CITATION_REMINDER)
+
+LANGUAGE_HINT_TOKEN_CNT = check_number_of_tokens(LANGUAGE_HINT)
@@ -5,6 +5,7 @@ from fastapi import Query
 from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
+from danswer.chat.chat_utils import build_doc_context_str
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
 from danswer.db.models import User
@@ -47,12 +48,22 @@ def get_document_info(
 
     contents = [chunk.content for chunk in inference_chunks]
 
-    combined = "\n".join(contents)
+    combined_contents = "\n".join(contents)
 
+    # get actual document context used for LLM
+    first_chunk = inference_chunks[0]
     tokenizer_encode = get_default_llm_token_encode()
+    full_context_str = build_doc_context_str(
+        semantic_identifier=first_chunk.semantic_identifier,
+        source_type=first_chunk.source_type,
+        content=combined_contents,
+        updated_at=first_chunk.updated_at,
+        ind=0,
+    )
 
     return DocumentInfo(
-        num_chunks=len(inference_chunks), num_tokens=len(tokenizer_encode(combined))
+        num_chunks=len(inference_chunks),
+        num_tokens=len(tokenizer_encode(full_context_str)),
     )
@@ -2,9 +2,11 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
+from danswer.chat.chat_utils import compute_max_document_tokens
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.process_message import stream_chat_message
 from danswer.db.chat import create_chat_session
@@ -13,6 +15,7 @@ from danswer.db.chat import get_chat_message
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.chat import get_chat_session_by_id
 from danswer.db.chat import get_chat_sessions_by_user
+from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import set_as_latest_chat_message
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.chat import update_chat_session
@@ -244,3 +247,27 @@ def create_search_feedback(
         document_index=document_index,
         db_session=db_session,
     )
+
+
+class MaxSelectedDocumentTokens(BaseModel):
+    max_tokens: int
+
+
+@router.get("/max-selected-document-tokens")
+def get_max_document_tokens(
+    persona_id: int,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> MaxSelectedDocumentTokens:
+    try:
+        persona = get_persona_by_id(
+            persona_id=persona_id,
+            user_id=user.id if user else None,
+            db_session=db_session,
+        )
+    except ValueError:
+        raise HTTPException(status_code=404, detail="Persona not found")
+
+    return MaxSelectedDocumentTokens(
+        max_tokens=compute_max_document_tokens(persona),
+    )
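As a usage sketch, the new route can be queried the same way the frontend does; the base URL below assumes a local deployment behind the web proxy (which adds the /api prefix) and is not something this diff specifies:

```python
import requests

# Assumed local dev address; add an auth cookie or header here if authentication is enabled.
resp = requests.get(
    "http://localhost:3000/api/chat/max-selected-document-tokens",
    params={"persona_id": 1},
)
resp.raise_for_status()
print(resp.json()["max_tokens"])  # token budget available for selected documents
```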
@@ -40,8 +40,8 @@ import { ResizableSection } from "@/components/resizable/ResizableSection";
 import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
 import { ChatIntro } from "./ChatIntro";
 import { HEADER_PADDING } from "@/lib/constants";
-import { getSourcesForPersona } from "@/lib/sources";
 import { computeAvailableFilters } from "@/lib/filters";
+import { useDocumentSelection } from "./useDocumentSelection";
 
 const MAX_INPUT_HEIGHT = 200;
 
@@ -138,9 +138,6 @@ export const Chat = ({
           selectedMessageForDocDisplay
         )
       : { aiMessage: null };
-  const [selectedDocuments, setSelectedDocuments] = useState<DanswerDocument[]>(
-    []
-  );
 
   const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
     existingChatSessionPersonaId !== undefined
@@ -165,6 +162,30 @@ export const Chat = ({
     }
   }, [defaultSelectedPersonaId]);
 
+  const [
+    selectedDocuments,
+    toggleDocumentSelection,
+    clearSelectedDocuments,
+    selectedDocumentTokens,
+  ] = useDocumentSelection();
+  // just choose a conservative default, this will be updated in the
+  // background on initial load / on persona change
+  const [maxTokens, setMaxTokens] = useState<number>(4096);
+  // fetch # of allowed document tokens for the selected Persona
+  useEffect(() => {
+    async function fetchMaxTokens() {
+      const response = await fetch(
+        `/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
+      );
+      if (response.ok) {
+        const maxTokens = (await response.json()).max_tokens as number;
+        setMaxTokens(maxTokens);
+      }
+    }
+
+    fetchMaxTokens();
+  }, [livePersona]);
+
   const filterManager = useFilters();
   const [finalAvailableSources, finalAvailableDocumentSets] =
     computeAvailableFilters({
@@ -725,7 +746,10 @@ export const Chat = ({
             <DocumentSidebar
               selectedMessage={aiMessage}
               selectedDocuments={selectedDocuments}
-              setSelectedDocuments={setSelectedDocuments}
+              toggleDocumentSelection={toggleDocumentSelection}
+              clearSelectedDocuments={clearSelectedDocuments}
+              selectedDocumentTokens={selectedDocumentTokens}
+              maxTokens={maxTokens}
               isLoading={isFetchingChatMessages}
             />
           </ResizableSection>
@@ -4,7 +4,6 @@ import { PopupSpec } from "@/components/admin/connectors/Popup";
 import { DocumentFeedbackBlock } from "@/components/search/DocumentFeedbackBlock";
 import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge";
 import { DanswerDocument } from "@/lib/search/interfaces";
-import { useState } from "react";
 import { FiInfo, FiRadio } from "react-icons/fi";
 import { DocumentSelector } from "./DocumentSelector";
 import {
@@ -19,6 +18,7 @@ interface DocumentDisplayProps {
   isSelected: boolean;
   handleSelect: (documentId: string) => void;
   setPopup: (popupSpec: PopupSpec | null) => void;
+  tokenLimitReached: boolean;
 }
 
 export function ChatDocumentDisplay({
@@ -28,24 +28,16 @@ export function ChatDocumentDisplay({
   isSelected,
   handleSelect,
   setPopup,
+  tokenLimitReached,
 }: DocumentDisplayProps) {
-  const [isHovered, setIsHovered] = useState(false);
-
   // Consider reintroducing null scored docs in the future
   if (document.score === null) {
     return null;
   }
 
   return (
-    <div
-      key={document.semantic_identifier}
-      className="text-sm px-3"
-      onMouseEnter={() => {
-        setIsHovered(true);
-      }}
-      onMouseLeave={() => setIsHovered(false)}
-    >
-      <div className="flex relative w-full overflow-x-hidden">
+    <div key={document.semantic_identifier} className="text-sm px-3">
+      <div className="flex relative w-full overflow-y-visible">
         <a
           className={
             "rounded-lg flex font-bold flex-shrink truncate " +
@@ -102,6 +94,7 @@ export function ChatDocumentDisplay({
           <DocumentSelector
             isSelected={isSelected}
            handleSelect={() => handleSelect(document.document_id)}
+            isDisabled={tokenLimitReached && !isSelected}
           />
         </div>
         <div>
@@ -1,14 +1,36 @@
+import { HoverPopup } from "@/components/HoverPopup";
+import { useState } from "react";
+
 export function DocumentSelector({
   isSelected,
   handleSelect,
+  isDisabled,
 }: {
   isSelected: boolean;
   handleSelect: () => void;
+  isDisabled?: boolean;
 }) {
+  const [popupDisabled, setPopupDisabled] = useState(false);
+
+  function onClick() {
+    if (!isDisabled) {
+      setPopupDisabled(true);
+      handleSelect();
+      // re-enable popup after 1 second so that we don't show the popup immediately upon the
+      // user de-selecting a document
+      setTimeout(() => {
+        setPopupDisabled(false);
+      }, 1000);
+    }
+  }
+
+  function Main() {
     return (
       <div
-        className="ml-auto flex cursor-pointer select-none"
-        onClick={handleSelect}
+        className={
+          "ml-auto flex select-none " + (!isDisabled ? " cursor-pointer" : "")
+        }
+        onClick={onClick}
       >
         <p className="mr-2 my-auto">Select</p>
         <input
@@ -17,7 +39,28 @@ export function DocumentSelector({
           checked={isSelected}
           // dummy function to prevent warning
           onChange={() => null}
+          disabled={isDisabled}
         />
       </div>
     );
   }
+
+  if (isDisabled && !popupDisabled) {
+    return (
+      <div className="ml-auto">
+        <HoverPopup
+          mainContent={Main()}
+          popupContent={
+            <div className="w-48">
+              LLM context limit reached 😔 If you want to chat with this
+              document, please de-select others to free up space.
+            </div>
+          }
+          direction="left-top"
+        />
+      </div>
+    );
+  }
+
+  return Main();
+}
@@ -2,12 +2,13 @@ import { DanswerDocument } from "@/lib/search/interfaces";
 import { Text } from "@tremor/react";
 import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
 import { usePopup } from "@/components/admin/connectors/Popup";
-import { FiFileText } from "react-icons/fi";
+import { FiAlertTriangle, FiFileText } from "react-icons/fi";
 import { SelectedDocumentDisplay } from "./SelectedDocumentDisplay";
 import { removeDuplicateDocs } from "@/lib/documentUtils";
 import { BasicSelectable } from "@/components/BasicClickable";
 import { Message, RetrievalType } from "../interfaces";
 import { HEADER_PADDING } from "@/lib/constants";
+import { HoverPopup } from "@/components/HoverPopup";
 
 function SectionHeader({
   name,
@@ -27,12 +28,18 @@ function SectionHeader({
 export function DocumentSidebar({
   selectedMessage,
   selectedDocuments,
-  setSelectedDocuments,
+  toggleDocumentSelection,
+  clearSelectedDocuments,
+  selectedDocumentTokens,
+  maxTokens,
   isLoading,
 }: {
   selectedMessage: Message | null;
   selectedDocuments: DanswerDocument[] | null;
-  setSelectedDocuments: (documents: DanswerDocument[]) => void;
+  toggleDocumentSelection: (document: DanswerDocument) => void;
+  clearSelectedDocuments: () => void;
+  selectedDocumentTokens: number;
+  maxTokens: number;
   isLoading: boolean;
 }) {
   const { popup, setPopup } = usePopup();
@@ -44,6 +51,13 @@ export function DocumentSidebar({
 
   const currentDocuments = selectedMessage?.documents || null;
   const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);
+
+  // NOTE: do not allow selection if less than 75 tokens are left
+  // this is to prevent the case where they are able to select the doc
+  // but it basically is unused since it's truncated right at the very
+  // start of the document (since title + metadata + misc overhead) takes up
+  // space
+  const tokenLimitReached = selectedDocumentTokens > maxTokens - 75;
   return (
     <div
      className={`
@@ -72,7 +86,7 @@ export function DocumentSidebar({
       </div>
 
       {currentDocuments ? (
-        <div className="overflow-y-auto dark-scrollbar overflow-x-hidden flex flex-col">
+        <div className="overflow-y-auto dark-scrollbar flex flex-col">
          <div>
            {dedupedDocuments.length > 0 ? (
              dedupedDocuments.map((document, ind) => (
@@ -93,21 +107,13 @@ export function DocumentSidebar({
                       document.document_id
                     )}
                     handleSelect={(documentId) => {
-                      if (selectedDocumentIds.includes(documentId)) {
-                        setSelectedDocuments(
-                          selectedDocuments!.filter(
-                            (document) => document.document_id !== documentId
-                          )
-                        );
-                      } else {
-                        setSelectedDocuments([
-                          ...selectedDocuments!,
-                          currentDocuments.find(
-                            (document) => document.document_id === documentId
-                          )!,
-                        ]);
-                      }
+                      toggleDocumentSelection(
+                        dedupedDocuments.find(
+                          (document) => document.document_id === documentId
+                        )!
+                      );
                     }}
+                    tokenLimitReached={tokenLimitReached}
                   />
                 </div>
               ))
@@ -132,15 +138,48 @@ export function DocumentSidebar({
 
      <div className="text-sm mb-4 border-t border-border pt-4 overflow-y-hidden flex flex-col">
        <div className="flex border-b border-border px-3">
-          <div>
+          <div className="flex">
            <SectionHeader name="Selected Documents" icon={FiFileText} />
+
+            {tokenLimitReached && (
+              <div className="ml-2 my-auto">
+                <div className="mb-2">
+                  <HoverPopup
+                    mainContent={
+                      <FiAlertTriangle
+                        className="text-alert my-auto"
+                        size="16"
+                      />
+                    }
+                    popupContent={
+                      <Text className="w-40">
+                        Over LLM context length by:{" "}
+                        <i>{selectedDocumentTokens - maxTokens} tokens</i>
+                        <br />
+                        <br />
+                        {selectedDocuments && selectedDocuments.length > 0 && (
+                          <>
+                            Truncating: "
+                            <i>
+                              {
+                                selectedDocuments[selectedDocuments.length - 1]
+                                  .semantic_identifier
+                              }
+                            </i>
+                            "
+                          </>
+                        )}
+                      </Text>
+                    }
+                    direction="left"
+                  />
+                </div>
+              </div>
+            )}
          </div>
 
          {selectedDocuments && selectedDocuments.length > 0 && (
-            <div
-              className="ml-auto my-auto"
-              onClick={() => setSelectedDocuments([])}
-            >
+            <div className="ml-auto my-auto" onClick={clearSelectedDocuments}>
              <BasicSelectable selected={false}>De-Select All</BasicSelectable>
            </div>
          )}
@@ -153,10 +192,10 @@ export function DocumentSidebar({
                   key={document.document_id}
                   document={document}
                   handleDeselect={(documentId) => {
-                    setSelectedDocuments(
-                      selectedDocuments!.filter(
-                        (document) => document.document_id !== documentId
-                      )
+                    toggleDocumentSelection(
+                      dedupedDocuments.find(
+                        (document) => document.document_id === documentId
+                      )!
                     );
                   }}
                 />
@@ -49,7 +49,7 @@ export default async function Page({
     | AuthTypeMetadata
     | FullEmbeddingModelResponse
     | null
-  )[] = [null, null, null, null, null, null, null, null];
+  )[] = [null, null, null, null, null, null, null, null, null];
   try {
     results = await Promise.all(tasks);
   } catch (e) {
web/src/app/chat/useDocumentSelection.ts (new file, 69 lines)
@@ -0,0 +1,69 @@
+import { DanswerDocument } from "@/lib/search/interfaces";
+import { useState } from "react";
+
+interface DocumentInfo {
+  num_chunks: number;
+  num_tokens: number;
+}
+
+async function fetchDocumentLength(documentId: string) {
+  const response = await fetch(
+    `/api/document/document-size-info?document_id=${documentId}`
+  );
+  if (!response.ok) {
+    return 0;
+  }
+  const data = (await response.json()) as DocumentInfo;
+  return data.num_tokens;
+}
+
+export function useDocumentSelection(): [
+  DanswerDocument[],
+  (document: DanswerDocument) => void,
+  () => void,
+  number,
+] {
+  const [selectedDocuments, setSelectedDocuments] = useState<DanswerDocument[]>(
+    []
+  );
+  const [totalTokens, setTotalTokens] = useState(0);
+  const selectedDocumentIds = selectedDocuments.map(
+    (document) => document.document_id
+  );
+  const documentIdToLength = new Map<string, number>();
+
+  function toggleDocumentSelection(document: DanswerDocument) {
+    const documentId = document.document_id;
+    const isAdding = !selectedDocumentIds.includes(documentId);
+    if (!isAdding) {
+      setSelectedDocuments(
+        selectedDocuments.filter(
+          (document) => document.document_id !== documentId
+        )
+      );
+    } else {
+      setSelectedDocuments([...selectedDocuments, document]);
+    }
+    if (documentIdToLength.has(documentId)) {
+      const length = documentIdToLength.get(documentId)!;
+      setTotalTokens(isAdding ? totalTokens + length : totalTokens - length);
+    } else {
+      fetchDocumentLength(documentId).then((length) => {
+        documentIdToLength.set(documentId, length);
+        setTotalTokens(isAdding ? totalTokens + length : totalTokens - length);
+      });
+    }
+  }
+
+  function clearDocuments() {
+    setSelectedDocuments([]);
+    setTotalTokens(0);
+  }
+
+  return [
+    selectedDocuments,
+    toggleDocumentSelection,
+    clearDocuments,
+    totalTokens,
+  ];
+}
@@ -4,7 +4,7 @@ interface HoverPopupProps {
   mainContent: string | JSX.Element;
   popupContent: string | JSX.Element;
   classNameModifications?: string;
-  direction?: "left" | "bottom" | "top";
+  direction?: "left" | "left-top" | "bottom" | "top";
   style?: "basic" | "dark";
 }
 
@@ -18,9 +18,15 @@ export const HoverPopup = ({
   const [hovered, setHovered] = useState(false);
 
   let popupDirectionClass;
+  let popupStyle = {};
   switch (direction) {
     case "left":
-      popupDirectionClass = "top-0 left-0 transform translate-x-[-110%]";
+      popupDirectionClass = "top-0 left-0 transform";
+      popupStyle = { transform: "translateX(calc(-100% - 5px))" };
+      break;
+    case "left-top":
+      popupDirectionClass = "bottom-0 left-0";
+      popupStyle = { transform: "translate(calc(-100% - 5px), 0)" };
       break;
     case "bottom":
       popupDirectionClass = "top-0 left-0 mt-6 pt-2";
@@ -39,7 +45,10 @@ export const HoverPopup = ({
       onMouseLeave={() => setHovered(false)}
     >
       {hovered && (
-        <div className={`absolute ${popupDirectionClass} z-30`}>
+        <div
+          className={`absolute ${popupDirectionClass} z-30`}
+          style={popupStyle}
+        >
           <div
             className={
               `px-3 py-2 rounded bg-background border border-border` +
@@ -59,6 +59,7 @@ module.exports = {
       },
       error: "#ef4444", // red-500
       success: "#059669", // emerald-600
+      alert: "#f59e0b", // amber-600
       user: "#fb7185", // yellow-400
       ai: "#60a5fa", // blue-400
       // light mode