Get accurate model output max (#2260)

* get accurate model output max

* squash

* udpated max default tokens

* rename + use fallbacks

* functional

* remove max tokens

* update naming

* comment out function to prevent mypy issues
This commit is contained in:
pablodanswer
2024-08-29 11:01:56 -07:00
committed by GitHub
parent 3b596fd6a8
commit 00fa36d591
5 changed files with 103 additions and 19 deletions

View File

@@ -80,11 +80,16 @@ GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This is the minimum token context we will leave for the LLM to generate an answer
GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily

View File

@@ -28,7 +28,6 @@ from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
@@ -193,10 +192,10 @@ class DefaultMultiLLM(LLM):
timeout: int,
model_provider: str,
model_name: str,
max_output_tokens: int | None = None,
api_base: str | None = GEN_AI_API_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
@@ -209,7 +208,17 @@ class DefaultMultiLLM(LLM):
self._api_base = api_base
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
self._max_output_tokens = max_output_tokens
# This can be used to store the maximum output tkoens for this model.
# self._max_output_tokens = (
# max_output_tokens
# if max_output_tokens is not None
# else get_llm_max_output_tokens(
# model_map=litellm.model_cost,
# model_name=model_name,
# model_provider=model_provider,
# )
# )
self._custom_config = custom_config
# NOTE: have to set these as environment variables for Litellm since
@@ -228,6 +237,30 @@ class DefaultMultiLLM(LLM):
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
# def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
# # NOTE: This method can be used for calculating the maximum tokens for the stream,
# # but it isn't used in practice due to the computational cost of counting tokens
# # and because LLM providers automatically cut off at the maximum output.
# # The implementation is kept for potential future use or debugging purposes.
# # Get max input tokens for the model
# max_context_tokens = get_max_input_tokens(
# model_name=self.config.model_name, model_provider=self.config.model_provider
# )
# llm_tokenizer = get_tokenizer(
# model_name=self.config.model_name,
# provider_type=self.config.model_provider,
# )
# # Calculate tokens in the input prompt
# input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)
# # Calculate available tokens for output
# available_output_tokens = max_context_tokens - input_tokens
# # Return the lesser of available tokens or configured max
# return min(self._max_output_tokens, available_output_tokens)
def _completion(
self,
prompt: LanguageModelInput,
@@ -259,9 +292,6 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=self._temperature,
max_tokens=self._max_output_tokens
if self._max_output_tokens > 0
else None,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified

View File

@@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage
from requests import Timeout
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.llm.utils import convert_lm_input_to_basic_string
@@ -38,7 +38,7 @@ class CustomModelServer(LLM):
api_key: str | None,
timeout: int,
endpoint: str | None = GEN_AI_API_ENDPOINT,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
):
if not endpoint:
raise ValueError(

View File

@@ -30,10 +30,10 @@ from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from danswer.db.models import ChatMessage
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
@@ -336,31 +336,80 @@ def get_llm_max_tokens(
"""Best effort attempt to get the max tokens for the LLM"""
if GEN_AI_MAX_TOKENS:
# This is an override, so always return this
logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
return GEN_AI_MAX_TOKENS
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
logger.debug(f"Using model object for {model_name}")
else:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if "max_input_tokens" in model_obj:
return model_obj["max_input_tokens"]
max_tokens = model_obj["max_input_tokens"]
logger.info(
f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
)
return max_tokens
if "max_tokens" in model_obj:
return model_obj["max_tokens"]
max_tokens = model_obj["max_tokens"]
logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
return max_tokens
logger.error(f"No max tokens found for LLM: {model_name}")
raise RuntimeError("No max tokens found for LLM")
except Exception:
logger.exception(
f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_DEFAULT_MAX_TOKENS}."
f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
)
return GEN_AI_MODEL_DEFAULT_MAX_TOKENS
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
def get_llm_max_output_tokens(
model_map: dict,
model_name: str,
model_provider: str = GEN_AI_MODEL_PROVIDER,
) -> int:
"""Best effort attempt to get the max output tokens for the LLM"""
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
logger.debug(f"Using model object for {model_name}")
else:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if "max_output_tokens" in model_obj:
max_output_tokens = model_obj["max_output_tokens"]
logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
return max_output_tokens
# Fallback to a fraction of max_tokens if max_output_tokens is not specified
if "max_tokens" in model_obj:
max_output_tokens = int(model_obj["max_tokens"] * 0.1)
logger.info(
f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
)
return max_output_tokens
logger.error(f"No max output tokens found for LLM: {model_name}")
raise RuntimeError("No max output tokens found for LLM")
except Exception:
default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
logger.exception(
f"Failed to get max output tokens for LLM with name {model_name}. "
f"Defaulting to {default_output_tokens} (fallback max tokens)."
)
return default_output_tokens
def get_max_input_tokens(
model_name: str,
model_provider: str,
output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
) -> int:
# NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
# returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,

View File

@@ -13,7 +13,7 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.dynamic_configs.interface import JSON_ro
@@ -130,7 +130,7 @@ class SearchTool(Tool):
# For small context models, don't include additional surrounding context
# The 3 here for at least minimum 1 above, 1 below and 1 for the middle chunk
max_llm_tokens = compute_max_llm_input_tokens(self.llm.config)
if max_llm_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS:
if max_llm_tokens < 3 * GEN_AI_MODEL_FALLBACK_MAX_TOKENS:
self.chunks_above = 0
self.chunks_below = 0