import json
import os
import traceback
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
from typing import cast

import litellm  # type: ignore
from httpx import RemoteProtocolError
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import BaseMessageChunk
from langchain_core.messages import ChatMessage
from langchain_core.messages import ChatMessageChunk
from langchain_core.messages import FunctionMessage
from langchain_core.messages import FunctionMessageChunk
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessageChunk
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger

logger = setup_logger()

# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False

_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"


class LLMTimeoutError(Exception):
    """
    Exception raised when an LLM call times out.
    """


class LLMRateLimitError(Exception):
    """
    Exception raised when an LLM call is rate limited.
    """


def _base_msg_to_role(msg: BaseMessage) -> str:
    if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
        return "user"
    if isinstance(msg, AIMessage) or isinstance(msg, AIMessageChunk):
        return "assistant"
    if isinstance(msg, SystemMessage) or isinstance(msg, SystemMessageChunk):
        return "system"
    if isinstance(msg, FunctionMessage) or isinstance(msg, FunctionMessageChunk):
        return "function"
    return "unknown"


def _convert_litellm_message_to_langchain_message(
    litellm_message: litellm.Message,
) -> BaseMessage:
    # Extracting the basic attributes from the litellm message
    content = litellm_message.content or ""
    role = litellm_message.role

    # Handling function calls and tool calls if present
    tool_calls = (
        cast(
            list[litellm.ChatCompletionMessageToolCall],
            litellm_message.tool_calls,
        )
        if hasattr(litellm_message, "tool_calls")
        else []
    )

    # Create the appropriate langchain message based on the role
    if role == "user":
        return HumanMessage(content=content)
    elif role == "assistant":
        return AIMessage(
            content=content,
            tool_calls=[
                {
                    "name": tool_call.function.name or "",
                    "args": json.loads(tool_call.function.arguments),
                    "id": tool_call.id,
                }
                for tool_call in tool_calls
            ]
            if tool_calls
            else [],
        )
    elif role == "system":
        return SystemMessage(content=content)
    else:
        raise ValueError(f"Unknown role type received: {role}")


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Adapted from langchain_community.chat_models.litellm._convert_message_to_dict"""
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if message.tool_calls:
            message_dict["tool_calls"] = [
                {
                    "id": tool_call.get("id"),
                    "function": {
                        "name": tool_call["name"],
                        "arguments": json.dumps(tool_call["args"]),
                    },
                    "type": "function",
                    "index": tool_call.get("index", 0),
                }
                for tool_call in message.tool_calls
            ]
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    elif isinstance(message, ToolMessage):
        message_dict = {
            "tool_call_id": message.tool_call_id,
            "role": "tool",
            "name": message.name or "",
            "content": message.content,
        }
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


def _convert_delta_to_message_chunk(
    _dict: dict[str, Any],
    curr_msg: BaseMessage | None,
    stop_reason: str | None = None,
) -> BaseMessageChunk:
    """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
    role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
    content = _dict.get("content") or ""
    additional_kwargs = {}
    if _dict.get("function_call"):
        additional_kwargs.update({"function_call": dict(_dict["function_call"])})
    tool_calls = cast(
        list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls")
    )

    if role == "user":
        return HumanMessageChunk(content=content)
    # NOTE: if tool calls are present, then it's an assistant.
    # In Ollama, the role will be None for tool-calls
    elif role == "assistant" or tool_calls:
        if tool_calls:
            tool_call = tool_calls[0]
            tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
            idx = tool_call.index
            tool_call_chunk = ToolCallChunk(
                name=tool_name,
                id=tool_call.id,
                args=tool_call.function.arguments,
                index=idx,
            )
            return AIMessageChunk(
                content=content,
                tool_call_chunks=[tool_call_chunk],
                additional_kwargs={
                    "usage_metadata": {"stop": stop_reason},
                    **additional_kwargs,
                },
            )
        return AIMessageChunk(
            content=content,
            additional_kwargs={
                "usage_metadata": {"stop": stop_reason},
                **additional_kwargs,
            },
        )
    elif role == "system":
        return SystemMessageChunk(content=content)
    elif role == "function":
        return FunctionMessageChunk(content=content, name=_dict["name"])
    elif role:
        return ChatMessageChunk(content=content, role=role)

    raise ValueError(f"Unknown role: {role}")
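

# Minimal sketch of how the streaming helpers above fit together. The delta
# payloads shown here are assumed examples for illustration; real deltas come
# from litellm's streaming response:
#
#   first = _convert_delta_to_message_chunk(
#       {"role": "assistant", "content": "Hel"}, curr_msg=None
#   )
#   second = _convert_delta_to_message_chunk(
#       {"content": "lo!"}, curr_msg=first, stop_reason="stop"
#   )
#   combined = first + second  # message chunks support `+`; content == "Hello!"
#
# `DefaultMultiLLM._stream_implementation` below accumulates chunks in exactly
# this way while yielding each one to the caller.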


def _prompt_to_dict(
    prompt: LanguageModelInput,
) -> Sequence[str | list[str] | dict[str, Any] | tuple[str, str]]:
    # NOTE: this check must go first, since a str is also a Sequence
    if isinstance(prompt, str):
        return [_convert_message_to_dict(HumanMessage(content=prompt))]

    if isinstance(prompt, (list, Sequence)):
        return [
            _convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg
            for msg in prompt
        ]

    if isinstance(prompt, PromptValue):
        return [_convert_message_to_dict(message) for message in prompt.to_messages()]


class DefaultMultiLLM(LLM):
    """Uses the LiteLLM library to allow easy configuration of a multitude of LLMs.
    See https://python.langchain.com/docs/integrations/chat/litellm"""

    def __init__(
        self,
        api_key: str | None,
        model_provider: str,
        model_name: str,
        timeout: int | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        max_output_tokens: int | None = None,
        custom_llm_provider: str | None = None,
        temperature: float | None = None,
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
        model_kwargs: dict[str, Any] | None = None,
        long_term_logger: LongTermLogger | None = None,
    ):
        self._timeout = timeout
        if timeout is None:
            if model_is_reasoning_model(model_name):
                self._timeout = QA_TIMEOUT * 10  # Reasoning models are slow
            else:
                self._timeout = QA_TIMEOUT

        self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature
        self._model_provider = model_provider
        self._model_version = model_name
        self._api_key = api_key
        self._deployment_name = deployment_name
        self._api_base = api_base
        self._api_version = api_version
        self._custom_llm_provider = custom_llm_provider
        self._long_term_logger = long_term_logger

        # This can be used to store the maximum output tokens for this model.
        # self._max_output_tokens = (
        #     max_output_tokens
        #     if max_output_tokens is not None
        #     else get_llm_max_output_tokens(
        #         model_map=litellm.model_cost,
        #         model_name=model_name,
        #         model_provider=model_provider,
        #     )
        # )
        self._custom_config = custom_config

        # Create a dictionary for model-specific arguments if it's None
        model_kwargs = model_kwargs or {}

        # NOTE: have to set these as environment variables for LiteLLM since
        # not all of them can be passed in directly, but they are always supported
        # as env variables. We'll also try passing them in, since litellm just
        # ignores additional kwargs (and some kwargs MUST be passed in rather than
        # set as env variables)
        if custom_config:
            # Specifically pass in "vertex_credentials" / "vertex_location" as a
            # model_kwarg to the completion call for Vertex AI. More details here:
            # https://docs.litellm.ai/docs/providers/vertex
            vertex_credentials_key = "vertex_credentials"
            vertex_location_key = "vertex_location"
            for k, v in custom_config.items():
                if model_provider == "vertex_ai":
                    if k == vertex_credentials_key:
                        model_kwargs[k] = v
                        continue
                    elif k == vertex_location_key:
                        model_kwargs[k] = v
                        continue

                # for all values, set them as env variables
                os.environ[k] = v

        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
            model_kwargs.update({"extra_body": extra_body})

        self._model_kwargs = model_kwargs
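
    # Illustrative constructor call for the custom_config handling above (all
    # values are placeholder assumptions, not project defaults). For
    # "vertex_ai", the credentials/location keys are forwarded to litellm as
    # kwargs, while every other custom_config entry is exported as an env var:
    #
    #   DefaultMultiLLM(
    #       api_key=None,
    #       model_provider="vertex_ai",
    #       model_name="gemini-1.5-pro",
    #       custom_config={
    #           "vertex_credentials": "<service-account-json>",
    #           "vertex_location": "us-central1",
    #       },
    #   )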

    def log_model_configs(self) -> None:
        logger.debug(f"Config: {self.config}")

    def _safe_model_config(self) -> dict:
        dump = self.config.model_dump()
        dump["api_key"] = mask_string(dump.get("api_key", ""))
        return dump

    def _record_call(self, prompt: LanguageModelInput) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {"prompt": _prompt_to_dict(prompt), "model": self._safe_model_config()},
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    def _record_result(
        self, prompt: LanguageModelInput, model_output: BaseMessage
    ) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {
                    "prompt": _prompt_to_dict(prompt),
                    "content": model_output.content,
                    "tool_calls": (
                        model_output.tool_calls
                        if hasattr(model_output, "tool_calls")
                        else []
                    ),
                    "model": self._safe_model_config(),
                },
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    def _record_error(self, prompt: LanguageModelInput, error: Exception) -> None:
        if self._long_term_logger:
            self._long_term_logger.record(
                {
                    "prompt": _prompt_to_dict(prompt),
                    "error": str(error),
                    "traceback": "".join(
                        traceback.format_exception(
                            type(error), error, error.__traceback__
                        )
                    ),
                    "model": self._safe_model_config(),
                },
                category=_LLM_PROMPT_LONG_TERM_LOG_CATEGORY,
            )

    # def _calculate_max_output_tokens(self, prompt: LanguageModelInput) -> int:
    #     # NOTE: This method can be used for calculating the maximum tokens for the stream,
    #     # but it isn't used in practice due to the computational cost of counting tokens
    #     # and because LLM providers automatically cut off at the maximum output.
    #     # The implementation is kept for potential future use or debugging purposes.

    #     # Get max input tokens for the model
    #     max_context_tokens = get_max_input_tokens(
    #         model_name=self.config.model_name, model_provider=self.config.model_provider
    #     )

    #     llm_tokenizer = get_tokenizer(
    #         model_name=self.config.model_name,
    #         provider_type=self.config.model_provider,
    #     )

    #     # Calculate tokens in the input prompt
    #     input_tokens = sum(len(llm_tokenizer.encode(str(m))) for m in prompt)

    #     # Calculate available tokens for output
    #     available_output_tokens = max_context_tokens - input_tokens

    #     # Return the lesser of available tokens or configured max
    #     return min(self._max_output_tokens, available_output_tokens)

    def _completion(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None,
        tool_choice: ToolChoiceOptions | None,
        stream: bool,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
        # litellm doesn't accept LangChain BaseMessage objects, so we need to
        # convert them to a dict representation
        processed_prompt = _prompt_to_dict(prompt)
        self._record_call(processed_prompt)

        try:
            return litellm.completion(
                mock_response=MOCK_LLM_RESPONSE,
                # model choice
                # model="openai/gpt-4",
                model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
                # NOTE: have to pass in None instead of empty string for these,
                # otherwise litellm can have some issues with bedrock
                api_key=self._api_key or None,
                base_url=self._api_base or None,
                api_version=self._api_version or None,
                custom_llm_provider=self._custom_llm_provider or None,
                # actual input
                messages=processed_prompt,
                tools=tools,
                tool_choice=tool_choice if tools else None,
                # streaming choice
                stream=stream,
                # model params
                temperature=0,
                timeout=timeout_override or self._timeout,
                # For now, we don't support parallel tool calls
                # NOTE: we can't pass this in if tools are not specified
                # or else OpenAI throws an error
                **(
                    {"parallel_tool_calls": False}
                    if tools
                    and self.config.model_name
                    not in [
                        "o3-mini",
                        "o3-preview",
                        "o1",
                        "o1-preview",
                        "o1-mini",
                        "o1-mini-2024-09-12",
                        "o3-mini-2025-01-31",
                    ]
                    else {}
                ),
                # TODO: remove once LiteLLM has patched
                **(
                    {"response_format": structured_response_format}
                    if structured_response_format
                    else {}
                ),
                **self._model_kwargs,
            )
        except Exception as e:
            self._record_error(processed_prompt, e)
            # convenient place to set a breakpoint when debugging LLM failures
            if isinstance(e, litellm.Timeout):
                raise LLMTimeoutError(e)
            elif isinstance(e, litellm.RateLimitError):
                raise LLMRateLimitError(e)
            raise e

    @property
    def config(self) -> LLMConfig:
        return LLMConfig(
            model_provider=self._model_provider,
            model_name=self._model_version,
            temperature=self._temperature,
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            deployment_name=self._deployment_name,
        )

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        response = cast(
            litellm.ModelResponse,
            self._completion(
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=False,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
            ),
        )
        choice = response.choices[0]
        if hasattr(choice, "message"):
            output = _convert_litellm_message_to_langchain_message(choice.message)
            if output:
                self._record_result(prompt, output)
            return output
        else:
            raise ValueError("Unexpected response choice type")
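
    # Illustrative tool-calling call (sketch only; the tool schema below follows
    # the OpenAI-style function-tool format accepted by litellm, and the tool
    # name/fields are made-up placeholders):
    #
    #   weather_tool = {
    #       "type": "function",
    #       "function": {
    #           "name": "get_weather",
    #           "description": "Look up the current weather for a city.",
    #           "parameters": {
    #               "type": "object",
    #               "properties": {"city": {"type": "string"}},
    #               "required": ["city"],
    #           },
    #       },
    #   }
    #   message = llm.invoke("What's the weather in Paris?", tools=[weather_tool])
    #   # If the model chose to call the tool, message.tool_calls holds the parsed
    #   # name/args (see _convert_litellm_message_to_langchain_message above).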

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(
                prompt,
                tools,
                tool_choice,
                structured_response_format,
                timeout_override,
            )
            return

        output = None
        response = cast(
            litellm.CustomStreamWrapper,
            self._completion(
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=True,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
            ),
        )
        try:
            for part in response:
                if not part["choices"]:
                    continue

                choice = part["choices"][0]
                message_chunk = _convert_delta_to_message_chunk(
                    choice["delta"],
                    output,
                    stop_reason=choice["finish_reason"],
                )
                if output is None:
                    output = message_chunk
                else:
                    output += message_chunk

                yield message_chunk

        except RemoteProtocolError:
            raise RuntimeError(
                "The AI model failed partway through generation, please try again."
            )

        if output:
            self._record_result(prompt, output)

        if LOG_DANSWER_MODEL_INTERACTIONS and output:
            content = output.content or ""
            if isinstance(output, AIMessage):
                if content:
                    log_msg = content
                elif output.tool_calls:
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in output.tool_calls
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Raw Model Output:\n{log_msg}")
            else:
                logger.debug(f"Raw Model Output:\n{content}")
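

# ---------------------------------------------------------------------------
# Minimal usage sketch for manual smoke-testing (illustrative only). The
# provider/model names and the OPENAI_API_KEY env var are assumptions, and it
# is assumed that the `LLM` base class exposes public `invoke()` / `stream()`
# wrappers around the `_invoke_implementation()` / `_stream_implementation()`
# methods defined above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_llm = DefaultMultiLLM(
        api_key=os.environ.get("OPENAI_API_KEY"),
        model_provider="openai",
        model_name="gpt-4o-mini",
    )

    # Blocking call: returns a single langchain BaseMessage
    print(example_llm.invoke("Say hello in one short sentence.").content)

    # Streaming call: yields message chunks as they are produced
    for streamed_chunk in example_llm.stream("Count from one to five."):
        print(streamed_chunk.content, end="", flush=True)
    print()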