Improve model token limit detection (#3292)
* Properly find context window for ollama llama
* Better ollama support + upgrade litellm
* Upgrade OpenAI as well
* Fix mypy
parent 63d1eefee5
commit 16863de0aa
@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
 )
 
 # Typically, GenAI models nowadays are at least 4K tokens
-GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
+GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
+    os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
+)
 
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
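The fallback (which, from the imports, appears to live in danswer/configs/model_configs.py) is now overridable via an environment variable. A minimal sketch of the parsing semantics, using only the standard library; note that `or 4096` also covers the case where the variable is set but empty:

import os

# Unset or empty -> falls back to 4096; any numeric string overrides it.
os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = "8192"
fallback = int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096)
print(fallback)  # 8192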
@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
 from langchain_core.prompt_values import PromptValue
 
 from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
-from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
+from danswer.configs.model_configs import (
+    DISABLE_LITELLM_STREAMING,
+)
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import LITELLM_EXTRA_BODY
 from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
 
     if role == "user":
         return HumanMessageChunk(content=content)
-    elif role == "assistant":
+    # NOTE: if tool calls are present, then it's an assistant.
+    # In Ollama, the role will be None for tool-calls
+    elif role == "assistant" or tool_calls:
         if tool_calls:
             tool_call = tool_calls[0]
             tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
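To illustrate why the branch now also checks tool_calls: Ollama can stream a tool-call delta whose role is None, which the old `elif role == "assistant"` check skipped. A hedged sketch using hypothetical stand-in objects (not litellm's real delta classes; the tool name is illustrative):

from types import SimpleNamespace

delta = SimpleNamespace(
    role=None,  # Ollama omits the role on tool-call deltas
    content="",
    tool_calls=[
        SimpleNamespace(
            id="call_0",
            function=SimpleNamespace(name="run_search", arguments='{"query": "token limits"}'),
        )
    ],
)

# Old check: role == "assistant" -> False, so the delta was not treated as assistant output.
# New check: role == "assistant" or tool_calls -> True, so the delta is converted
# into an assistant message chunk that carries the tool call.
print(bool(delta.role == "assistant" or delta.tool_calls))  # True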
@@ -236,6 +240,7 @@ class DefaultMultiLLM(LLM):
         custom_config: dict[str, str] | None = None,
         extra_headers: dict[str, str] | None = None,
         extra_body: dict | None = LITELLM_EXTRA_BODY,
+        model_kwargs: dict[str, Any] | None = None,
         long_term_logger: LongTermLogger | None = None,
     ):
         self._timeout = timeout
@@ -268,7 +273,7 @@ class DefaultMultiLLM(LLM):
             for k, v in custom_config.items():
                 os.environ[k] = v
 
-        model_kwargs: dict[str, Any] = {}
+        model_kwargs = model_kwargs or {}
         if extra_headers:
             model_kwargs.update({"extra_headers": extra_headers})
         if extra_body:
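The net effect of the two DefaultMultiLLM changes above: a caller-supplied model_kwargs dict is kept and merged with the header/body kwargs instead of being shadowed by a fresh local dict. A small self-contained sketch with illustrative values:

model_kwargs = {"num_ctx": 4096}         # e.g. what the factory passes for Ollama
extra_headers = {"X-Request-Id": "abc"}  # hypothetical header

model_kwargs = model_kwargs or {}        # keep the caller's dict, default to {}
if extra_headers:
    model_kwargs.update({"extra_headers": extra_headers})

print(model_kwargs)  # {'num_ctx': 4096, 'extra_headers': {'X-Request-Id': 'abc'}}
# These kwargs are then forwarded to the underlying litellm call, as extra_headers
# and extra_body already were, so the Ollama context-window hint travels with each request.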
@@ -1,5 +1,8 @@
+from typing import Any
+
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
+from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.db.engine import get_session_context_manager
 from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@ from danswer.utils.headers import build_llm_extra_headers
 from danswer.utils.long_term_log import LongTermLogger
 
 
+def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
+    """Ollama requires us to specify the max context window.
+
+    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
+    TODO: allow model-specific values to be configured via the UI.
+    """
+    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
+
+
 def get_main_llm_from_tuple(
     llms: tuple[LLM, LLM],
 ) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
         temperature=temperature,
         custom_config=custom_config,
         extra_headers=build_llm_extra_headers(additional_headers),
+        model_kwargs=_build_extra_model_kwargs(provider),
         long_term_logger=long_term_logger,
     )
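A quick usage sketch of the wiring above, assuming the helper and constant from these hunks are in scope (with the default fallback, GEN_AI_MODEL_FALLBACK_MAX_TOKENS is 4096):

# Only the Ollama provider gets the explicit context-window hint.
assert _build_extra_model_kwargs("ollama") == {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS}
assert _build_extra_model_kwargs("openai") == {}

# get_llm passes the result straight through as DefaultMultiLLM's model_kwargs,
# so Ollama requests carry num_ctx while other providers are unaffected.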
@@ -1,3 +1,4 @@
+import copy
 import io
 import json
 from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
     return error_msg
 
 
+def get_model_map() -> dict:
+    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
+
+    # NOTE: we could add additional models here in the future,
+    # but for now there is no point. Ollama allows the user to
+    # to specify their desired max context window, and it's
+    # unlikely to be standard across users even for the same model
+    # (it heavily depends on their hardware). For now, we'll just
+    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
+    # for model_name in [
+    #     "llama3.2",
+    #     "llama3.2:1b",
+    #     "llama3.2:3b",
+    #     "llama3.2:11b",
+    #     "llama3.2:90b",
+    # ]:
+    #     starting_map[f"ollama/{model_name}"] = {
+    #         "max_tokens": 128000,
+    #         "max_input_tokens": 128000,
+    #         "max_output_tokens": 128000,
+    #     }
+
+    return starting_map
+
+
+def _strip_extra_provider_from_model_name(model_name: str) -> str:
+    return model_name.split("/")[1] if "/" in model_name else model_name
+
+
+def _strip_colon_from_model_name(model_name: str) -> str:
+    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
+
+
+def _find_model_obj(
+    model_map: dict, provider: str, model_names: list[str | None]
+) -> dict | None:
+    # Filter out None values and deduplicate model names
+    filtered_model_names = [name for name in model_names if name]
+
+    # First try all model names with provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(f"{provider}/{model_name}")
+        if model_obj:
+            logger.debug(f"Using model object for {provider}/{model_name}")
+            return model_obj
+
+    # Then try all model names without provider prefix
+    for model_name in filtered_model_names:
+        model_obj = model_map.get(model_name)
+        if model_obj:
+            logger.debug(f"Using model object for {model_name}")
+            return model_obj
+
+    return None
+
+
 def get_llm_max_tokens(
     model_map: dict,
     model_name: str,
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     try:
-        model_obj = model_map.get(f"{model_provider}/{model_name}")
-        if model_obj:
-            logger.debug(f"Using model object for {model_provider}/{model_name}")
-
-        if not model_obj:
-            model_obj = model_map.get(model_name)
-            if model_obj:
-                logger.debug(f"Using model object for {model_name}")
-
-        if not model_obj:
-            model_name_split = model_name.split("/")
-            if len(model_name_split) > 1:
-                model_obj = model_map.get(model_name_split[1])
-                if model_obj:
-                    logger.debug(f"Using model object for {model_name_split[1]}")
-
+        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
+            model_name
+        )
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            [
+                model_name,
+                # Remove leading extra provider. Usually for cases where user has a
+                # customer model proxy which appends another prefix
+                extra_provider_stripped_model_name,
+                # remove :XXXX from the end, if present. Needed for ollama.
+                _strip_colon_from_model_name(model_name),
+                _strip_colon_from_model_name(extra_provider_stripped_model_name),
+            ],
+        )
         if not model_obj:
             raise RuntimeError(
                 f"No litellm entry found for {model_provider}/{model_name}"
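To make the new lookup order concrete, an illustrative trace for a name that combines a proxy-style prefix with an Ollama tag (the model name and prefix are hypothetical; the helpers are the ones added above):

model_name = "custom_proxy/llama3.2:3b"
provider = "ollama"

candidates = [
    model_name,                                         # "custom_proxy/llama3.2:3b"
    _strip_extra_provider_from_model_name(model_name),  # "llama3.2:3b"
    _strip_colon_from_model_name(model_name),           # "custom_proxy/llama3.2"
    _strip_colon_from_model_name(
        _strip_extra_provider_from_model_name(model_name)
    ),                                                  # "llama3.2"
]

# _find_model_obj tries each candidate as f"ollama/<name>" first, then bare,
# returning the first hit in the copied litellm model map from get_model_map().
model_obj = _find_model_obj(get_model_map(), provider, candidates)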
@@ -488,7 +545,7 @@ def get_max_input_tokens(
     # `model_cost` dict is a named public interface:
     # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
     # model_map is litellm.model_cost
-    litellm_model_map = litellm.model_cost
+    litellm_model_map = get_model_map()
 
     input_toks = (
         get_llm_max_tokens(
@@ -29,7 +29,7 @@ trafilatura==1.12.2
 langchain==0.1.17
 langchain-core==0.1.50
 langchain-text-splitters==0.0.1
-litellm==1.50.2
+litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.52.2
+openai==1.55.3
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5