import copy
import json
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast

import litellm  # type: ignore
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from litellm.exceptions import APIConnectionError  # type: ignore
from litellm.exceptions import APIError  # type: ignore
from litellm.exceptions import AuthenticationError  # type: ignore
from litellm.exceptions import BadRequestError  # type: ignore
from litellm.exceptions import BudgetExceededError  # type: ignore
from litellm.exceptions import ContentPolicyViolationError  # type: ignore
from litellm.exceptions import ContextWindowExceededError  # type: ignore
from litellm.exceptions import NotFoundError  # type: ignore
from litellm.exceptions import PermissionDeniedError  # type: ignore
from litellm.exceptions import RateLimitError  # type: ignore
from litellm.exceptions import Timeout  # type: ignore
from litellm.exceptions import UnprocessableEntityError  # type: ignore

from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()


def litellm_exception_to_error_msg(
    e: Exception,
    llm: LLM,
    fallback_to_error_msg: bool = False,
    custom_error_msg_mappings: dict[str, str]
    | None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
) -> str:
    error_msg = str(e)

    if custom_error_msg_mappings:
        for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
            if error_msg_pattern in error_msg:
                return custom_error_msg

    if isinstance(e, BadRequestError):
        error_msg = "Bad request: The server couldn't process your request. Please check your input."
    elif isinstance(e, AuthenticationError):
        error_msg = "Authentication failed: Please check your API key and credentials."
    elif isinstance(e, PermissionDeniedError):
        error_msg = (
            "Permission denied: You don't have the necessary permissions for this operation. "
            "Ensure you have access to this model."
        )
    elif isinstance(e, NotFoundError):
        error_msg = "Resource not found: The requested resource doesn't exist."
    elif isinstance(e, UnprocessableEntityError):
        error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
    elif isinstance(e, RateLimitError):
        error_msg = (
            "Rate limit exceeded: Please slow down your requests and try again later."
        )
    elif isinstance(e, ContextWindowExceededError):
        error_msg = (
            "Context window exceeded: Your input is too long for the model to process."
        )
        if llm is not None:
            try:
                max_context = get_max_input_tokens(
                    model_name=llm.config.model_name,
                    model_provider=llm.config.model_provider,
                )
                error_msg += f" Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}."
            except Exception:
                logger.warning(
                    "Unable to get maximum input tokens for LiteLLM exception handling"
                )
    elif isinstance(e, ContentPolicyViolationError):
        error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
    elif isinstance(e, APIConnectionError):
        error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
    elif isinstance(e, BudgetExceededError):
        error_msg = (
            "Budget exceeded: You've exceeded your allocated budget for API usage."
        )
    elif isinstance(e, Timeout):
        error_msg = "Request timed out: The operation took too long to complete. Please try again."
    elif isinstance(e, APIError):
        error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
    elif not fallback_to_error_msg:
        error_msg = "An unexpected error occurred while processing your request. Please try again later."
    return error_msg


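# Illustrative usage (a sketch, not part of the module's API surface): callers
# typically wrap an LLM call and convert any raised exception into a
# user-friendly message.
#
#     try:
#         llm.invoke(prompt)
#     except Exception as e:
#         user_facing_msg = litellm_exception_to_error_msg(e, llm)

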
def _build_content(
    message: str,
    files: list[InMemoryChatFile] | None = None,
) -> str:
    """Prepends the contents of all non-image (plain text / CSV) files to the message."""
    if not files:
        return message

    text_files = [
        file
        for file in files
        if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
    ]

    if not text_files:
        return message

    final_message_with_files = "FILES:\n\n"
    for file in text_files:
        file_content = file.content.decode("utf-8")
        file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
        final_message_with_files += (
            f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
        )

    return final_message_with_files + message


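# Illustrative output for one text file (the file name is hypothetical, and this
# assumes CODE_BLOCK_PAT wraps its argument in a fenced code block):
#
#     FILES:
#
#     DOCUMENT: report.csv
#     ```
#     <file contents>
#     ```
#
#     <original message>

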
def build_content_with_imgs(
    message: str,
    files: list[InMemoryChatFile] | None = None,
    img_urls: list[str] | None = None,
    b64_imgs: list[str] | None = None,
    message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
    files = files or []

    # Only include image files for user messages
    img_files = (
        [file for file in files if file.file_type == ChatFileType.IMAGE]
        if message_type == MessageType.USER
        else []
    )

    img_urls = img_urls or []
    b64_imgs = b64_imgs or []

    message_main_content = _build_content(message, files)

    # No images at all -> return plain string content. b64_imgs is included in
    # this check so base64-only images are not silently dropped.
    if not img_files and not img_urls and not b64_imgs:
        return message_main_content

    return cast(
        list[str | dict[str, Any]],
        [
            {
                "type": "text",
                "text": message_main_content,
            },
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": (
                        f"data:{get_image_type_from_bytes(file.content)};"
                        f"base64,{file.to_base64()}"
                    ),
                },
            }
            for file in img_files
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
                },
            }
            for b64_img in b64_imgs
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": url,
                },
            }
            for url in img_urls
        ],
    )


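# Illustrative return value when images are present (shape only; the values are
# hypothetical and abridged):
#
#     [
#         {"type": "text", "text": "<message, possibly prefixed with FILES:>"},
#         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#     ]

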
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
    if isinstance(message.content, str):
        return message.content, []

    imgs = []
    texts = []
    for part in message.content:
        if isinstance(part, dict):
            if part.get("type") == "image_url":
                img_url = part.get("image_url", {}).get("url")
                if img_url:
                    imgs.append(img_url)
            elif part.get("type") == "text":
                text = part.get("text")
                if text:
                    texts.append(text)
        else:
            texts.append(part)

    return "".join(texts), imgs


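# Note: this is roughly the inverse of build_content_with_imgs -- a multimodal
# content list is flattened back into ("<concatenated text>", ["<image url>", ...]).

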
def dict_based_prompt_to_langchain_prompt(
    messages: list[dict[str, str]]
) -> list[BaseMessage]:
    prompt: list[BaseMessage] = []
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if not role:
            raise ValueError(f"Message missing `role`: {message}")
        if not content:
            raise ValueError(f"Message missing `content`: {message}")

        if role == "user":
            prompt.append(HumanMessage(content=content))
        elif role == "system":
            prompt.append(SystemMessage(content=content))
        elif role == "assistant":
            prompt.append(AIMessage(content=content))
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt


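# Illustrative usage (a sketch):
#
#     dict_based_prompt_to_langchain_prompt(
#         [
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Hello!"},
#         ]
#     )
#     # -> [SystemMessage(content="..."), HumanMessage(content="Hello!")]

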
def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
    return [HumanMessage(content=message)]


def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
    """Heavily inspired by:
    https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
    """
    prompt_value = None
    if isinstance(lm_input, PromptValue):
        prompt_value = lm_input
    elif isinstance(lm_input, str):
        prompt_value = StringPromptValue(text=lm_input)
    elif isinstance(lm_input, list):
        prompt_value = ChatPromptValue(messages=lm_input)

    if prompt_value is None:
        raise ValueError(
            f"Invalid input type {type(lm_input)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )

    return prompt_value.to_string()


def message_to_string(message: BaseMessage) -> str:
    if not isinstance(message.content, str):
        raise RuntimeError("LLM message not in expected format.")

    return message.content


def message_generator_to_string_generator(
    messages: Iterator[BaseMessage],
) -> Iterator[str]:
    for message in messages:
        yield message_to_string(message)


def should_be_verbose() -> bool:
    return LOG_LEVEL == "debug"


# Estimate of the number of tokens in an image URL. This is roughly correct when
# OpenAI downsamples the image, and very wrong when it does not.
# TODO: improve this
_IMG_TOKENS = 85


def check_message_tokens(
    message: BaseMessage, encode_fn: Callable[[str], list] | None = None
) -> int:
    if isinstance(message.content, str):
        return check_number_of_tokens(message.content, encode_fn)

    total_tokens = 0
    for part in message.content:
        if isinstance(part, str):
            total_tokens += check_number_of_tokens(part, encode_fn)
            continue

        if part["type"] == "text":
            total_tokens += check_number_of_tokens(part["text"], encode_fn)
        elif part["type"] == "image_url":
            total_tokens += _IMG_TOKENS

    if isinstance(message, AIMessage) and message.tool_calls:
        for tool_call in message.tool_calls:
            total_tokens += check_number_of_tokens(
                json.dumps(tool_call["args"]), encode_fn
            )
            total_tokens += check_number_of_tokens(tool_call["name"], encode_fn)

    return total_tokens


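# Illustrative sketch (hypothetical message): the token count for a multimodal
# user message is its text token count plus a flat _IMG_TOKENS per image part.
#
#     msg = HumanMessage(
#         content=[
#             {"type": "text", "text": "Describe this image"},
#             {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#         ]
#     )
#     check_message_tokens(msg)
#     # == check_number_of_tokens("Describe this image") + _IMG_TOKENS

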
def check_number_of_tokens(
    text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
    """Gets the number of tokens in the provided text, using the provided encoding
    function. If none is provided, defaults to the tiktoken encoder used by GPT-3.5
    and GPT-4.
    """

    if encode_fn is None:
        encode_fn = tiktoken.get_encoding("cl100k_base").encode

    return len(encode_fn(text))


def test_llm(llm: LLM) -> str | None:
    # try for up to 2 timeouts (e.g. 10 seconds in total)
    error_msg = None
    for _ in range(2):
        try:
            llm.invoke("Do not respond")
            return None
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"Failed to call LLM with the following error: {error_msg}")

    return error_msg


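# Note: returns None on success; otherwise returns the error string from the
# last failed attempt.

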
def get_model_map() -> dict:
    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))

    # NOTE: we could add additional models here in the future,
    # but for now there is no point. Ollama allows the user to
    # specify their desired max context window, and it's
    # unlikely to be standard across users even for the same model
    # (it heavily depends on their hardware). For now, we'll just
    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
    # for model_name in [
    #     "llama3.2",
    #     "llama3.2:1b",
    #     "llama3.2:3b",
    #     "llama3.2:11b",
    #     "llama3.2:90b",
    # ]:
    #     starting_map[f"ollama/{model_name}"] = {
    #         "max_tokens": 128000,
    #         "max_input_tokens": 128000,
    #         "max_output_tokens": 128000,
    #     }

    return starting_map


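# Note: entries in the returned map are keyed by "<provider>/<model>" or just
# "<model>" and, per the lookups below, may contain "max_tokens",
# "max_input_tokens", and "max_output_tokens" among other fields.

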
def _strip_extra_provider_from_model_name(model_name: str) -> str:
    return model_name.split("/")[1] if "/" in model_name else model_name


def _strip_colon_from_model_name(model_name: str) -> str:
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


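# For example:
#     _strip_extra_provider_from_model_name("openai/gpt-4o")  -> "gpt-4o"
#     _strip_colon_from_model_name("llama3.2:3b")             -> "llama3.2"

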
def _find_model_obj(
    model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
    # Filter out None values from the candidate model names
    filtered_model_names = [name for name in model_names if name]

    # First try all model names with provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(f"{provider}/{model_name}")
        if model_obj:
            logger.debug(f"Using model object for {provider}/{model_name}")
            return model_obj

    # Then try all model names without provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(model_name)
        if model_obj:
            logger.debug(f"Using model object for {model_name}")
            return model_obj

    return None


def get_llm_max_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max tokens for the LLM"""
    if GEN_AI_MAX_TOKENS:
        # This is an override, so always return this
        logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
        return GEN_AI_MAX_TOKENS

    try:
        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
            model_name
        )
        model_obj = _find_model_obj(
            model_map,
            model_provider,
            [
                model_name,
                # Remove leading extra provider. Usually for cases where the user has a
                # custom model proxy which prepends another prefix
                extra_provider_stripped_model_name,
                # remove :XXXX from the end, if present. Needed for ollama.
                _strip_colon_from_model_name(model_name),
                _strip_colon_from_model_name(extra_provider_stripped_model_name),
            ],
        )
        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
            )

        if "max_input_tokens" in model_obj:
            max_tokens = model_obj["max_input_tokens"]
            logger.info(
                f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
            )
            return max_tokens

        if "max_tokens" in model_obj:
            max_tokens = model_obj["max_tokens"]
            logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
            return max_tokens

        logger.error(f"No max tokens found for LLM: {model_name}")
        raise RuntimeError("No max tokens found for LLM")
    except Exception:
        logger.exception(
            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
        )
        return GEN_AI_MODEL_FALLBACK_MAX_TOKENS


def get_llm_max_output_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max output tokens for the LLM"""
    try:
        model_obj = model_map.get(f"{model_provider}/{model_name}")
        if not model_obj:
            model_obj = model_map[model_name]
            logger.debug(f"Using model object for {model_name}")
        else:
            logger.debug(f"Using model object for {model_provider}/{model_name}")

        if "max_output_tokens" in model_obj:
            max_output_tokens = model_obj["max_output_tokens"]
            logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
            return max_output_tokens

        # Fallback to a fraction of max_tokens if max_output_tokens is not specified
        if "max_tokens" in model_obj:
            max_output_tokens = int(model_obj["max_tokens"] * 0.1)
            logger.info(
                f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
            )
            return max_output_tokens

        logger.error(f"No max output tokens found for LLM: {model_name}")
        raise RuntimeError("No max output tokens found for LLM")
    except Exception:
        default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
        logger.exception(
            f"Failed to get max output tokens for LLM with name {model_name}. "
            f"Defaulting to {default_output_tokens} (fallback max tokens)."
        )
        return default_output_tokens


def get_max_input_tokens(
    model_name: str,
    model_provider: str,
    output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
) -> int:
    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
    # and there is no other interface to get what we want. This should be okay though, since the
    # `model_cost` dict is a named public interface:
    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
    # model_map is litellm.model_cost
    litellm_model_map = get_model_map()

    input_toks = (
        get_llm_max_tokens(
            model_name=model_name,
            model_provider=model_provider,
            model_map=litellm_model_map,
        )
        - output_tokens
    )

    if input_toks <= 0:
        raise RuntimeError("No tokens for input for the LLM given settings")

    return input_toks


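# Illustrative arithmetic (hypothetical numbers): for a model whose map entry
# reports max_input_tokens=128000 and with output_tokens=1024 reserved, the
# usable input budget is 128000 - 1024 = 126976 tokens.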