danswer/backend/onyx/llm/interfaces.py
Raunak Bhagat b97628070e feat: Add ability to specify max input token limit for custom LLM providers (#4510)
* Add multi text array field

* Add multiple values to model configuration for a custom LLM provider

* Fix reference to old field name

* Add migration

* Update all instances of model_names / display_model_names to use the new schema from the migration

* Update background task

* Update endpoints to not throw errors

* Add test

* Update backend/alembic/versions/7a70b7664e37_add_models_configuration_table.py

* Update backend/onyx/background/celery/tasks/llm_model_update/tasks.py

* Fix list comprehension nits

* Update web/src/components/admin/connectors/Field.tsx

* Update web/src/app/admin/configuration/llm/interfaces.ts

* Implement greptile recommendations

* Update backend/onyx/db/llm.py

* Update backend/onyx/server/manage/llm/api.py

* Update backend/onyx/background/celery/tasks/llm_model_update/tasks.py

* Update backend/onyx/db/llm.py

* Fix more greptile suggestions

* Run formatter again

* Update backend/onyx/db/models.py

* Add relationship between the `LLMProvider` and `ModelConfigurations` classes

* Use SQLAlchemy ORM relationships instead of manually populating fields

* Upgrade migration

* Update interface

* Remove all instances of model_names and display_model_names from backend

* Add more tests and fix bugs

* Run prettier

* Add types

* Update migration to perform data transformation

* Ensure native LLM providers don't have custom max input tokens

* Start updating frontend logic to support custom max input tokens

* Pass max input tokens to the LLM class (to be passed into the `litellm.completion` call later)

* Add ModelConfigurationField component for custom llm providers

* Edit spacing and styling of model configuration matrix

* Fix error message displaying bug

* Edit opacity of `FiX` field for first index

* Change opacity back

* Change roundness

* Address comments on PR

* Fetch `max_input_tokens` at the top of the call graph and thread it through the entire call stack

* Change `add` to `execute`

* Move `max_input_tokens` into `LLMConfig`

* Fix bug with error messages not being cleared

* Change field used to fetch LLMProvider

* Fix model-configuration UI

* Address comments

* Remove circular import

* Fix failing tests in GH

* Fix failing tests

* Use `isSubset` instead of equality to determine native vs custom LLM Provider

* Remove unused import

* Make responses always display max_input_tokens

* Fix the API endpoint to hit

* Update types in web application

* Update object field

* Fix more type errors

* Fix failing LLM provider tests

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-04-21 04:30:21 -07:00


import abc
from collections.abc import Iterator
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from onyx.utils.logger import setup_logger

logger = setup_logger()

ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]
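# These Literal values mirror the `tool_choice` options accepted by
# OpenAI-style chat-completion APIs ("required", "auto", "none").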


class LLMConfig(BaseModel):
    model_provider: str
    model_name: str
    temperature: float
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    credentials_file: str | None = None
    max_input_tokens: int

    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}
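
# Usage sketch (hypothetical values, not from the original file): since
# `max_input_tokens` has no default, every constructed config must now state
# the provider's input-token limit explicitly.
#
#     config = LLMConfig(
#         model_provider="openai",
#         model_name="gpt-4o",
#         temperature=0.0,
#         max_input_tokens=128_000,
#     )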


def log_prompt(prompt: LanguageModelInput) -> None:
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if isinstance(msg, AIMessageChunk):
                if msg.content:
                    log_msg = msg.content
                elif msg.tool_call_chunks:
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in msg.tool_call_chunks
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Message {ind}:\n{log_msg}")
            else:
                logger.debug(f"Message {ind}:\n{msg.content}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")
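
# Illustrative behavior (not in the original): a plain-string prompt logs a
# single "Prompt:" entry, while a message list logs one "Message {i}:" entry
# per element, with tool-call chunks rendered minus their "index" key.
#
#     log_prompt("Hello")  # logs "Prompt:\nHello"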


class LLM(abc.ABC):
    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
    to use these implementations to connect to a variety of LLM providers."""

    @property
    def requires_warm_up(self) -> bool:
        """Is this model running in memory and does it need an initial call to warm it up?"""
        return False

    @property
    def requires_api_key(self) -> bool:
        return True

    @property
    @abc.abstractmethod
    def config(self) -> LLMConfig:
        raise NotImplementedError

    @abc.abstractmethod
    def log_model_configs(self) -> None:
        raise NotImplementedError

    def _precall(self, prompt: LanguageModelInput) -> None:
        if DISABLE_GENERATIVE_AI:
            raise Exception("Generative AI is disabled")
        if LOG_DANSWER_MODEL_INTERACTIONS:
            log_prompt(prompt)

    def invoke(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )

    @abc.abstractmethod
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

    def stream(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        messages = self._stream_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )

        tokens = []
        for message in messages:
            if LOG_INDIVIDUAL_MODEL_TOKENS:
                tokens.append(message.content)
            yield message

        if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
            logger.debug(f"Model Tokens: {tokens}")

    @abc.abstractmethod
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError
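

# ---------------------------------------------------------------------------
# Hedged sketch, not part of the original module: a minimal concrete subclass
# showing how the abstract interface above is satisfied. The class name and
# every value below are hypothetical; a real implementation would call out to
# a provider (e.g. via litellm) instead of echoing the prompt back.
# ---------------------------------------------------------------------------
class _ExampleEchoLLM(LLM):
    """Illustrative only: returns the prompt text back as the 'completion'."""

    @property
    def config(self) -> LLMConfig:
        return LLMConfig(
            model_provider="example",  # hypothetical provider name
            model_name="echo",  # hypothetical model name
            temperature=0.0,
            max_input_tokens=4096,  # assumed input limit for the sketch
        )

    def log_model_configs(self) -> None:
        logger.debug(f"Example LLM config: {self.config}")

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        # A real subclass would call the provider here, honoring
        # `self.config.max_input_tokens` when building the request.
        return AIMessageChunk(content=str(prompt))

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        # Yield the whole "response" as a single chunk for simplicity.
        yield AIMessageChunk(content=str(prompt))


# Usage (illustrative): `_ExampleEchoLLM().invoke("Hi")` returns an
# AIMessageChunk with content "Hi"; `stream` yields the same as one chunk.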