From 00fa36d591506b2dfe7ad961313bf74e10daf853 Mon Sep 17 00:00:00 2001
From: pablodanswer
Date: Thu, 29 Aug 2024 11:01:56 -0700
Subject: [PATCH] Get accurate model output max (#2260)

* get accurate model output max
* squash
* updated max default tokens
* rename + use fallbacks
* functional
* remove max tokens
* update naming
* comment out function to prevent mypy issues
---
 backend/danswer/configs/model_configs.py    |  9 ++-
 backend/danswer/llm/chat_llm.py             | 42 ++++++++++++--
 backend/danswer/llm/custom_llm.py           |  4 +-
 backend/danswer/llm/utils.py                | 63 ++++++++++++++++++---
 backend/danswer/tools/search/search_tool.py |  4 +-
 5 files changed, 103 insertions(+), 19 deletions(-)

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index e7dce6f12f5a..9e323c2b5394 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -80,11 +80,16 @@ GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
 GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
 # Override the auto-detection of LLM max context length
 GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
+
 # Set this to be enough for an answer + quotes. Also used for Chat
-GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
+# This is the minimum token context we will leave for the LLM to generate an answer
+GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
+    os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
+)
 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
+
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
 # as this drives up the cost unnecessarily
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index b0e7d8034e8c..33b1cc24c811 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -28,7 +28,6 @@ from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
 from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
 from danswer.llm.interfaces import LLMConfig
@@ -193,10 +192,10 @@ class DefaultMultiLLM(LLM):
         timeout: int,
         model_provider: str,
         model_name: str,
+        max_output_tokens: int | None = None,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
         custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
@@ -209,7 +208,17 @@ class DefaultMultiLLM(LLM):
         self._api_base = api_base
         self._api_version = api_version
         self._custom_llm_provider = custom_llm_provider
-        self._max_output_tokens = max_output_tokens
+
+        # This can be used to store the maximum output tokens for this model.
+ # self._max_output_tokens = ( + # max_output_tokens + # if max_output_tokens is not None + # else get_llm_max_output_tokens( + # model_map=litellm.model_cost, + # model_name=model_name, + # model_provider=model_provider, + # ) + # ) self._custom_config = custom_config # NOTE: have to set these as environment variables for Litellm since @@ -228,6 +237,30 @@ class DefaultMultiLLM(LLM): def log_model_configs(self) -> None: logger.debug(f"Config: {self.config}") + # def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int: + # # NOTE: This method can be used for calculating the maximum tokens for the stream, + # # but it isn't used in practice due to the computational cost of counting tokens + # # and because LLM providers automatically cut off at the maximum output. + # # The implementation is kept for potential future use or debugging purposes. + + # # Get max input tokens for the model + # max_context_tokens = get_max_input_tokens( + # model_name=self.config.model_name, model_provider=self.config.model_provider + # ) + + # llm_tokenizer = get_tokenizer( + # model_name=self.config.model_name, + # provider_type=self.config.model_provider, + # ) + # # Calculate tokens in the input prompt + # input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt) + + # # Calculate available tokens for output + # available_output_tokens = max_context_tokens - input_tokens + + # # Return the lesser of available tokens or configured max + # return min(self._max_output_tokens, available_output_tokens) + def _completion( self, prompt: LanguageModelInput, @@ -259,9 +292,6 @@ class DefaultMultiLLM(LLM): stream=stream, # model params temperature=self._temperature, - max_tokens=self._max_output_tokens - if self._max_output_tokens > 0 - else None, timeout=self._timeout, # For now, we don't support parallel tool calls # NOTE: we can't pass this in if tools are not specified diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index da71e0e5b651..967e014a903a 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -8,7 +8,7 @@ from langchain_core.messages import BaseMessage from requests import Timeout from danswer.configs.model_configs import GEN_AI_API_ENDPOINT -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.llm.interfaces import LLM from danswer.llm.interfaces import ToolChoiceOptions from danswer.llm.utils import convert_lm_input_to_basic_string @@ -38,7 +38,7 @@ class CustomModelServer(LLM): api_key: str | None, timeout: int, endpoint: str | None = GEN_AI_API_ENDPOINT, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ): if not endpoint: raise ValueError( diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index db777c8e27bd..82617f3f05b6 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -30,10 +30,10 @@ from litellm.exceptions import Timeout # type: ignore from litellm.exceptions import UnprocessableEntityError # type: ignore from danswer.configs.constants import MessageType -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MAX_TOKENS -from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from danswer.configs.model_configs import 
GEN_AI_MODEL_PROVIDER +from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from danswer.db.models import ChatMessage from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile @@ -336,31 +336,80 @@ def get_llm_max_tokens( """Best effort attempt to get the max tokens for the LLM""" if GEN_AI_MAX_TOKENS: # This is an override, so always return this + logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}") return GEN_AI_MAX_TOKENS try: model_obj = model_map.get(f"{model_provider}/{model_name}") if not model_obj: model_obj = model_map[model_name] + logger.debug(f"Using model object for {model_name}") + else: + logger.debug(f"Using model object for {model_provider}/{model_name}") if "max_input_tokens" in model_obj: - return model_obj["max_input_tokens"] + max_tokens = model_obj["max_input_tokens"] + logger.info( + f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)" + ) + return max_tokens if "max_tokens" in model_obj: - return model_obj["max_tokens"] + max_tokens = model_obj["max_tokens"] + logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)") + return max_tokens + logger.error(f"No max tokens found for LLM: {model_name}") raise RuntimeError("No max tokens found for LLM") except Exception: logger.exception( - f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_DEFAULT_MAX_TOKENS}." + f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}." ) - return GEN_AI_MODEL_DEFAULT_MAX_TOKENS + return GEN_AI_MODEL_FALLBACK_MAX_TOKENS + + +def get_llm_max_output_tokens( + model_map: dict, + model_name: str, + model_provider: str = GEN_AI_MODEL_PROVIDER, +) -> int: + """Best effort attempt to get the max output tokens for the LLM""" + try: + model_obj = model_map.get(f"{model_provider}/{model_name}") + if not model_obj: + model_obj = model_map[model_name] + logger.debug(f"Using model object for {model_name}") + else: + logger.debug(f"Using model object for {model_provider}/{model_name}") + + if "max_output_tokens" in model_obj: + max_output_tokens = model_obj["max_output_tokens"] + logger.info(f"Max output tokens for {model_name}: {max_output_tokens}") + return max_output_tokens + + # Fallback to a fraction of max_tokens if max_output_tokens is not specified + if "max_tokens" in model_obj: + max_output_tokens = int(model_obj["max_tokens"] * 0.1) + logger.info( + f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)" + ) + return max_output_tokens + + logger.error(f"No max output tokens found for LLM: {model_name}") + raise RuntimeError("No max output tokens found for LLM") + except Exception: + default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS) + logger.exception( + f"Failed to get max output tokens for LLM with name {model_name}. " + f"Defaulting to {default_output_tokens} (fallback max tokens)." + ) + return default_output_tokens def get_max_input_tokens( model_name: str, model_provider: str, - output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ) -> int: # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually # returns the max OUTPUT tokens. 
Under the hood, this uses the `litellm.model_cost` dict, diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 4ec6ac050e3a..13d3a304b06f 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -13,7 +13,7 @@ from danswer.chat.models import LlmDoc from danswer.chat.models import SectionRelevancePiece from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW -from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from danswer.db.models import Persona from danswer.db.models import User from danswer.dynamic_configs.interface import JSON_ro @@ -130,7 +130,7 @@ class SearchTool(Tool): # For small context models, don't include additional surrounding context # The 3 here for at least minimum 1 above, 1 below and 1 for the middle chunk max_llm_tokens = compute_max_llm_input_tokens(self.llm.config) - if max_llm_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS: + if max_llm_tokens < 3 * GEN_AI_MODEL_FALLBACK_MAX_TOKENS: self.chunks_above = 0 self.chunks_below = 0
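
The sketch below is not part of the patch; it only illustrates how the token limits are expected to resolve after this change, based on the helpers added or renamed in backend/danswer/llm/utils.py. It assumes litellm is installed, that litellm.model_cost is the model map passed in (as in the commented-out chat_llm.py code above), and that get_llm_max_tokens takes the same (model_map, model_name, model_provider) parameters as the new get_llm_max_output_tokens; the "gpt-4o"/"openai" values are placeholders for illustration only.

# Illustrative sketch only -- not part of the patch.
import litellm

from danswer.llm.utils import (
    get_llm_max_output_tokens,
    get_llm_max_tokens,
    get_max_input_tokens,
)

model_name = "gpt-4o"      # placeholder model for the example
model_provider = "openai"  # placeholder provider for the example

# Context window: max_input_tokens (or max_tokens) from litellm.model_cost,
# falling back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (4096) for unknown models.
context_window = get_llm_max_tokens(
    model_map=litellm.model_cost,
    model_name=model_name,
    model_provider=model_provider,
)

# Output budget: max_output_tokens from litellm.model_cost, else 10% of
# max_tokens, else the same 4096 fallback (see get_llm_max_output_tokens).
output_budget = get_llm_max_output_tokens(
    model_map=litellm.model_cost,
    model_name=model_name,
    model_provider=model_provider,
)

# Input budget: context window minus the reserved output tokens
# (GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, 1024 by default), since the renamed
# setting now reserves room for the answer rather than capping output.
input_budget = get_max_input_tokens(
    model_name=model_name,
    model_provider=model_provider,
)

print(f"context window: {context_window}")
print(f"output budget:  {output_budget}")
print(f"input budget:   {input_budget}")

This mirrors the intent of the patch: the hard output limit is left to the provider (max_tokens is no longer passed to the litellm completion call in _completion), while the reserved-output setting is only used to protect the input budget.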