Mirror of https://github.com/danswer-ai/danswer.git
Get accurate model output max (#2260)
* get accurate model output max
* squash
* updated max default tokens
* rename + use fallbacks
* functional
* remove max tokens
* update naming
* comment out function to prevent mypy issues
@@ -80,11 +80,16 @@ GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
 # Override the auto-detection of LLM max context length
 GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
 
 # Set this to be enough for an answer + quotes. Also used for Chat
-GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
+# This is the minimum token context we will leave for the LLM to generate an answer
+GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
+    os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
+)
 
 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
 
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
 # as this drives up the cost unnecessarily
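A note on the parsing pattern used in this config file: int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None treats an unset, empty, or zero value as "no override", while the `or 1024` form supplies a default. A minimal standalone sketch of the same pattern (values shown are hypothetical, not from the commit):

import os

# Unset, empty, or "0" all collapse to None, i.e. "no override".
gen_ai_max_tokens = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

# Unset or empty falls back to 1024 reserved output tokens.
gen_ai_num_reserved_output_tokens = int(
    os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
)

print(gen_ai_max_tokens, gen_ai_num_reserved_output_tokens)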
@@ -28,7 +28,6 @@ from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
 from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
 from danswer.llm.interfaces import LLMConfig
@@ -193,10 +192,10 @@ class DefaultMultiLLM(LLM):
         timeout: int,
         model_provider: str,
         model_name: str,
+        max_output_tokens: int | None = None,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
         custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
@@ -209,7 +208,17 @@ class DefaultMultiLLM(LLM):
         self._api_base = api_base
         self._api_version = api_version
         self._custom_llm_provider = custom_llm_provider
-        self._max_output_tokens = max_output_tokens
+
+        # This can be used to store the maximum output tokens for this model.
+        # self._max_output_tokens = (
+        #     max_output_tokens
+        #     if max_output_tokens is not None
+        #     else get_llm_max_output_tokens(
+        #         model_map=litellm.model_cost,
+        #         model_name=model_name,
+        #         model_provider=model_provider,
+        #     )
+        # )
         self._custom_config = custom_config
 
         # NOTE: have to set these as environment variables for Litellm since
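If the commented-out fallback above were re-enabled, the resolution would look roughly like the sketch below. This is a hedged illustration, assuming get_llm_max_output_tokens is importable from danswer.llm.utils (where the other token helpers in this commit live) and relying on litellm's public model_cost metadata dict:

import litellm

from danswer.llm.utils import get_llm_max_output_tokens  # assumed module path


def resolve_max_output_tokens(
    max_output_tokens: int | None,
    model_name: str,
    model_provider: str,
) -> int:
    # An explicit caller-supplied value wins; otherwise look the model up
    # in litellm's model metadata.
    if max_output_tokens is not None:
        return max_output_tokens
    return get_llm_max_output_tokens(
        model_map=litellm.model_cost,
        model_name=model_name,
        model_provider=model_provider,
    )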
@@ -228,6 +237,30 @@ class DefaultMultiLLM(LLM):
     def log_model_configs(self) -> None:
         logger.debug(f"Config: {self.config}")
 
+    # def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
+    #     # NOTE: This method can be used for calculating the maximum tokens for the stream,
+    #     # but it isn't used in practice due to the computational cost of counting tokens
+    #     # and because LLM providers automatically cut off at the maximum output.
+    #     # The implementation is kept for potential future use or debugging purposes.
+
+    #     # Get max input tokens for the model
+    #     max_context_tokens = get_max_input_tokens(
+    #         model_name=self.config.model_name, model_provider=self.config.model_provider
+    #     )
+
+    #     llm_tokenizer = get_tokenizer(
+    #         model_name=self.config.model_name,
+    #         provider_type=self.config.model_provider,
+    #     )
+    #     # Calculate tokens in the input prompt
+    #     input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)
+
+    #     # Calculate available tokens for output
+    #     available_output_tokens = max_context_tokens - input_tokens
+
+    #     # Return the lesser of available tokens or configured max
+    #     return min(self._max_output_tokens, available_output_tokens)
+
     def _completion(
         self,
         prompt: LanguageModelInput,
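The commented-out method boils down to simple budget arithmetic: whatever context the model has left after the prompt is the ceiling for output, further capped by the configured maximum. With hypothetical numbers:

# Hypothetical values purely for illustration
max_context_tokens = 8192      # model's total context window
input_tokens = 5200            # tokens counted in the prompt
configured_max_output = 1024   # what self._max_output_tokens would hold

available_output_tokens = max_context_tokens - input_tokens            # 2992
max_tokens_for_request = min(configured_max_output, available_output_tokens)  # 1024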
@@ -259,9 +292,6 @@ class DefaultMultiLLM(LLM):
             stream=stream,
             # model params
             temperature=self._temperature,
-            max_tokens=self._max_output_tokens
-            if self._max_output_tokens > 0
-            else None,
             timeout=self._timeout,
             # For now, we don't support parallel tool calls
             # NOTE: we can't pass this in if tools are not specified
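With the max_tokens argument dropped, the completion call leans on the provider's own output limit instead of a locally computed cap. A minimal, hedged sketch of that call shape; the model name and message are placeholders, not values from the commit:

import litellm

# No max_tokens kwarg: the provider decides where generation is cut off.
response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Summarize the release notes."}],
    temperature=0.0,
    timeout=30,
    stream=False,
)
print(response.choices[0].message.content)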
@@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage
 from requests import Timeout
 
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
 from danswer.llm.interfaces import LLM
 from danswer.llm.interfaces import ToolChoiceOptions
 from danswer.llm.utils import convert_lm_input_to_basic_string
@@ -38,7 +38,7 @@ class CustomModelServer(LLM):
         api_key: str | None,
         timeout: int,
         endpoint: str | None = GEN_AI_API_ENDPOINT,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+        max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
     ):
         if not endpoint:
             raise ValueError(
@@ -30,10 +30,10 @@ from litellm.exceptions import Timeout # type: ignore
 from litellm.exceptions import UnprocessableEntityError # type: ignore
 
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
 from danswer.db.models import ChatMessage
 from danswer.file_store.models import ChatFileType
 from danswer.file_store.models import InMemoryChatFile
@@ -336,31 +336,80 @@ def get_llm_max_tokens(
     """Best effort attempt to get the max tokens for the LLM"""
     if GEN_AI_MAX_TOKENS:
         # This is an override, so always return this
+        logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
         return GEN_AI_MAX_TOKENS
 
     try:
         model_obj = model_map.get(f"{model_provider}/{model_name}")
         if not model_obj:
             model_obj = model_map[model_name]
+            logger.debug(f"Using model object for {model_name}")
+        else:
+            logger.debug(f"Using model object for {model_provider}/{model_name}")
 
         if "max_input_tokens" in model_obj:
-            return model_obj["max_input_tokens"]
+            max_tokens = model_obj["max_input_tokens"]
+            logger.info(
+                f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
+            )
+            return max_tokens
 
         if "max_tokens" in model_obj:
-            return model_obj["max_tokens"]
+            max_tokens = model_obj["max_tokens"]
+            logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
+            return max_tokens
 
+        logger.error(f"No max tokens found for LLM: {model_name}")
         raise RuntimeError("No max tokens found for LLM")
     except Exception:
         logger.exception(
-            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_DEFAULT_MAX_TOKENS}."
+            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
         )
-        return GEN_AI_MODEL_DEFAULT_MAX_TOKENS
+        return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 
 
+def get_llm_max_output_tokens(
+    model_map: dict,
+    model_name: str,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int:
+    """Best effort attempt to get the max output tokens for the LLM"""
+    try:
+        model_obj = model_map.get(f"{model_provider}/{model_name}")
+        if not model_obj:
+            model_obj = model_map[model_name]
+            logger.debug(f"Using model object for {model_name}")
+        else:
+            logger.debug(f"Using model object for {model_provider}/{model_name}")
+
+        if "max_output_tokens" in model_obj:
+            max_output_tokens = model_obj["max_output_tokens"]
+            logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
+            return max_output_tokens
+
+        # Fallback to a fraction of max_tokens if max_output_tokens is not specified
+        if "max_tokens" in model_obj:
+            max_output_tokens = int(model_obj["max_tokens"] * 0.1)
+            logger.info(
+                f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
+            )
+            return max_output_tokens
+
+        logger.error(f"No max output tokens found for LLM: {model_name}")
+        raise RuntimeError("No max output tokens found for LLM")
+    except Exception:
+        default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
+        logger.exception(
+            f"Failed to get max output tokens for LLM with name {model_name}. "
+            f"Defaulting to {default_output_tokens} (fallback max tokens)."
+        )
+        return default_output_tokens
+
+
 def get_max_input_tokens(
     model_name: str,
     model_provider: str,
-    output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
+    output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
 ) -> int:
     # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
     # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
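A hedged usage sketch of the helpers above. litellm.model_cost is litellm's public metadata dict, keyed by model name and sometimes by a provider-prefixed name, which is why both helpers first try f"{provider}/{model}" and then fall back to the bare model name. The import path assumes these helpers live in danswer.llm.utils next to get_max_input_tokens; the model and provider strings are placeholders:

import litellm

from danswer.llm.utils import get_llm_max_output_tokens  # assumed module path
from danswer.llm.utils import get_llm_max_tokens  # assumed module path

context_window = get_llm_max_tokens(
    model_map=litellm.model_cost,
    model_name="gpt-4o",
    model_provider="openai",
)
output_cap = get_llm_max_output_tokens(
    model_map=litellm.model_cost,
    model_name="gpt-4o",
    model_provider="openai",
)
print(context_window, output_cap)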
@@ -13,7 +13,7 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import SectionRelevancePiece
 from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
 from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
-from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.db.models import Persona
 from danswer.db.models import User
 from danswer.dynamic_configs.interface import JSON_ro
@@ -130,7 +130,7 @@ class SearchTool(Tool):
         # For small context models, don't include additional surrounding context
         # The 3 here for at least minimum 1 above, 1 below and 1 for the middle chunk
         max_llm_tokens = compute_max_llm_input_tokens(self.llm.config)
-        if max_llm_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS:
+        if max_llm_tokens < 3 * GEN_AI_MODEL_FALLBACK_MAX_TOKENS:
             self.chunks_above = 0
             self.chunks_below = 0
 
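To make the threshold concrete, here is the same check with hypothetical numbers; a model whose usable input context is under 3 * 4096 tokens gets no surrounding chunks:

GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
max_llm_tokens = 8192  # e.g. what compute_max_llm_input_tokens() returns for a small model

chunks_above, chunks_below = 1, 1
if max_llm_tokens < 3 * GEN_AI_MODEL_FALLBACK_MAX_TOKENS:  # 8192 < 12288 -> True
    chunks_above = 0
    chunks_below = 0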