Support Detection of LLM Max Context for Non OpenAI Models (#1060)

Yuhong Sun 2024-02-08 15:15:58 -08:00 committed by GitHub
parent cd8d8def1e
commit 1a1c91a7d9
5 changed files with 28 additions and 11 deletions

View File

@@ -574,7 +574,7 @@ def compute_max_document_tokens(
     max_input_tokens = (
         max_llm_token_override
         if max_llm_token_override
-        else get_max_input_tokens(llm_name)
+        else get_max_input_tokens(model_name=llm_name)
     )
     if persona.prompts:
         # TODO this may not always be the first prompt

View File

@@ -443,7 +443,7 @@ def stream_chat_message(
     if persona.llm_model_version_override:
         llm_name = persona.llm_model_version_override

-    llm_max_input_tokens = get_max_input_tokens(llm_name)
+    llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)
     llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
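
Both call sites above only switch to passing the model name as a keyword argument, which matters now that get_max_input_tokens takes additional defaulted parameters (model_provider, output_tokens). A minimal standalone sketch of that signature shape, with invented defaults and a toy context lookup rather than Danswer's real values:

# Stand-in with the same shape as the updated signature; the defaults and the
# toy context lookup are assumptions for this sketch, not Danswer's values.
def get_max_input_tokens_sketch(
    model_name: str | None = None,
    model_provider: str = "openai",
    output_tokens: int = 1024,
) -> int:
    context_window = 8192 if model_name == "gpt-4" else 4096  # toy lookup
    return context_window - output_tokens


# Passing model_name as a keyword, as the call sites above now do, keeps the
# call unambiguous if further defaulted parameters are added later.
print(get_max_input_tokens_sketch(model_name="gpt-4"))  # 7168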

View File

@@ -96,9 +96,8 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 # LiteLLM custom_llm_provider
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
-# If the max tokens can't be found from the name, use this as the backup
-# This happens if user is configuring a different LLM to use
-GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
+# Override the auto-detection of LLM max context length
+GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 # Number of tokens from chat history to include at maximum
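
The new default turns GEN_AI_MAX_TOKENS into an optional override: unset, empty, or "0" all collapse to None so auto-detection is used, while any positive value wins outright. A small sketch of the int(... or 0) or None pattern, using a placeholder variable name rather than the real setting:

import os

# SOME_MAX_TOKENS is a placeholder name for this sketch, not a Danswer setting.
# Unset / "" / "0" all become None; any other integer string becomes an override.
SOME_MAX_TOKENS = int(os.environ.get("SOME_MAX_TOKENS") or 0) or None
print(SOME_MAX_TOKENS)  # None when the variable is unset, empty, or "0"

os.environ["SOME_MAX_TOKENS"] = "32768"
SOME_MAX_TOKENS = int(os.environ.get("SOME_MAX_TOKENS") or 0) or None
print(SOME_MAX_TOKENS)  # 32768, treated downstream as an explicit override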

View File

@@ -224,8 +224,12 @@ def handle_message(
     max_document_tokens: int | None = None
     max_history_tokens: int | None = None
     if len(new_message_request.messages) > 1:
+        llm_name = GEN_AI_MODEL_VERSION
+        if persona and persona.llm_model_version_override:
+            llm_name = persona.llm_model_version_override
+
         # In cases of threads, split the available tokens between docs and thread context
-        input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
+        input_tokens = get_max_input_tokens(model_name=llm_name)
         max_history_tokens = int(input_tokens * thread_context_percent)
         remaining_tokens = input_tokens - max_history_tokens
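
The added lines resolve the model name (persona override first, otherwise GEN_AI_MODEL_VERSION) before asking for its input budget, which is then split between thread history and documents. A rough worked example of that split, assuming a 7168-token budget and a hypothetical thread_context_percent of 0.5:

# Illustrative numbers only; in Danswer, input_tokens comes from
# get_max_input_tokens(model_name=llm_name) and thread_context_percent
# from the Slack bot configuration.
input_tokens = 7168            # e.g. an 8192-token context minus 1024 for output
thread_context_percent = 0.5   # assumed value for this sketch

max_history_tokens = int(input_tokens * thread_context_percent)  # 3584
remaining_tokens = input_tokens - max_history_tokens             # 3584
print(max_history_tokens, remaining_tokens)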

View File

@@ -24,6 +24,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
 from danswer.dynamic_configs import get_dynamic_config_store
@@ -203,22 +204,35 @@ def test_llm(llm: LLM) -> bool:
     return False


-def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
+def get_llm_max_tokens(
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
-    if not model_name:
+    if GEN_AI_MAX_TOKENS:
+        # This is an override, so always return this
+        return GEN_AI_MAX_TOKENS
+
+    if not model_name:
         return 4096
+
     try:
-        return get_max_tokens(model_name)
+        if model_provider == "openai":
+            return get_max_tokens(model_name)
+        return get_max_tokens("/".join([model_provider, model_name]))
     except Exception:
-        return GEN_AI_MAX_TOKENS
+        return 4096


 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
-    input_toks = get_llm_max_tokens(model_name) - output_tokens
+    input_toks = (
+        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
+        - output_tokens
+    )
     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
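
The rewritten helper now checks the GEN_AI_MAX_TOKENS override first, then asks LiteLLM's model registry for the context length, using the bare model name for OpenAI and a "provider/model" key (e.g. "ollama/llama2") for other providers, and finally falls back to 4096. A standalone sketch of that lookup order, with a toy map standing in for the litellm.get_max_tokens call used above; the map contents and override value are assumptions for illustration:

# Toy registry standing in for LiteLLM's model map.
TOY_MODEL_MAP = {
    "gpt-4": 8192,          # OpenAI models are looked up by bare name
    "ollama/llama2": 4096,  # other providers use "provider/model" keys
}


def sketch_get_llm_max_tokens(
    model_name: str | None,
    model_provider: str = "openai",
    max_tokens_override: int | None = None,  # plays the role of GEN_AI_MAX_TOKENS
) -> int:
    if max_tokens_override:
        return max_tokens_override  # explicit override always wins
    if not model_name:
        return 4096
    key = model_name if model_provider == "openai" else f"{model_provider}/{model_name}"
    try:
        return TOY_MODEL_MAP[key]
    except KeyError:
        return 4096  # unknown model: conservative default


def sketch_get_max_input_tokens(
    model_name: str | None,
    model_provider: str = "openai",
    output_tokens: int = 1024,
) -> int:
    input_toks = sketch_get_llm_max_tokens(model_name, model_provider) - output_tokens
    if input_toks <= 0:
        raise RuntimeError("No tokens for input for the LLM given settings")
    return input_toks


print(sketch_get_max_input_tokens("gpt-4"))                  # 7168
print(sketch_get_max_input_tokens("llama2", "ollama"))       # 3072
print(sketch_get_llm_max_tokens("llama2", "ollama", 16384))  # 16384 (override)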