Mirror of https://github.com/danswer-ai/danswer.git · synced 2025-07-05 04:01:31 +02:00
* Add multi text array field
* Add multiple values to model configuration for a custom LLM provider
* Fix reference to old field name
* Add migration
* Update all instances of model_names / display_model_names to use new schema migration
* Update background task
* Update endpoints to not throw errors
* Add test
* Update backend/alembic/versions/7a70b7664e37_add_models_configuration_table.py
* Update backend/onyx/background/celery/tasks/llm_model_update/tasks.py
* Fix list comprehension nits
* Update web/src/components/admin/connectors/Field.tsx
* Update web/src/app/admin/configuration/llm/interfaces.ts
* Implement greptile recommendations
* Update backend/onyx/db/llm.py
* Update backend/onyx/server/manage/llm/api.py
* Update backend/onyx/background/celery/tasks/llm_model_update/tasks.py
* Update backend/onyx/db/llm.py
* Fix more greptile suggestions
* Run formatter again
* Update backend/onyx/db/models.py
* Add relationship to `LLMProvider` and `ModelConfigurations` classes
* Use sqlalchemy ORM relationships instead of manually populating fields
* Upgrade migration
* Update interface
* Remove all instances of model_names and display_model_names from backend
* Add more tests and fix bugs
* Run prettier
* Add types
* Update migration to perform data transformation
* Ensure native llm providers don't have custom max input tokens
* Start updating frontend logic to support custom max input tokens
* Pass max input tokens to LLM class (to be passed into `litellm.completion` call later)
* Add ModelConfigurationField component for custom llm providers
* Edit spacing and styling of model configuration matrix
* Fix error message displaying bug
* Edit opacity of `FiX` field for first index
* Change opacity back
* Change roundness
* Address comments on PR
* Perform fetching of `max_input_tokens` at the beginning of the callgraph and rope it throughout the entire callstack
* Change `add` to `execute`
* Move `max_input_tokens` into `LLMConfig`
* Fix bug with error messages not being cleared
* Change field used to fetch LLMProvider
* Fix model-configuration UI
* Address comments
* Remove circular import
* Fix failing tests in GH
* Fix failing tests
* Use `isSubset` instead of equality to determine native vs custom LLM Provider
* Remove unused import
* Make responses always display max_input_tokens
* Fix api endpoint to hit
* Update types in web application
* Update object field
* Fix more type errors
* Fix failing llm provider tests

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
163 lines · 5.2 KiB · Python
import abc
from collections.abc import Iterator
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from onyx.utils.logger import setup_logger

logger = setup_logger()

ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]

class LLMConfig(BaseModel):
    model_provider: str
    model_name: str
    temperature: float
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    credentials_file: str | None = None
    max_input_tokens: int

    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}

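# Example (illustrative, not from the original file): constructing an
# LLMConfig directly. The provider, model, and token-limit values below are
# hypothetical placeholders; in the application this object is built from the
# stored LLMProvider / model-configuration records.
#
#     config = LLMConfig(
#         model_provider="openai",
#         model_name="gpt-4o",
#         temperature=0.0,
#         api_key="sk-...",         # optional fields default to None
#         max_input_tokens=128000,  # required; allows per-model custom caps
#     )
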
def log_prompt(prompt: LanguageModelInput) -> None:
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if isinstance(msg, AIMessageChunk):
                if msg.content:
                    log_msg = msg.content
                elif msg.tool_call_chunks:
                    # Log tool calls, dropping the streaming "index" key
                    # since it is just chunk-ordering noise
                    log_msg = "Tool Calls: " + str(
                        [
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                            for tool_call in msg.tool_call_chunks
                        ]
                    )
                else:
                    log_msg = ""
                logger.debug(f"Message {ind}:\n{log_msg}")
            else:
                logger.debug(f"Message {ind}:\n{msg.content}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")

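# Usage sketch (hypothetical inputs): log_prompt handles both prompt shapes:
#
#     log_prompt("Why is the sky blue?")
#     # -> DEBUG  Prompt:\nWhy is the sky blue?
#
#     from langchain_core.messages import HumanMessage
#     log_prompt([HumanMessage(content="hi")])
#     # -> DEBUG  Message 0:\nhi
#
# It only runs when LOG_DANSWER_MODEL_INTERACTIONS is enabled (see
# LLM._precall below).
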
class LLM(abc.ABC):
    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
    to use these implementations to connect to a variety of LLM providers."""

    @property
    def requires_warm_up(self) -> bool:
        """Is this model running in memory and needs an initial call to warm it up?"""
        return False

    @property
    def requires_api_key(self) -> bool:
        return True

    @property
    @abc.abstractmethod
    def config(self) -> LLMConfig:
        raise NotImplementedError

    @abc.abstractmethod
    def log_model_configs(self) -> None:
        raise NotImplementedError

    def _precall(self, prompt: LanguageModelInput) -> None:
        if DISABLE_GENERATIVE_AI:
            raise Exception("Generative AI is disabled")
        if LOG_DANSWER_MODEL_INTERACTIONS:
            log_prompt(prompt)

    def invoke(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )

    @abc.abstractmethod
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

    def stream(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        messages = self._stream_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )

        tokens = []
        for message in messages:
            if LOG_INDIVIDUAL_MODEL_TOKENS:
                tokens.append(message.content)
            yield message

        if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
            logger.debug(f"Model Tokens: {tokens}")

    @abc.abstractmethod
    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError
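
# --- Illustrative sketch, not part of the original module ---
# A minimal concrete subclass, assuming a hypothetical "echo" provider, to
# show what implementers of this interface must supply: the `config` property,
# `log_model_configs`, and the two `_..._implementation` methods. A real
# provider would make a network call (e.g. via litellm) in those methods.
class EchoLLM(LLM):
    """Hypothetical example implementation that echoes the prompt back."""

    def __init__(self, config: LLMConfig) -> None:
        self._config = config

    @property
    def config(self) -> LLMConfig:
        return self._config

    def log_model_configs(self) -> None:
        logger.debug(f"EchoLLM config: {self._config.model_dump()}")

    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> BaseMessage:
        # AIMessageChunk is already imported above; a real provider would
        # return the provider's response message here.
        return AIMessageChunk(content=str(prompt))

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
        max_tokens: int | None = None,
    ) -> Iterator[BaseMessage]:
        # Trivial one-chunk "stream" so that LLM.stream's token logging has
        # something to iterate over.
        yield self._invoke_implementation(
            prompt,
            tools,
            tool_choice,
            structured_response_format,
            timeout_override,
            max_tokens,
        )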