Improve model token limit detection (#3292)

* Properly find the context window for Ollama Llama models

* Better ollama support + upgrade litellm

* Upgrade OpenAI as well

* Fix mypy
Chris Weaver 2024-11-29 20:42:56 -08:00 committed by GitHub
parent 63d1eefee5
commit 16863de0aa
5 changed files with 100 additions and 23 deletions


@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
 )
 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
+    os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
+)
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
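The hunk above makes the 4K fallback overridable via an environment variable. A quick standalone sketch (not part of the diff) of how the `int(... or 4096)` expression behaves, including the empty-string case:

```python
# Illustration of the new fallback expression: unset or empty env vars fall
# back to 4096, anything else is parsed as an int.
import os

os.environ.pop("GEN_AI_MODEL_FALLBACK_MAX_TOKENS", None)
print(int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096))  # 4096

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = ""
print(int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096))  # 4096 (empty string is falsy)

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = "8192"
print(int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096))  # 8192
```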


@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
 from langchain_core.prompt_values import PromptValue
 from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
-from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
+from danswer.configs.model_configs import (
+    DISABLE_LITELLM_STREAMING,
+)
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import LITELLM_EXTRA_BODY
 from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
     if role == "user":
         return HumanMessageChunk(content=content)
-    elif role == "assistant":
+    # NOTE: if tool calls are present, then it's an assistant.
+    # In Ollama, the role will be None for tool-calls
+    elif role == "assistant" or tool_calls:
         if tool_calls:
             tool_call = tool_calls[0]
             tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
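The second hunk widens the assistant check because Ollama streams tool-call deltas with a `None` role. A minimal sketch of the branching logic, using simplified stand-ins rather than the actual danswer/litellm types:

```python
# FakeDelta and classify_delta are illustrative only; they mimic the shape of
# litellm's streaming delta and the role-detection fix above.
from dataclasses import dataclass, field


@dataclass
class FakeDelta:
    role: str | None = None
    content: str = ""
    tool_calls: list[dict] = field(default_factory=list)


def classify_delta(delta: FakeDelta) -> str:
    if delta.role == "user":
        return "human"
    # If tool calls are present, treat the chunk as assistant output: Ollama
    # streams tool-call deltas with role=None, so checking only the role
    # would drop them.
    elif delta.role == "assistant" or delta.tool_calls:
        return "assistant"
    elif delta.role == "system":
        return "system"
    return "unknown"


# An Ollama-style tool-call delta: role is None but a tool call is present.
print(classify_delta(FakeDelta(role=None, tool_calls=[{"name": "search"}])))  # assistant
```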
@@ -236,6 +240,7 @@
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
         extra_body: dict | None = LITELLM_EXTRA_BODY,
+        model_kwargs: dict[str, Any] | None = None,
         long_term_logger: LongTermLogger | None = None,
     ):
         self._timeout = timeout
@@ -268,7 +273,7 @@
             for k, v in custom_config.items():
                 os.environ[k] = v
 
-        model_kwargs: dict[str, Any] = {}
+        model_kwargs = model_kwargs or {}
         if extra_headers:
             model_kwargs.update({"extra_headers": extra_headers})
         if extra_body:
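These two hunks thread a caller-supplied `model_kwargs` dict through `DefaultMultiLLM` instead of always starting from an empty dict. Roughly how the merge behaves; `build_completion_kwargs` is only an illustrative helper, the real constructor does more:

```python
from typing import Any


def build_completion_kwargs(
    model_kwargs: dict[str, Any] | None,
    extra_headers: dict[str, str] | None,
    extra_body: dict | None,
) -> dict[str, Any]:
    # Caller-supplied kwargs (e.g. {"num_ctx": 4096} for Ollama) form the base,
    # then extra_headers / extra_body are layered on top, mirroring the hunks above.
    merged = dict(model_kwargs or {})
    if extra_headers:
        merged["extra_headers"] = extra_headers
    if extra_body:
        merged["extra_body"] = extra_body
    return merged


print(build_completion_kwargs({"num_ctx": 4096}, None, {"format": "json"}))
# {'num_ctx': 4096, 'extra_body': {'format': 'json'}}
```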


@@ -1,5 +1,8 @@
+from typing import Any
+
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@ from danswer.utils.headers import build_llm_extra_headers
 from danswer.utils.long_term_log import LongTermLogger
 
 
+def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
+    """Ollama requires us to specify the max context window.
+
+    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
+    TODO: allow model-specific values to be configured via the UI.
+    """
+    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
+
+
 def get_main_llm_from_tuple(
     llms: tuple[LLM, LLM],
 ) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
         temperature=temperature,
         custom_config=custom_config,
         extra_headers=build_llm_extra_headers(additional_headers),
+        model_kwargs=_build_extra_model_kwargs(provider),
         long_term_logger=long_term_logger,
     )
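Taken together, `get_llm` now asks `_build_extra_model_kwargs` for provider-specific kwargs and passes them down to `DefaultMultiLLM`. A small demo of the helper's behavior; the constant is hardcoded here purely for the example:

```python
from typing import Any

# Hardcoded for the demo; in the codebase this comes from danswer.configs.model_configs.
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096


def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
    # Ollama needs num_ctx, otherwise it falls back to its own default context
    # window (commonly 2048); other providers get no extra kwargs.
    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}


print(_build_extra_model_kwargs("ollama"))  # {'num_ctx': 4096}
print(_build_extra_model_kwargs("openai"))  # {}
```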


@@ -1,3 +1,4 @@
+import copy
 import io
 import json
 from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
     return error_msg
 
 
+def get_model_map() -> dict:
+    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+
+    # NOTE: we could add additional models here in the future,
+    # but for now there is no point. Ollama allows the user to
+    # to specify their desired max context window, and it's
+    # unlikely to be standard across users even for the same model
+    # (it heavily depends on their hardware). For now, we'll just
+    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
+    # for model_name in [
+    #     "llama3.2",
+    #     "llama3.2:1b",
+    #     "llama3.2:3b",
+    #     "llama3.2:11b",
+    #     "llama3.2:90b",
+    # ]:
+    #     starting_map[f"ollama/{model_name}"] = {
+    #         "max_tokens": 128000,
+    #         "max_input_tokens": 128000,
+    #         "max_output_tokens": 128000,
+    #     }
+
+    return starting_map
+
+
+def _strip_extra_provider_from_model_name(model_name: str) -> str:
+    return model_name.split("/")[1] if "/" in model_name else model_name
+
+
+def _strip_colon_from_model_name(model_name: str) -> str:
+    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
+
+
+def _find_model_obj(
+    model_map: dict, provider: str, model_names: list[str | None]
+) -> dict | None:
+    # Filter out None values and deduplicate model names
+    filtered_model_names = [name for name in model_names if name]
+
+    # First try all model names with provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(f"{provider}/{model_name}")
+        if model_obj:
+            logger.debug(f"Using model object for {provider}/{model_name}")
+            return model_obj
+
+    # Then try all model names without provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(model_name)
+        if model_obj:
+            logger.debug(f"Using model object for {model_name}")
+            return model_obj
+
+    return None
+
+
 def get_llm_max_tokens(
     model_map: dict,
     model_name: str,
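The new `_find_model_obj` tries provider-prefixed names before bare names. A runnable sketch with a made-up `model_map`; in the codebase the map comes from `litellm.model_cost` via `get_model_map()`:

```python
# Minimal sketch of the lookup order introduced above: provider-prefixed names
# are tried first, then bare names. The model_map entries are invented for the demo.
import logging

logger = logging.getLogger(__name__)


def _find_model_obj(
    model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
    filtered = [name for name in model_names if name]
    for name in filtered:
        if obj := model_map.get(f"{provider}/{name}"):
            logger.debug(f"Using model object for {provider}/{name}")
            return obj
    for name in filtered:
        if obj := model_map.get(name):
            logger.debug(f"Using model object for {name}")
            return obj
    return None


fake_map = {"ollama/llama3.2": {"max_input_tokens": 128000}}
print(_find_model_obj(fake_map, "ollama", ["llama3.2:3b", "llama3.2"]))
# {'max_input_tokens': 128000}  (matched via the provider-prefixed entry)
```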
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     try:
-        model_obj = model_map.get(f"{model_provider}/{model_name}")
-        if model_obj:
-            logger.debug(f"Using model object for {model_provider}/{model_name}")
-
-        if not model_obj:
-            model_obj = model_map.get(model_name)
-            if model_obj:
-                logger.debug(f"Using model object for {model_name}")
-
-        if not model_obj:
-            model_name_split = model_name.split("/")
-            if len(model_name_split) > 1:
-                model_obj = model_map.get(model_name_split[1])
-            if model_obj:
-                logger.debug(f"Using model object for {model_name_split[1]}")
-
+        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
+            model_name
+        )
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            [
+                model_name,
+                # Remove leading extra provider. Usually for cases where user has a
+                # customer model proxy which appends another prefix
+                extra_provider_stripped_model_name,
+                # remove :XXXX from the end, if present. Needed for ollama.
+                _strip_colon_from_model_name(model_name),
+                _strip_colon_from_model_name(extra_provider_stripped_model_name),
+            ],
+        )
         if not model_obj:
             raise RuntimeError(
                 f"No litellm entry found for {model_provider}/{model_name}"
@@ -488,7 +545,7 @@ def get_max_input_tokens(
     # `model_cost` dict is a named public interface:
     # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
     # model_map is litellm.model_cost
-    litellm_model_map = litellm.model_cost
+    litellm_model_map = get_model_map()
 
     input_toks = (
         get_llm_max_tokens(
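`get_max_input_tokens` now reads from `get_model_map()` rather than `litellm.model_cost` directly, so the deep copy can later be extended with custom entries without mutating litellm's global map. A quick check of that property (assumes litellm is installed; the extra entry is hypothetical):

```python
# Because get_model_map() deep-copies litellm.model_cost, local additions
# never leak back into litellm's global cost map.
import copy
from typing import cast

import litellm


def get_model_map() -> dict:
    return copy.deepcopy(cast(dict, litellm.model_cost))


model_map = get_model_map()
model_map["ollama/my-local-model"] = {"max_input_tokens": 32768}  # hypothetical entry
print("ollama/my-local-model" in litellm.model_cost)  # False: the global map is untouched
```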


@@ -29,7 +29,7 @@ trafilatura==1.12.2
 langchain==0.1.17
 langchain-core==0.1.50
 langchain-text-splitters==0.0.1
-litellm==1.50.2
+litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.52.2
+openai==1.55.3
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5