diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py
index fb2a32c23..92bad7cb7 100644
--- a/backend/danswer/chat/chat_utils.py
+++ b/backend/danswer/chat/chat_utils.py
@@ -14,12 +14,9 @@ from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import LlmDoc
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
-from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.chat_configs import STOP_STREAM_PAT
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
-from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
@@ -28,7 +25,7 @@ from danswer.db.models import Persona
 from danswer.db.models import Prompt
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
-from danswer.llm.utils import get_llm_max_tokens
+from danswer.llm.utils import get_max_input_tokens
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
 from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -239,7 +236,7 @@ def _get_usable_chunks(

 def get_usable_chunks(
     chunks: list[InferenceChunk],
-    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+    token_limit: int,
     offset: int = 0,
 ) -> list[InferenceChunk]:
     offset_into_chunks = 0
@@ -261,7 +258,7 @@ def get_usable_chunks(

 def get_chunks_for_qa(
     chunks: list[InferenceChunk],
     llm_chunk_selection: list[bool],
-    token_limit: float | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+    token_limit: int | None,
     batch_offset: int = 0,
 ) -> list[int]:
     """
@@ -363,10 +360,10 @@ def create_chat_chain(

 def combine_message_chain(
     messages: list[ChatMessage],
-    msg_limit: int | None = 10,
-    token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
+    token_limit: int,
+    msg_limit: int | None = None,
 ) -> str:
     """Used for secondary LLM flows that require the chat history"""
     message_strs: list[str] = []
     total_token_count = 0
@@ -376,10 +373,7 @@ def combine_message_chain(
     for message in reversed(messages):
         message_token_count = message.token_count

-        if (
-            token_limit is not None
-            and total_token_count + message_token_count > token_limit
-        ):
+        if total_token_count + message_token_count > token_limit:
             break

         role = message.message_type.value.upper()
@@ -557,7 +551,9 @@ _MISC_BUFFER = 40

 def compute_max_document_tokens(
-    persona: Persona, actual_user_input: str | None = None
+    persona: Persona,
+    actual_user_input: str | None = None,
+    max_llm_token_override: int | None = None,
 ) -> int:
     """Estimates the number of tokens available for context documents
     Formula is roughly:
@@ -575,8 +571,13 @@ def compute_max_document_tokens(
         llm_name = persona.llm_model_version_override

     # if we can't find a number of tokens, just assume some common default
-    model_full_context_window = get_llm_max_tokens(llm_name) or 4096
+    max_input_tokens = (
+        max_llm_token_override
+        if max_llm_token_override
+        else get_max_input_tokens(llm_name)
+    )
     if persona.prompts:
+        # TODO this may not always be the first prompt
         prompt_tokens = get_prompt_tokens(persona.prompts[0])
     else:
         raise RuntimeError("Persona has no prompts - this should never happen")
@@ -586,13 +587,7 @@ def compute_max_document_tokens(
         else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
     )

-    return (
-        model_full_context_window
-        - GEN_AI_MAX_OUTPUT_TOKENS
-        - prompt_tokens
-        - user_input_tokens
-        - _MISC_BUFFER
-    )
+    return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER


 def compute_max_llm_input_tokens(persona: Persona) -> int:
@@ -601,5 +596,5 @@ def compute_max_llm_input_tokens(persona: Persona) -> int:
     if persona.llm_model_version_override:
         llm_name = persona.llm_model_version_override

-    model_full_context_window = get_llm_max_tokens(llm_name) or 4096
-    return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER
+    input_tokens = get_max_input_tokens(model_name=llm_name)
+    return input_tokens - _MISC_BUFFER
diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py
index b02c18cc4..d2e93c474 100644
--- a/backend/danswer/chat/load_yamls.py
+++ b/backend/danswer/chat/load_yamls.py
@@ -3,7 +3,7 @@ from typing import cast
 import yaml
 from sqlalchemy.orm import Session

-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import PERSONAS_YAML
 from danswer.configs.chat_configs import PROMPTS_YAML
 from danswer.db.chat import get_prompt_by_name
@@ -42,7 +42,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:

 def load_personas_from_yaml(
     personas_yaml: str = PERSONAS_YAML,
-    default_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+    default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
 ) -> None:
     with open(personas_yaml, "r") as file:
         data = yaml.safe_load(file)
diff --git a/backend/danswer/chat/personas.yaml b/backend/danswer/chat/personas.yaml
index b2c44ead5..1f358e4b1 100644
--- a/backend/danswer/chat/personas.yaml
+++ b/backend/danswer/chat/personas.yaml
@@ -13,9 +13,8 @@ personas:
       - "Answer-Question"
     # Default number of chunks to include as context, set to 0 to disable retrieval
     # Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
-    # If selecting documents, user can bypass this up until NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
     # Each chunk is 512 tokens long
-    num_chunks: 5
+    num_chunks: 10
     # Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
     # if the chunk is useful or not towards the latest user query
     # This feature can be overridden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
@@ -46,7 +45,7 @@
       extrapolate any answers for you.
     prompts:
       - "Summarize"
-    num_chunks: 5
+    num_chunks: 10
     llm_relevance_filter: true
     llm_filter_extraction: true
     recency_bias: "auto"
@@ -58,7 +57,7 @@
       The least creative default assistant that only provides quotes from the documents.
     prompts:
       - "Paraphrase"
-    num_chunks: 5
+    num_chunks: 10
     llm_relevance_filter: true
     llm_filter_extraction: true
     recency_bias: "auto"
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index d24c3ca9d..57be65679 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -22,10 +22,12 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import LLMRelevanceFilterResponse
 from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
-from danswer.configs.chat_configs import CHUNK_SIZE
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
+from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_chat_message
@@ -46,6 +48,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.llm.utils import get_max_input_tokens
 from danswer.llm.utils import tokenizer_trim_content
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.search.models import OptionalSearchSetting
@@ -156,8 +159,11 @@ def stream_chat_message(
     user: User | None,
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
-    default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
     default_chunk_size: int = CHUNK_SIZE,
+    # For flows with search, don't include as many chunks as possible since we need to leave
+    # space for the chat history; smaller models likely won't fit MAX_CHUNKS_FED_TO_CHAT chunks
+    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
 ) -> Iterator[str]:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -260,6 +266,10 @@ def stream_chat_message(
             query_message=final_msg, history=history_msgs, llm=llm
         )

+        max_document_tokens = compute_max_document_tokens(
+            persona=persona, actual_user_input=message_text
+        )
+
         rephrased_query = None
         if reference_doc_ids:
             identifier_tuples = get_doc_query_identifiers_from_model(
@@ -277,9 +287,6 @@ def stream_chat_message(
             )

             # truncate the last document if it exceeds the token limit
-            max_document_tokens = compute_max_document_tokens(
-                persona, actual_user_input=message_text
-            )
             tokens_per_doc = [
                 len(
                     llm_tokenizer_encode_func(
@@ -431,10 +438,26 @@ def stream_chat_message(
             if persona.num_chunks is not None
             else default_num_chunks
         )
+
+        llm_name = GEN_AI_MODEL_VERSION
+        if persona.llm_model_version_override:
+            llm_name = persona.llm_model_version_override
+
+        llm_max_input_tokens = get_max_input_tokens(llm_name)
+
+        llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
+
+        chunk_token_limit = int(
+            min(
+                num_llm_chunks * default_chunk_size,
+                max_document_tokens,
+                llm_token_based_chunk_lim,
+            )
+        )
         llm_chunks_indices = get_chunks_for_qa(
             chunks=top_chunks,
             llm_chunk_selection=llm_chunk_selection,
-            token_limit=num_llm_chunks * default_chunk_size,
+            token_limit=chunk_token_limit,
         )
         llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
         llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py
index c5cc605aa..6fc2b6fb9 100644
--- a/backend/danswer/configs/chat_configs.py
+++ b/backend/danswer/configs/chat_configs.py
@@ -1,28 +1,18 @@
 import os

-from danswer.configs.model_configs import CHUNK_SIZE

 PROMPTS_YAML = "./danswer/chat/prompts.yaml"
 PERSONAS_YAML = "./danswer/chat/personas.yaml"

 NUM_RETURNED_HITS = 50
 NUM_RERANKED_RESULTS = 15
-# We feed in document chunks until we reach this token limit.
-# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
-# significantly smaller which could result in passing in more total chunks.
-# There is also a slight bit of overhead, not accounted for here such as separator patterns
-# between the docs, metadata for the docs, etc.
-# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
-# model token limit
-NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
-    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (CHUNK_SIZE * 5)
-)
-DEFAULT_NUM_CHUNKS_FED_TO_CHAT: float = (
-    float(NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL) / CHUNK_SIZE
-)
-NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
-    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (CHUNK_SIZE * 3)
-)
+
+# May be fewer depending on the model
+MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
+# For Chat, need to keep enough space for history and other prompt pieces
+# ~3k input, half for docs, half for chat history + prompts
+CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072

 # For selecting a different LLM question-answering prompt format
 # Valid values: default, cot, weak
 QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
@@ -60,7 +50,7 @@ if os.environ.get("EDIT_KEYWORD_QUERY"):
 else:
     EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
 # Weighting factor between Vector and Keyword Search, 1 for completely vector search
-HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.66)))
+HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
 # Weighting factor between Title and Content of documents during search, 1 for completely
 # Title based. Default heavily favors Content because Title is also included at the top of
 # Content. This is to avoid cases where the Content is very relevant but it may not be clear
diff --git a/backend/danswer/configs/danswerbot_configs.py b/backend/danswer/configs/danswerbot_configs.py
index 6fd610f99..484ba144b 100644
--- a/backend/danswer/configs/danswerbot_configs.py
+++ b/backend/danswer/configs/danswerbot_configs.py
@@ -7,6 +7,8 @@ DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
 DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
     os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
 )
+# How much of the available input context can be used for thread context
+DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
 # Number of docs to display in "Reference Documents"
 DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
     os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
 )
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 11d6ce26d..24691b7f0 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -78,7 +78,7 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
 # Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
 GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
 # If using Azure, it's the engine name, for example: Danswer
-GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-0125"
 # For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
 # as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
 FAST_GEN_AI_MODEL_VERSION = (
@@ -96,14 +96,15 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 # LiteLLM custom_llm_provider
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
-
+# If the max tokens can't be found from the name, use this as the backup
+# This happens if the user is configuring a different LLM to use
+GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
-# This next restriction is only used for chat ATM, used to expire old messages as needed
-GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
-# History for secondary LLM flows, not primary chat flow, generally we don't need to
-# include as much as possible as this just bumps up the cost unnecessarily
-GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
+# Maximum number of tokens of chat history to include
+# 3000 should be enough context regardless of use case; no need to include as much as
+# possible, as this drives up the cost unnecessarily
+GEN_AI_HISTORY_CUTOFF = 3000
 # This is used when computing how much context space is available for documents
 # ahead of time in order to let the user know if they can "select" more documents
 # It represents a maximum "expected" number of input tokens from the latest user
 # message
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index 2bd7c7e5b..e5b67285c 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -11,14 +11,17 @@ from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
 from sqlalchemy.orm import Session

+from danswer.chat.chat_utils import compute_max_document_tokens
 from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
+from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
 from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
 from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
 from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.danswerbot.slack.blocks import build_documents_blocks
 from danswer.danswerbot.slack.blocks import build_follow_up_block
 from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -33,6 +36,8 @@ from danswer.danswerbot.slack.utils import SlackRateLimiter
 from danswer.danswerbot.slack.utils import update_emote_react
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import SlackBotConfig
+from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
@@ -98,6 +103,7 @@ def handle_message(
     disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
     reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
     disable_cot: bool = DANSWER_BOT_DISABLE_COT,
+    thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
 ) -> bool:
     """Potentially respond to the user message depending on filters and if an answer was generated
     slack_usage_report(action=action, sender_id=sender_id, client=client)

+    max_document_tokens: int | None = None
+    max_history_tokens: int | None = None
+    if len(new_message_request.messages) > 1:
+        # In cases of threads, split the available tokens between docs and thread context
+        input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
+        max_history_tokens = int(input_tokens * thread_context_percent)
+
+        remaining_tokens = input_tokens - max_history_tokens
+
+        query_text = new_message_request.messages[0].message
+        if persona:
+            max_document_tokens = compute_max_document_tokens(
+                persona=persona,
+                actual_user_input=query_text,
+                max_llm_token_override=remaining_tokens,
+            )
+        else:
+            max_document_tokens = (
+                remaining_tokens
+                - 512  # Needs to be more than any of the QA prompts
+                - check_number_of_tokens(query_text)
+            )
+
     with Session(get_sqlalchemy_engine()) as db_session:
         # This also handles creating the query event in postgres
         answer = get_search_answer(
             query_req=new_message_request,
             user=None,
+            max_document_tokens=max_document_tokens,
+            max_history_tokens=max_history_tokens,
             db_session=db_session,
             answer_generation_timeout=answer_generation_timeout,
             enable_reflexion=reflexion,
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 606040a00..19b1e4967 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -723,7 +723,6 @@ class Persona(Base):
         Enum(SearchType), default=SearchType.HYBRID
     )
     # Number of chunks to pass to the LLM for generation.
-    # If unspecified, uses the default DEFAULT_NUM_CHUNKS_FED_TO_CHAT set in the env variable
     num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
     # Pass every chunk through LLM for evaluation, fairly expensive
     # Can be turned off globally by admin, in which case, this setting is ignored
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index f6478ca5b..bbf4ff0b6 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from sqlalchemy import select
 from sqlalchemy.orm import Session

-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.db.chat import upsert_persona
 from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
 from danswer.db.document_set import get_document_sets_by_ids
@@ -35,7 +35,7 @@ def create_slack_bot_persona(
     channel_names: list[str],
     document_set_ids: list[int],
     existing_persona_id: int | None = None,
-    num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+    num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
 ) -> Persona:
     """NOTE: does not commit changes"""
     document_sets = list(
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index a8c0b40ab..91da9160c 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -22,6 +22,9 @@ from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -59,7 +62,6 @@ def get_default_llm_token_encode() -> Callable[[str], Any]:

 def tokenizer_trim_content(
     content: str, desired_length: int, tokenizer: Encoding
 ) -> str:
-    tokenizer = get_default_llm_tokenizer()
     tokens = tokenizer.encode(content)
     if len(tokens) > desired_length:
         content = tokenizer.decode(tokens[:desired_length])
@@ -201,9 +203,24 @@ def test_llm(llm: LLM) -> bool:
         return False


-def get_llm_max_tokens(model_name: str) -> int | None:
+def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
+    if not model_name:
+        return GEN_AI_MAX_TOKENS
+
     try:
         return get_max_tokens(model_name)
     except Exception:
-        return None
+        return GEN_AI_MAX_TOKENS
+
+
+def get_max_input_tokens(
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+) -> int:
+    input_toks = get_llm_max_tokens(model_name) - output_tokens
+
+    if input_toks <= 0:
+        raise RuntimeError("No tokens for input for the LLM given settings")
+
+    return input_toks
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 4a2cb493c..ff4f2cc00 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -5,6 +5,7 @@ from typing import cast

 from sqlalchemy.orm import Session

+from danswer.chat.chat_utils import compute_max_document_tokens
 from danswer.chat.chat_utils import get_chunks_for_qa
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import DanswerContext
@@ -14,7 +15,7 @@ from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import LLMRelevanceFilterResponse
 from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
-from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import CHUNK_SIZE
@@ -54,9 +55,14 @@ logger = setup_logger()

 def stream_answer_objects(
     query_req: DirectQARequest,
     user: User | None,
+    # These need to be passed in because in the Web UI one-shot flow,
+    # we can have many more documents as there is no history.
+    # For the Slack flow, we need to save more tokens for the thread context
+    max_document_tokens: int | None,
+    max_history_tokens: int | None,
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
-    default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
+    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
     default_chunk_size: int = CHUNK_SIZE,
     timeout: int = QA_TIMEOUT,
     bypass_acl: bool = False,
@@ -106,7 +112,9 @@ def stream_answer_objects(
         chat_session_id=chat_session.id, db_session=db_session
     )

-    history_str = combine_message_thread(history)
+    history_str = combine_message_thread(
+        messages=history, max_tokens=max_history_tokens
+    )

     rephrased_query = thread_based_query_rephrase(
         user_query=query_msg.message,
@@ -174,10 +182,20 @@ def stream_answer_objects(
         if chat_session.persona.num_chunks is not None
         else default_num_chunks
     )
+
+    chunk_token_limit = int(num_llm_chunks * default_chunk_size)
+    if max_document_tokens:
+        chunk_token_limit = min(chunk_token_limit, max_document_tokens)
+    else:
+        max_document_tokens = compute_max_document_tokens(
+            persona=chat_session.persona, actual_user_input=query_msg.message
+        )
+        chunk_token_limit = min(chunk_token_limit, max_document_tokens)
+
     llm_chunks_indices = get_chunks_for_qa(
         chunks=top_chunks,
         llm_chunk_selection=llm_chunk_selection,
-        token_limit=num_llm_chunks * default_chunk_size,
+        token_limit=chunk_token_limit,
     )
     llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
@@ -288,10 +306,16 @@ def stream_answer_objects(

 def stream_search_answer(
     query_req: DirectQARequest,
     user: User | None,
+    max_document_tokens: int | None,
+    max_history_tokens: int | None,
     db_session: Session,
 ) -> Iterator[str]:
     objects = stream_answer_objects(
-        query_req=query_req, user=user, db_session=db_session
+        query_req=query_req,
+        user=user,
+        max_document_tokens=max_document_tokens,
+        max_history_tokens=max_history_tokens,
+        db_session=db_session,
     )
     for obj in objects:
         yield get_json_line(obj.dict())
@@ -300,6 +324,8 @@ def stream_search_answer(

 def get_search_answer(
     query_req: DirectQARequest,
     user: User | None,
+    max_document_tokens: int | None,
+    max_history_tokens: int | None,
     db_session: Session,
     answer_generation_timeout: int = QA_TIMEOUT,
     enable_reflexion: bool = False,
@@ -315,6 +341,8 @@ def get_search_answer(
     results = stream_answer_objects(
         query_req=query_req,
         user=user,
+        max_document_tokens=max_document_tokens,
+        max_history_tokens=max_history_tokens,
         db_session=db_session,
         bypass_acl=bypass_acl,
         timeout=answer_generation_timeout,
diff --git a/backend/danswer/one_shot_answer/qa_utils.py b/backend/danswer/one_shot_answer/qa_utils.py
index c1dd36889..032d24345 100644
--- a/backend/danswer/one_shot_answer/qa_utils.py
+++ b/backend/danswer/one_shot_answer/qa_utils.py
@@ -15,7 +15,6 @@ from danswer.chat.models import DanswerQuote
 from danswer.chat.models import DanswerQuotes
 from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.models import ThreadMessage
@@ -279,10 +278,13 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:

 def combine_message_thread(
     messages: list[ThreadMessage],
-    token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
+    max_tokens: int | None,
     llm_tokenizer: Callable | None = None,
 ) -> str:
     """Used to create a single combined message context from threads"""
+    if not messages:
+        return ""
+
     message_strs: list[str] = []
     total_token_count = 0
     if llm_tokenizer is None:
@@ -304,8 +306,8 @@ def combine_message_thread(
         message_token_count = len(llm_tokenizer(msg_str))

         if (
-            token_limit is not None
-            and total_token_count + message_token_count > token_limit
+            max_tokens is not None
+            and total_token_count + message_token_count > max_tokens
         ):
             break
diff --git a/backend/danswer/secondary_llm_flows/chat_session_naming.py b/backend/danswer/secondary_llm_flows/chat_session_naming.py
index e65c8dd36..aa604131b 100644
--- a/backend/danswer/secondary_llm_flows/chat_session_naming.py
+++ b/backend/danswer/secondary_llm_flows/chat_session_naming.py
@@ -1,4 +1,5 @@
 from danswer.chat.chat_utils import combine_message_chain
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.db.models import ChatMessage
 from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
@@ -31,7 +32,9 @@ def get_renamed_conversation_name(
         # clear thing we can do
         return full_history[0].message

-    history_str = combine_message_chain(full_history)
+    history_str = combine_message_chain(
+        messages=full_history, token_limit=GEN_AI_HISTORY_CUTOFF
+    )

     prompt_msgs = get_chat_rename_messages(history_str)
diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py
index 626b10775..9e07bf647 100644
--- a/backend/danswer/secondary_llm_flows/choose_search.py
+++ b/backend/danswer/secondary_llm_flows/choose_search.py
@@ -4,6 +4,7 @@ from langchain.schema import SystemMessage

 from danswer.chat.chat_utils import combine_message_chain
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.db.models import ChatMessage
 from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
@@ -77,7 +78,9 @@ def check_if_need_search(
         # as just a search engine
         return True

-    history_str = combine_message_chain(history)
+    history_str = combine_message_chain(
+        messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
+    )

     prompt_msgs = _get_search_messages(
         question=query_message.message, history_str=history_str
     )
diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py
index 9a4946074..d0fb19b73 100644
--- a/backend/danswer/secondary_llm_flows/query_expansion.py
+++ b/backend/danswer/secondary_llm_flows/query_expansion.py
@@ -2,6 +2,7 @@ from collections.abc import Callable
 from typing import cast

 from danswer.chat.chat_utils import combine_message_chain
+from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.db.models import ChatMessage
 from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
@@ -119,7 +120,9 @@ def history_based_query_rephrase(
     if count_punctuation(user_query) >= punctuation_heuristic:
         return user_query

-    history_str = combine_message_chain(history)
+    history_str = combine_message_chain(
+        messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
+    )

     prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )
diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py
index 0c27ecaf3..0f0e540c6 100644
--- a/backend/danswer/server/query_and_chat/query_backend.py
+++ b/backend/danswer/server/query_and_chat/query_backend.py
@@ -153,6 +153,10 @@ def get_answer_with_quote(
     query = query_request.messages[0].message
     logger.info(f"Received query for one shot answer with quotes: {query}")
     packets = stream_search_answer(
-        query_req=query_request, user=user, db_session=db_session
+        query_req=query_request,
+        user=user,
+        max_document_tokens=None,
+        max_history_tokens=0,
+        db_session=db_session,
     )
     return StreamingResponse(packets, media_type="application/json")
diff --git a/backend/scripts/test-openapi-key.py b/backend/scripts/test-openapi-key.py
index 30f97571a..ba2e61dce 100644
--- a/backend/scripts/test-openapi-key.py
+++ b/backend/scripts/test-openapi-key.py
@@ -10,6 +10,7 @@ VALID_MODEL_LIST = [
     "gpt-4-32k",
     "gpt-4-32k-0314",
     "gpt-4-32k-0613",
+    "gpt-3.5-turbo-0125",
     "gpt-3.5-turbo-1106",
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-16k",
diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py
index cfd921a6f..bd2f70010 100644
--- a/backend/tests/regression/answer_quality/eval_direct_qa.py
+++ b/backend/tests/regression/answer_quality/eval_direct_qa.py
@@ -108,6 +108,8 @@ def get_answer_for_question(
     answer = get_search_answer(
         query_req=new_message_request,
         user=None,
+        max_document_tokens=None,
+        max_history_tokens=None,
         db_session=db_session,
         answer_generation_timeout=100,
         enable_reflexion=False,
diff --git a/backend/tests/regression/answer_quality/relari.py b/backend/tests/regression/answer_quality/relari.py
index 7f60a8a58..a63226a3d 100644
--- a/backend/tests/regression/answer_quality/relari.py
+++ b/backend/tests/regression/answer_quality/relari.py
@@ -41,6 +41,8 @@ def get_answer_for_question(query: str, db_session: Session) -> OneShotQAResponse:
     answer = get_search_answer(
         query_req=new_message_request,
         user=None,
+        max_document_tokens=None,
+        max_history_tokens=None,
         db_session=db_session,
         answer_generation_timeout=100,
         enable_reflexion=False,
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 5e3bf1e04..a78784448 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -30,14 +30,15 @@ services:
       - EMAIL_FROM=${EMAIL_FROM:-}
       # Gen AI Settings
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
-      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
-      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
       - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
+      - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
       - QA_TIMEOUT=${QA_TIMEOUT:-}
-      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
+      - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
      - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
@@ -93,14 +94,15 @@ services:
     environment:
       # Gen AI Settings (Needed by DanswerBot)
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
-      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
-      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
       - GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
+      - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
       - QA_TIMEOUT=${QA_TIMEOUT:-}
-      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
+      - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml
index 97f8354b0..11ec91934 100644
--- a/deployment/kubernetes/env-configmap.yaml
+++ b/deployment/kubernetes/env-configmap.yaml
@@ -14,14 +14,15 @@ data:
   EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
   # Gen AI Settings
   GEN_AI_MODEL_PROVIDER: "openai"
-  GEN_AI_MODEL_VERSION: "gpt-3.5-turbo" # Use GPT-4 if you have it
-  FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo"
+  GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125" # Use GPT-4 if you have it
+  FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125"
   GEN_AI_API_KEY: ""
   GEN_AI_API_ENDPOINT: ""
   GEN_AI_API_VERSION: ""
   GEN_AI_LLM_PROVIDER_TYPE: ""
+  GEN_AI_MAX_TOKENS: ""
   QA_TIMEOUT: "60"
-  NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL: ""
+  MAX_CHUNKS_FED_TO_CHAT: ""
   DISABLE_LLM_FILTER_EXTRACTION: ""
   DISABLE_LLM_CHUNK_FILTER: ""
   DISABLE_LLM_CHOOSE_SEARCH: ""
diff --git a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx
index 29ea52170..42106108c 100644
--- a/web/src/app/admin/personas/PersonaEditor.tsx
+++ b/web/src/app/admin/personas/PersonaEditor.tsx
@@ -86,7 +86,7 @@ export function PersonaEditor({
       description: existingPersona?.description ?? "",
       system_prompt: existingPrompt?.system_prompt ?? "",
       task_prompt: existingPrompt?.task_prompt ?? "",
-      disable_retrieval: (existingPersona?.num_chunks ?? 5) === 0,
+      disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0,
       document_set_ids:
         existingPersona?.document_sets?.map(
           (documentSet) => documentSet.id
@@ -148,7 +148,7 @@ export function PersonaEditor({
               // to tell the backend to not fetch any documents
               const numChunks = values.disable_retrieval
                 ? 0
-                : values.num_chunks || 5;
+                : values.num_chunks || 10;

               let promptResponse;
               let personaResponse;
@@ -414,7 +414,7 @@ export function PersonaEditor({
                             input length limit.

-                            If unspecified, will use 5 chunks.
+                            If unspecified, will use 10 chunks.
                           }
                           onChange={(e) => {
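
For reviewers, here is a minimal, self-contained sketch of the chat-flow token budgeting this diff introduces. The constants and the three-way min() are taken from the diff; the hard-coded context-window table (a stand-in for litellm's get_max_tokens() lookup) and the helper name chat_chunk_token_limit are illustrative assumptions, not names from the codebase:

```python
# Sketch only: constants and formulas mirror the diff, everything else is assumed.
GEN_AI_MAX_TOKENS = 4096  # backup context window when the model name is unknown
GEN_AI_MAX_OUTPUT_TOKENS = 1024  # reserved for the model's answer
CHUNK_SIZE = 512  # tokens per retrieved chunk
MAX_CHUNKS_FED_TO_CHAT = 10.0
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072  # ~half the input budget for documents

# Stand-in for litellm's get_max_tokens(); values here are illustrative
_CONTEXT_WINDOWS = {"gpt-3.5-turbo-0125": 16385, "gpt-4": 8192}


def get_llm_max_tokens(model_name: str | None) -> int:
    """Best-effort context window lookup with a configurable backup."""
    if not model_name:
        return GEN_AI_MAX_TOKENS
    return _CONTEXT_WINDOWS.get(model_name, GEN_AI_MAX_TOKENS)


def get_max_input_tokens(model_name: str | None) -> int:
    """Context window minus the reserved output tokens, as in llm/utils.py."""
    input_toks = get_llm_max_tokens(model_name) - GEN_AI_MAX_OUTPUT_TOKENS
    if input_toks <= 0:
        raise RuntimeError("No tokens for input for the LLM given settings")
    return input_toks


def chat_chunk_token_limit(model_name: str | None, max_document_tokens: int) -> int:
    """The chat flow now takes the smallest of three ceilings (process_message.py)."""
    llm_max_input = get_max_input_tokens(model_name)
    return int(
        min(
            MAX_CHUNKS_FED_TO_CHAT * CHUNK_SIZE,  # persona / default chunk cap
            max_document_tokens,  # prompt-aware budget from compute_max_document_tokens
            CHAT_TARGET_CHUNK_PERCENTAGE * llm_max_input,  # leave room for history
        )
    )


if __name__ == "__main__":
    # 16385 - 1024 = 15361 input tokens; the 10-chunk cap (5120) binds here,
    # while the ~50% history reservation (7680) would bind on smaller windows.
    print(chat_chunk_token_limit("gpt-3.5-turbo-0125", max_document_tokens=14000))
```

Taking the most restrictive of the three ceilings means small-context models automatically feed fewer chunks than MAX_CHUNKS_FED_TO_CHAT rather than overflowing the prompt.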
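The DanswerBot handler applies the same idea with a fixed split between thread history and documents. A rough sketch under the same caveats: split_slack_budget and count_tokens are hypothetical stand-ins (the real code inlines this logic in handle_message and uses check_number_of_tokens); only the percentage constant and the 512-token prompt reserve come from the diff:

```python
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072  # one third of input for thread context


def count_tokens(text: str) -> int:
    # Crude whitespace stand-in for the real LLM tokenizer
    return len(text.split())


def split_slack_budget(
    input_tokens: int, query_text: str, is_thread: bool
) -> tuple[int | None, int | None]:
    """Returns (max_document_tokens, max_history_tokens) as handle_message() does."""
    if not is_thread:
        # Single message: nothing to reserve, downstream defaults apply
        return None, None
    max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)
    remaining_tokens = input_tokens - max_history_tokens
    # No-persona fallback from the diff: reserve 512 tokens (more than any of
    # the QA prompts) plus the query itself; the rest goes to documents
    max_document_tokens = remaining_tokens - 512 - count_tokens(query_text)
    return max_document_tokens, max_history_tokens


if __name__ == "__main__":
    # With 15361 usable input tokens: ~5120 for the thread, and the remainder
    # minus the prompt reserve and the query for documents
    print(split_slack_budget(15361, "where are the deployment docs?", is_thread=True))
```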