mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-17 08:33:06 +02:00
Enforce Disable GenAI if set (#1860)
This commit is contained in:
@ -232,32 +232,6 @@ class DefaultMultiLLM(LLM):
|
|||||||
|
|
||||||
self._model_kwargs = model_kwargs
|
self._model_kwargs = model_kwargs
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _log_prompt(prompt: LanguageModelInput) -> None:
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
for ind, msg in enumerate(prompt):
|
|
||||||
if isinstance(msg, AIMessageChunk):
|
|
||||||
if msg.content:
|
|
||||||
log_msg = msg.content
|
|
||||||
elif msg.tool_call_chunks:
|
|
||||||
log_msg = "Tool Calls: " + str(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
key: value
|
|
||||||
for key, value in tool_call.items()
|
|
||||||
if key != "index"
|
|
||||||
}
|
|
||||||
for tool_call in msg.tool_call_chunks
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
log_msg = ""
|
|
||||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
|
||||||
if isinstance(prompt, str):
|
|
||||||
logger.debug(f"Prompt:\n{prompt}")
|
|
||||||
|
|
||||||
def log_model_configs(self) -> None:
|
def log_model_configs(self) -> None:
|
||||||
logger.info(f"Config: {self.config}")
|
logger.info(f"Config: {self.config}")
|
||||||
|
|
||||||
@ -311,7 +285,7 @@ class DefaultMultiLLM(LLM):
|
|||||||
api_version=self._api_version,
|
api_version=self._api_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(
|
def _invoke_implementation(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
@ -319,7 +293,6 @@ class DefaultMultiLLM(LLM):
|
|||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||||
self.log_model_configs()
|
self.log_model_configs()
|
||||||
self._log_prompt(prompt)
|
|
||||||
|
|
||||||
response = cast(
|
response = cast(
|
||||||
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
|
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
|
||||||
@ -328,7 +301,7 @@ class DefaultMultiLLM(LLM):
|
|||||||
response.choices[0].message
|
response.choices[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream(
|
def _stream_implementation(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
@ -336,7 +309,6 @@ class DefaultMultiLLM(LLM):
|
|||||||
) -> Iterator[BaseMessage]:
|
) -> Iterator[BaseMessage]:
|
||||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||||
self.log_model_configs()
|
self.log_model_configs()
|
||||||
self._log_prompt(prompt)
|
|
||||||
|
|
||||||
if DISABLE_LITELLM_STREAMING:
|
if DISABLE_LITELLM_STREAMING:
|
||||||
yield self.invoke(prompt)
|
yield self.invoke(prompt)
|
||||||
|
@ -76,7 +76,7 @@ class CustomModelServer(LLM):
|
|||||||
def log_model_configs(self) -> None:
|
def log_model_configs(self) -> None:
|
||||||
logger.debug(f"Custom model at: {self._endpoint}")
|
logger.debug(f"Custom model at: {self._endpoint}")
|
||||||
|
|
||||||
def invoke(
|
def _invoke_implementation(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
@ -84,7 +84,7 @@ class CustomModelServer(LLM):
|
|||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
return self._execute(prompt)
|
return self._execute(prompt)
|
||||||
|
|
||||||
def stream(
|
def _stream_implementation(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
|
@ -3,9 +3,12 @@ from collections.abc import Iterator
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from langchain.schema.language_model import LanguageModelInput
|
from langchain.schema.language_model import LanguageModelInput
|
||||||
|
from langchain_core.messages import AIMessageChunk
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||||
|
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
@ -23,6 +26,32 @@ class LLMConfig(BaseModel):
|
|||||||
api_version: str | None
|
api_version: str | None
|
||||||
|
|
||||||
|
|
||||||
|
def log_prompt(prompt: LanguageModelInput) -> None:
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
for ind, msg in enumerate(prompt):
|
||||||
|
if isinstance(msg, AIMessageChunk):
|
||||||
|
if msg.content:
|
||||||
|
log_msg = msg.content
|
||||||
|
elif msg.tool_call_chunks:
|
||||||
|
log_msg = "Tool Calls: " + str(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
key: value
|
||||||
|
for key, value in tool_call.items()
|
||||||
|
if key != "index"
|
||||||
|
}
|
||||||
|
for tool_call in msg.tool_call_chunks
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = ""
|
||||||
|
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
logger.debug(f"Prompt:\n{prompt}")
|
||||||
|
|
||||||
|
|
||||||
class LLM(abc.ABC):
|
class LLM(abc.ABC):
|
||||||
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
||||||
to use these implementations to connect to a variety of LLM providers."""
|
to use these implementations to connect to a variety of LLM providers."""
|
||||||
@ -45,20 +74,48 @@ class LLM(abc.ABC):
|
|||||||
def log_model_configs(self) -> None:
|
def log_model_configs(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
def _precall(self, prompt: LanguageModelInput) -> None:
|
||||||
|
if DISABLE_GENERATIVE_AI:
|
||||||
|
raise Exception("Generative AI is disabled")
|
||||||
|
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||||
|
log_prompt(prompt)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
tool_choice: ToolChoiceOptions | None = None,
|
tool_choice: ToolChoiceOptions | None = None,
|
||||||
|
) -> BaseMessage:
|
||||||
|
self._precall(prompt)
|
||||||
|
# TODO add a postcall to log model outputs independent of concrete class
|
||||||
|
# implementation
|
||||||
|
return self._invoke_implementation(prompt, tools, tool_choice)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _invoke_implementation(
|
||||||
|
self,
|
||||||
|
prompt: LanguageModelInput,
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
tool_choice: ToolChoiceOptions | None = None,
|
||||||
) -> BaseMessage:
|
) -> BaseMessage:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
prompt: LanguageModelInput,
|
prompt: LanguageModelInput,
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
tool_choice: ToolChoiceOptions | None = None,
|
tool_choice: ToolChoiceOptions | None = None,
|
||||||
|
) -> Iterator[BaseMessage]:
|
||||||
|
self._precall(prompt)
|
||||||
|
# TODO add a postcall to log model outputs independent of concrete class
|
||||||
|
# implementation
|
||||||
|
return self._stream_implementation(prompt, tools, tool_choice)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _stream_implementation(
|
||||||
|
self,
|
||||||
|
prompt: LanguageModelInput,
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
tool_choice: ToolChoiceOptions | None = None,
|
||||||
) -> Iterator[BaseMessage]:
|
) -> Iterator[BaseMessage]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
Reference in New Issue
Block a user