Support Detection of LLM Max Context for Non OpenAI Models (#1060)
parent cd8d8def1e
commit 1a1c91a7d9
@@ -574,7 +574,7 @@ def compute_max_document_tokens(
     max_input_tokens = (
         max_llm_token_override
         if max_llm_token_override
-        else get_max_input_tokens(llm_name)
+        else get_max_input_tokens(model_name=llm_name)
     )
     if persona.prompts:
         # TODO this may not always be the first prompt

@@ -443,7 +443,7 @@ def stream_chat_message(
     if persona.llm_model_version_override:
         llm_name = persona.llm_model_version_override
 
-    llm_max_input_tokens = get_max_input_tokens(llm_name)
+    llm_max_input_tokens = get_max_input_tokens(model_name=llm_name)
 
     llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
 
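Both hunks above are call-site updates only: get_max_input_tokens is now called with model_name as a keyword argument, which keeps the call unambiguous now that the function (see the final hunk below) also takes a model_provider parameter with a configured default. A toy illustration, not the Danswer code, of why keyword arguments are the safer style once a parameter is inserted into the middle of a signature:

    # Toy versions of the signature before and after the change (names and
    # defaults here are illustrative, not the real configuration values).
    def max_input_tokens_old(model_name: str | None = None, output_tokens: int = 1024) -> int:
        return 4096 - output_tokens

    def max_input_tokens_new(
        model_name: str | None = None,
        model_provider: str = "openai",  # new parameter inserted before output_tokens
        output_tokens: int = 1024,
    ) -> int:
        return 4096 - output_tokens

    # Keyword call sites keep their meaning across the signature change; a
    # positional call like max_input_tokens_new("gpt-4", 512) would instead
    # feed 512 into model_provider by mistake.
    assert max_input_tokens_new(model_name="gpt-4") == 3072
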
@@ -96,9 +96,8 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
 GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 # LiteLLM custom_llm_provider
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
-# If the max tokens can't be found from the name, use this as the backup
-# This happens if user is configuring a different LLM to use
-GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
+# Override the auto-detection of LLM max context length
+GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
 # Number of tokens from chat history to include at maximum
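The config change above turns GEN_AI_MAX_TOKENS from a hard default of 4096 into an optional override. A small self-contained sketch of how the new parsing idiom behaves (the environment values shown are examples only):

    import os

    def read_max_tokens_override() -> int | None:
        # Mirrors the new parsing: an unset variable or "0" collapses to None,
        # so the context length is auto-detected from the model name; any
        # positive value becomes a hard override.
        return int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

    # unset                     -> None  (auto-detect)
    # GEN_AI_MAX_TOKENS="0"     -> None  (auto-detect)
    # GEN_AI_MAX_TOKENS="8192"  -> 8192  (override)
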
@@ -224,8 +224,12 @@ def handle_message(
     max_document_tokens: int | None = None
     max_history_tokens: int | None = None
     if len(new_message_request.messages) > 1:
+        llm_name = GEN_AI_MODEL_VERSION
+        if persona and persona.llm_model_version_override:
+            llm_name = persona.llm_model_version_override
+
         # In cases of threads, split the available tokens between docs and thread context
-        input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
+        input_tokens = get_max_input_tokens(model_name=llm_name)
         max_history_tokens = int(input_tokens * thread_context_percent)
 
         remaining_tokens = input_tokens - max_history_tokens
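Since this handler now splits a detected context window rather than a fixed 4096, a worked example with made-up numbers may help (the real values come from model detection and the handler's thread_context_percent setting):

    input_tokens = 4096           # e.g. what get_max_input_tokens(model_name=llm_name) returned
    thread_context_percent = 0.5  # illustrative value only
    max_history_tokens = int(input_tokens * thread_context_percent)  # 2048 tokens for thread history
    remaining_tokens = input_tokens - max_history_tokens             # 2048 tokens left for documents
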
@@ -24,6 +24,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
 from danswer.dynamic_configs import get_dynamic_config_store

@@ -203,22 +204,35 @@ def test_llm(llm: LLM) -> bool:
         return False
 
 
-def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
+def get_llm_max_tokens(
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
-    if not model_name:
+    if GEN_AI_MAX_TOKENS:
+        # This is an override, so always return this
+        return GEN_AI_MAX_TOKENS
+
+    if not model_name:
         return 4096
 
     try:
-        return get_max_tokens(model_name)
+        if model_provider == "openai":
+            return get_max_tokens(model_name)
+        return get_max_tokens("/".join([model_provider, model_name]))
     except Exception:
-        return GEN_AI_MAX_TOKENS
+        return 4096
 
 
 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
-    input_toks = get_llm_max_tokens(model_name) - output_tokens
+    input_toks = (
+        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
+        - output_tokens
+    )
 
     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
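The core of the change is in get_llm_max_tokens: an explicit GEN_AI_MAX_TOKENS override wins, otherwise the model is looked up by its bare name for OpenAI or by a LiteLLM-style "<provider>/<model>" key for every other provider, with 4096 as the last-resort default. A self-contained sketch of that fallback order, using a hypothetical in-memory table in place of LiteLLM's get_max_tokens:

    # Hypothetical stand-in for LiteLLM's model registry, only so the sketch runs.
    _KNOWN_CONTEXT_SIZES = {"gpt-3.5-turbo": 4096, "azure/gpt-4-32k": 32768}

    def sketch_get_llm_max_tokens(
        model_name: str | None,
        model_provider: str,
        max_tokens_override: int | None = None,  # stands in for GEN_AI_MAX_TOKENS
    ) -> int:
        if max_tokens_override:
            return max_tokens_override  # explicit override always wins
        if not model_name:
            return 4096
        try:
            lookup_key = (
                model_name
                if model_provider == "openai"
                else "/".join([model_provider, model_name])
            )
            return _KNOWN_CONTEXT_SIZES[lookup_key]
        except Exception:
            return 4096  # unknown model -> safe default

    assert sketch_get_llm_max_tokens("gpt-4-32k", "azure") == 32768
    assert sketch_get_llm_max_tokens("mystery-model", "huggingface") == 4096
    assert sketch_get_llm_max_tokens(None, "openai", max_tokens_override=8192) == 8192

The real code defers to LiteLLM's get_max_tokens for the lookup, so context-length detection covers whatever models LiteLLM knows about rather than a Danswer-side table.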