mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-12 22:23:01 +02:00
Enforce Disable GenAI if set (#1860)
This commit is contained in:
@ -232,32 +232,6 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._model_kwargs = model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _log_prompt(prompt: LanguageModelInput) -> None:
|
||||
if isinstance(prompt, list):
|
||||
for ind, msg in enumerate(prompt):
|
||||
if isinstance(msg, AIMessageChunk):
|
||||
if msg.content:
|
||||
log_msg = msg.content
|
||||
elif msg.tool_call_chunks:
|
||||
log_msg = "Tool Calls: " + str(
|
||||
[
|
||||
{
|
||||
key: value
|
||||
for key, value in tool_call.items()
|
||||
if key != "index"
|
||||
}
|
||||
for tool_call in msg.tool_call_chunks
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||
if isinstance(prompt, str):
|
||||
logger.debug(f"Prompt:\n{prompt}")
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.info(f"Config: {self.config}")
|
||||
|
||||
@ -311,7 +285,7 @@ class DefaultMultiLLM(LLM):
|
||||
api_version=self._api_version,
|
||||
)
|
||||
|
||||
def invoke(
|
||||
def _invoke_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
@ -319,7 +293,6 @@ class DefaultMultiLLM(LLM):
|
||||
) -> BaseMessage:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
self._log_prompt(prompt)
|
||||
|
||||
response = cast(
|
||||
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
|
||||
@ -328,7 +301,7 @@ class DefaultMultiLLM(LLM):
|
||||
response.choices[0].message
|
||||
)
|
||||
|
||||
def stream(
|
||||
def _stream_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
@ -336,7 +309,6 @@ class DefaultMultiLLM(LLM):
|
||||
) -> Iterator[BaseMessage]:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
self._log_prompt(prompt)
|
||||
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
yield self.invoke(prompt)
|
||||
|
@ -76,7 +76,7 @@ class CustomModelServer(LLM):
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Custom model at: {self._endpoint}")
|
||||
|
||||
def invoke(
|
||||
def _invoke_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
@ -84,7 +84,7 @@ class CustomModelServer(LLM):
|
||||
) -> BaseMessage:
|
||||
return self._execute(prompt)
|
||||
|
||||
def stream(
|
||||
def _stream_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
|
@ -3,9 +3,12 @@ from collections.abc import Iterator
|
||||
from typing import Literal
|
||||
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@ -23,6 +26,32 @@ class LLMConfig(BaseModel):
|
||||
api_version: str | None
|
||||
|
||||
|
||||
def log_prompt(prompt: LanguageModelInput) -> None:
|
||||
if isinstance(prompt, list):
|
||||
for ind, msg in enumerate(prompt):
|
||||
if isinstance(msg, AIMessageChunk):
|
||||
if msg.content:
|
||||
log_msg = msg.content
|
||||
elif msg.tool_call_chunks:
|
||||
log_msg = "Tool Calls: " + str(
|
||||
[
|
||||
{
|
||||
key: value
|
||||
for key, value in tool_call.items()
|
||||
if key != "index"
|
||||
}
|
||||
for tool_call in msg.tool_call_chunks
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||
if isinstance(prompt, str):
|
||||
logger.debug(f"Prompt:\n{prompt}")
|
||||
|
||||
|
||||
class LLM(abc.ABC):
|
||||
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
||||
to use these implementations to connect to a variety of LLM providers."""
|
||||
@ -45,20 +74,48 @@ class LLM(abc.ABC):
|
||||
def log_model_configs(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _precall(self, prompt: LanguageModelInput) -> None:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise Exception("Generative AI is disabled")
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
log_prompt(prompt)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> BaseMessage:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._invoke_implementation(prompt, tools, tool_choice)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _invoke_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def stream(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._stream_implementation(prompt, tools, tool_choice)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _stream_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user