danswer/backend/onyx/llm/interfaces.py
Commit 0c29743538 by evan-danswer, 2025-03-06: use max_tokens to do better rate limit handling (#4224)

import abc
from collections.abc import Iterator
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from onyx.utils.logger import setup_logger

logger = setup_logger()

ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]


class LLMConfig(BaseModel):
model_provider: str
model_name: str
temperature: float
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
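
# Illustrative only (not part of the original module): constructing an LLMConfig
# by hand. The provider and model values below are assumptions for demonstration.
_EXAMPLE_CONFIG = LLMConfig(
    model_provider="openai",
    model_name="gpt-4o",
    temperature=0.0,
)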


def log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
if isinstance(msg, AIMessageChunk):
if msg.content:
log_msg = msg.content
elif msg.tool_call_chunks:
log_msg = "Tool Calls: " + str(
[
{
key: value
for key, value in tool_call.items()
if key != "index"
}
for tool_call in msg.tool_call_chunks
]
)
else:
log_msg = ""
logger.debug(f"Message {ind}:\n{log_msg}")
else:
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")


class LLM(abc.ABC):
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
to use these implementations to connect to a variety of LLM providers."""

    @property
def requires_warm_up(self) -> bool:
"""Is this model running in memory and needs an initial call to warm it up?"""
return False

    @property
def requires_api_key(self) -> bool:
return True

    @property
@abc.abstractmethod
def config(self) -> LLMConfig:
raise NotImplementedError

    @abc.abstractmethod
def log_model_configs(self) -> None:
raise NotImplementedError

    def _precall(self, prompt: LanguageModelInput) -> None:
if DISABLE_GENERATIVE_AI:
raise Exception("Generative AI is disabled")
if LOG_DANSWER_MODEL_INTERACTIONS:
log_prompt(prompt)

    def invoke(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
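        """Run a single, non-streaming completion. The optional max_tokens cap
        is forwarded to the concrete implementation; per the commit above it
        was added to enable better rate-limit handling (#4224)."""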
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)

    @abc.abstractmethod
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
raise NotImplementedError

    def stream(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
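        """Stream the model response message by message. When
        LOG_INDIVIDUAL_MODEL_TOKENS is enabled, the streamed contents are
        collected and logged once the stream is exhausted."""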
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
tokens = []
for message in messages:
if LOG_INDIVIDUAL_MODEL_TOKENS:
tokens.append(message.content)
yield message
if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
logger.debug(f"Model Tokens: {tokens}")

    @abc.abstractmethod
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
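

# A minimal sketch (not part of the original module) of a concrete subclass,
# showing how the abstract hooks fit together. "EchoLLM" and its behavior are
# assumptions for illustration only: it echoes the prompt back instead of
# calling a real provider, and it treats max_tokens as a crude character cap
# rather than a true token count.
class EchoLLM(LLM):
    @property
    def config(self) -> LLMConfig:
        # Hypothetical provider/model names, chosen only for the example
        return LLMConfig(
            model_provider="echo", model_name="echo-1", temperature=0.0
        )

    def log_model_configs(self) -> None:
        logger.debug(f"EchoLLM config: {self.config}")

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        text = prompt if isinstance(prompt, str) else str(prompt)
        if max_tokens is not None:
            text = text[:max_tokens]  # crude stand-in for a provider-side cap
        return AIMessageChunk(content=text)

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        # A real implementation would yield incremental chunks from the provider
        yield self._invoke_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )


# Usage (illustrative):
#     llm = EchoLLM()
#     msg = llm.invoke("hello world", max_tokens=5)  # content == "hello"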