From e54ce779fd6197f1dc86d2a4b64c99de4e92a411 Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 2 Feb 2024 20:17:59 -0800 Subject: [PATCH] Enable selection of long documents --- backend/danswer/chat/chat_utils.py | 141 +++++++++++++++--- backend/danswer/chat/process_message.py | 88 +++++++++-- backend/danswer/configs/model_configs.py | 6 + backend/danswer/llm/utils.py | 18 ++- backend/danswer/one_shot_answer/qa_block.py | 6 +- backend/danswer/prompts/token_counts.py | 24 +++ backend/danswer/server/documents/document.py | 15 +- .../server/query_and_chat/chat_backend.py | 27 ++++ web/src/app/chat/Chat.tsx | 34 ++++- .../documentSidebar/ChatDocumentDisplay.tsx | 19 +-- .../chat/documentSidebar/DocumentSelector.tsx | 73 +++++++-- .../chat/documentSidebar/DocumentSidebar.tsx | 93 ++++++++---- web/src/app/chat/page.tsx | 2 +- web/src/app/chat/useDocumentSelection.ts | 69 +++++++++ web/src/components/HoverPopup.tsx | 15 +- web/tailwind.config.js | 1 + 16 files changed, 523 insertions(+), 108 deletions(-) create mode 100644 backend/danswer/prompts/token_counts.py create mode 100644 web/src/app/chat/useDocumentSelection.ts diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 235cce9b3..fb2a32c23 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,6 +1,7 @@ import re from collections.abc import Callable from collections.abc import Iterator +from datetime import datetime from functools import lru_cache from typing import cast @@ -15,14 +16,19 @@ from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.chat_configs import STOP_STREAM_PAT +from danswer.configs.constants import DocumentSource from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF -from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from danswer.db.chat import get_chat_messages_by_session from danswer.db.models import ChatMessage +from danswer.db.models import Persona from danswer.db.models import Prompt from danswer.indexing.models import InferenceChunk from danswer.llm.utils import check_number_of_tokens +from danswer.llm.utils import get_llm_max_tokens from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT from danswer.prompts.chat_prompts import CHAT_USER_PROMPT from danswer.prompts.chat_prompts import CITATION_REMINDER @@ -33,6 +39,12 @@ from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.prompts.token_counts import ( + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT, +) +from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT +from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT +from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT # Maps connector enum string to a more natural language representation for the LLM # If not on the list, uses the original but slightly cleaned up, see below @@ -50,19 +62,39 @@ def clean_up_source(source_str: str) -> str: return 
source_str.replace("_", " ").title() -def build_context_str( +def build_doc_context_str( + semantic_identifier: str, + source_type: DocumentSource, + content: str, + ind: int, + include_metadata: bool = True, + updated_at: datetime | None = None, +) -> str: + context_str = "" + if include_metadata: + context_str += f"DOCUMENT {ind}: {semantic_identifier}\n" + context_str += f"Source: {clean_up_source(source_type)}\n" + if updated_at: + update_str = updated_at.strftime("%B %d, %Y %H:%M") + context_str += f"Updated: {update_str}\n" + context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n" + return context_str + + +def build_complete_context_str( context_docs: list[LlmDoc | InferenceChunk], include_metadata: bool = True, ) -> str: context_str = "" for ind, doc in enumerate(context_docs, start=1): - if include_metadata: - context_str += f"DOCUMENT {ind}: {doc.semantic_identifier}\n" - context_str += f"Source: {clean_up_source(doc.source_type)}\n" - if doc.updated_at: - update_str = doc.updated_at.strftime("%B %d, %Y %H:%M") - context_str += f"Updated: {update_str}\n" - context_str += f"{CODE_BLOCK_PAT.format(doc.content.strip())}\n\n\n" + context_str += build_doc_context_str( + semantic_identifier=doc.semantic_identifier, + source_type=doc.source_type, + content=doc.content, + updated_at=doc.updated_at, + ind=ind, + include_metadata=include_metadata, + ) return context_str.strip() @@ -71,7 +103,7 @@ def build_context_str( def build_chat_system_message( prompt: Prompt, context_exists: bool, - llm_tokenizer: Callable, + llm_tokenizer_encode_func: Callable, citation_line: str = REQUIRE_CITATION_STATEMENT, no_citation_line: str = NO_CITATION_STATEMENT, ) -> tuple[SystemMessage | None, int]: @@ -92,7 +124,7 @@ def build_chat_system_message( if not system_prompt: return None, 0 - token_count = len(llm_tokenizer(system_prompt)) + token_count = len(llm_tokenizer_encode_func(system_prompt)) system_msg = SystemMessage(content=system_prompt) return system_msg, token_count @@ -138,7 +170,7 @@ def build_chat_user_message( chat_message: ChatMessage, prompt: Prompt, context_docs: list[LlmDoc], - llm_tokenizer: Callable, + llm_tokenizer_encode_func: Callable, all_doc_useful: bool, user_prompt_template: str = CHAT_USER_PROMPT, context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT, @@ -156,11 +188,11 @@ def build_chat_user_message( else user_query ) user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer(user_prompt)) + token_count = len(llm_tokenizer_encode_func(user_prompt)) user_msg = HumanMessage(content=user_prompt) return user_msg, token_count - context_docs_str = build_context_str( + context_docs_str = build_complete_context_str( cast(list[LlmDoc | InferenceChunk], context_docs) ) optional_ignore = "" if all_doc_useful else ignore_str @@ -175,7 +207,7 @@ def build_chat_user_message( ) user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer(user_prompt)) + token_count = len(llm_tokenizer_encode_func(user_prompt)) user_msg = HumanMessage(content=user_prompt) return user_msg, token_count @@ -357,16 +389,17 @@ def combine_message_chain( return "\n\n".join(message_strs) -def find_last_index( - lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS -) -> int: +_PER_MESSAGE_TOKEN_BUFFER = 7 + + +def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: """From the back, find the index of the last element to include before the list exceeds the maximum""" running_sum = 0 last_ind = 0 for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + 
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER if running_sum > max_prompt_tokens: last_ind = i + 1 break @@ -382,14 +415,11 @@ def drop_messages_history_overflow( history_token_counts: list[int], final_msg: BaseMessage, final_msg_token_count: int, - max_allowed_tokens: int | None, + max_allowed_tokens: int, ) -> list[BaseMessage]: """As message history grows, messages need to be dropped starting from the furthest in the past. The System message should be kept if at all possible and the latest user input which is inserted in the prompt template must be included""" - if max_allowed_tokens is None: - max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS - if len(history_msgs) != len(history_token_counts): # This should never happen raise ValueError("Need exactly 1 token count per message for tracking overflow") @@ -508,3 +538,68 @@ def extract_citations_from_stream( yield DanswerAnswerPiece(answer_piece="[" + curr_segment) else: yield DanswerAnswerPiece(answer_piece=curr_segment) + + +def get_prompt_tokens(prompt: Prompt) -> int: + return ( + check_number_of_tokens(prompt.system_prompt) + + check_number_of_tokens(prompt.task_prompt) + + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT + + CITATION_STATEMENT_TOKEN_CNT + + CITATION_REMINDER_TOKEN_CNT + + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0) + ) + + +# buffer just to be safe so that we don't overflow the token limit due to +# a small miscalculation +_MISC_BUFFER = 40 + + +def compute_max_document_tokens( + persona: Persona, actual_user_input: str | None = None +) -> int: + """Estimates the number of tokens available for context documents. Formula is roughly: + + ( + model_context_window - reserved_output_tokens - prompt_tokens + - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe) + ) + + The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g. + if we're trying to determine if the user should be able to select another document) then we just set an + arbitrary "upper bound". 
+ """ + llm_name = GEN_AI_MODEL_VERSION + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + # if we can't find a number of tokens, just assume some common default + model_full_context_window = get_llm_max_tokens(llm_name) or 4096 + if persona.prompts: + prompt_tokens = get_prompt_tokens(persona.prompts[0]) + else: + raise RuntimeError("Persona has no prompts - this should never happen") + user_input_tokens = ( + check_number_of_tokens(actual_user_input) + if actual_user_input is not None + else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS + ) + + return ( + model_full_context_window + - GEN_AI_MAX_OUTPUT_TOKENS + - prompt_tokens + - user_input_tokens + - _MISC_BUFFER + ) + + +def compute_max_llm_input_tokens(persona: Persona) -> int: + """Maximum tokens allows in the input to the LLM (of any type).""" + llm_name = GEN_AI_MODEL_VERSION + if persona.llm_model_version_override: + llm_name = persona.llm_model_version_override + + model_full_context_window = get_llm_max_tokens(llm_name) or 4096 + return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 4bde92cbd..d24c3ca9d 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,6 +7,9 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import build_chat_system_message from danswer.chat.chat_utils import build_chat_user_message +from danswer.chat.chat_utils import build_doc_context_str +from danswer.chat.chat_utils import compute_max_document_tokens +from danswer.chat.chat_utils import compute_max_llm_input_tokens from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import drop_messages_history_overflow from danswer.chat.chat_utils import extract_citations_from_stream @@ -42,8 +45,8 @@ from danswer.indexing.models import InferenceChunk from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.interfaces import LLM -from danswer.llm.utils import get_default_llm_token_encode -from danswer.llm.utils import get_llm_max_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import tokenizer_trim_content from danswer.llm.utils import translate_history_to_basemessages from danswer.search.models import OptionalSearchSetting from danswer.search.models import RetrievalDetails @@ -68,7 +71,7 @@ def generate_ai_chat_response( context_docs: list[LlmDoc], doc_id_to_rank_map: dict[str, int], llm: LLM | None, - llm_tokenizer: Callable, + llm_tokenizer_encode_func: Callable, all_doc_useful: bool, ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]: if llm is None: @@ -88,7 +91,7 @@ def generate_ai_chat_response( system_message_or_none, system_tokens = build_chat_system_message( prompt=query_message.prompt, context_exists=context_exists, - llm_tokenizer=llm_tokenizer, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, ) history_basemessages, history_token_counts = translate_history_to_basemessages( @@ -101,7 +104,7 @@ def generate_ai_chat_response( chat_message=query_message, prompt=query_message.prompt, context_docs=context_docs, - llm_tokenizer=llm_tokenizer, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, all_doc_useful=all_doc_useful, ) @@ -112,9 +115,7 @@ def generate_ai_chat_response( history_token_counts=history_token_counts, final_msg=user_message, final_msg_token_count=user_tokens, 
- max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override) - if persona.llm_model_version_override - else None, + max_allowed_tokens=compute_max_llm_input_tokens(persona), ) # Good Debug/Breakpoint @@ -195,7 +196,10 @@ def stream_chat_message( except GenAIDisabledException: llm = None - llm_tokenizer = get_default_llm_token_encode() + llm_tokenizer = get_default_llm_tokenizer() + llm_tokenizer_encode_func = cast( + Callable[[str], list[int]], llm_tokenizer.encode + ) embedding_model = get_current_db_embedding_model(db_session) document_index = get_default_document_index( @@ -223,7 +227,7 @@ def stream_chat_message( parent_message=parent_message, prompt_id=prompt_id, message=message_text, - token_count=len(llm_tokenizer(message_text)), + token_count=len(llm_tokenizer_encode_func(message_text)), message_type=MessageType.USER, db_session=db_session, commit=False, @@ -271,6 +275,66 @@ def stream_chat_message( doc_identifiers=identifier_tuples, document_index=document_index, ) + + # truncate the last document if it exceeds the token limit + max_document_tokens = compute_max_document_tokens( + persona, actual_user_input=message_text + ) + tokens_per_doc = [ + len( + llm_tokenizer_encode_func( + build_doc_context_str( + semantic_identifier=llm_doc.semantic_identifier, + source_type=llm_doc.source_type, + content=llm_doc.content, + updated_at=llm_doc.updated_at, + ind=ind, + ) + ) + ) + for ind, llm_doc in enumerate(llm_docs) + ] + final_doc_ind = None + total_tokens = 0 + for ind, tokens in enumerate(tokens_per_doc): + total_tokens += tokens + if total_tokens > max_document_tokens: + final_doc_ind = ind + break + if final_doc_ind is not None: + # only allow the final document to get truncated + # if more than that, then the user message is too long + if final_doc_ind != len(tokens_per_doc) - 1: + yield get_json_line( + StreamingError( + error="LLM context window exceeded. Please de-select some documents or shorten your query." + ).dict() + ) + return + + final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( + total_tokens - max_document_tokens + ) + # 75 tokens is a reasonable over-estimate of the metadata and title + final_doc_content_length = final_doc_desired_length - 75 + # this could occur if we only have space for the title / metadata + # not ideal, but it's the most reasonable thing to do + # NOTE: the frontend prevents documents from being selected if + # less than 75 tokens are available to try and avoid this situation + # from occuring in the first place + if final_doc_content_length <= 0: + logger.error( + f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content " + "length is less than 0. Removing this doc from the final prompt." 
+ ) + llm_docs.pop() + else: + llm_docs[final_doc_ind].content = tokenizer_trim_content( + content=llm_docs[final_doc_ind].content, + desired_length=final_doc_content_length, + tokenizer=llm_tokenizer, + ) + doc_id_to_rank_map = map_document_id_order( cast(list[InferenceChunk | LlmDoc], llm_docs) ) @@ -423,7 +487,7 @@ def stream_chat_message( context_docs=llm_docs, doc_id_to_rank_map=doc_id_to_rank_map, llm=llm, - llm_tokenizer=llm_tokenizer, + llm_tokenizer_encode_func=llm_tokenizer_encode_func, all_doc_useful=reference_doc_ids is not None, ) @@ -467,7 +531,7 @@ def stream_chat_message( # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( message=llm_output, - token_count=len(llm_tokenizer(llm_output)), + token_count=len(llm_tokenizer_encode_func(llm_output)), citations=db_citations, error=error, ) diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index b1b2725e3..11d6ce26d 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -104,4 +104,10 @@ GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000) # History for secondary LLM flows, not primary chat flow, generally we don't need to # include as much as possible as this just bumps up the cost unnecessarily GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS) +# This is used when computing how much context space is available for documents +# ahead of time in order to let the user know if they can "select" more documents +# It represents a maximum "expected" number of input tokens from the latest user +# message. At query time, we don't actually enforce this - we will only throw an +# error if the total # of tokens exceeds the max input tokens. 
+GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS = 512 GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0) diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 37963da8d..a8c0b40ab 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -35,7 +35,7 @@ _LLM_TOKENIZER: Any = None _LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None -def get_default_llm_tokenizer() -> Any: +def get_default_llm_tokenizer() -> Encoding: """Currently only supports the OpenAI default tokenizer: tiktoken""" global _LLM_TOKENIZER if _LLM_TOKENIZER is None: @@ -56,16 +56,26 @@ def get_default_llm_token_encode() -> Callable[[str], Any]: return _LLM_TOKENIZER_ENCODE +def tokenizer_trim_content( + content: str, desired_length: int, tokenizer: Encoding +) -> str: + tokenizer = get_default_llm_tokenizer() + tokens = tokenizer.encode(content) + if len(tokens) > desired_length: + content = tokenizer.decode(tokens[:desired_length]) + return content + + def tokenizer_trim_chunks( chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE ) -> list[InferenceChunk]: tokenizer = get_default_llm_tokenizer() new_chunks = copy(chunks) for ind, chunk in enumerate(new_chunks): - tokens = tokenizer.encode(chunk.content) - if len(tokens) > max_chunk_toks: + new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer) + if len(new_content) != len(chunk.content): new_chunk = copy(chunk) - new_chunk.content = tokenizer.decode(tokens[:max_chunk_toks]) + new_chunk.content = new_content new_chunks[ind] = new_chunk return new_chunks diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py index 3a9bfdf03..c7b702d26 100644 --- a/backend/danswer/one_shot_answer/qa_block.py +++ b/backend/danswer/one_shot_answer/qa_block.py @@ -4,7 +4,7 @@ from collections.abc import Callable from collections.abc import Iterator from typing import cast -from danswer.chat.chat_utils import build_context_str +from danswer.chat.chat_utils import build_complete_context_str from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import DanswerAnswer from danswer.chat.models import DanswerAnswerPiece @@ -145,7 +145,7 @@ class SingleMessageQAHandler(QAHandler): ) -> str: context_block = "" if context_chunks: - context_docs_str = build_context_str( + context_docs_str = build_complete_context_str( cast(list[LlmDoc | InferenceChunk], context_chunks) ) context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str) @@ -194,7 +194,7 @@ class SingleMessageScratchpadHandler(QAHandler): def build_prompt( self, query: str, history_str: str, context_chunks: list[InferenceChunk] ) -> str: - context_docs_str = build_context_str( + context_docs_str = build_complete_context_str( cast(list[LlmDoc | InferenceChunk], context_chunks) ) diff --git a/backend/danswer/prompts/token_counts.py b/backend/danswer/prompts/token_counts.py new file mode 100644 index 000000000..35d082b8d --- /dev/null +++ b/backend/danswer/prompts/token_counts.py @@ -0,0 +1,24 @@ +from danswer.llm.utils import check_number_of_tokens +from danswer.prompts.chat_prompts import CHAT_USER_PROMPT +from danswer.prompts.chat_prompts import CITATION_REMINDER +from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT +from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT +from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT + + +# tokens outside of the actual persona's "user_prompt" that make up the end +# 
user message +CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT = check_number_of_tokens( + CHAT_USER_PROMPT.format( + context_docs_str="", + task_prompt="", + user_query="", + optional_ignore_statement=DEFAULT_IGNORE_STATEMENT, + ) +) + +CITATION_STATEMENT_TOKEN_CNT = check_number_of_tokens(REQUIRE_CITATION_STATEMENT) + +CITATION_REMINDER_TOKEN_CNT = check_number_of_tokens(CITATION_REMINDER) + +LANGUAGE_HINT_TOKEN_CNT = check_number_of_tokens(LANGUAGE_HINT) diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py index 2778beaa0..05232fb2b 100644 --- a/backend/danswer/server/documents/document.py +++ b/backend/danswer/server/documents/document.py @@ -5,6 +5,7 @@ from fastapi import Query from sqlalchemy.orm import Session from danswer.auth.users import current_user +from danswer.chat.chat_utils import build_doc_context_str from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session from danswer.db.models import User @@ -47,12 +48,22 @@ def get_document_info( contents = [chunk.content for chunk in inference_chunks] - combined = "\n".join(contents) + combined_contents = "\n".join(contents) + # get actual document context used for LLM + first_chunk = inference_chunks[0] tokenizer_encode = get_default_llm_token_encode() + full_context_str = build_doc_context_str( + semantic_identifier=first_chunk.semantic_identifier, + source_type=first_chunk.source_type, + content=combined_contents, + updated_at=first_chunk.updated_at, + ind=0, + ) return DocumentInfo( - num_chunks=len(inference_chunks), num_tokens=len(tokenizer_encode(combined)) + num_chunks=len(inference_chunks), + num_tokens=len(tokenizer_encode(full_context_str)), ) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 726c3342d..66c69fa87 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -2,9 +2,11 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi.responses import StreamingResponse +from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_user +from danswer.chat.chat_utils import compute_max_document_tokens from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message from danswer.db.chat import create_chat_session @@ -13,6 +15,7 @@ from danswer.db.chat import get_chat_message from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_sessions_by_user +from danswer.db.chat import get_persona_by_id from danswer.db.chat import set_as_latest_chat_message from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import update_chat_session @@ -244,3 +247,27 @@ def create_search_feedback( document_index=document_index, db_session=db_session, ) + + +class MaxSelectedDocumentTokens(BaseModel): + max_tokens: int + + +@router.get("/max-selected-document-tokens") +def get_max_document_tokens( + persona_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> MaxSelectedDocumentTokens: + try: + persona = get_persona_by_id( + persona_id=persona_id, + user_id=user.id if user else None, + db_session=db_session, + ) + except ValueError: + raise HTTPException(status_code=404, 
detail="Persona not found") + + return MaxSelectedDocumentTokens( + max_tokens=compute_max_document_tokens(persona), + ) diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx index 8bbd6592e..8192d2cff 100644 --- a/web/src/app/chat/Chat.tsx +++ b/web/src/app/chat/Chat.tsx @@ -40,8 +40,8 @@ import { ResizableSection } from "@/components/resizable/ResizableSection"; import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader"; import { ChatIntro } from "./ChatIntro"; import { HEADER_PADDING } from "@/lib/constants"; -import { getSourcesForPersona } from "@/lib/sources"; import { computeAvailableFilters } from "@/lib/filters"; +import { useDocumentSelection } from "./useDocumentSelection"; const MAX_INPUT_HEIGHT = 200; @@ -138,9 +138,6 @@ export const Chat = ({ selectedMessageForDocDisplay ) : { aiMessage: null }; - const [selectedDocuments, setSelectedDocuments] = useState( - [] - ); const [selectedPersona, setSelectedPersona] = useState( existingChatSessionPersonaId !== undefined @@ -165,6 +162,30 @@ export const Chat = ({ } }, [defaultSelectedPersonaId]); + const [ + selectedDocuments, + toggleDocumentSelection, + clearSelectedDocuments, + selectedDocumentTokens, + ] = useDocumentSelection(); + // just choose a conservative default, this will be updated in the + // background on initial load / on persona change + const [maxTokens, setMaxTokens] = useState(4096); + // fetch # of allowed document tokens for the selected Persona + useEffect(() => { + async function fetchMaxTokens() { + const response = await fetch( + `/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}` + ); + if (response.ok) { + const maxTokens = (await response.json()).max_tokens as number; + setMaxTokens(maxTokens); + } + } + + fetchMaxTokens(); + }, [livePersona]); + const filterManager = useFilters(); const [finalAvailableSources, finalAvailableDocumentSets] = computeAvailableFilters({ @@ -725,7 +746,10 @@ export const Chat = ({ diff --git a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx index f653cfe8d..a03f695f8 100644 --- a/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx +++ b/web/src/app/chat/documentSidebar/ChatDocumentDisplay.tsx @@ -4,7 +4,6 @@ import { PopupSpec } from "@/components/admin/connectors/Popup"; import { DocumentFeedbackBlock } from "@/components/search/DocumentFeedbackBlock"; import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge"; import { DanswerDocument } from "@/lib/search/interfaces"; -import { useState } from "react"; import { FiInfo, FiRadio } from "react-icons/fi"; import { DocumentSelector } from "./DocumentSelector"; import { @@ -19,6 +18,7 @@ interface DocumentDisplayProps { isSelected: boolean; handleSelect: (documentId: string) => void; setPopup: (popupSpec: PopupSpec | null) => void; + tokenLimitReached: boolean; } export function ChatDocumentDisplay({ @@ -28,27 +28,19 @@ export function ChatDocumentDisplay({ isSelected, handleSelect, setPopup, + tokenLimitReached, }: DocumentDisplayProps) { - const [isHovered, setIsHovered] = useState(false); - // Consider reintroducing null scored docs in the future if (document.score === null) { return null; } return ( -
{ - setIsHovered(true); - }} - onMouseLeave={() => setIsHovered(false)} - > -
+
+
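For reference, the document-token budget that the new compute_max_document_tokens helper in chat_utils.py (above) works out reduces to straightforward subtraction. The sketch below walks through it with assumed numbers - only the 512-token expected user message and the 40-token buffer come from this patch; the context window, output reservation, prompt overhead, and the constant names used here are made-up example values:

# Rough sketch of the compute_max_document_tokens arithmetic (illustrative values only).
MODEL_CONTEXT_WINDOW = 16_385        # assumed model limit, e.g. a 16k-context model
RESERVED_OUTPUT_TOKENS = 1_024       # stand-in for GEN_AI_MAX_OUTPUT_TOKENS
PROMPT_OVERHEAD_TOKENS = 600         # assumed system/task prompt + citation boilerplate
EXPECTED_USER_MESSAGE_TOKENS = 512   # GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
MISC_BUFFER = 40                     # _MISC_BUFFER safety margin

max_document_tokens = (
    MODEL_CONTEXT_WINDOW
    - RESERVED_OUTPUT_TOKENS
    - PROMPT_OVERHEAD_TOKENS
    - EXPECTED_USER_MESSAGE_TOKENS
    - MISC_BUFFER
)
print(max_document_tokens)  # 14209 tokens left for selected documents

At query time the real helper swaps the 512-token placeholder for the token count of the actual user message, as described in its docstring.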
diff --git a/web/src/app/chat/documentSidebar/DocumentSelector.tsx b/web/src/app/chat/documentSidebar/DocumentSelector.tsx index 29f7fe3b0..833c6a7ca 100644 --- a/web/src/app/chat/documentSidebar/DocumentSelector.tsx +++ b/web/src/app/chat/documentSidebar/DocumentSelector.tsx @@ -1,23 +1,66 @@ +import { HoverPopup } from "@/components/HoverPopup"; +import { useState } from "react"; + export function DocumentSelector({ isSelected, handleSelect, + isDisabled, }: { isSelected: boolean; handleSelect: () => void; + isDisabled?: boolean; }) { - return ( -
-

Select

- null} - /> -
- ); + const [popupDisabled, setPopupDisabled] = useState(false); + + function onClick() { + if (!isDisabled) { + setPopupDisabled(true); + handleSelect(); + // re-enable popup after 1 second so that we don't show the popup immediately upon the + // user de-selecting a document + setTimeout(() => { + setPopupDisabled(false); + }, 1000); + } + } + + function Main() { + return ( +
+

Select

+ null} + disabled={isDisabled} + /> +
+ ); + } + + if (isDisabled && !popupDisabled) { + return ( +
+ + LLM context limit reached 😔 If you want to chat with this + document, please de-select others to free up space. +
+ } + direction="left-top" + /> +
+ ); + } + + return Main(); } diff --git a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx b/web/src/app/chat/documentSidebar/DocumentSidebar.tsx index c00b93f75..34660c4a3 100644 --- a/web/src/app/chat/documentSidebar/DocumentSidebar.tsx +++ b/web/src/app/chat/documentSidebar/DocumentSidebar.tsx @@ -2,12 +2,13 @@ import { DanswerDocument } from "@/lib/search/interfaces"; import { Text } from "@tremor/react"; import { ChatDocumentDisplay } from "./ChatDocumentDisplay"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { FiFileText } from "react-icons/fi"; +import { FiAlertTriangle, FiFileText } from "react-icons/fi"; import { SelectedDocumentDisplay } from "./SelectedDocumentDisplay"; import { removeDuplicateDocs } from "@/lib/documentUtils"; import { BasicSelectable } from "@/components/BasicClickable"; import { Message, RetrievalType } from "../interfaces"; import { HEADER_PADDING } from "@/lib/constants"; +import { HoverPopup } from "@/components/HoverPopup"; function SectionHeader({ name, @@ -27,12 +28,18 @@ function SectionHeader({ export function DocumentSidebar({ selectedMessage, selectedDocuments, - setSelectedDocuments, + toggleDocumentSelection, + clearSelectedDocuments, + selectedDocumentTokens, + maxTokens, isLoading, }: { selectedMessage: Message | null; selectedDocuments: DanswerDocument[] | null; - setSelectedDocuments: (documents: DanswerDocument[]) => void; + toggleDocumentSelection: (document: DanswerDocument) => void; + clearSelectedDocuments: () => void; + selectedDocumentTokens: number; + maxTokens: number; isLoading: boolean; }) { const { popup, setPopup } = usePopup(); @@ -44,6 +51,13 @@ export function DocumentSidebar({ const currentDocuments = selectedMessage?.documents || null; const dedupedDocuments = removeDuplicateDocs(currentDocuments || []); + + // NOTE: do not allow selection if less than 75 tokens are left + // this is to prevent the case where they are able to select the doc + // but it basically is unused since it's truncated right at the very + // start of the document (since title + metadata + misc overhead) takes up + // space + const tokenLimitReached = selectedDocumentTokens > maxTokens - 75; return (
{currentDocuments ? ( -
+
{dedupedDocuments.length > 0 ? ( dedupedDocuments.map((document, ind) => ( @@ -93,21 +107,13 @@ export function DocumentSidebar({ document.document_id )} handleSelect={(documentId) => { - if (selectedDocumentIds.includes(documentId)) { - setSelectedDocuments( - selectedDocuments!.filter( - (document) => document.document_id !== documentId - ) - ); - } else { - setSelectedDocuments([ - ...selectedDocuments!, - currentDocuments.find( - (document) => document.document_id === documentId - )!, - ]); - } + toggleDocumentSelection( + dedupedDocuments.find( + (document) => document.document_id === documentId + )! + ); }} + tokenLimitReached={tokenLimitReached} />
)) @@ -132,15 +138,48 @@ export function DocumentSidebar({
-
+
+ + {tokenLimitReached && ( +
+
+ + } + popupContent={ + + Over LLM context length by:{" "} + {selectedDocumentTokens - maxTokens} tokens +
+
+ {selectedDocuments && selectedDocuments.length > 0 && ( + <> + Truncating: " + + { + selectedDocuments[selectedDocuments.length - 1] + .semantic_identifier + } + + " + + )} +
+ } + direction="left" + /> +
+
+ )}
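The warning block above only fires once the selection has crossed the budget; selection itself is cut off 75 tokens earlier, since roughly that much of each document's budget goes to its title and metadata. Sketched in Python for brevity - the function and constant names are invented for illustration, and tokenLimitReached in this component is the negation of this check:

SELECTION_HEADROOM_TOKENS = 75  # reserve for title/metadata, mirrors the frontend's 75

def can_select_more(selected_document_tokens: int, max_tokens: int) -> bool:
    # disable further selection once fewer than ~75 tokens of the
    # persona's document budget remain
    return selected_document_tokens <= max_tokens - SELECTION_HEADROOM_TOKENS

print(can_select_more(3_950, 4_000))  # False - within 75 tokens of the limit
print(can_select_more(2_000, 4_000))  # True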
{selectedDocuments && selectedDocuments.length > 0 && ( -
setSelectedDocuments([])} - > +
De-Select All
)} @@ -153,10 +192,10 @@ export function DocumentSidebar({ key={document.document_id} document={document} handleDeselect={(documentId) => { - setSelectedDocuments( - selectedDocuments!.filter( - (document) => document.document_id !== documentId - ) + toggleDocumentSelection( + dedupedDocuments.find( + (document) => document.document_id === documentId + )! ); }} /> diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index eb8ff8918..5099c7226 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -49,7 +49,7 @@ export default async function Page({ | AuthTypeMetadata | FullEmbeddingModelResponse | null - )[] = [null, null, null, null, null, null, null, null]; + )[] = [null, null, null, null, null, null, null, null, null]; try { results = await Promise.all(tasks); } catch (e) { diff --git a/web/src/app/chat/useDocumentSelection.ts b/web/src/app/chat/useDocumentSelection.ts new file mode 100644 index 000000000..df33f13c3 --- /dev/null +++ b/web/src/app/chat/useDocumentSelection.ts @@ -0,0 +1,69 @@ +import { DanswerDocument } from "@/lib/search/interfaces"; +import { useState } from "react"; + +interface DocumentInfo { + num_chunks: number; + num_tokens: number; +} + +async function fetchDocumentLength(documentId: string) { + const response = await fetch( + `/api/document/document-size-info?document_id=${documentId}` + ); + if (!response.ok) { + return 0; + } + const data = (await response.json()) as DocumentInfo; + return data.num_tokens; +} + +export function useDocumentSelection(): [ + DanswerDocument[], + (document: DanswerDocument) => void, + () => void, + number, +] { + const [selectedDocuments, setSelectedDocuments] = useState( + [] + ); + const [totalTokens, setTotalTokens] = useState(0); + const selectedDocumentIds = selectedDocuments.map( + (document) => document.document_id + ); + const documentIdToLength = new Map(); + + function toggleDocumentSelection(document: DanswerDocument) { + const documentId = document.document_id; + const isAdding = !selectedDocumentIds.includes(documentId); + if (!isAdding) { + setSelectedDocuments( + selectedDocuments.filter( + (document) => document.document_id !== documentId + ) + ); + } else { + setSelectedDocuments([...selectedDocuments, document]); + } + if (documentIdToLength.has(documentId)) { + const length = documentIdToLength.get(documentId)!; + setTotalTokens(isAdding ? totalTokens + length : totalTokens - length); + } else { + fetchDocumentLength(documentId).then((length) => { + documentIdToLength.set(documentId, length); + setTotalTokens(isAdding ? 
totalTokens + length : totalTokens - length); + }); + } + } + + function clearDocuments() { + setSelectedDocuments([]); + setTotalTokens(0); + } + + return [ + selectedDocuments, + toggleDocumentSelection, + clearDocuments, + totalTokens, + ]; +} diff --git a/web/src/components/HoverPopup.tsx b/web/src/components/HoverPopup.tsx index 25d68e29a..6fac81cd2 100644 --- a/web/src/components/HoverPopup.tsx +++ b/web/src/components/HoverPopup.tsx @@ -4,7 +4,7 @@ interface HoverPopupProps { mainContent: string | JSX.Element; popupContent: string | JSX.Element; classNameModifications?: string; - direction?: "left" | "bottom" | "top"; + direction?: "left" | "left-top" | "bottom" | "top"; style?: "basic" | "dark"; } @@ -18,9 +18,15 @@ export const HoverPopup = ({ const [hovered, setHovered] = useState(false); let popupDirectionClass; + let popupStyle = {}; switch (direction) { case "left": - popupDirectionClass = "top-0 left-0 transform translate-x-[-110%]"; + popupDirectionClass = "top-0 left-0 transform"; + popupStyle = { transform: "translateX(calc(-100% - 5px))" }; + break; + case "left-top": + popupDirectionClass = "bottom-0 left-0"; + popupStyle = { transform: "translate(calc(-100% - 5px), 0)" }; break; case "bottom": popupDirectionClass = "top-0 left-0 mt-6 pt-2"; @@ -39,7 +45,10 @@ export const HoverPopup = ({ onMouseLeave={() => setHovered(false)} > {hovered && ( -
+
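To recap the backend side of the feature, the selection-and-truncation flow added in process_message.py above can be condensed into the following sketch. The helper name and return shape are assumptions made for illustration, not the actual implementation - the real code streams a StreamingError instead of raising and trims the document content via the tokenizer:

def fit_selected_docs(
    tokens_per_doc: list[int], max_document_tokens: int, metadata_overhead: int = 75
) -> tuple[int | None, int]:
    """Return (index_of_doc_to_truncate, allowed_content_tokens).

    (None, 0) means every selected document fits untouched. Only the final
    document may be truncated; if an earlier one already overflows the budget,
    the user has to de-select something.
    """
    total = 0
    for ind, tokens in enumerate(tokens_per_doc):
        total += tokens
        if total <= max_document_tokens:
            continue
        if ind != len(tokens_per_doc) - 1:
            raise ValueError("LLM context window exceeded - de-select some documents")
        # shrink only the last document, keeping ~75 tokens for its title/metadata
        allowed = tokens - (total - max_document_tokens) - metadata_overhead
        return ind, max(allowed, 0)  # 0 means the doc would be dropped entirely
    return None, 0


# example: only the third document overflows, so it is trimmed to 1925 content tokens
print(fit_selected_docs([4_000, 6_000, 5_000], 12_000))  # (2, 1925)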