diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py
index fb2a32c23..92bad7cb7 100644
--- a/backend/danswer/chat/chat_utils.py
+++ b/backend/danswer/chat/chat_utils.py
@@ -14,12 +14,9 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
-from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
-from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
@@ -28,7 +25,7 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
-from danswer.llm.utils import get_llm_max_tokens
+from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -239,7 +236,7 @@ def _get_usable_chunks(
def get_usable_chunks(
chunks: list[InferenceChunk],
- token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+ token_limit: int,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
@@ -261,7 +258,7 @@ def get_usable_chunks(
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
- token_limit: float | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+ token_limit: int | None,
batch_offset: int = 0,
) -> list[int]:
"""
@@ -363,10 +360,10 @@ def create_chat_chain(
def combine_message_chain(
messages: list[ChatMessage],
- msg_limit: int | None = 10,
- token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
+ token_limit: int,
+ msg_limit: int | None = None,
) -> str:
- """Used for secondary LLM flows that require the chat history"""
+ """Used for secondary LLM flows that require the chat history,"""
message_strs: list[str] = []
total_token_count = 0
@@ -376,10 +373,7 @@ def combine_message_chain(
for message in reversed(messages):
message_token_count = message.token_count
- if (
- token_limit is not None
- and total_token_count + message_token_count > token_limit
- ):
+ if total_token_count + message_token_count > token_limit:
break
role = message.message_type.value.upper()
@@ -557,7 +551,9 @@ _MISC_BUFFER = 40
def compute_max_document_tokens(
- persona: Persona, actual_user_input: str | None = None
+ persona: Persona,
+ actual_user_input: str | None = None,
+ max_llm_token_override: int | None = None,
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:
@@ -575,8 +571,13 @@ def compute_max_document_tokens(
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
- model_full_context_window = get_llm_max_tokens(llm_name) or 4096
+ max_input_tokens = (
+ max_llm_token_override
+ if max_llm_token_override
+ else get_max_input_tokens(llm_name)
+ )
if persona.prompts:
+ # TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
raise RuntimeError("Persona has no prompts - this should never happen")
@@ -586,13 +587,7 @@ def compute_max_document_tokens(
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)
- return (
- model_full_context_window
- - GEN_AI_MAX_OUTPUT_TOKENS
- - prompt_tokens
- - user_input_tokens
- - _MISC_BUFFER
- )
+ return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
def compute_max_llm_input_tokens(persona: Persona) -> int:
@@ -601,5 +596,5 @@ def compute_max_llm_input_tokens(persona: Persona) -> int:
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
- model_full_context_window = get_llm_max_tokens(llm_name) or 4096
- return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER
+ input_tokens = get_max_input_tokens(model_name=llm_name)
+ return input_tokens - _MISC_BUFFER
diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py
index b02c18cc4..d2e93c474 100644
--- a/backend/danswer/chat/load_yamls.py
+++ b/backend/danswer/chat/load_yamls.py
@@ -3,7 +3,7 @@ from typing import cast
import yaml
from sqlalchemy.orm import Session
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.chat import get_prompt_by_name
@@ -42,7 +42,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_personas_from_yaml(
personas_yaml: str = PERSONAS_YAML,
- default_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+ default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml
index b2c44ead5..1f358e4b1 100644
--- a/backend/danswer/chat/personas.yaml
+++ b/backend/danswer/chat/personas.yaml
@@ -13,9 +13,8 @@ personas:
- "Answer-Question"
# Default number of chunks to include as context, set to 0 to disable retrieval
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
- # If selecting documents, user can bypass this up until NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
# Each chunk is 512 tokens long
- num_chunks: 5
+ num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
@@ -46,7 +45,7 @@ personas:
extrapolate any answers for you.
prompts:
- "Summarize"
- num_chunks: 5
+ num_chunks: 10
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
@@ -58,7 +57,7 @@ personas:
The least creative default assistant that only provides quotes from the documents.
prompts:
- "Paraphrase"
- num_chunks: 5
+ num_chunks: 10
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index d24c3ca9d..57be65679 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -22,10 +22,12 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
-from danswer.configs.chat_configs import CHUNK_SIZE
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
+from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -46,6 +48,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.search.models import OptionalSearchSetting
@@ -156,8 +159,11 @@ def stream_chat_message(
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
- default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+ default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = CHUNK_SIZE,
+ # For flow with search, don't include as many chunks as possible since we need to leave space
+ # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
+ max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> Iterator[str]:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -260,6 +266,10 @@ def stream_chat_message(
query_message=final_msg, history=history_msgs, llm=llm
)
+ max_document_tokens = compute_max_document_tokens(
+ persona=persona, actual_user_input=message_text
+ )
+
rephrased_query = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
@@ -277,9 +287,6 @@ def stream_chat_message(
)
# truncate the last document if it exceeds the token limit
- max_document_tokens = compute_max_document_tokens(
- persona, actual_user_input=message_text
- )
tokens_per_doc = [
len(
llm_tokenizer_encode_func(
@@ -431,10 +438,26 @@ def stream_chat_message(
if persona.num_chunks is not None
else default_num_chunks
)
+
+ llm_name = GEN_AI_MODEL_VERSION
+ if persona.llm_model_version_override:
+ llm_name = persona.llm_model_version_override
+
+ llm_max_input_tokens = get_max_input_tokens(llm_name)
+
+ llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
+
+ chunk_token_limit = int(
+ min(
+ num_llm_chunks * default_chunk_size,
+ max_document_tokens,
+ llm_token_based_chunk_lim,
+ )
+ )
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
- token_limit=num_llm_chunks * default_chunk_size,
+ token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py
index c5cc605aa..6fc2b6fb9 100644
--- a/backend/danswer/configs/chat_configs.py
+++ b/backend/danswer/configs/chat_configs.py
@@ -1,28 +1,18 @@
import os
-from danswer.configs.model_configs import CHUNK_SIZE
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
-# We feed in document chunks until we reach this token limit.
-# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
-# significantly smaller which could result in passing in more total chunks.
-# There is also a slight bit of overhead, not accounted for here such as separator patterns
-# between the docs, metadata for the docs, etc.
-# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
-# model token limit
-NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
- os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (CHUNK_SIZE * 5)
-)
-DEFAULT_NUM_CHUNKS_FED_TO_CHAT: float = (
- float(NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL) / CHUNK_SIZE
-)
-NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
- os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (CHUNK_SIZE * 3)
-)
+
+# May be less depending on model
+MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
+# For Chat, need to keep enough space for history and other prompt pieces
+# ~3k input, half for docs, half for chat history + prompts
+CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
+
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
@@ -60,7 +50,7 @@ if os.environ.get("EDIT_KEYWORD_QUERY"):
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
-HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.66)))
+HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py
index 6fd610f99..484ba144b 100644
--- a/backend/danswer/configs/danswerbot_configs.py
+++ b/backend/danswer/configs/danswerbot_configs.py
@@ -7,6 +7,8 @@ DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
+# How much of the available input context can be used for thread context
+DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 11d6ce26d..24691b7f0 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -78,7 +78,7 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
-GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-0125"
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = (
@@ -96,14 +96,15 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
-
+# If the max tokens can't be found from the name, use this as the backup
+# This happens if user is configuring a different LLM to use
+GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
-# This next restriction is only used for chat ATM, used to expire old messages as needed
-GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
-# History for secondary LLM flows, not primary chat flow, generally we don't need to
-# include as much as possible as this just bumps up the cost unnecessarily
-GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
+# Number of tokens from chat history to include at maximum
+# 3000 should be enough context regardless of use, no need to include as much as possible
+# as this drives up the cost unnecessarily
+GEN_AI_HISTORY_CUTOFF = 3000
# This is used when computing how much context space is available for documents
# ahead of time in order to let the user know if they can "select" more documents
# It represents a maximum "expected" number of input tokens from the latest user
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index 2bd7c7e5b..e5b67285c 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -11,14 +11,17 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
+from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
+from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -33,6 +36,8 @@ from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
+from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
@@ -98,6 +103,7 @@ def handle_message(
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
+ thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -215,11 +221,36 @@ def handle_message(
slack_usage_report(action=action, sender_id=sender_id, client=client)
+ max_document_tokens: int | None = None
+ max_history_tokens: int | None = None
+ if len(new_message_request.messages) > 1:
+ # In cases of threads, split the available tokens between docs and thread context
+ input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
+ max_history_tokens = int(input_tokens * thread_context_percent)
+
+ remaining_tokens = input_tokens - max_history_tokens
+
+ query_text = new_message_request.messages[0].message
+ if persona:
+ max_document_tokens = compute_max_document_tokens(
+ persona=persona,
+ actual_user_input=query_text,
+ max_llm_token_override=remaining_tokens,
+ )
+ else:
+ max_document_tokens = (
+ remaining_tokens
+ - 512 # Needs to be more than any of the QA prompts
+ - check_number_of_tokens(query_text)
+ )
+
with Session(get_sqlalchemy_engine()) as db_session:
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
+ max_document_tokens=max_document_tokens,
+ max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 606040a00..19b1e4967 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -723,7 +723,6 @@ class Persona(Base):
Enum(SearchType), default=SearchType.HYBRID
)
# Number of chunks to pass to the LLM for generation.
- # If unspecified, uses the default DEFAULT_NUM_CHUNKS_FED_TO_CHAT set in the env variable
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
# Pass every chunk through LLM for evaluation, fairly expensive
# Can be turned off globally by admin, in which case, this setting is ignored
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index f6478ca5b..bbf4ff0b6 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -3,7 +3,7 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.db.chat import upsert_persona
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
@@ -35,7 +35,7 @@ def create_slack_bot_persona(
channel_names: list[str],
document_set_ids: list[int],
existing_persona_id: int | None = None,
- num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+ num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index a8c0b40ab..91da9160c 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -22,6 +22,9 @@ from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.models import ChatMessage
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -59,7 +62,6 @@ def get_default_llm_token_encode() -> Callable[[str], Any]:
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
- tokenizer = get_default_llm_tokenizer()
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
@@ -201,9 +203,24 @@ def test_llm(llm: LLM) -> bool:
return False
-def get_llm_max_tokens(model_name: str) -> int | None:
+def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
"""Best effort attempt to get the max tokens for the LLM"""
+ if not model_name:
+ return GEN_AI_MAX_TOKENS
+
try:
return get_max_tokens(model_name)
except Exception:
- return None
+ return GEN_AI_MAX_TOKENS
+
+
+def get_max_input_tokens(
+ model_name: str | None = GEN_AI_MODEL_VERSION,
+ output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+) -> int:
+ input_toks = get_llm_max_tokens(model_name) - output_tokens
+
+ if input_toks <= 0:
+ raise RuntimeError("No tokens for input for the LLM given settings")
+
+ return input_toks
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 4a2cb493c..ff4f2cc00 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -5,6 +5,7 @@ from typing import cast
from sqlalchemy.orm import Session
+from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContext
@@ -14,7 +15,7 @@ from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import CHUNK_SIZE
@@ -54,9 +55,14 @@ logger = setup_logger()
def stream_answer_objects(
query_req: DirectQARequest,
user: User | None,
+ # These need to be passed in because in Web UI one shot flow,
+ # we can have much more document as there is no history.
+ # For Slack flow, we need to save more tokens for the thread context
+ max_document_tokens: int | None,
+ max_history_tokens: int | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
- default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+ default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = CHUNK_SIZE,
timeout: int = QA_TIMEOUT,
bypass_acl: bool = False,
@@ -106,7 +112,9 @@ def stream_answer_objects(
chat_session_id=chat_session.id, db_session=db_session
)
- history_str = combine_message_thread(history)
+ history_str = combine_message_thread(
+ messages=history, max_tokens=max_history_tokens
+ )
rephrased_query = thread_based_query_rephrase(
user_query=query_msg.message,
@@ -174,10 +182,20 @@ def stream_answer_objects(
if chat_session.persona.num_chunks is not None
else default_num_chunks
)
+
+ chunk_token_limit = int(num_llm_chunks * default_chunk_size)
+ if max_document_tokens:
+ chunk_token_limit = min(chunk_token_limit, max_document_tokens)
+ else:
+ max_document_tokens = compute_max_document_tokens(
+ persona=chat_session.persona, actual_user_input=query_msg.message
+ )
+ chunk_token_limit = min(chunk_token_limit, max_document_tokens)
+
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
- token_limit=num_llm_chunks * default_chunk_size,
+ token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
@@ -288,10 +306,16 @@ def stream_answer_objects(
def stream_search_answer(
query_req: DirectQARequest,
user: User | None,
+ max_document_tokens: int | None,
+ max_history_tokens: int | None,
db_session: Session,
) -> Iterator[str]:
objects = stream_answer_objects(
- query_req=query_req, user=user, db_session=db_session
+ query_req=query_req,
+ user=user,
+ max_document_tokens=max_document_tokens,
+ max_history_tokens=max_history_tokens,
+ db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.dict())
@@ -300,6 +324,8 @@ def stream_search_answer(
def get_search_answer(
query_req: DirectQARequest,
user: User | None,
+ max_document_tokens: int | None,
+ max_history_tokens: int | None,
db_session: Session,
answer_generation_timeout: int = QA_TIMEOUT,
enable_reflexion: bool = False,
@@ -315,6 +341,8 @@ def get_search_answer(
results = stream_answer_objects(
query_req=query_req,
user=user,
+ max_document_tokens=max_document_tokens,
+ max_history_tokens=max_history_tokens,
db_session=db_session,
bypass_acl=bypass_acl,
timeout=answer_generation_timeout,
diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py
index c1dd36889..032d24345 100644
--- a/backend/danswer/one_shot_answer/qa_utils.py
+++ b/backend/danswer/one_shot_answer/qa_utils.py
@@ -15,7 +15,6 @@ from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
@@ -279,10 +278,13 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
def combine_message_thread(
messages: list[ThreadMessage],
- token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
+ max_tokens: int | None,
llm_tokenizer: Callable | None = None,
) -> str:
"""Used to create a single combined message context from threads"""
+ if not messages:
+ return ""
+
message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
@@ -304,8 +306,8 @@ def combine_message_thread(
message_token_count = len(llm_tokenizer(msg_str))
if (
- token_limit is not None
- and total_token_count + message_token_count > token_limit
+ max_tokens is not None
+ and total_token_count + message_token_count > max_tokens
):
break
diff --git a/backend/danswer/secondary_llm_flows/chat_session_naming.py b/backend/danswer/secondary_llm_flows/chat_session_naming.py
index e65c8dd36..aa604131b 100644
--- a/backend/danswer/secondary_llm_flows/chat_session_naming.py
+++ b/backend/danswer/secondary_llm_flows/chat_session_naming.py
@@ -1,4 +1,5 @@
from danswer.chat.chat_utils import combine_message_chain
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -31,7 +32,9 @@ def get_renamed_conversation_name(
# clear thing we can do
return full_history[0].message
- history_str = combine_message_chain(full_history)
+ history_str = combine_message_chain(
+ messages=full_history, token_limit=GEN_AI_HISTORY_CUTOFF
+ )
prompt_msgs = get_chat_rename_messages(history_str)
diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py
index 626b10775..9e07bf647 100644
--- a/backend/danswer/secondary_llm_flows/choose_search.py
+++ b/backend/danswer/secondary_llm_flows/choose_search.py
@@ -4,6 +4,7 @@ from langchain.schema import SystemMessage
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -77,7 +78,9 @@ def check_if_need_search(
# as just a search engine
return True
- history_str = combine_message_chain(history)
+ history_str = combine_message_chain(
+ messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
+ )
prompt_msgs = _get_search_messages(
question=query_message.message, history_str=history_str
diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py
index 9a4946074..d0fb19b73 100644
--- a/backend/danswer/secondary_llm_flows/query_expansion.py
+++ b/backend/danswer/secondary_llm_flows/query_expansion.py
@@ -2,6 +2,7 @@ from collections.abc import Callable
from typing import cast
from danswer.chat.chat_utils import combine_message_chain
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -119,7 +120,9 @@ def history_based_query_rephrase(
if count_punctuation(user_query) >= punctuation_heuristic:
return user_query
- history_str = combine_message_chain(history)
+ history_str = combine_message_chain(
+ messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
+ )
prompt_msgs = get_contextual_rephrase_messages(
question=user_query, history_str=history_str
diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py
index 0c27ecaf3..0f0e540c6 100644
--- a/backend/danswer/server/query_and_chat/query_backend.py
+++ b/backend/danswer/server/query_and_chat/query_backend.py
@@ -153,6 +153,10 @@ def get_answer_with_quote(
query = query_request.messages[0].message
logger.info(f"Received query for one shot answer with quotes: {query}")
packets = stream_search_answer(
- query_req=query_request, user=user, db_session=db_session
+ query_req=query_request,
+ user=user,
+ max_document_tokens=None,
+ max_history_tokens=0,
+ db_session=db_session,
)
return StreamingResponse(packets, media_type="application/json")
diff --git a/backend/scripts/test-openapi-key.py b/backend/scripts/test-openapi-key.py
index 30f97571a..ba2e61dce 100644
--- a/backend/scripts/test-openapi-key.py
+++ b/backend/scripts/test-openapi-key.py
@@ -10,6 +10,7 @@ VALID_MODEL_LIST = [
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
+ "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py
index cfd921a6f..bd2f70010 100644
--- a/backend/tests/regression/answer_quality/eval_direct_qa.py
+++ b/backend/tests/regression/answer_quality/eval_direct_qa.py
@@ -108,6 +108,8 @@ def get_answer_for_question(
answer = get_search_answer(
query_req=new_message_request,
user=None,
+ max_document_tokens=None,
+ max_history_tokens=None,
db_session=db_session,
answer_generation_timeout=100,
enable_reflexion=False,
diff --git a/backend/tests/regression/answer_quality/relari.py b/backend/tests/regression/answer_quality/relari.py
index 7f60a8a58..a63226a3d 100644
--- a/backend/tests/regression/answer_quality/relari.py
+++ b/backend/tests/regression/answer_quality/relari.py
@@ -41,6 +41,8 @@ def get_answer_for_question(query: str, db_session: Session) -> OneShotQARespons
answer = get_search_answer(
query_req=new_message_request,
user=None,
+ max_document_tokens=None,
+ max_history_tokens=None,
db_session=db_session,
answer_generation_timeout=100,
enable_reflexion=False,
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 5e3bf1e04..a78784448 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -30,14 +30,15 @@ services:
- EMAIL_FROM=${EMAIL_FROM:-}
# Gen AI Settings
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+ - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
+ - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
+ - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
+ - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
@@ -93,14 +94,15 @@ services:
environment:
# Gen AI Settings (Needed by DanswerBot)
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+ - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
+ - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
+ - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
+ - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml
index 97f8354b0..11ec91934 100644
--- a/deployment/kubernetes/env-configmap.yaml
+++ b/deployment/kubernetes/env-configmap.yaml
@@ -14,14 +14,15 @@ data:
EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
# Gen AI Settings
GEN_AI_MODEL_PROVIDER: "openai"
- GEN_AI_MODEL_VERSION: "gpt-3.5-turbo" # Use GPT-4 if you have it
- FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo"
+ GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125" # Use GPT-4 if you have it
+ FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125"
GEN_AI_API_KEY: ""
GEN_AI_API_ENDPOINT: ""
GEN_AI_API_VERSION: ""
GEN_AI_LLM_PROVIDER_TYPE: ""
+ GEN_AI_MAX_TOKENS: ""
QA_TIMEOUT: "60"
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL: ""
+ MAX_CHUNKS_FED_TO_CHAT: ""
DISABLE_LLM_FILTER_EXTRACTION: ""
DISABLE_LLM_CHUNK_FILTER: ""
DISABLE_LLM_CHOOSE_SEARCH: ""
diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx
index 29ea52170..42106108c 100644
--- a/web/src/app/admin/personas/PersonaEditor.tsx
+++ b/web/src/app/admin/personas/PersonaEditor.tsx
@@ -86,7 +86,7 @@ export function PersonaEditor({
description: existingPersona?.description ?? "",
system_prompt: existingPrompt?.system_prompt ?? "",
task_prompt: existingPrompt?.task_prompt ?? "",
- disable_retrieval: (existingPersona?.num_chunks ?? 5) === 0,
+ disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0,
document_set_ids:
existingPersona?.document_sets?.map(
(documentSet) => documentSet.id
@@ -148,7 +148,7 @@ export function PersonaEditor({
// to tell the backend to not fetch any documents
const numChunks = values.disable_retrieval
? 0
- : values.num_chunks || 5;
+ : values.num_chunks || 10;
let promptResponse;
let personaResponse;
@@ -414,7 +414,7 @@ export function PersonaEditor({
input length limit.
- If unspecified, will use 5 chunks.
+ If unspecified, will use 10 chunks.
}
onChange={(e) => {