Enable selection of long documents

Weves 2024-02-02 20:17:59 -08:00 committed by Chris Weaver
parent 4c9709ae4a
commit e54ce779fd
16 changed files with 523 additions and 108 deletions

View File

@ -1,6 +1,7 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from functools import lru_cache
from typing import cast
@ -15,14 +16,19 @@ from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_llm_max_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@ -33,6 +39,12 @@ from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
@ -50,19 +62,39 @@ def clean_up_source(source_str: str) -> str:
return source_str.replace("_", " ").title()
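# For illustration (hypothetical source string): anything not in the mapping
# falls through to the cleanup above, e.g. clean_up_source("product_wiki")
# returns "Product Wiki".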
def build_context_str(
def build_doc_context_str(
semantic_identifier: str,
source_type: DocumentSource,
content: str,
ind: int,
include_metadata: bool = True,
updated_at: datetime | None = None,
) -> str:
context_str = ""
if include_metadata:
context_str += f"DOCUMENT {ind}: {semantic_identifier}\n"
context_str += f"Source: {clean_up_source(source_type)}\n"
if updated_at:
update_str = updated_at.strftime("%B %d, %Y %H:%M")
context_str += f"Updated: {update_str}\n"
context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n"
return context_str
def build_complete_context_str(
context_docs: list[LlmDoc | InferenceChunk],
include_metadata: bool = True,
) -> str:
context_str = ""
for ind, doc in enumerate(context_docs, start=1):
if include_metadata:
context_str += f"DOCUMENT {ind}: {doc.semantic_identifier}\n"
context_str += f"Source: {clean_up_source(doc.source_type)}\n"
if doc.updated_at:
update_str = doc.updated_at.strftime("%B %d, %Y %H:%M")
context_str += f"Updated: {update_str}\n"
context_str += f"{CODE_BLOCK_PAT.format(doc.content.strip())}\n\n\n"
context_str += build_doc_context_str(
semantic_identifier=doc.semantic_identifier,
source_type=doc.source_type,
content=doc.content,
updated_at=doc.updated_at,
ind=ind,
include_metadata=include_metadata,
)
return context_str.strip()
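# Illustrative sketch of the string these helpers build (document title, source,
# timestamp, and content are made up; assumes CODE_BLOCK_PAT wraps content in a
# triple-backtick block and that DocumentSource.CONFLUENCE cleans up to "Confluence"):
example = build_doc_context_str(
    semantic_identifier="Onboarding Guide",
    source_type=DocumentSource.CONFLUENCE,
    content="Step 1: request access ...",
    ind=1,
    updated_at=datetime(2024, 1, 15, 9, 30),
)
# example is roughly:
#   DOCUMENT 1: Onboarding Guide
#   Source: Confluence
#   Updated: January 15, 2024 09:30
#   ```
#   Step 1: request access ...
#   ```
# build_complete_context_str() simply concatenates one such block per document, numbering from 1.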
@ -71,7 +103,7 @@ def build_context_str(
def build_chat_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer: Callable,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
@ -92,7 +124,7 @@ def build_chat_system_message(
if not system_prompt:
return None, 0
token_count = len(llm_tokenizer(system_prompt))
token_count = len(llm_tokenizer_encode_func(system_prompt))
system_msg = SystemMessage(content=system_prompt)
return system_msg, token_count
@ -138,7 +170,7 @@ def build_chat_user_message(
chat_message: ChatMessage,
prompt: Prompt,
context_docs: list[LlmDoc],
llm_tokenizer: Callable,
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
user_prompt_template: str = CHAT_USER_PROMPT,
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
@ -156,11 +188,11 @@ def build_chat_user_message(
else user_query
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer(user_prompt))
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
context_docs_str = build_context_str(
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_docs)
)
optional_ignore = "" if all_doc_useful else ignore_str
@ -175,7 +207,7 @@ def build_chat_user_message(
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer(user_prompt))
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
@ -357,16 +389,17 @@ def combine_message_chain(
return "\n\n".join(message_strs)
def find_last_index(
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
_PER_MESSAGE_TOKEN_BUFFER = 7
def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i]
running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
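# Worked example with hypothetical per-message token counts: walking from the
# back, 300 + 7 (buffer) fits under 600, but adding 400 + 7 brings the running
# sum to 714 > 600, so the cutoff index is 3 and only the last message is kept.
assert find_last_index([50, 200, 400, 300], max_prompt_tokens=600) == 3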
@ -382,14 +415,11 @@ def drop_messages_history_overflow(
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
max_allowed_tokens: int | None,
max_allowed_tokens: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible and the latest user input which is inserted in the
prompt template must be included"""
if max_allowed_tokens is None:
max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
@ -508,3 +538,68 @@ def extract_citations_from_stream(
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
def get_prompt_tokens(prompt: Prompt) -> int:
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
)
# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40
def compute_max_document_tokens(
persona: Persona, actual_user_input: str | None = None
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:
(
model_context_window - reserved_output_tokens - prompt_tokens
- (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
)
The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = GEN_AI_MODEL_VERSION
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
model_full_context_window = get_llm_max_tokens(llm_name) or 4096
if persona.prompts:
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
raise RuntimeError("Persona has no prompts - this should never happen")
user_input_tokens = (
check_number_of_tokens(actual_user_input)
if actual_user_input is not None
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)
return (
model_full_context_window
- GEN_AI_MAX_OUTPUT_TOKENS
- prompt_tokens
- user_input_tokens
- _MISC_BUFFER
)
def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = GEN_AI_MODEL_VERSION
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
model_full_context_window = get_llm_max_tokens(llm_name) or 4096
return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER
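# Back-of-the-envelope illustration of the two budgets above (all numbers are
# assumptions, not real config values):
model_full_context_window = 4096  # e.g. get_llm_max_tokens(llm_name)
reserved_output_tokens = 1024     # stand-in for GEN_AI_MAX_OUTPUT_TOKENS
prompt_tokens = 300               # stand-in for get_prompt_tokens(persona.prompts[0])
reserved_user_message = 512       # GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
max_document_tokens = (
    model_full_context_window
    - reserved_output_tokens
    - prompt_tokens
    - reserved_user_message
    - _MISC_BUFFER
)  # 2220 tokens left for selected documents
max_llm_input_tokens = (
    model_full_context_window - reserved_output_tokens - _MISC_BUFFER
)  # 3032 tokens for the whole prompt (history + docs + final user message)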

View File

@ -7,6 +7,9 @@ from sqlalchemy.orm import Session
from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import build_chat_user_message
from danswer.chat.chat_utils import build_doc_context_str
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import compute_max_llm_input_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import drop_messages_history_overflow
from danswer.chat.chat_utils import extract_citations_from_stream
@ -42,8 +45,8 @@ from danswer.indexing.models import InferenceChunk
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import get_llm_max_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
@ -68,7 +71,7 @@ def generate_ai_chat_response(
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
llm: LLM | None,
llm_tokenizer: Callable,
llm_tokenizer_encode_func: Callable,
all_doc_useful: bool,
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
if llm is None:
@ -88,7 +91,7 @@ def generate_ai_chat_response(
system_message_or_none, system_tokens = build_chat_system_message(
prompt=query_message.prompt,
context_exists=context_exists,
llm_tokenizer=llm_tokenizer,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
)
history_basemessages, history_token_counts = translate_history_to_basemessages(
@ -101,7 +104,7 @@ def generate_ai_chat_response(
chat_message=query_message,
prompt=query_message.prompt,
context_docs=context_docs,
llm_tokenizer=llm_tokenizer,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
all_doc_useful=all_doc_useful,
)
@ -112,9 +115,7 @@ def generate_ai_chat_response(
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override)
if persona.llm_model_version_override
else None,
max_allowed_tokens=compute_max_llm_input_tokens(persona),
)
# Good Debug/Breakpoint
@ -195,7 +196,10 @@ def stream_chat_message(
except GenAIDisabledException:
llm = None
llm_tokenizer = get_default_llm_token_encode()
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
@ -223,7 +227,7 @@ def stream_chat_message(
parent_message=parent_message,
prompt_id=prompt_id,
message=message_text,
token_count=len(llm_tokenizer(message_text)),
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
db_session=db_session,
commit=False,
@ -271,6 +275,66 @@ def stream_chat_message(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# truncate the last document if it exceeds the token limit
max_document_tokens = compute_max_document_tokens(
persona, actual_user_input=message_text
)
tokens_per_doc = [
len(
llm_tokenizer_encode_func(
build_doc_context_str(
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
content=llm_doc.content,
updated_at=llm_doc.updated_at,
ind=ind,
)
)
)
for ind, llm_doc in enumerate(llm_docs)
]
final_doc_ind = None
total_tokens = 0
for ind, tokens in enumerate(tokens_per_doc):
total_tokens += tokens
if total_tokens > max_document_tokens:
final_doc_ind = ind
break
if final_doc_ind is not None:
# only allow the final document to get truncated
# if more than one document would need to be cut, the overall input is too long
if final_doc_ind != len(tokens_per_doc) - 1:
yield get_json_line(
StreamingError(
error="LLM context window exceeded. Please de-select some documents or shorten your query."
).dict()
)
return
final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
total_tokens - max_document_tokens
)
# 75 tokens is a reasonable over-estimate of the metadata and title
final_doc_content_length = final_doc_desired_length - 75
# this could occur if we only have space for the title / metadata
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final doc ({llm_docs[final_doc_ind].semantic_identifier}) content "
"length is less than 0. Removing this doc from the final prompt."
)
llm_docs.pop()
else:
llm_docs[final_doc_ind].content = tokenizer_trim_content(
content=llm_docs[final_doc_ind].content,
desired_length=final_doc_content_length,
tokenizer=llm_tokenizer,
)
doc_id_to_rank_map = map_document_id_order(
cast(list[InferenceChunk | LlmDoc], llm_docs)
)
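# Hypothetical trace of the truncation above: three selected docs costing
# [500, 800, 1200] tokens against a 2000-token document budget. Cumulative
# sums are 500, 1300, 2500, so the third (and last) doc is the one that
# overflows and gets trimmed rather than triggering the error path:
#   final_doc_desired_length = 1200 - (2500 - 2000) = 700
#   final_doc_content_length = 700 - 75 = 625 tokens of content kept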
@ -423,7 +487,7 @@ def stream_chat_message(
context_docs=llm_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
llm=llm,
llm_tokenizer=llm_tokenizer,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
all_doc_useful=reference_doc_ids is not None,
)
@ -467,7 +531,7 @@ def stream_chat_message(
# Saving Gen AI answer and responding with message info
gen_ai_response_message = partial_response(
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
token_count=len(llm_tokenizer_encode_func(llm_output)),
citations=db_citations,
error=error,
)

View File

@ -104,4 +104,10 @@ GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
# History for secondary LLM flows, not the primary chat flow. Generally we don't need to
# include as much history as possible, since that just bumps up the cost unnecessarily
GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
# This is used when computing how much context space is available for documents
# ahead of time in order to let the user know if they can "select" more documents
# It represents a maximum "expected" number of input tokens from the latest user
# message. At query time, we don't actually enforce this - we will only throw an
# error if the total # of tokens exceeds the max input tokens.
GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS = 512
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)

View File

@ -35,7 +35,7 @@ _LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Any:
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
@ -56,16 +56,26 @@ def get_default_llm_token_encode() -> Callable[[str], Any]:
return _LLM_TOKENIZER_ENCODE
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
tokens = tokenizer.encode(chunk.content)
if len(tokens) > max_chunk_toks:
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = tokenizer.decode(tokens[:max_chunk_toks])
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
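# Minimal usage sketch, assuming the default tokenizer is tiktoken's cl100k_base
# encoding (what get_default_llm_tokenizer() is expected to return for OpenAI models):
import tiktoken

_tok = tiktoken.get_encoding("cl100k_base")
trimmed = tokenizer_trim_content("word " * 5000, desired_length=100, tokenizer=_tok)
assert len(_tok.encode(trimmed)) <= 100  # trimmed at a token boundary, not a character count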

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from danswer.chat.chat_utils import build_context_str
from danswer.chat.chat_utils import build_complete_context_str
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
@ -145,7 +145,7 @@ class SingleMessageQAHandler(QAHandler):
) -> str:
context_block = ""
if context_chunks:
context_docs_str = build_context_str(
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)
@ -194,7 +194,7 @@ class SingleMessageScratchpadHandler(QAHandler):
def build_prompt(
self, query: str, history_str: str, context_chunks: list[InferenceChunk]
) -> str:
context_docs_str = build_context_str(
context_docs_str = build_complete_context_str(
cast(list[LlmDoc | InferenceChunk], context_chunks)
)

View File

@ -0,0 +1,24 @@
from danswer.llm.utils import check_number_of_tokens
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
# tokens outside of the actual persona's "user_prompt" that make up the end
# user message
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT = check_number_of_tokens(
CHAT_USER_PROMPT.format(
context_docs_str="",
task_prompt="",
user_query="",
optional_ignore_statement=DEFAULT_IGNORE_STATEMENT,
)
)
CITATION_STATEMENT_TOKEN_CNT = check_number_of_tokens(REQUIRE_CITATION_STATEMENT)
CITATION_REMINDER_TOKEN_CNT = check_number_of_tokens(CITATION_REMINDER)
LANGUAGE_HINT_TOKEN_CNT = check_number_of_tokens(LANGUAGE_HINT)

View File

@ -5,6 +5,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import build_doc_context_str
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
@ -47,12 +48,22 @@ def get_document_info(
contents = [chunk.content for chunk in inference_chunks]
combined = "\n".join(contents)
combined_contents = "\n".join(contents)
# get actual document context used for LLM
first_chunk = inference_chunks[0]
tokenizer_encode = get_default_llm_token_encode()
full_context_str = build_doc_context_str(
semantic_identifier=first_chunk.semantic_identifier,
source_type=first_chunk.source_type,
content=combined_contents,
updated_at=first_chunk.updated_at,
ind=0,
)
return DocumentInfo(
num_chunks=len(inference_chunks), num_tokens=len(tokenizer_encode(combined))
num_chunks=len(inference_chunks),
num_tokens=len(tokenizer_encode(full_context_str)),
)

View File

@ -2,9 +2,11 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.process_message import stream_chat_message
from danswer.db.chat import create_chat_session
@ -13,6 +15,7 @@ from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import set_as_latest_chat_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import update_chat_session
@ -244,3 +247,27 @@ def create_search_feedback(
document_index=document_index,
db_session=db_session,
)
class MaxSelectedDocumentTokens(BaseModel):
max_tokens: int
@router.get("/max-selected-document-tokens")
def get_max_document_tokens(
persona_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> MaxSelectedDocumentTokens:
try:
persona = get_persona_by_id(
persona_id=persona_id,
user_id=user.id if user else None,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Persona not found")
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens(persona),
)
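# Illustrative client call for the new endpoint (base URL, port, and route prefix are
# assumptions; the web UI hits it through its own /api proxy with fetch()):
import requests

resp = requests.get(
    "http://localhost:8080/chat/max-selected-document-tokens",  # assumed backend prefix
    params={"persona_id": 0},
)
resp.raise_for_status()
max_tokens = resp.json()["max_tokens"]  # budget available for selected documents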

View File

@ -40,8 +40,8 @@ import { ResizableSection } from "@/components/resizable/ResizableSection";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { ChatIntro } from "./ChatIntro";
import { HEADER_PADDING } from "@/lib/constants";
import { getSourcesForPersona } from "@/lib/sources";
import { computeAvailableFilters } from "@/lib/filters";
import { useDocumentSelection } from "./useDocumentSelection";
const MAX_INPUT_HEIGHT = 200;
@ -138,9 +138,6 @@ export const Chat = ({
selectedMessageForDocDisplay
)
: { aiMessage: null };
const [selectedDocuments, setSelectedDocuments] = useState<DanswerDocument[]>(
[]
);
const [selectedPersona, setSelectedPersona] = useState<Persona | undefined>(
existingChatSessionPersonaId !== undefined
@ -165,6 +162,30 @@ export const Chat = ({
}
}, [defaultSelectedPersonaId]);
const [
selectedDocuments,
toggleDocumentSelection,
clearSelectedDocuments,
selectedDocumentTokens,
] = useDocumentSelection();
// just choose a conservative default; this will be updated in the
// background on initial load / on persona change
const [maxTokens, setMaxTokens] = useState<number>(4096);
// fetch # of allowed document tokens for the selected Persona
useEffect(() => {
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${livePersona.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
setMaxTokens(maxTokens);
}
}
fetchMaxTokens();
}, [livePersona]);
const filterManager = useFilters();
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
@ -725,7 +746,10 @@ export const Chat = ({
<DocumentSidebar
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
setSelectedDocuments={setSelectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
isLoading={isFetchingChatMessages}
/>
</ResizableSection>

View File

@ -4,7 +4,6 @@ import { PopupSpec } from "@/components/admin/connectors/Popup";
import { DocumentFeedbackBlock } from "@/components/search/DocumentFeedbackBlock";
import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge";
import { DanswerDocument } from "@/lib/search/interfaces";
import { useState } from "react";
import { FiInfo, FiRadio } from "react-icons/fi";
import { DocumentSelector } from "./DocumentSelector";
import {
@ -19,6 +18,7 @@ interface DocumentDisplayProps {
isSelected: boolean;
handleSelect: (documentId: string) => void;
setPopup: (popupSpec: PopupSpec | null) => void;
tokenLimitReached: boolean;
}
export function ChatDocumentDisplay({
@ -28,27 +28,19 @@ export function ChatDocumentDisplay({
isSelected,
handleSelect,
setPopup,
tokenLimitReached,
}: DocumentDisplayProps) {
const [isHovered, setIsHovered] = useState(false);
// Consider reintroducing null scored docs in the future
if (document.score === null) {
return null;
}
return (
<div
key={document.semantic_identifier}
className="text-sm px-3"
onMouseEnter={() => {
setIsHovered(true);
}}
onMouseLeave={() => setIsHovered(false)}
>
<div className="flex relative w-full overflow-x-hidden">
<div key={document.semantic_identifier} className="text-sm px-3">
<div className="flex relative w-full overflow-y-visible">
<a
className={
"rounded-lg flex font-bold flex-shrink truncate " +
"rounded-lg flex font-bold flex-shrink truncate " +
(document.link ? "" : "pointer-events-none")
}
href={document.link}
@ -102,6 +94,7 @@ export function ChatDocumentDisplay({
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
</div>
<div>

View File

@ -1,23 +1,66 @@
import { HoverPopup } from "@/components/HoverPopup";
import { useState } from "react";
export function DocumentSelector({
isSelected,
handleSelect,
isDisabled,
}: {
isSelected: boolean;
handleSelect: () => void;
isDisabled?: boolean;
}) {
return (
<div
className="ml-auto flex cursor-pointer select-none"
onClick={handleSelect}
>
<p className="mr-2 my-auto">Select</p>
<input
className="my-auto"
type="checkbox"
checked={isSelected}
// dummy function to prevent warning
onChange={() => null}
/>
</div>
);
const [popupDisabled, setPopupDisabled] = useState(false);
function onClick() {
if (!isDisabled) {
setPopupDisabled(true);
handleSelect();
// re-enable popup after 1 second so that we don't show the popup immediately upon the
// user de-selecting a document
setTimeout(() => {
setPopupDisabled(false);
}, 1000);
}
}
function Main() {
return (
<div
className={
"ml-auto flex select-none " + (!isDisabled ? " cursor-pointer" : "")
}
onClick={onClick}
>
<p className="mr-2 my-auto">Select</p>
<input
className="my-auto"
type="checkbox"
checked={isSelected}
// dummy function to prevent warning
onChange={() => null}
disabled={isDisabled}
/>
</div>
);
}
if (isDisabled && !popupDisabled) {
return (
<div className="ml-auto">
<HoverPopup
mainContent={Main()}
popupContent={
<div className="w-48">
LLM context limit reached 😔 If you want to chat with this
document, please de-select others to free up space.
</div>
}
direction="left-top"
/>
</div>
);
}
return Main();
}

View File

@ -2,12 +2,13 @@ import { DanswerDocument } from "@/lib/search/interfaces";
import { Text } from "@tremor/react";
import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
import { usePopup } from "@/components/admin/connectors/Popup";
import { FiFileText } from "react-icons/fi";
import { FiAlertTriangle, FiFileText } from "react-icons/fi";
import { SelectedDocumentDisplay } from "./SelectedDocumentDisplay";
import { removeDuplicateDocs } from "@/lib/documentUtils";
import { BasicSelectable } from "@/components/BasicClickable";
import { Message, RetrievalType } from "../interfaces";
import { HEADER_PADDING } from "@/lib/constants";
import { HoverPopup } from "@/components/HoverPopup";
function SectionHeader({
name,
@ -27,12 +28,18 @@ function SectionHeader({
export function DocumentSidebar({
selectedMessage,
selectedDocuments,
setSelectedDocuments,
toggleDocumentSelection,
clearSelectedDocuments,
selectedDocumentTokens,
maxTokens,
isLoading,
}: {
selectedMessage: Message | null;
selectedDocuments: DanswerDocument[] | null;
setSelectedDocuments: (documents: DanswerDocument[]) => void;
toggleDocumentSelection: (document: DanswerDocument) => void;
clearSelectedDocuments: () => void;
selectedDocumentTokens: number;
maxTokens: number;
isLoading: boolean;
}) {
const { popup, setPopup } = usePopup();
@ -44,6 +51,13 @@ export function DocumentSidebar({
const currentDocuments = selectedMessage?.documents || null;
const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);
// NOTE: do not allow selection if fewer than 75 tokens are left.
// This prevents the case where the user can select a doc but it goes
// effectively unused because it gets truncated right at the very start
// of the document (the title + metadata + misc overhead alone take up
// the remaining space).
const tokenLimitReached = selectedDocumentTokens > maxTokens - 75;
return (
<div
className={`
@ -72,7 +86,7 @@ export function DocumentSidebar({
</div>
{currentDocuments ? (
<div className="overflow-y-auto dark-scrollbar overflow-x-hidden flex flex-col">
<div className="overflow-y-auto dark-scrollbar flex flex-col">
<div>
{dedupedDocuments.length > 0 ? (
dedupedDocuments.map((document, ind) => (
@ -93,21 +107,13 @@ export function DocumentSidebar({
document.document_id
)}
handleSelect={(documentId) => {
if (selectedDocumentIds.includes(documentId)) {
setSelectedDocuments(
selectedDocuments!.filter(
(document) => document.document_id !== documentId
)
);
} else {
setSelectedDocuments([
...selectedDocuments!,
currentDocuments.find(
(document) => document.document_id === documentId
)!,
]);
}
toggleDocumentSelection(
dedupedDocuments.find(
(document) => document.document_id === documentId
)!
);
}}
tokenLimitReached={tokenLimitReached}
/>
</div>
))
@ -132,15 +138,48 @@ export function DocumentSidebar({
<div className="text-sm mb-4 border-t border-border pt-4 overflow-y-hidden flex flex-col">
<div className="flex border-b border-border px-3">
<div>
<div className="flex">
<SectionHeader name="Selected Documents" icon={FiFileText} />
{tokenLimitReached && (
<div className="ml-2 my-auto">
<div className="mb-2">
<HoverPopup
mainContent={
<FiAlertTriangle
className="text-alert my-auto"
size="16"
/>
}
popupContent={
<Text className="w-40">
Over LLM context length by:{" "}
<i>{selectedDocumentTokens - maxTokens} tokens</i>
<br />
<br />
{selectedDocuments && selectedDocuments.length > 0 && (
<>
Truncating: &quot;
<i>
{
selectedDocuments[selectedDocuments.length - 1]
.semantic_identifier
}
</i>
&quot;
</>
)}
</Text>
}
direction="left"
/>
</div>
</div>
)}
</div>
{selectedDocuments && selectedDocuments.length > 0 && (
<div
className="ml-auto my-auto"
onClick={() => setSelectedDocuments([])}
>
<div className="ml-auto my-auto" onClick={clearSelectedDocuments}>
<BasicSelectable selected={false}>De-Select All</BasicSelectable>
</div>
)}
@ -153,10 +192,10 @@ export function DocumentSidebar({
key={document.document_id}
document={document}
handleDeselect={(documentId) => {
setSelectedDocuments(
selectedDocuments!.filter(
(document) => document.document_id !== documentId
)
toggleDocumentSelection(
dedupedDocuments.find(
(document) => document.document_id === documentId
)!
);
}}
/>

View File

@ -49,7 +49,7 @@ export default async function Page({
| AuthTypeMetadata
| FullEmbeddingModelResponse
| null
)[] = [null, null, null, null, null, null, null, null];
)[] = [null, null, null, null, null, null, null, null, null];
try {
results = await Promise.all(tasks);
} catch (e) {

View File

@ -0,0 +1,69 @@
import { DanswerDocument } from "@/lib/search/interfaces";
import { useState } from "react";
interface DocumentInfo {
num_chunks: number;
num_tokens: number;
}
async function fetchDocumentLength(documentId: string) {
const response = await fetch(
`/api/document/document-size-info?document_id=${documentId}`
);
if (!response.ok) {
return 0;
}
const data = (await response.json()) as DocumentInfo;
return data.num_tokens;
}
export function useDocumentSelection(): [
DanswerDocument[],
(document: DanswerDocument) => void,
() => void,
number,
] {
const [selectedDocuments, setSelectedDocuments] = useState<DanswerDocument[]>(
[]
);
const [totalTokens, setTotalTokens] = useState(0);
const selectedDocumentIds = selectedDocuments.map(
(document) => document.document_id
);
const documentIdToLength = new Map<string, number>();
function toggleDocumentSelection(document: DanswerDocument) {
const documentId = document.document_id;
const isAdding = !selectedDocumentIds.includes(documentId);
if (!isAdding) {
setSelectedDocuments(
selectedDocuments.filter(
(document) => document.document_id !== documentId
)
);
} else {
setSelectedDocuments([...selectedDocuments, document]);
}
if (documentIdToLength.has(documentId)) {
const length = documentIdToLength.get(documentId)!;
setTotalTokens(isAdding ? totalTokens + length : totalTokens - length);
} else {
fetchDocumentLength(documentId).then((length) => {
documentIdToLength.set(documentId, length);
setTotalTokens(isAdding ? totalTokens + length : totalTokens - length);
});
}
}
function clearDocuments() {
setSelectedDocuments([]);
setTotalTokens(0);
}
return [
selectedDocuments,
toggleDocumentSelection,
clearDocuments,
totalTokens,
];
}

View File

@ -4,7 +4,7 @@ interface HoverPopupProps {
mainContent: string | JSX.Element;
popupContent: string | JSX.Element;
classNameModifications?: string;
direction?: "left" | "bottom" | "top";
direction?: "left" | "left-top" | "bottom" | "top";
style?: "basic" | "dark";
}
@ -18,9 +18,15 @@ export const HoverPopup = ({
const [hovered, setHovered] = useState(false);
let popupDirectionClass;
let popupStyle = {};
switch (direction) {
case "left":
popupDirectionClass = "top-0 left-0 transform translate-x-[-110%]";
popupDirectionClass = "top-0 left-0 transform";
popupStyle = { transform: "translateX(calc(-100% - 5px))" };
break;
case "left-top":
popupDirectionClass = "bottom-0 left-0";
popupStyle = { transform: "translate(calc(-100% - 5px), 0)" };
break;
case "bottom":
popupDirectionClass = "top-0 left-0 mt-6 pt-2";
@ -39,7 +45,10 @@ export const HoverPopup = ({
onMouseLeave={() => setHovered(false)}
>
{hovered && (
<div className={`absolute ${popupDirectionClass} z-30`}>
<div
className={`absolute ${popupDirectionClass} z-30`}
style={popupStyle}
>
<div
className={
`px-3 py-2 rounded bg-background border border-border` +

View File

@ -59,6 +59,7 @@ module.exports = {
},
error: "#ef4444", // red-500
success: "#059669", // emerald-600
alert: "#f59e0b", // amber-600
user: "#fb7185", // yellow-400
ai: "#60a5fa", // blue-400
// light mode