mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 13:25:50 +02:00
Get accurate model output max (#2260)
* get accurate model output max * squash * udpated max default tokens * rename + use fallbacks * functional * remove max tokens * update naming * comment out function to prevent mypy issues
This commit is contained in:
@@ -80,11 +80,16 @@ GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
|
||||
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
|
||||
# Override the auto-detection of LLM max context length
|
||||
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
|
||||
|
||||
# Set this to be enough for an answer + quotes. Also used for Chat
|
||||
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
|
||||
# This is the minimum token context we will leave for the LLM to generate an answer
|
||||
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
# as this drives up the cost unnecessarily
|
||||
|
@@ -28,7 +28,6 @@ from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
|
||||
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_API_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
@@ -193,10 +192,10 @@ class DefaultMultiLLM(LLM):
|
||||
timeout: int,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
max_output_tokens: int | None = None,
|
||||
api_base: str | None = GEN_AI_API_ENDPOINT,
|
||||
api_version: str | None = GEN_AI_API_VERSION,
|
||||
custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
|
||||
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
@@ -209,7 +208,17 @@ class DefaultMultiLLM(LLM):
|
||||
self._api_base = api_base
|
||||
self._api_version = api_version
|
||||
self._custom_llm_provider = custom_llm_provider
|
||||
self._max_output_tokens = max_output_tokens
|
||||
|
||||
# This can be used to store the maximum output tkoens for this model.
|
||||
# self._max_output_tokens = (
|
||||
# max_output_tokens
|
||||
# if max_output_tokens is not None
|
||||
# else get_llm_max_output_tokens(
|
||||
# model_map=litellm.model_cost,
|
||||
# model_name=model_name,
|
||||
# model_provider=model_provider,
|
||||
# )
|
||||
# )
|
||||
self._custom_config = custom_config
|
||||
|
||||
# NOTE: have to set these as environment variables for Litellm since
|
||||
@@ -228,6 +237,30 @@ class DefaultMultiLLM(LLM):
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self.config}")
|
||||
|
||||
# def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
|
||||
# # NOTE: This method can be used for calculating the maximum tokens for the stream,
|
||||
# # but it isn't used in practice due to the computational cost of counting tokens
|
||||
# # and because LLM providers automatically cut off at the maximum output.
|
||||
# # The implementation is kept for potential future use or debugging purposes.
|
||||
|
||||
# # Get max input tokens for the model
|
||||
# max_context_tokens = get_max_input_tokens(
|
||||
# model_name=self.config.model_name, model_provider=self.config.model_provider
|
||||
# )
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=self.config.model_name,
|
||||
# provider_type=self.config.model_provider,
|
||||
# )
|
||||
# # Calculate tokens in the input prompt
|
||||
# input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)
|
||||
|
||||
# # Calculate available tokens for output
|
||||
# available_output_tokens = max_context_tokens - input_tokens
|
||||
|
||||
# # Return the lesser of available tokens or configured max
|
||||
# return min(self._max_output_tokens, available_output_tokens)
|
||||
|
||||
def _completion(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
@@ -259,9 +292,6 @@ class DefaultMultiLLM(LLM):
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=self._temperature,
|
||||
max_tokens=self._max_output_tokens
|
||||
if self._max_output_tokens > 0
|
||||
else None,
|
||||
timeout=self._timeout,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
|
@@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage
|
||||
from requests import Timeout
|
||||
|
||||
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.llm.utils import convert_lm_input_to_basic_string
|
||||
@@ -38,7 +38,7 @@ class CustomModelServer(LLM):
|
||||
api_key: str | None,
|
||||
timeout: int,
|
||||
endpoint: str | None = GEN_AI_API_ENDPOINT,
|
||||
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
|
||||
):
|
||||
if not endpoint:
|
||||
raise ValueError(
|
||||
|
@@ -30,10 +30,10 @@ from litellm.exceptions import Timeout # type: ignore
|
||||
from litellm.exceptions import UnprocessableEntityError # type: ignore
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@@ -336,31 +336,80 @@ def get_llm_max_tokens(
|
||||
"""Best effort attempt to get the max tokens for the LLM"""
|
||||
if GEN_AI_MAX_TOKENS:
|
||||
# This is an override, so always return this
|
||||
logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
|
||||
return GEN_AI_MAX_TOKENS
|
||||
|
||||
try:
|
||||
model_obj = model_map.get(f"{model_provider}/{model_name}")
|
||||
if not model_obj:
|
||||
model_obj = model_map[model_name]
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
else:
|
||||
logger.debug(f"Using model object for {model_provider}/{model_name}")
|
||||
|
||||
if "max_input_tokens" in model_obj:
|
||||
return model_obj["max_input_tokens"]
|
||||
max_tokens = model_obj["max_input_tokens"]
|
||||
logger.info(
|
||||
f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
|
||||
)
|
||||
return max_tokens
|
||||
|
||||
if "max_tokens" in model_obj:
|
||||
return model_obj["max_tokens"]
|
||||
max_tokens = model_obj["max_tokens"]
|
||||
logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
|
||||
return max_tokens
|
||||
|
||||
logger.error(f"No max tokens found for LLM: {model_name}")
|
||||
raise RuntimeError("No max tokens found for LLM")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_DEFAULT_MAX_TOKENS}."
|
||||
f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
|
||||
)
|
||||
return GEN_AI_MODEL_DEFAULT_MAX_TOKENS
|
||||
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
|
||||
|
||||
def get_llm_max_output_tokens(
|
||||
model_map: dict,
|
||||
model_name: str,
|
||||
model_provider: str = GEN_AI_MODEL_PROVIDER,
|
||||
) -> int:
|
||||
"""Best effort attempt to get the max output tokens for the LLM"""
|
||||
try:
|
||||
model_obj = model_map.get(f"{model_provider}/{model_name}")
|
||||
if not model_obj:
|
||||
model_obj = model_map[model_name]
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
else:
|
||||
logger.debug(f"Using model object for {model_provider}/{model_name}")
|
||||
|
||||
if "max_output_tokens" in model_obj:
|
||||
max_output_tokens = model_obj["max_output_tokens"]
|
||||
logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
|
||||
return max_output_tokens
|
||||
|
||||
# Fallback to a fraction of max_tokens if max_output_tokens is not specified
|
||||
if "max_tokens" in model_obj:
|
||||
max_output_tokens = int(model_obj["max_tokens"] * 0.1)
|
||||
logger.info(
|
||||
f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
|
||||
)
|
||||
return max_output_tokens
|
||||
|
||||
logger.error(f"No max output tokens found for LLM: {model_name}")
|
||||
raise RuntimeError("No max output tokens found for LLM")
|
||||
except Exception:
|
||||
default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
|
||||
logger.exception(
|
||||
f"Failed to get max output tokens for LLM with name {model_name}. "
|
||||
f"Defaulting to {default_output_tokens} (fallback max tokens)."
|
||||
)
|
||||
return default_output_tokens
|
||||
|
||||
|
||||
def get_max_input_tokens(
|
||||
model_name: str,
|
||||
model_provider: str,
|
||||
output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
|
||||
) -> int:
|
||||
# NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
|
||||
# returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
|
||||
|
@@ -13,7 +13,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
@@ -130,7 +130,7 @@ class SearchTool(Tool):
|
||||
# For small context models, don't include additional surrounding context
|
||||
# The 3 here for at least minimum 1 above, 1 below and 1 for the middle chunk
|
||||
max_llm_tokens = compute_max_llm_input_tokens(self.llm.config)
|
||||
if max_llm_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS:
|
||||
if max_llm_tokens < 3 * GEN_AI_MODEL_FALLBACK_MAX_TOKENS:
|
||||
self.chunks_above = 0
|
||||
self.chunks_below = 0
|
||||
|
||||
|
Reference in New Issue
Block a user