diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py
index ee2f582c9..d75209557 100644
--- a/backend/danswer/chat/chat_utils.py
+++ b/backend/danswer/chat/chat_utils.py
@@ -55,7 +55,7 @@ def create_chat_chain(
     id_to_msg = {msg.id: msg for msg in all_chat_messages}
 
     if not all_chat_messages:
-        raise ValueError("No messages in Chat Session")
+        raise RuntimeError("No messages in Chat Session")
 
     root_message = all_chat_messages[0]
     if root_message.parent_message is not None:
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index c07b708bb..50eb3cef8 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -241,6 +241,7 @@ def test_llm(llm: LLM) -> str | None:
 
 
 def get_llm_max_tokens(
+    model_map: dict,
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
 ) -> int:
@@ -250,18 +251,12 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     model_name = model_name or get_default_llm_version()[0]
 
-    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
-    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
-    # and there is no other interface to get what we want. This should be okay though, since the
-    # `model_cost` dict is a named public interface:
-    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
-    litellm_model_map = litellm.model_cost
     try:
         if model_provider == "openai":
-            model_obj = litellm_model_map[model_name]
+            model_obj = model_map[model_name]
         else:
-            model_obj = litellm_model_map[f"{model_provider}/{model_name}"]
+            model_obj = model_map[f"{model_provider}/{model_name}"]
         if "max_tokens" in model_obj:
             return model_obj["max_tokens"]
         elif "max_input_tokens" in model_obj and "max_output_tokens" in model_obj:
@@ -275,17 +270,73 @@
     return 4096
 
 
+def get_llm_max_input_tokens(
+    output_tokens: int,
+    model_map: dict,
+    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_provider: str = GEN_AI_MODEL_PROVIDER,
+) -> int | None:
+    try:
+        if model_provider == "openai":
+            model_obj = model_map[model_name]
+        else:
+            model_obj = model_map[f"{model_provider}/{model_name}"]
+
+        max_in = model_obj.get("max_input_tokens")
+        max_out = model_obj.get("max_output_tokens")
+        if max_in is None or max_out is None:
+            # Can't calculate precisely, use the fallback method
+            return None
+
+        # Some APIs may not actually work like this, but it's a safer approach generally speaking
+        # since worst case we remove some extra tokens from the input space
+        output_token_debt = 0
+        if output_tokens > max_out:
+            logger.warning(
+                "More output tokens requested than model is likely able to handle"
+            )
+            output_token_debt = output_tokens - max_out
+
+        remaining_max_input_tokens = max_in - output_token_debt
+        return remaining_max_input_tokens
+
+    except Exception:
+        # We can try the less accurate approach if this fails
+        return None
+
+
 def get_max_input_tokens(
     model_name: str | None = GEN_AI_MODEL_VERSION,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
+    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
+    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
+    # and there is no other interface to get what we want. This should be okay though, since the
+    # `model_cost` dict is a named public interface:
+    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
+    # model_map is litellm.model_cost
+    litellm_model_map = litellm.model_cost
+
     model_name = model_name or get_default_llm_version()[0]
-    input_toks = (
-        get_llm_max_tokens(model_name=model_name, model_provider=model_provider)
-        - output_tokens
+
+    input_toks = get_llm_max_input_tokens(
+        output_tokens=output_tokens,
+        model_map=litellm_model_map,
+        model_name=model_name,
+        model_provider=model_provider,
     )
+    if input_toks is None:
+        input_toks = (
+            get_llm_max_tokens(
+                model_name=model_name,
+                model_provider=model_provider,
+                model_map=litellm_model_map,
+            )
+            - output_tokens
+        )
+
     if input_toks <= 0:
         raise RuntimeError("No tokens for input for the LLM given settings")
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 008ee8480..9ca893876 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -24,7 +24,7 @@ httpx-oauth==0.11.2
 huggingface-hub==0.20.1
 jira==3.5.1
 langchain==0.1.9
-litellm==1.27.10
+litellm==1.34.8
 llama-index==0.9.45
 Mako==1.2.4
 msal==1.26.0
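Background for the `model_cost` NOTE carried through this change: `litellm.model_cost` is the documented public dict behind `litellm.get_max_tokens()`, and, per that NOTE, its plain `max_tokens` field is effectively the output cap, while `max_input_tokens`/`max_output_tokens` give the split limits the new code prefers. A quick way to see what the dict exposes for a given model; the model name below is only an example, and the exact keys and values vary by model and litellm version:

# Inspect a litellm.model_cost entry (public interface, see
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost).
import litellm

entry = litellm.model_cost.get("gpt-3.5-turbo", {})
for key in ("max_tokens", "max_input_tokens", "max_output_tokens"):
    print(key, "->", entry.get(key))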
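The new `get_llm_max_input_tokens` helper only charges the input budget for output that overflows the model's own output window. A self-contained sketch of that arithmetic, using a hand-written entry shaped like a `litellm.model_cost` value (the limits below are illustrative, not taken from any real model):

# Token-budget arithmetic mirroring get_llm_max_input_tokens, in isolation.
def input_budget(model_entry: dict, output_tokens: int) -> int | None:
    max_in = model_entry.get("max_input_tokens")
    max_out = model_entry.get("max_output_tokens")
    if max_in is None or max_out is None:
        # Not enough info; the caller falls back to get_llm_max_tokens() - output_tokens
        return None
    # Only the portion of the requested output that exceeds the model's
    # dedicated output window eats into the input side.
    output_token_debt = max(0, output_tokens - max_out)
    return max_in - output_token_debt

# Hypothetical model with a 128k input window and a 4k output window:
entry = {"max_input_tokens": 128_000, "max_output_tokens": 4_096}
print(input_budget(entry, output_tokens=1_024))  # 128000 (no debt)
print(input_budget(entry, output_tokens=8_192))  # 123904 (4096 tokens of overflow)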