mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 20:24:32 +02:00)

Commit: welcome to onyx

0    backend/onyx/llm/__init__.py    Normal file
514    backend/onyx/llm/chat_llm.py    Normal file
@@ -0,0 +1,514 @@
import json
import os
import traceback
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import cast

import litellm  # type: ignore
from httpx import RemoteProtocolError
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessageChunk
from langchain_core.messages import ChatMessage
from langchain_core.messages import ChatMessageChunk
from langchain_core.messages import FunctionMessage
from langchain_core.messages import FunctionMessageChunk
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessageChunk
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger


logger = setup_logger()

# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False

_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"


def _base_msg_to_role(msg: BaseMessage) -> str:
    if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
        return "user"
    if isinstance(msg, AIMessage) or isinstance(msg, AIMessageChunk):
        return "assistant"
    if isinstance(msg, SystemMessage) or isinstance(msg, SystemMessageChunk):
        return "system"
    if isinstance(msg, FunctionMessage) or isinstance(msg, FunctionMessageChunk):
        return "function"
    return "unknown"


def _convert_litellm_message_to_langchain_message(
    litellm_message: litellm.Message,
) -> BaseMessage:
    # Extracting the basic attributes from the litellm message
    content = litellm_message.content or ""
    role = litellm_message.role

    # Handling function calls and tool calls if present
    tool_calls = (
        cast(
            list[litellm.ChatCompletionMessageToolCall],
            litellm_message.tool_calls,
        )
        if hasattr(litellm_message, "tool_calls")
        else []
    )

    # Create the appropriate langchain message based on the role
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        return AIMessage(
            content=content,
            tool_calls=[
                {
                    "name": tool_call.function.name or "",
                    "args": json.loads(tool_call.function.arguments),
                    "id": tool_call.id,
                }
                for tool_call in tool_calls
            ]
            if tool_calls
            else [],
        )
    elif role == "system":
        return SystemMessage(content=content)
    else:
        raise ValueError(f"Unknown role type received: {role}")


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Adapted from langchain_community.chat_models.litellm._convert_message_to_dict"""
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if message.tool_calls:
            message_dict["tool_calls"] = [
                {
                    "id": tool_call.get("id"),
                    "function": {
                        "name": tool_call["name"],
                        "arguments": json.dumps(tool_call["args"]),
                    },
                    "type": "function",
                    "index": tool_call.get("index", 0),
                }
                for tool_call in message.tool_calls
            ]
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "tool_call_id": message.tool_call_id,
            "role": "tool",
            "name": message.name or "",
            "content": message.content,
        }
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


def _convert_delta_to_message_chunk(
    _dict: dict[str, Any],
    curr_msg: BaseMessage | None,
    stop_reason: str | None = None,
) -> BaseMessageChunk:
    """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
    role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
    content = _dict.get("content") or ""
    additional_kwargs = {}
    if _dict.get("function_call"):
        additional_kwargs.update({"function_call": dict(_dict["function_call"])})
    tool_calls = cast(
        list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
    )

    if role == "user":
        return HumanMessageChunk(content=content)
    # NOTE: if tool calls are present, then it's an assistant.
    # In Ollama, the role will be None for tool-calls
    elif role == "assistant" or tool_calls:
        if tool_calls:
            tool_call = tool_calls[0]
            tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
            idx = tool_call.index

            tool_call_chunk = ToolCallChunk(
                name=tool_name,
                id=tool_call.id,
                args=tool_call.function.arguments,
                index=idx,
            )

            return AIMessageChunk(
                content=content,
                tool_call_chunks=[tool_call_chunk],
                additional_kwargs={
                    "usage_metadata": {"stop": stop_reason},
                    **additional_kwargs,
                },
            )

        return AIMessageChunk(
            content=content,
            additional_kwargs={
                "usage_metadata": {"stop": stop_reason},
                **additional_kwargs,
            },
        )
    elif role == "system":
        return SystemMessageChunk(content=content)
    elif role == "function":
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role:
        return ChatMessageChunk(content=content, role=role)

    raise ValueError(f"Unknown role: {role}")


def _prompt_to_dict(
    prompt: LanguageModelInput,
) -> Sequence[str | list[str] | dict[str, Any] | tuple[str, str]]:
    # NOTE: this must go first, since it is also a Sequence
    if isinstance(prompt, str):
        return [_convert_message_to_dict(HumanMessage(content=prompt))]

    if isinstance(prompt, (list, Sequence)):
        return [
            _convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
            for msg in prompt
        ]

    if isinstance(prompt, PromptValue):
        return [_convert_message_to_dict(message) for message in prompt.to_messages()]


class DefaultMultiLLM(LLM):
    """Uses the Litellm library to allow easy configuration of a multitude of LLMs.
    See https://python.langchain.com/docs/integrations/chat/litellm"""

    def __init__(
        self,
        api_key: str | None,
        timeout: int,
        model_provider: str,
        model_name: str,
        api_base: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        max_output_tokens: int | None = None,
        custom_llm_provider: str | None = None,
        temperature: float = GEN_AI_TEMPERATURE,
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
        model_kwargs: dict[str, Any] | None = None,
        long_term_logger: LongTermLogger | None = None,
    ):
        self._timeout = timeout
        self._model_provider = model_provider
        self._model_version = model_name
        self._temperature = temperature
        self._api_key = api_key
        self._deployment_name = deployment_name
        self._api_base = api_base
        self._api_version = api_version
        self._custom_llm_provider = custom_llm_provider
        self._long_term_logger = long_term_logger

        # This can be used to store the maximum output tokens for this model.
        # self._max_output_tokens = (
        #     max_output_tokens
        #     if max_output_tokens is not None
        #     else get_llm_max_output_tokens(
        #         model_map=litellm.model_cost,
        #         model_name=model_name,
        #         model_provider=model_provider,
        #     )
        # )
        self._custom_config = custom_config

        # NOTE: have to set these as environment variables for Litellm since
        # not all are able to be passed in, but they are always supported when set as
        # env variables. We'll also try passing them in, since litellm just ignores
        # additional kwargs (and some kwargs MUST be passed in rather than set as
        # env variables)
        if custom_config:
            for k, v in custom_config.items():
                os.environ[k] = v

        model_kwargs = model_kwargs or {}
        if custom_config:
            model_kwargs.update(custom_config)
        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
            model_kwargs.update({"extra_body": extra_body})

        self._model_kwargs = model_kwargs

    def log_model_configs(self) -> None:
        logger.debug(f"Config: {self.config}")

    def _safe_model_config(self) -> dict:
        dump = self.config.model_dump()
        dump["api_key"] = mask_string(dump.get("api_key", ""))
        return dump

    def _record_call(self, prompt: LanguageModelInput) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {"prompt": _prompt_to_dict(prompt), "model": self._safe_model_config()},
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    def _record_result(
        self, prompt: LanguageModelInput, model_output: BaseMessage
    ) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {
                    "prompt": _prompt_to_dict(prompt),
                    "content": model_output.content,
                    "tool_calls": (
                        model_output.tool_calls
                        if hasattr(model_output, "tool_calls")
                        else []
                    ),
                    "model": self._safe_model_config(),
                },
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    def _record_error(self, prompt: LanguageModelInput, error: Exception) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {
                    "prompt": _prompt_to_dict(prompt),
                    "error": str(error),
                    "traceback": "".join(
                        traceback.format_exception(
                            type(error), error, error.__traceback__
                        )
                    ),
                    "model": self._safe_model_config(),
                },
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    # def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
    #     # NOTE: This method can be used for calculating the maximum tokens for the stream,
    #     # but it isn't used in practice due to the computational cost of counting tokens
    #     # and because LLM providers automatically cut off at the maximum output.
    #     # The implementation is kept for potential future use or debugging purposes.

    #     # Get max input tokens for the model
    #     max_context_tokens = get_max_input_tokens(
    #         model_name=self.config.model_name, model_provider=self.config.model_provider
    #     )

    #     llm_tokenizer = get_tokenizer(
    #         model_name=self.config.model_name,
    #         provider_type=self.config.model_provider,
    #     )
    #     # Calculate tokens in the input prompt
    #     input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)

    #     # Calculate available tokens for output
    #     available_output_tokens = max_context_tokens - input_tokens

    #     # Return the lesser of available tokens or configured max
    #     return min(self._max_output_tokens, available_output_tokens)

    def _completion(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None,
        tool_choice: ToolChoiceOptions | None,
        stream: bool,
        structured_response_format: dict | None = None,
    ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
        # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
        # to a dict representation
        processed_prompt = _prompt_to_dict(prompt)
        self._record_call(processed_prompt)

        try:
            return litellm.completion(
                # model choice
                model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
                # NOTE: have to pass in None instead of empty string for these
                # otherwise litellm can have some issues with bedrock
                api_key=self._api_key or None,
                base_url=self._api_base or None,
                api_version=self._api_version or None,
                custom_llm_provider=self._custom_llm_provider or None,
                # actual input
                messages=processed_prompt,
                tools=tools,
                tool_choice=tool_choice if tools else None,
                # streaming choice
                stream=stream,
                # model params
                temperature=self._temperature,
                timeout=self._timeout,
                # For now, we don't support parallel tool calls
                # NOTE: we can't pass this in if tools are not specified
                # or else OpenAI throws an error
                **({"parallel_tool_calls": False} if tools else {}),
                **(
                    {"response_format": structured_response_format}
                    if structured_response_format
                    else {}
                ),
                **self._model_kwargs,
            )
        except Exception as e:
            self._record_error(processed_prompt, e)
            # useful for setting breakpoints
            raise e

    @property
    def config(self) -> LLMConfig:
        return LLMConfig(
            model_provider=self._model_provider,
            model_name=self._model_version,
            temperature=self._temperature,
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            deployment_name=self._deployment_name,
        )

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        response = cast(
            litellm.ModelResponse,
            self._completion(
                prompt, tools, tool_choice, False, structured_response_format
            ),
        )
        choice = response.choices[0]
        if hasattr(choice, "message"):
            output = _convert_litellm_message_to_langchain_message(choice.message)
            if output:
                self._record_result(prompt, output)
            return output
        else:
            raise ValueError("Unexpected response choice type")

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
            return

        output = None
        response = cast(
            litellm.CustomStreamWrapper,
            self._completion(
                prompt, tools, tool_choice, True, structured_response_format
            ),
        )
        try:
            for part in response:
                if not part["choices"]:
                    continue

                choice = part["choices"][0]
                message_chunk = _convert_delta_to_message_chunk(
                    choice["delta"],
                    output,
                    stop_reason=choice["finish_reason"],
                )

                if output is None:
                    output = message_chunk
                else:
                    output += message_chunk

                yield message_chunk

        except RemoteProtocolError:
            raise RuntimeError(
                "The AI model failed partway through generation, please try again."
            )

        if output:
            self._record_result(prompt, output)

        if LOG_DANSWER_MODEL_INTERACTIONS and output:
            content = output.content or ""
            if isinstance(output, AIMessage):
                if content:
                    log_msg = content
                elif output.tool_calls:
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in output.tool_calls
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Raw Model Output:\n{log_msg}")
            else:
                logger.debug(f"Raw Model Output:\n{content}")
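
Usage sketch for DefaultMultiLLM (illustrative only; the API key, model names, and timeout below are placeholder assumptions, not values from this commit):

# Minimal usage sketch for DefaultMultiLLM (placeholder values).
from onyx.llm.chat_llm import DefaultMultiLLM

llm = DefaultMultiLLM(
    api_key="sk-...",          # placeholder key
    timeout=30,                # seconds
    model_provider="openai",   # any provider litellm understands
    model_name="gpt-4o-mini",  # example model name
)

# Blocking call: returns a single langchain BaseMessage (usually an AIMessage)
answer = llm.invoke("Say hello in one word.")
print(answer.content)

# Streaming call: yields BaseMessage chunks as they arrive
for chunk in llm.stream("Count from 1 to 5."):
    print(chunk.content, end="", flush=True)
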
94    backend/onyx/llm/custom_llm.py    Normal file
@@ -0,0 +1,94 @@
import json
from collections.abc import Iterator

import requests
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from requests import Timeout

from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import convert_lm_input_to_basic_string
from onyx.utils.logger import setup_logger


logger = setup_logger()


class CustomModelServer(LLM):
    """This class provides an example of how to use Onyx
    with any LLM, even servers with custom API definitions.
    To use it with your own model server, simply implement the functions
    below to fit your model server's expectations.

    The implementation below works against the custom FastAPI server from the blog:
    https://medium.com/@yuhongsun96/how-to-augment-llms-with-private-data-29349bd8ae9f
    """

    @property
    def requires_api_key(self) -> bool:
        return False

    def __init__(
        self,
        # Not used here but you probably want a model server that isn't completely open
        api_key: str | None,
        timeout: int,
        endpoint: str,
        max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
    ):
        if not endpoint:
            raise ValueError(
                "Cannot point Onyx to a custom LLM server without providing the "
                "endpoint for the model server."
            )

        self._endpoint = endpoint
        self._max_output_tokens = max_output_tokens
        self._timeout = timeout

    def _execute(self, input: LanguageModelInput) -> AIMessage:
        headers = {
            "Content-Type": "application/json",
        }

        data = {
            "inputs": convert_lm_input_to_basic_string(input),
            "parameters": {
                "temperature": 0.0,
                "max_tokens": self._max_output_tokens,
            },
        }
        try:
            response = requests.post(
                self._endpoint, headers=headers, json=data, timeout=self._timeout
            )
        except Timeout as error:
            raise Timeout(f"Model inference to {self._endpoint} timed out") from error

        response.raise_for_status()
        response_content = json.loads(response.content).get("generated_text", "")
        return AIMessage(content=response_content)

    def log_model_configs(self) -> None:
        logger.debug(f"Custom model at: {self._endpoint}")

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        return self._execute(prompt)

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._execute(prompt)
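
A sketch of pointing CustomModelServer at a hypothetical server that accepts {"inputs": ..., "parameters": {...}} and replies with {"generated_text": ...}; the endpoint URL and timeout are assumptions for illustration:

# Illustrative only: instantiate CustomModelServer against a placeholder endpoint.
from onyx.llm.custom_llm import CustomModelServer

llm = CustomModelServer(
    api_key=None,                                # this example server is unauthenticated
    timeout=60,                                  # seconds for the HTTP request
    endpoint="http://localhost:8000/generate",   # placeholder URL
)

print(llm.invoke("What is Onyx?").content)
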
4    backend/onyx/llm/exceptions.py    Normal file
@@ -0,0 +1,4 @@
class GenAIDisabledException(Exception):
    def __init__(self, message: str = "Generative AI has been turned off") -> None:
        self.message = message
        super().__init__(self.message)
160    backend/onyx/llm/factory.py    Normal file
@@ -0,0 +1,160 @@
from typing import Any

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_provider
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger

logger = setup_logger()


def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
    """Ollama requires us to specify the max context window.

    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
    TODO: allow model-specific values to be configured via the UI.
    """
    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}


def get_main_llm_from_tuple(
    llms: tuple[LLM, LLM],
) -> LLM:
    return llms[0]


def get_llms_for_persona(
    persona: Persona | PersonaOverrideConfig | None,
    llm_override: LLMOverride | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
    if persona is None:
        logger.warning("No persona provided, using default LLMs")
        return get_default_llms()

    model_provider_override = llm_override.model_provider if llm_override else None
    model_version_override = llm_override.model_version if llm_override else None
    temperature_override = llm_override.temperature if llm_override else None

    provider_name = model_provider_override or persona.llm_model_provider_override
    if not provider_name:
        return get_default_llms(
            temperature=temperature_override or GEN_AI_TEMPERATURE,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,
        )

    with get_session_context_manager() as db_session:
        llm_provider = fetch_provider(db_session, provider_name)

    if not llm_provider:
        raise ValueError("No LLM provider found")

    model = model_version_override or persona.llm_model_version_override
    fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
    if not model:
        raise ValueError("No model name found")
    if not fast_model:
        raise ValueError("No fast model name found")

    def _create_llm(model: str) -> LLM:
        return get_llm(
            provider=llm_provider.provider,
            model=model,
            deployment_name=llm_provider.deployment_name,
            api_key=llm_provider.api_key,
            api_base=llm_provider.api_base,
            api_version=llm_provider.api_version,
            custom_config=llm_provider.custom_config,
            temperature=temperature_override,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,
        )

    return _create_llm(model), _create_llm(fast_model)


def get_default_llms(
    timeout: int = QA_TIMEOUT,
    temperature: float = GEN_AI_TEMPERATURE,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()

    with get_session_context_manager() as db_session:
        llm_provider = fetch_default_provider(db_session)

    if not llm_provider:
        raise ValueError("No default LLM provider found")

    model_name = llm_provider.default_model_name
    fast_model_name = (
        llm_provider.fast_default_model_name or llm_provider.default_model_name
    )
    if not model_name:
        raise ValueError("No default model name found")
    if not fast_model_name:
        raise ValueError("No fast default model name found")

    def _create_llm(model: str) -> LLM:
        return get_llm(
            provider=llm_provider.provider,
            model=model,
            deployment_name=llm_provider.deployment_name,
            api_key=llm_provider.api_key,
            api_base=llm_provider.api_base,
            api_version=llm_provider.api_version,
            custom_config=llm_provider.custom_config,
            timeout=timeout,
            temperature=temperature,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,
        )

    return _create_llm(model_name), _create_llm(fast_model_name)


def get_llm(
    provider: str,
    model: str,
    deployment_name: str | None,
    api_key: str | None = None,
    api_base: str | None = None,
    api_version: str | None = None,
    custom_config: dict[str, str] | None = None,
    temperature: float | None = None,
    timeout: int = QA_TIMEOUT,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
    if temperature is None:
        temperature = GEN_AI_TEMPERATURE
    return DefaultMultiLLM(
        model_provider=provider,
        model_name=model,
        deployment_name=deployment_name,
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        timeout=timeout,
        temperature=temperature,
        custom_config=custom_config,
        extra_headers=build_llm_extra_headers(additional_headers),
        model_kwargs=_build_extra_model_kwargs(provider),
        long_term_logger=long_term_logger,
    )
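
A sketch of the two main factory entry points (assumes a default LLM provider row exists in the database; the provider, model, and key values are placeholders):

# Illustrative: obtain LLMs through the factory.
from onyx.llm.factory import get_default_llms, get_llm

# Primary + fast LLM built from the configured default provider
primary_llm, fast_llm = get_default_llms(timeout=30)

# Or build one directly, bypassing the DB lookup (placeholder values)
llm = get_llm(
    provider="openai",
    model="gpt-4o-mini",
    deployment_name=None,
    api_key="sk-...",
)
print(llm.config)
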
142    backend/onyx/llm/interfaces.py    Normal file
@@ -0,0 +1,142 @@
import abc
from collections.abc import Iterator
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from onyx.utils.logger import setup_logger


logger = setup_logger()

ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]


class LLMConfig(BaseModel):
    model_provider: str
    model_name: str
    temperature: float
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}


def log_prompt(prompt: LanguageModelInput) -> None:
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if isinstance(msg, AIMessageChunk):
                if msg.content:
                    log_msg = msg.content
                elif msg.tool_call_chunks:
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in msg.tool_call_chunks
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Message {ind}:\n{log_msg}")
            else:
                logger.debug(f"Message {ind}:\n{msg.content}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")


class LLM(abc.ABC):
    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
    to use these implementations to connect to a variety of LLM providers."""

    @property
    def requires_warm_up(self) -> bool:
        """Is this model running in memory and needs an initial call to warm it up?"""
        return False

    @property
    def requires_api_key(self) -> bool:
        return True

    @property
    @abc.abstractmethod
    def config(self) -> LLMConfig:
        raise NotImplementedError

    @abc.abstractmethod
    def log_model_configs(self) -> None:
        raise NotImplementedError

    def _precall(self, prompt: LanguageModelInput) -> None:
        if DISABLE_GENERATIVE_AI:
            raise Exception("Generative AI is disabled")
        if LOG_DANSWER_MODEL_INTERACTIONS:
            log_prompt(prompt)

    def invoke(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format
        )

    @abc.abstractmethod
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

    def stream(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        messages = self._stream_implementation(
            prompt, tools, tool_choice, structured_response_format
        )

        tokens = []
        for message in messages:
            if LOG_INDIVIDUAL_MODEL_TOKENS:
                tokens.append(message.content)
            yield message

        if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
            logger.debug(f"Model Tokens: {tokens}")

    @abc.abstractmethod
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError
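
A minimal subclass sketch showing the three abstract members an LLM implementation must provide; this "echo" model is illustrative only, not part of the codebase:

# Illustrative echo LLM built on the abstract interface above.
from collections.abc import Iterator

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage

from onyx.llm.interfaces import LLM, LLMConfig, ToolChoiceOptions
from onyx.llm.utils import convert_lm_input_to_basic_string


class EchoLLM(LLM):
    @property
    def config(self) -> LLMConfig:
        return LLMConfig(model_provider="echo", model_name="echo-1", temperature=0.0)

    def log_model_configs(self) -> None:
        pass

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        # Echo the flattened prompt back as the "model" output
        return AIMessage(content=convert_lm_input_to_basic_string(prompt))

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._invoke_implementation(prompt)
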
146    backend/onyx/llm/llm_provider_options.py    Normal file
@@ -0,0 +1,146 @@
import litellm  # type: ignore
from pydantic import BaseModel


class CustomConfigKey(BaseModel):
    name: str
    description: str | None = None
    is_required: bool = True
    is_secret: bool = False


class WellKnownLLMProviderDescriptor(BaseModel):
    name: str
    display_name: str
    api_key_required: bool
    api_base_required: bool
    api_version_required: bool
    custom_config_keys: list[CustomConfigKey] | None = None
    llm_names: list[str]
    default_model: str | None = None
    default_fast_model: str | None = None
    # set for providers like Azure, which require a deployment name.
    deployment_name_required: bool = False
    # set for providers like Azure, which support a single model per deployment.
    single_model_supported: bool = False


OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
    "o1-mini",
    "o1-preview",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "gpt-4-turbo",
    "gpt-4-turbo-preview",
    "gpt-4-1106-preview",
    "gpt-4-vision-preview",
    "gpt-4-0613",
    "gpt-4o-2024-08-06",
    "gpt-4-0314",
    "gpt-4-32k-0314",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0301",
]

BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [
    model
    for model in litellm.bedrock_models
    if "/" not in model and "embed" not in model
][::-1]

IGNORABLE_ANTHROPIC_MODELS = [
    "claude-2",
    "claude-instant-1",
    "anthropic/claude-3-5-sonnet-20241022",
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [
    model
    for model in litellm.anthropic_models
    if model not in IGNORABLE_ANTHROPIC_MODELS
][::-1]

AZURE_PROVIDER_NAME = "azure"


_PROVIDER_TO_MODELS_MAP = {
    OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
}


def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
    return [
        WellKnownLLMProviderDescriptor(
            name=OPENAI_PROVIDER_NAME,
            display_name="OpenAI",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
            default_model="gpt-4",
            default_fast_model="gpt-4o-mini",
        ),
        WellKnownLLMProviderDescriptor(
            name=ANTHROPIC_PROVIDER_NAME,
            display_name="Anthropic",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
            default_model="claude-3-5-sonnet-20241022",
            default_fast_model="claude-3-5-sonnet-20241022",
        ),
        WellKnownLLMProviderDescriptor(
            name=AZURE_PROVIDER_NAME,
            display_name="Azure OpenAI",
            api_key_required=True,
            api_base_required=True,
            api_version_required=True,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
            deployment_name_required=True,
            single_model_supported=True,
        ),
        WellKnownLLMProviderDescriptor(
            name=BEDROCK_PROVIDER_NAME,
            display_name="AWS Bedrock",
            api_key_required=False,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[
                CustomConfigKey(name="AWS_REGION_NAME"),
                CustomConfigKey(
                    name="AWS_ACCESS_KEY_ID",
                    is_required=False,
                    description="If using AWS IAM roles, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY can be left blank.",
                ),
                CustomConfigKey(
                    name="AWS_SECRET_ACCESS_KEY",
                    is_required=False,
                    is_secret=True,
                    description="If using AWS IAM roles, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY can be left blank.",
                ),
            ],
            llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
            default_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
            default_fast_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
        ),
    ]


def fetch_models_for_provider(provider_name: str) -> list[str]:
    return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
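
A small illustrative sketch enumerating the built-in provider descriptors defined above:

# Illustrative: list the well-known providers and their default models.
from onyx.llm.llm_provider_options import fetch_available_well_known_llms

for descriptor in fetch_available_well_known_llms():
    print(
        f"{descriptor.display_name}: default={descriptor.default_model}, "
        f"{len(descriptor.llm_names)} known models"
    )
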
59    backend/onyx/llm/models.py    Normal file
@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel

from onyx.configs.constants import MessageType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.utils import build_content_with_imgs
from onyx.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from onyx.db.models import ChatMessage


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_call: ToolCallFinalResult | None

    @classmethod
    def from_chat_message(
        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
    ) -> "PreviousMessage":
        message_file_ids = (
            [file["id"] for file in chat_message.files] if chat_message.files else []
        )
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            files=[
                file
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_call=ToolCallFinalResult(
                tool_name=chat_message.tool_call.tool_name,
                tool_args=chat_message.tool_call.tool_arguments,
                tool_result=chat_message.tool_call.tool_result,
            )
            if chat_message.tool_call
            else None,
        )

    def to_langchain_msg(self) -> BaseMessage:
        content = build_content_with_imgs(self.message, self.files)
        if self.message_type == MessageType.USER:
            return HumanMessage(content=content)
        elif self.message_type == MessageType.ASSISTANT:
            return AIMessage(content=content)
        else:
            return SystemMessage(content=content)
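
A sketch of constructing a PreviousMessage by hand and converting it to a langchain message (the message text and token count are made-up illustration values):

# Illustrative: build a PreviousMessage directly and convert it.
from onyx.configs.constants import MessageType
from onyx.llm.models import PreviousMessage

prev = PreviousMessage(
    message="What connectors does Onyx support?",
    token_count=8,                  # placeholder count
    message_type=MessageType.USER,
    files=[],
    tool_call=None,
)
print(prev.to_langchain_msg())  # -> HumanMessage with the same text content
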
20    backend/onyx/llm/override_models.py    Normal file
@@ -0,0 +1,20 @@
"""Overrides sent over the wire / stored in the DB

NOTE: these models are used in many places, so they have to be
kept in a separate file to avoid circular imports.
"""
from pydantic import BaseModel


class LLMOverride(BaseModel):
    model_provider: str | None = None
    model_version: str | None = None
    temperature: float | None = None

    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}


class PromptOverride(BaseModel):
    system_prompt: str | None = None
    task_prompt: str | None = None
515    backend/onyx/llm/utils.py    Normal file
@@ -0,0 +1,515 @@
import copy
import json
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast

import litellm  # type: ignore
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from litellm.exceptions import APIConnectionError  # type: ignore
from litellm.exceptions import APIError  # type: ignore
from litellm.exceptions import AuthenticationError  # type: ignore
from litellm.exceptions import BadRequestError  # type: ignore
from litellm.exceptions import BudgetExceededError  # type: ignore
from litellm.exceptions import ContentPolicyViolationError  # type: ignore
from litellm.exceptions import ContextWindowExceededError  # type: ignore
from litellm.exceptions import NotFoundError  # type: ignore
from litellm.exceptions import PermissionDeniedError  # type: ignore
from litellm.exceptions import RateLimitError  # type: ignore
from litellm.exceptions import Timeout  # type: ignore
from litellm.exceptions import UnprocessableEntityError  # type: ignore

from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()


def litellm_exception_to_error_msg(
    e: Exception, llm: LLM, fallback_to_error_msg: bool = False
) -> str:
    error_msg = str(e)

    if isinstance(e, BadRequestError):
        error_msg = "Bad request: The server couldn't process your request. Please check your input."
    elif isinstance(e, AuthenticationError):
        error_msg = "Authentication failed: Please check your API key and credentials."
    elif isinstance(e, PermissionDeniedError):
        error_msg = (
            "Permission denied: You don't have the necessary permissions for this operation. "
            "Ensure you have access to this model."
        )
    elif isinstance(e, NotFoundError):
        error_msg = "Resource not found: The requested resource doesn't exist."
    elif isinstance(e, UnprocessableEntityError):
        error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
    elif isinstance(e, RateLimitError):
        error_msg = (
            "Rate limit exceeded: Please slow down your requests and try again later."
        )
    elif isinstance(e, ContextWindowExceededError):
        error_msg = (
            "Context window exceeded: Your input is too long for the model to process."
        )
        if llm is not None:
            try:
                max_context = get_max_input_tokens(
                    model_name=llm.config.model_name,
                    model_provider=llm.config.model_provider,
                )
                error_msg += f" Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}"
            except Exception:
                logger.warning(
                    "Unable to get maximum input tokens for LiteLLM exception handling"
                )
    elif isinstance(e, ContentPolicyViolationError):
        error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
    elif isinstance(e, APIConnectionError):
        error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
    elif isinstance(e, BudgetExceededError):
        error_msg = (
            "Budget exceeded: You've exceeded your allocated budget for API usage."
        )
    elif isinstance(e, Timeout):
        error_msg = "Request timed out: The operation took too long to complete. Please try again."
    elif isinstance(e, APIError):
        error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
    elif not fallback_to_error_msg:
        error_msg = "An unexpected error occurred while processing your request. Please try again later."
    return error_msg


def _build_content(
    message: str,
    files: list[InMemoryChatFile] | None = None,
) -> str:
    """Applies all non-image files."""
    if not files:
        return message

    text_files = [
        file
        for file in files
        if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
    ]

    if not text_files:
        return message

    final_message_with_files = "FILES:\n\n"
    for file in text_files:
        file_content = file.content.decode("utf-8")
        file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
        final_message_with_files += (
            f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
        )

    return final_message_with_files + message


def build_content_with_imgs(
    message: str,
    files: list[InMemoryChatFile] | None = None,
    img_urls: list[str] | None = None,
    b64_imgs: list[str] | None = None,
    message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
    files = files or []

    # Only include image files for user messages
    img_files = (
        [file for file in files if file.file_type == ChatFileType.IMAGE]
        if message_type == MessageType.USER
        else []
    )

    img_urls = img_urls or []
    b64_imgs = b64_imgs or []

    message_main_content = _build_content(message, files)

    if not img_files and not img_urls:
        return message_main_content

    return cast(
        list[str | dict[str, Any]],
        [
            {
                "type": "text",
                "text": message_main_content,
            },
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": (
                        f"data:{get_image_type_from_bytes(file.content)};"
                        f"base64,{file.to_base64()}"
                    ),
                },
            }
            for file in img_files
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
                },
            }
            for b64_img in b64_imgs
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": url,
                },
            }
            for url in img_urls
        ],
    )


def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
    if isinstance(message.content, str):
        return message.content, []

    imgs = []
    texts = []
    for part in message.content:
        if isinstance(part, dict):
            if part.get("type") == "image_url":
                img_url = part.get("image_url", {}).get("url")
                if img_url:
                    imgs.append(img_url)
            elif part.get("type") == "text":
                text = part.get("text")
                if text:
                    texts.append(text)
        else:
            texts.append(part)

    return "".join(texts), imgs


def dict_based_prompt_to_langchain_prompt(
    messages: list[dict[str, str]]
) -> list[BaseMessage]:
    prompt: list[BaseMessage] = []
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if not role:
            raise ValueError(f"Message missing `role`: {message}")
        if not content:
            raise ValueError(f"Message missing `content`: {message}")
        elif role == "user":
            prompt.append(HumanMessage(content=content))
        elif role == "system":
            prompt.append(SystemMessage(content=content))
        elif role == "assistant":
            prompt.append(AIMessage(content=content))
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt


def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
    return [HumanMessage(content=message)]


def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
    """Heavily inspired by:
    https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
    """
    prompt_value = None
    if isinstance(lm_input, PromptValue):
        prompt_value = lm_input
    elif isinstance(lm_input, str):
        prompt_value = StringPromptValue(text=lm_input)
    elif isinstance(lm_input, list):
        prompt_value = ChatPromptValue(messages=lm_input)

    if prompt_value is None:
        raise ValueError(
            f"Invalid input type {type(lm_input)}. "
            "Must be a PromptValue, str, or list of BaseMessages."
        )

    return prompt_value.to_string()


def message_to_string(message: BaseMessage) -> str:
    if not isinstance(message.content, str):
        raise RuntimeError("LLM message not in expected format.")

    return message.content


def message_generator_to_string_generator(
    messages: Iterator[BaseMessage],
) -> Iterator[str]:
    for message in messages:
        yield message_to_string(message)


def should_be_verbose() -> bool:
    return LOG_LEVEL == "debug"


# Estimate of the number of tokens in an image url.
# This is correct when downsampling is used and very wrong when OpenAI does not downsample.
# TODO: improve this
_IMG_TOKENS = 85


def check_message_tokens(
    message: BaseMessage, encode_fn: Callable[[str], list] | None = None
) -> int:
    if isinstance(message.content, str):
        return check_number_of_tokens(message.content, encode_fn)

    total_tokens = 0
    for part in message.content:
        if isinstance(part, str):
            total_tokens += check_number_of_tokens(part, encode_fn)
            continue

        if part["type"] == "text":
            total_tokens += check_number_of_tokens(part["text"], encode_fn)
        elif part["type"] == "image_url":
            total_tokens += _IMG_TOKENS

    if isinstance(message, AIMessage) and message.tool_calls:
        for tool_call in message.tool_calls:
            total_tokens += check_number_of_tokens(
                json.dumps(tool_call["args"]), encode_fn
            )
            total_tokens += check_number_of_tokens(tool_call["name"], encode_fn)

    return total_tokens


def check_number_of_tokens(
    text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
    """Gets the number of tokens in the provided text, using the provided encoding
    function. If none is provided, default to the tiktoken encoder used by GPT-3.5
    and GPT-4.
    """

    if encode_fn is None:
        encode_fn = tiktoken.get_encoding("cl100k_base").encode

    return len(encode_fn(text))


def test_llm(llm: LLM) -> str | None:
    # try for up to 2 timeouts (e.g. 10 seconds in total)
    error_msg = None
    for _ in range(2):
        try:
            llm.invoke("Do not respond")
            return None
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"Failed to call LLM with the following error: {error_msg}")

    return error_msg


def get_model_map() -> dict:
    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))

    # NOTE: we could add additional models here in the future,
    # but for now there is no point. Ollama allows the user to
    # specify their desired max context window, and it's
    # unlikely to be standard across users even for the same model
    # (it heavily depends on their hardware). For now, we'll just
    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
    # for model_name in [
    #     "llama3.2",
    #     "llama3.2:1b",
    #     "llama3.2:3b",
    #     "llama3.2:11b",
    #     "llama3.2:90b",
    # ]:
    #     starting_map[f"ollama/{model_name}"] = {
    #         "max_tokens": 128000,
    #         "max_input_tokens": 128000,
    #         "max_output_tokens": 128000,
    #     }

    return starting_map


def _strip_extra_provider_from_model_name(model_name: str) -> str:
    return model_name.split("/")[1] if "/" in model_name else model_name


def _strip_colon_from_model_name(model_name: str) -> str:
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(
    model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
    # Filter out None values and deduplicate model names
    filtered_model_names = [name for name in model_names if name]

    # First try all model names with provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(f"{provider}/{model_name}")
        if model_obj:
            logger.debug(f"Using model object for {provider}/{model_name}")
            return model_obj

    # Then try all model names without provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(model_name)
        if model_obj:
            logger.debug(f"Using model object for {model_name}")
            return model_obj

    return None


def get_llm_max_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max tokens for the LLM"""
    if GEN_AI_MAX_TOKENS:
        # This is an override, so always return this
        logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
        return GEN_AI_MAX_TOKENS

    try:
        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
            model_name
        )
        model_obj = _find_model_obj(
            model_map,
            model_provider,
            [
                model_name,
                # Remove leading extra provider. Usually for cases where the user has a
                # custom model proxy which appends another prefix
                extra_provider_stripped_model_name,
                # remove :XXXX from the end, if present. Needed for ollama.
                _strip_colon_from_model_name(model_name),
                _strip_colon_from_model_name(extra_provider_stripped_model_name),
            ],
        )
        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
            )

        if "max_input_tokens" in model_obj:
            max_tokens = model_obj["max_input_tokens"]
            logger.info(
                f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
            )
            return max_tokens

        if "max_tokens" in model_obj:
            max_tokens = model_obj["max_tokens"]
            logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
            return max_tokens

        logger.error(f"No max tokens found for LLM: {model_name}")
        raise RuntimeError("No max tokens found for LLM")
    except Exception:
        logger.exception(
            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
        )
        return GEN_AI_MODEL_FALLBACK_MAX_TOKENS


def get_llm_max_output_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max output tokens for the LLM"""
    try:
        model_obj = model_map.get(f"{model_provider}/{model_name}")
        if not model_obj:
            model_obj = model_map[model_name]
            logger.debug(f"Using model object for {model_name}")
        else:
            logger.debug(f"Using model object for {model_provider}/{model_name}")

        if "max_output_tokens" in model_obj:
            max_output_tokens = model_obj["max_output_tokens"]
            logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
            return max_output_tokens

        # Fallback to a fraction of max_tokens if max_output_tokens is not specified
        if "max_tokens" in model_obj:
            max_output_tokens = int(model_obj["max_tokens"] * 0.1)
            logger.info(
                f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
            )
            return max_output_tokens

        logger.error(f"No max output tokens found for LLM: {model_name}")
        raise RuntimeError("No max output tokens found for LLM")
    except Exception:
        default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
        logger.exception(
            f"Failed to get max output tokens for LLM with name {model_name}. "
            f"Defaulting to {default_output_tokens} (fallback max tokens)."
        )
        return default_output_tokens


def get_max_input_tokens(
    model_name: str,
    model_provider: str,
    output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
) -> int:
    # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
    # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
    # and there is no other interface to get what we want. This should be okay though, since the
    # `model_cost` dict is a named public interface:
    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
    # model_map is litellm.model_cost
    litellm_model_map = get_model_map()

    input_toks = (
        get_llm_max_tokens(
            model_name=model_name,
            model_provider=model_provider,
            model_map=litellm_model_map,
        )
        - output_tokens
    )

    if input_toks <= 0:
        raise RuntimeError("No tokens for input for the LLM given settings")

    return input_toks
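
A rough token-accounting sketch using the helpers above (the model name and provider are example values, and the input budget comes from litellm's model map minus the reserved output tokens):

# Illustrative: token counting and input-budget lookup.
from langchain.schema.messages import HumanMessage

from onyx.llm.utils import (
    check_message_tokens,
    check_number_of_tokens,
    get_max_input_tokens,
)

print(check_number_of_tokens("Hello, Onyx!"))                      # tiktoken cl100k_base count
print(check_message_tokens(HumanMessage(content="Hello, Onyx!")))  # same text, message wrapper

# Max input budget = model context size minus reserved output tokens
print(get_max_input_tokens(model_name="gpt-4o-mini", model_provider="openai"))
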