Enforce Disable GenAI if set (#1860)

This commit is contained in:
Yuhong Sun
2024-07-18 13:25:55 -07:00
committed by GitHub
parent a595d43ae3
commit 5230f7e22f
3 changed files with 63 additions and 34 deletions

View File

@@ -232,32 +232,6 @@ class DefaultMultiLLM(LLM):
self._model_kwargs = model_kwargs
@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
    """Debug-log the outgoing prompt, one log entry per message.

    Supports both message-list prompts and bare-string prompts. For AI
    message chunks carrying tool-call deltas instead of text, the
    per-chunk "index" bookkeeping field is stripped before rendering.
    """
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if isinstance(msg, AIMessageChunk):
                # Prefer text content; fall back to tool calls; else empty.
                if msg.content:
                    log_msg = msg.content
                elif msg.tool_call_chunks:
                    trimmed = []
                    for tool_call in msg.tool_call_chunks:
                        trimmed.append(
                            {
                                key: value
                                for key, value in tool_call.items()
                                if key != "index"
                            }
                        )
                    log_msg = "Tool Calls: " + str(trimmed)
                else:
                    log_msg = ""
            else:
                log_msg = msg.content
            logger.debug(f"Message {ind}:\n{log_msg}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")
def log_model_configs(self) -> None:
    # Log the active LLM configuration at INFO level for visibility.
    logger.info(f"Config: {self.config}")
@@ -311,7 +285,7 @@ class DefaultMultiLLM(LLM):
api_version=self._api_version,
)
def invoke(
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
@@ -319,7 +293,6 @@ class DefaultMultiLLM(LLM):
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
self._log_prompt(prompt)
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
@@ -328,7 +301,7 @@ class DefaultMultiLLM(LLM):
response.choices[0].message
)
def stream(
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
@@ -336,7 +309,6 @@ class DefaultMultiLLM(LLM):
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
self._log_prompt(prompt)
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)

View File

@@ -76,7 +76,7 @@ class CustomModelServer(LLM):
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")
def invoke(
def _invoke_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,
@@ -84,7 +84,7 @@ class CustomModelServer(LLM):
) -> BaseMessage:
return self._execute(prompt)
def stream(
def _stream_implementation(
self,
prompt: LanguageModelInput,
tools: list[dict] | None = None,

View File

@@ -3,9 +3,12 @@ from collections.abc import Iterator
from typing import Literal
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.utils.logger import setup_logger
@@ -23,6 +26,32 @@ class LLMConfig(BaseModel):
api_version: str | None
def log_prompt(prompt: LanguageModelInput) -> None:
    """Write the outgoing prompt to the debug log.

    A prompt may be a list of messages or a bare string; both forms are
    handled. For AI message chunks that carry tool-call deltas rather than
    text content, the per-chunk "index" field is dropped before rendering.
    """
    if isinstance(prompt, list):
        for ind, msg in enumerate(prompt):
            if not isinstance(msg, AIMessageChunk):
                logger.debug(f"Message {ind}:\n{msg.content}")
                continue
            # AIMessageChunk: prefer text content, then tool calls, else empty.
            log_msg = ""
            if msg.content:
                log_msg = msg.content
            elif msg.tool_call_chunks:
                sanitized = [
                    {
                        key: value
                        for key, value in tool_call.items()
                        if key != "index"
                    }
                    for tool_call in msg.tool_call_chunks
                ]
                log_msg = "Tool Calls: " + str(sanitized)
            logger.debug(f"Message {ind}:\n{log_msg}")
    if isinstance(prompt, str):
        logger.debug(f"Prompt:\n{prompt}")
class LLM(abc.ABC):
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
to use these implementations to connect to a variety of LLM providers."""
@@ -45,20 +74,48 @@ class LLM(abc.ABC):
def log_model_configs(self) -> None:
    # Base-class stub: each concrete LLM must override to log its own config.
    raise NotImplementedError
@abc.abstractmethod
def _precall(self, prompt: LanguageModelInput) -> None:
    """Shared pre-call hook run by both `invoke` and `stream`.

    Raises:
        Exception: if generative AI has been globally disabled via config.
    """
    # Enforce the global kill-switch before any provider call is made.
    if DISABLE_GENERATIVE_AI:
        raise Exception("Generative AI is disabled")
    # Prompt logging is opt-in via config.
    if LOG_DANSWER_MODEL_INTERACTIONS:
        log_prompt(prompt)
def invoke(
    self,
    prompt: LanguageModelInput,
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
) -> BaseMessage:
    """Template method: run the shared pre-call checks, then delegate
    the actual model call to the subclass's `_invoke_implementation`."""
    self._precall(prompt)
    # TODO add a postcall to log model outputs independent of concrete class
    # implementation
    return self._invoke_implementation(prompt, tools, tool_choice)
@abc.abstractmethod
def _invoke_implementation(
    self,
    prompt: LanguageModelInput,
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
) -> BaseMessage:
    """Provider-specific synchronous completion; called via `invoke`
    after the shared pre-call checks have run."""
    raise NotImplementedError
@abc.abstractmethod
def stream(
    self,
    prompt: LanguageModelInput,
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[BaseMessage]:
    """Template method for streaming: run the shared pre-call checks,
    then delegate to the subclass's `_stream_implementation`.

    NOTE: this returns (rather than yields) the iterator, so `_precall`
    executes eagerly at call time, not on first iteration.
    """
    self._precall(prompt)
    # TODO add a postcall to log model outputs independent of concrete class
    # implementation
    return self._stream_implementation(prompt, tools, tool_choice)
@abc.abstractmethod
def _stream_implementation(
    self,
    prompt: LanguageModelInput,
    tools: list[dict] | None = None,
    tool_choice: ToolChoiceOptions | None = None,
) -> Iterator[BaseMessage]:
    """Provider-specific streaming call; invoked via `stream` after the
    shared pre-call checks have run."""
    raise NotImplementedError