welcome to onyx

pablodanswer
2024-12-13 09:48:43 -08:00
parent 54dcbfa288
commit 21ec5ed795
813 changed files with 7021 additions and 6824 deletions

@@ -0,0 +1,514 @@
import json
import os
import traceback
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import cast
import litellm # type: ignore
from httpx import RemoteProtocolError
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessageChunk
from langchain_core.messages import ChatMessage
from langchain_core.messages import ChatMessageChunk
from langchain_core.messages import FunctionMessage
from langchain_core.messages import FunctionMessageChunk
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessageChunk
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
logger = setup_logger()
# If a user configures a different model and it doesn't support all the same
# parameters (e.g. frequency_penalty, presence_penalty), just ignore them
litellm.drop_params = True
litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
def _base_msg_to_role(msg: BaseMessage) -> str:
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
return "user"
if isinstance(msg, AIMessage) or isinstance(msg, AIMessageChunk):
return "assistant"
if isinstance(msg, SystemMessage) or isinstance(msg, SystemMessageChunk):
return "system"
if isinstance(msg, FunctionMessage) or isinstance(msg, FunctionMessageChunk):
return "function"
return "unknown"
def _convert_litellm_message_to_langchain_message(
litellm_message: litellm.Message,
) -> BaseMessage:
# Extracting the basic attributes from the litellm message
content = litellm_message.content or ""
role = litellm_message.role
# Handling function calls and tool calls if present
tool_calls = (
cast(
list[litellm.ChatCompletionMessageToolCall],
litellm_message.tool_calls,
)
if hasattr(litellm_message, "tool_calls")
else []
)
# Create the appropriate langchain message based on the role
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(
content=content,
tool_calls=[
{
"name": tool_call.function.name or "",
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else [],
)
elif role == "system":
return SystemMessage(content=content)
else:
raise ValueError(f"Unknown role type received: {role}")
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Adapted from langchain_community.chat_models.litellm._convert_message_to_dict"""
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [
{
"id": tool_call.get("id"),
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
"type": "function",
"index": tool_call.get("index", 0),
}
for tool_call in message.tool_calls
]
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
def _convert_delta_to_message_chunk(
_dict: dict[str, Any],
curr_msg: BaseMessage | None,
stop_reason: str | None = None,
) -> BaseMessageChunk:
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
content = _dict.get("content") or ""
additional_kwargs = {}
if _dict.get("function_call"):
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
tool_calls = cast(
list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
)
if role == "user":
return HumanMessageChunk(content=content)
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
idx = tool_call.index
tool_call_chunk = ToolCallChunk(
name=tool_name,
id=tool_call.id,
args=tool_call.function.arguments,
index=idx,
)
return AIMessageChunk(
content=content,
tool_call_chunks=[tool_call_chunk],
additional_kwargs={
"usage_metadata": {"stop": stop_reason},
**additional_kwargs,
},
)
return AIMessageChunk(
content=content,
additional_kwargs={
"usage_metadata": {"stop": stop_reason},
**additional_kwargs,
},
)
elif role == "system":
return SystemMessageChunk(content=content)
elif role == "function":
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role:
return ChatMessageChunk(content=content, role=role)
raise ValueError(f"Unknown role: {role}")
def _prompt_to_dict(
prompt: LanguageModelInput,
) -> Sequence[str | list[str] | dict[str, Any] | tuple[str, str]]:
# NOTE: this must go first, since it is also a Sequence
if isinstance(prompt, str):
return [_convert_message_to_dict(HumanMessage(content=prompt))]
if isinstance(prompt, (list, Sequence)):
return [
_convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
for msg in prompt
]
if isinstance(prompt, PromptValue):
return [_convert_message_to_dict(message) for message in prompt.to_messages()]
class DefaultMultiLLM(LLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
def __init__(
self,
api_key: str | None,
timeout: int,
model_provider: str,
model_name: str,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
max_output_tokens: int | None = None,
custom_llm_provider: str | None = None,
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
self._model_provider = model_provider
self._model_version = model_name
self._temperature = temperature
self._api_key = api_key
self._deployment_name = deployment_name
self._api_base = api_base
self._api_version = api_version
self._custom_llm_provider = custom_llm_provider
self._long_term_logger = long_term_logger
# This can be used to store the maximum output tokens for this model.
# self._max_output_tokens = (
# max_output_tokens
# if max_output_tokens is not None
# else get_llm_max_output_tokens(
# model_map=litellm.model_cost,
# model_name=model_name,
# model_provider=model_provider,
# )
# )
self._custom_config = custom_config
# NOTE: have to set these as environment variables for Litellm since
# not all of them can be passed in, but they are always supported when set as
# env variables. We'll also try passing them in, since litellm just ignores
# additional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
model_kwargs.update({"extra_body": extra_body})
self._model_kwargs = model_kwargs
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
def _safe_model_config(self) -> dict:
dump = self.config.model_dump()
dump["api_key"] = mask_string(dump.get("api_key", ""))
return dump
def _record_call(self, prompt: LanguageModelInput) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{"prompt": _prompt_to_dict(prompt), "model": self._safe_model_config()},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
def _record_result(
self, prompt: LanguageModelInput, model_output: BaseMessage
) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{
"prompt": _prompt_to_dict(prompt),
"content": model_output.content,
"tool_calls": (
model_output.tool_calls
if hasattr(model_output, "tool_calls")
else []
),
"model": self._safe_model_config(),
},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
def _record_error(self, prompt: LanguageModelInput, error: Exception) -> None:
if self._long_term_logger:
self._long_term_logger.record(
{
"prompt": _prompt_to_dict(prompt),
"error": str(error),
"traceback": "".join(
traceback.format_exception(
type(error), error, error.__traceback__
)
),
"model": self._safe_model_config(),
},
category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
)
# def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
# # NOTE: This method can be used for calculating the maximum tokens for the stream,
# # but it isn't used in practice due to the computational cost of counting tokens
# # and because LLM providers automatically cut off at the maximum output.
# # The implementation is kept for potential future use or debugging purposes.
# # Get max input tokens for the model
# max_context_tokens = get_max_input_tokens(
# model_name=self.config.model_name, model_provider=self.config.model_provider
# )
# llm_tokenizer = get_tokenizer(
# model_name=self.config.model_name,
# provider_type=self.config.model_provider,
# )
# # Calculate tokens in the input prompt
# input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)
# # Calculate available tokens for output
# available_output_tokens = max_context_tokens - input_tokens
# # Return the lesser of available tokens or configured max
# return min(self._max_output_tokens, available_output_tokens)
def _completion(
self,
prompt: LanguageModelInput,
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
processed_prompt = _prompt_to_dict(prompt)
self._record_call(processed_prompt)
try:
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
# streaming choice
stream=stream,
# model params
temperature=self._temperature,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**(
{"response_format": structured_response_format}
if structured_response_format
else {}
),
**self._model_kwargs,
)
except Exception as e:
self._record_error(processed_prompt, e)
# convenient spot for a breakpoint when debugging
raise e
@property
def config(self) -> LLMConfig:
return LLMConfig(
model_provider=self._model_provider,
model_name=self._model_version,
temperature=self._temperature,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
deployment_name=self._deployment_name,
)
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
if hasattr(choice, "message"):
output = _convert_litellm_message_to_langchain_message(choice.message)
if output:
self._record_result(prompt, output)
return output
else:
raise ValueError("Unexpected response choice type")
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(
prompt, tools, tool_choice, True, structured_response_format
),
)
try:
for part in response:
if not part["choices"]:
continue
choice = part["choices"][0]
message_chunk = _convert_delta_to_message_chunk(
choice["delta"],
output,
stop_reason=choice["finish_reason"],
)
if output is None:
output = message_chunk
else:
output += message_chunk
yield message_chunk
except RemoteProtocolError:
raise RuntimeError(
"The AI model failed partway through generation, please try again."
)
if output:
self._record_result(prompt, output)
if LOG_DANSWER_MODEL_INTERACTIONS and output:
content = output.content or ""
if isinstance(output, AIMessage):
if content:
log_msg = content
elif output.tool_calls:
log_msg = "Tool Calls: " + str(
[
{
key: value
for key, value in tool_call.items()
if key != "index"
}
for tool_call in output.tool_calls
]
)
else:
log_msg = ""
logger.debug(f"Raw Model Output:\n{log_msg}")
else:
logger.debug(f"Raw Model Output:\n{content}")


@@ -0,0 +1,94 @@
import json
from collections.abc import Iterator
import requests
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from requests import Timeout
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import convert_lm_input_to_basic_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
class CustomModelServer(LLM):
"""This class is to provide an example for how to use Onyx
with any LLM, even servers with custom API definitions.
To use with your own model server, simply implement the functions
below to fit your model server's expected request/response format.
The implementation below works against the custom FastAPI server from the blog:
https://medium.com/@yuhongsun96/how-to-augment-llms-with-private-data-29349bd8ae9f
"""
@property
def requires_api_key(self) -> bool:
return False
def __init__(
self,
# Not used here but you probably want a model server that isn't completely open
api_key: str | None,
timeout: int,
endpoint: str,
max_output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
):
if not endpoint:
raise ValueError(
"Cannot point Onyx to a custom LLM server without providing the "
"endpoint for the model server."
)
self._endpoint = endpoint
self._max_output_tokens = max_output_tokens
self._timeout = timeout
def _execute(self, input: LanguageModelInput) -> AIMessage:
headers = {
"Content-Type": "application/json",
}
data = {
"inputs": convert_lm_input_to_basic_string(input),
"parameters": {
"temperature": 0.0,
"max_tokens": self._max_output_tokens,
},
}
try:
response = requests.post(
self._endpoint, headers=headers, json=data, timeout=self._timeout
)
except Timeout as error:
raise Timeout(f"Model inference to {self._endpoint} timed out") from error
response.raise_for_status()
response_content = json.loads(response.content).get("generated_text", "")
return AIMessage(content=response_content)
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)
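
A hedged sketch of a server endpoint compatible with the request/response contract that _execute uses above: it receives {"inputs": ..., "parameters": {...}} and must return JSON containing a "generated_text" key. The FastAPI app, route path, and echo logic are illustrative and are not the blog's implementation:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class GenerationRequest(BaseModel):
    inputs: str
    parameters: dict


@app.post("/generate")
def generate(req: GenerationRequest) -> dict:
    # Stand-in for a real model call; CustomModelServer only reads "generated_text"
    return {"generated_text": f"echo: {req.inputs[:80]}"}

Pointing Onyx at it would then look roughly like CustomModelServer(api_key=None, timeout=30, endpoint="http://localhost:8000/generate"), with the endpoint value again illustrative.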


@@ -0,0 +1,4 @@
class GenAIDisabledException(Exception):
def __init__(self, message: str = "Generative AI has been turned off") -> None:
self.message = message
super().__init__(self.message)
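
Callers that construct LLMs typically guard against generative AI being disabled. A small hedged sketch using the factory function added later in this commit; it assumes a default provider is configured in the database, and the fallback behavior shown is illustrative:

from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms

try:
    primary_llm, fast_llm = get_default_llms()
except GenAIDisabledException as e:
    # DISABLE_GENERATIVE_AI is set; degrade gracefully instead of crashing
    primary_llm, fast_llm = None, None
    print(e.message)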

backend/onyx/llm/factory.py

@@ -0,0 +1,160 @@
from typing import Any
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_provider
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
logger = setup_logger()
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
return llms[0]
def get_llms_for_persona(
persona: Persona | PersonaOverrideConfig | None,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
if persona is None:
logger.warning("No persona provided, using default LLMs")
return get_default_llms()
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
provider_name = model_provider_override or persona.llm_model_provider_override
if not provider_name:
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
model = model_version_override or persona.llm_model_version_override
fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
if not model:
raise ValueError("No model name found")
if not fast_model:
raise ValueError("No fast model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
temperature=temperature_override,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return _create_llm(model), _create_llm(fast_model)
def get_default_llms(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
with get_session_context_manager() as db_session:
llm_provider = fetch_default_provider(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")
model_name = llm_provider.default_model_name
fast_model_name = (
llm_provider.fast_default_model_name or llm_provider.default_model_name
)
if not model_name:
raise ValueError("No default model name found")
if not fast_model_name:
raise ValueError("No fast default model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return _create_llm(model_name), _create_llm(fast_model_name)
def get_llm(
provider: str,
model: str,
deployment_name: str | None,
api_key: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
custom_config: dict[str, str] | None = None,
temperature: float | None = None,
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
if temperature is None:
temperature = GEN_AI_TEMPERATURE
return DefaultMultiLLM(
model_provider=provider,
model_name=model,
deployment_name=deployment_name,
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
long_term_logger=long_term_logger,
)
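
For reference, a minimal sketch of building an LLM directly with get_llm, bypassing the database-backed provider lookup; every value below is illustrative:

from onyx.llm.factory import get_llm

llm = get_llm(
    provider="openai",
    model="gpt-4o-mini",
    deployment_name=None,  # only relevant for Azure-style providers
    api_key="sk-...",      # placeholder
    timeout=30,
)
print(llm.config.model_name)  # -> "gpt-4o-mini"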


@@ -0,0 +1,142 @@
import abc
from collections.abc import Iterator
from typing import Literal
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from onyx.utils.logger import setup_logger
logger = setup_logger()
ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]
class LLMConfig(BaseModel):
model_provider: str
model_name: str
temperature: float
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
def log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
if isinstance(msg, AIMessageChunk):
if msg.content:
log_msg = msg.content
elif msg.tool_call_chunks:
log_msg = "Tool Calls: " + str(
[
{
key: value
for key, value in tool_call.items()
if key != "index"
}
for tool_call in msg.tool_call_chunks
]
)
else:
log_msg = ""
logger.debug(f"Message {ind}:\n{log_msg}")
else:
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
class LLM(abc.ABC):
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
to use these implementations to connect to a variety of LLM providers."""
@property
def requires_warm_up(self) -> bool:
"""Is this model running in memory and needs an initial call to warm it up?"""
return False
@property
def requires_api_key(self) -> bool:
return True
@property
@abc.abstractmethod
def config(self) -> LLMConfig:
raise NotImplementedError
@abc.abstractmethod
def log_model_configs(self) -> None:
raise NotImplementedError
def _precall(self, prompt: LanguageModelInput) -> None:
if DISABLE_GENERATIVE_AI:
raise Exception("Generative AI is disabled")
if LOG_DANSWER_MODEL_INTERACTIONS:
log_prompt(prompt)
def invoke(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError
def stream(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
tokens = []
for message in messages:
if LOG_INDIVIDUAL_MODEL_TOKENS:
tokens.append(message.content)
yield message
if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
logger.debug(f"Model Tokens: {tokens}")
@abc.abstractmethod
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
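
To make the abstract surface concrete, a hedged sketch of about the smallest possible subclass, an echo model. The EchoLLM name and behavior are invented for illustration; convert_lm_input_to_basic_string comes from the utils module added later in this commit:

from collections.abc import Iterator

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage

from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import convert_lm_input_to_basic_string


class EchoLLM(LLM):
    """Toy implementation that echoes the prompt back (illustrative only)."""

    @property
    def config(self) -> LLMConfig:
        return LLMConfig(model_provider="echo", model_name="echo-1", temperature=0.0)

    def log_model_configs(self) -> None:
        pass  # nothing interesting to log for the toy model

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        return AIMessage(content=convert_lm_input_to_basic_string(prompt))

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._invoke_implementation(prompt)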


@@ -0,0 +1,146 @@
import litellm # type: ignore
from pydantic import BaseModel
class CustomConfigKey(BaseModel):
name: str
description: str | None = None
is_required: bool = True
is_secret: bool = False
class WellKnownLLMProviderDescriptor(BaseModel):
name: str
display_name: str
api_key_required: bool
api_base_required: bool
api_version_required: bool
custom_config_keys: list[CustomConfigKey] | None = None
llm_names: list[str]
default_model: str | None = None
default_fast_model: str | None = None
# set for providers like Azure, which require a deployment name.
deployment_name_required: bool = False
# set for providers like Azure, which support a single model per deployment.
single_model_supported: bool = False
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o1-mini",
"o1-preview",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-0613",
"gpt-4o-2024-08-06",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
if "/" not in model and "embed" not in model
][::-1]
IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
"claude-instant-1",
"anthropic/claude-3-5-sonnet-20241022",
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [
model
for model in litellm.anthropic_models
if model not in IGNORABLE_ANTHROPIC_MODELS
][::-1]
AZURE_PROVIDER_NAME = "azure"
_PROVIDER_TO_MODELS_MAP = {
OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
}
def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
return [
WellKnownLLMProviderDescriptor(
name="openai",
display_name="OpenAI",
api_key_required=True,
api_base_required=False,
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
default_model="gpt-4",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=ANTHROPIC_PROVIDER_NAME,
display_name="Anthropic",
api_key_required=True,
api_base_required=False,
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-5-sonnet-20241022",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(
name=AZURE_PROVIDER_NAME,
display_name="Azure OpenAI",
api_key_required=True,
api_base_required=True,
api_version_required=True,
custom_config_keys=[],
llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
deployment_name_required=True,
single_model_supported=True,
),
WellKnownLLMProviderDescriptor(
name=BEDROCK_PROVIDER_NAME,
display_name="AWS Bedrock",
api_key_required=False,
api_base_required=False,
api_version_required=False,
custom_config_keys=[
CustomConfigKey(name="AWS_REGION_NAME"),
CustomConfigKey(
name="AWS_ACCESS_KEY_ID",
is_required=False,
description="If using AWS IAM roles, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY can be left blank.",
),
CustomConfigKey(
name="AWS_SECRET_ACCESS_KEY",
is_required=False,
is_secret=True,
description="If using AWS IAM roles, AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY can be left blank.",
),
],
llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
default_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
default_fast_model="anthropic.claude-3-5-sonnet-20241022-v2:0",
),
]
def fetch_models_for_provider(provider_name: str) -> list[str]:
return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
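
Illustrative usage of the descriptor helpers above; the module path in the import is assumed, since the diff header for this file does not show it:

# Module path assumed for illustration
from onyx.llm.llm_provider_options import fetch_available_well_known_llms

for descriptor in fetch_available_well_known_llms():
    print(descriptor.display_name, "default:", descriptor.default_model)
    for model_name in descriptor.llm_names[:3]:
        print("   ", model_name)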


@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from onyx.configs.constants import MessageType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.utils import build_content_with_imgs
from onyx.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from onyx.db.models import ChatMessage
class PreviousMessage(BaseModel):
"""Simplified version of `ChatMessage`"""
message: str
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
) -> "PreviousMessage":
message_file_ids = (
[file["id"] for file in chat_message.files] if chat_message.files else []
)
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
files=[
file
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:
content = build_content_with_imgs(self.message, self.files)
if self.message_type == MessageType.USER:
return HumanMessage(content=content)
elif self.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
else:
return SystemMessage(content=content)
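
A small sketch of constructing a PreviousMessage by hand and converting it to a LangChain message; the import path and token count are illustrative, and in practice instances come from from_chat_message:

from onyx.configs.constants import MessageType
from onyx.llm.models import PreviousMessage  # module path assumed for illustration

prev = PreviousMessage(
    message="What is Onyx?",
    token_count=4,  # illustrative
    message_type=MessageType.USER,
    files=[],
    tool_call=None,
)
print(prev.to_langchain_msg())  # -> HumanMessage for USER messages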


@@ -0,0 +1,20 @@
"""Overrides sent over the wire / stored in the DB
NOTE: these models are used in many places, so have to be
kept in a separate file to avoid circular imports.
"""
from pydantic import BaseModel
class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class PromptOverride(BaseModel):
system_prompt: str | None = None
task_prompt: str | None = None
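
For illustration, how an LLMOverride would be constructed and threaded into the factory shown earlier; the provider and model names are placeholders that must match what is configured in the database:

from onyx.llm.override_models import LLMOverride

override = LLMOverride(
    model_provider="OpenAI",      # must match a configured provider name
    model_version="gpt-4o-mini",  # illustrative
    temperature=0.2,
)
# Passed alongside a Persona, e.g.:
#   primary_llm, fast_llm = get_llms_for_persona(persona, llm_override=override)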

backend/onyx/llm/utils.py

@@ -0,0 +1,515 @@
import copy
import json
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
import litellm # type: ignore
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from litellm.exceptions import APIConnectionError # type: ignore
from litellm.exceptions import APIError # type: ignore
from litellm.exceptions import AuthenticationError # type: ignore
from litellm.exceptions import BadRequestError # type: ignore
from litellm.exceptions import BudgetExceededError # type: ignore
from litellm.exceptions import ContentPolicyViolationError # type: ignore
from litellm.exceptions import ContextWindowExceededError # type: ignore
from litellm.exceptions import NotFoundError # type: ignore
from litellm.exceptions import PermissionDeniedError # type: ignore
from litellm.exceptions import RateLimitError # type: ignore
from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
def litellm_exception_to_error_msg(
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
) -> str:
error_msg = str(e)
if isinstance(e, BadRequestError):
error_msg = "Bad request: The server couldn't process your request. Please check your input."
elif isinstance(e, AuthenticationError):
error_msg = "Authentication failed: Please check your API key and credentials."
elif isinstance(e, PermissionDeniedError):
error_msg = (
"Permission denied: You don't have the necessary permissions for this operation."
"Ensure you have access to this model."
)
elif isinstance(e, NotFoundError):
error_msg = "Resource not found: The requested resource doesn't exist."
elif isinstance(e, UnprocessableEntityError):
error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
elif isinstance(e, RateLimitError):
error_msg = (
"Rate limit exceeded: Please slow down your requests and try again later."
)
elif isinstance(e, ContextWindowExceededError):
error_msg = (
"Context window exceeded: Your input is too long for the model to process."
)
if llm is not None:
try:
max_context = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
error_msg += f"Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}"
except Exception:
logger.warning(
"Unable to get maximum input token for LiteLLM excpetion handling"
)
elif isinstance(e, ContentPolicyViolationError):
error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
elif isinstance(e, APIConnectionError):
error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
elif isinstance(e, BudgetExceededError):
error_msg = (
"Budget exceeded: You've exceeded your allocated budget for API usage."
)
elif isinstance(e, Timeout):
error_msg = "Request timed out: The operation took too long to complete. Please try again."
elif isinstance(e, APIError):
error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
elif not fallback_to_error_msg:
error_msg = "An unexpected error occurred while processing your request. Please try again later."
return error_msg
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
) -> str:
"""Applies all non-image files."""
if not files:
return message
text_files = [
file
for file in files
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
]
if not text_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
return final_message_with_files + message
def build_content_with_imgs(
message: str,
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
# Only include image files for user messages
img_files = (
[file for file in files if file.file_type == ChatFileType.IMAGE]
if message_type == MessageType.USER
else []
)
img_urls = img_urls or []
b64_imgs = b64_imgs or []
message_main_content = _build_content(message, files)
if not img_files and not img_urls:
return message_main_content
return cast(
list[str | dict[str, Any]],
[
{
"type": "text",
"text": message_main_content,
},
]
+ [
{
"type": "image_url",
"image_url": {
"url": (
f"data:{get_image_type_from_bytes(file.content)};"
f"base64,{file.to_base64()}"
),
},
}
for file in img_files
]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
},
}
for b64_img in b64_imgs
]
+ [
{
"type": "image_url",
"image_url": {
"url": url,
},
}
for url in img_urls
],
)
def message_to_prompt_and_imgs(message: BaseMessage) -> tuple[str, list[str]]:
if isinstance(message.content, str):
return message.content, []
imgs = []
texts = []
for part in message.content:
if isinstance(part, dict):
if part.get("type") == "image_url":
img_url = part.get("image_url", {}).get("url")
if img_url:
imgs.append(img_url)
elif part.get("type") == "text":
text = part.get("text")
if text:
texts.append(text)
else:
texts.append(part)
return "".join(texts), imgs
def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt
def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]
def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str:
"""Heavily inspired by:
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
"""
prompt_value = None
if isinstance(lm_input, PromptValue):
prompt_value = lm_input
elif isinstance(lm_input, str):
prompt_value = StringPromptValue(text=lm_input)
elif isinstance(lm_input, list):
prompt_value = ChatPromptValue(messages=lm_input)
if prompt_value is None:
raise ValueError(
f"Invalid input type {type(lm_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
return prompt_value.to_string()
def message_to_string(message: BaseMessage) -> str:
if not isinstance(message.content, str):
raise RuntimeError("LLM message not in expected format.")
return message.content
def message_generator_to_string_generator(
messages: Iterator[BaseMessage],
) -> Iterator[str]:
for message in messages:
yield message_to_string(message)
def should_be_verbose() -> bool:
return LOG_LEVEL == "debug"
# Rough estimate of the number of tokens in an image URL.
# Accurate when downsampling is used; very wrong when OpenAI does not downsample.
# TODO: improve this
_IMG_TOKENS = 85
def check_message_tokens(
message: BaseMessage, encode_fn: Callable[[str], list] | None = None
) -> int:
if isinstance(message.content, str):
return check_number_of_tokens(message.content, encode_fn)
total_tokens = 0
for part in message.content:
if isinstance(part, str):
total_tokens += check_number_of_tokens(part, encode_fn)
continue
if part["type"] == "text":
total_tokens += check_number_of_tokens(part["text"], encode_fn)
elif part["type"] == "image_url":
total_tokens += _IMG_TOKENS
if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
total_tokens += check_number_of_tokens(
json.dumps(tool_call["args"]), encode_fn
)
total_tokens += check_number_of_tokens(tool_call["name"], encode_fn)
return total_tokens
def check_number_of_tokens(
text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
"""Gets the number of tokens in the provided text, using the provided encoding
function. If none is provided, default to the tiktoken encoder used by GPT-3.5
and GPT-4.
"""
if encode_fn is None:
encode_fn = tiktoken.get_encoding("cl100k_base").encode
return len(encode_fn(text))
def test_llm(llm: LLM) -> str | None:
# try for up to 2 timeouts (e.g. 10 seconds in total)
error_msg = None
for _ in range(2):
try:
llm.invoke("Do not respond")
return None
except Exception as e:
error_msg = str(e)
logger.warning(f"Failed to call LLM with the following error: {error_msg}")
return error_msg
def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# specify their desired max context window, and it's
# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }
return starting_map
def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name
def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]
# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
def get_llm_max_tokens(
model_map: dict,
model_name: str,
model_provider: str,
) -> int:
"""Best effort attempt to get the max tokens for the LLM"""
if GEN_AI_MAX_TOKENS:
# This is an override, so always return this
logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
return GEN_AI_MAX_TOKENS
try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
if "max_input_tokens" in model_obj:
max_tokens = model_obj["max_input_tokens"]
logger.info(
f"Max tokens for {model_name}: {max_tokens} (from max_input_tokens)"
)
return max_tokens
if "max_tokens" in model_obj:
max_tokens = model_obj["max_tokens"]
logger.info(f"Max tokens for {model_name}: {max_tokens} (from max_tokens)")
return max_tokens
logger.error(f"No max tokens found for LLM: {model_name}")
raise RuntimeError("No max tokens found for LLM")
except Exception:
logger.exception(
f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
)
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
def get_llm_max_output_tokens(
model_map: dict,
model_name: str,
model_provider: str,
) -> int:
"""Best effort attempt to get the max output tokens for the LLM"""
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
logger.debug(f"Using model object for {model_name}")
else:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if "max_output_tokens" in model_obj:
max_output_tokens = model_obj["max_output_tokens"]
logger.info(f"Max output tokens for {model_name}: {max_output_tokens}")
return max_output_tokens
# Fallback to a fraction of max_tokens if max_output_tokens is not specified
if "max_tokens" in model_obj:
max_output_tokens = int(model_obj["max_tokens"] * 0.1)
logger.info(
f"Fallback max output tokens for {model_name}: {max_output_tokens} (10% of max_tokens)"
)
return max_output_tokens
logger.error(f"No max output tokens found for LLM: {model_name}")
raise RuntimeError("No max output tokens found for LLM")
except Exception:
default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
logger.exception(
f"Failed to get max output tokens for LLM with name {model_name}. "
f"Defaulting to {default_output_tokens} (fallback max tokens)."
)
return default_output_tokens
def get_max_input_tokens(
model_name: str,
model_provider: str,
output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS,
) -> int:
# NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
# returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict,
# and there is no other interface to get what we want. This should be okay though, since the
# `model_cost` dict is a named public interface:
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
# model_map is litellm.model_cost
litellm_model_map = get_model_map()
input_toks = (
get_llm_max_tokens(
model_name=model_name,
model_provider=model_provider,
model_map=litellm_model_map,
)
- output_tokens
)
if input_toks <= 0:
raise RuntimeError("No tokens for input for the LLM given settings")
return input_toks
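
Finally, a short sketch of the token-budget helpers above in use; the model and provider names are illustrative, and resolution depends on litellm's model_cost map:

from onyx.llm.utils import get_llm_max_tokens, get_max_input_tokens, get_model_map

model_map = get_model_map()
total_ctx = get_llm_max_tokens(
    model_map=model_map, model_name="gpt-4o", model_provider="openai"
)
input_budget = get_max_input_tokens(model_name="gpt-4o", model_provider="openai")
# input_budget == total_ctx - GEN_AI_NUM_RESERVED_OUTPUT_TOKENS (when no overrides apply)
print(total_ctx, input_budget)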