From 1a1c91a7d9d06c6dcbc68c0ba94559a41d2be133 Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Thu, 8 Feb 2024 15:15:58 -0800
Subject: [PATCH] Support Detection of LLM Max Context for Non OpenAI Models
 (#1060)

---
 backend/danswer/chat/chat_utils.py        |  2 +-
 backend/danswer/chat/process_message.py   |  2 +-
 backend/danswer/configs/model_configs.py  |  5 ++---
 .../slack/handlers/handle_message.py      |  6 +++++-
 backend/danswer/llm/utils.py              | 24 +++++++++++++++++++-----
 5 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py
index 92bad7cb7..efd0924e4 100644
--- a/backend/danswer/chat/chat_utils.py
+++ b/backend/danswer/chat/chat_utils.py
@@ -574,7 +574,7 @@ def compute_max_document_tokens(
     max_input_tokens = (
         max_llm_token_override
         if max_llm_token_override
-        else get_max_input_tokens(llm_name)
+        else get_max_input_tokens(model_name=llm_name)
     )
     if persona.prompts:
         # TODO this may not always be the first prompt
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 57be65679..4885ff877 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -443,7 +443,7 @@ def stream_chat_message(
     if persona.llm_model_version_override:
         llm_name = persona.llm_model_version_override
 
-    llm_max_input_tokens = get_max_input_tokens(llm_name)
+    llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)
 
     llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
 
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 24691b7f0..62563b111 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -96,9 +96,8 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 # LiteLLM custom_llm_provider
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
-# If the max tokens can't be found from the name, use this as the backup
-# This happens if user is configuring a different LLM to use
-GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
+# Override the auto-detection of LLM max context length
+GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 # Number of tokens from chat history to include at maximum
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index e5b67285c..2bd9759ff 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -224,8 +224,12 @@ def handle_message(
     max_document_tokens: int | None = None
     max_history_tokens: int | None = None
     if len(new_message_request.messages) > 1:
+        llm_name = GEN_AI_MODEL_VERSION
+        if persona and persona.llm_model_version_override:
+            llm_name = persona.llm_model_version_override
+
         # In cases of threads, split the available tokens between docs and thread context
-        input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
+        input_tokens = get_max_input_tokens(model_name=llm_name)
         max_history_tokens = int(input_tokens * thread_context_percent)
 
         remaining_tokens = input_tokens - max_history_tokens
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index 91da9160c..379ccfbae 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -24,6 +24,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
 from danswer.dynamic_configs import get_dynamic_config_store
@@ -203,22 +204,35 @@ def test_llm(llm: LLM) -> bool:
         return False
 
 
-def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
+def get_llm_max_tokens(
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
-    if not model_name:
+    if GEN_AI_MAX_TOKENS:
+        # This is an override, so always return this
        return GEN_AI_MAX_TOKENS
 
+    if not model_name:
+        return 4096
+
     try:
-        return get_max_tokens(model_name)
+        if model_provider == "openai":
+            return get_max_tokens(model_name)
+        return get_max_tokens("/".join([model_provider, model_name]))
     except Exception:
-        return GEN_AI_MAX_TOKENS
+        return 4096
 
 
 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
-    input_toks = get_llm_max_tokens(model_name) - output_tokens
+    input_toks = (
+        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
+        - output_tokens
+    )
 
     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
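
For a quick sanity check of the behavior this patch introduces, the sketch below mirrors the patched lookup logic outside of Danswer: a non-zero GEN_AI_MAX_TOKENS wins outright, OpenAI model names are passed to litellm's get_max_tokens as-is, non-OpenAI models are looked up under a "provider/model" key, and anything unresolvable falls back to 4096. The resolve_* helper names and the example provider/model values are illustrative only (they are not part of the Danswer codebase), and litellm's get_max_tokens return shape has varied across versions, so treat this as an assumption-laden sketch rather than a drop-in implementation.

# Standalone sketch of the token-budget resolution this patch implements in
# backend/danswer/llm/utils.py. The resolve_* names and example values below
# are hypothetical, used only to illustrate the lookup order.
import os

from litellm import get_max_tokens  # same helper the patched utils.py calls


def resolve_llm_max_tokens(model_name: str | None, model_provider: str) -> int:
    # GEN_AI_MAX_TOKENS is now a pure override: a non-zero value short-circuits detection
    override = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
    if override:
        return override

    if not model_name:
        return 4096  # nothing to look up, fall back to a conservative default

    try:
        if model_provider == "openai":
            # OpenAI models are registered in litellm without a provider prefix
            return get_max_tokens(model_name)
        # Non-OpenAI models are keyed as "provider/model", e.g. "ollama/llama2"
        return get_max_tokens("/".join([model_provider, model_name]))
    except Exception:
        return 4096  # unknown model: degrade gracefully instead of failing


def resolve_max_input_tokens(
    model_name: str | None, model_provider: str, output_tokens: int = 1024
) -> int:
    # Whatever remains after reserving output_tokens is the budget for docs + history
    input_toks = resolve_llm_max_tokens(model_name, model_provider) - output_tokens
    if input_toks <= 0:
        raise RuntimeError("No tokens for input for the LLM given settings")
    return input_toks


if __name__ == "__main__":
    # e.g. GEN_AI_MODEL_PROVIDER=ollama, GEN_AI_MODEL_VERSION=llama2 in Danswer terms
    print(resolve_max_input_tokens("llama2", "ollama"))

Note the failure fallback changes from GEN_AI_MAX_TOKENS to a hard-coded 4096: with this patch GEN_AI_MAX_TOKENS defaults to None and is only consulted as an explicit override, rather than doubling as the backup context size.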