diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index 90ad481d453..632e09c5118 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -232,32 +232,6 @@ class DefaultMultiLLM(LLM):
 
         self._model_kwargs = model_kwargs
 
-    @staticmethod
-    def _log_prompt(prompt: LanguageModelInput) -> None:
-        if isinstance(prompt, list):
-            for ind, msg in enumerate(prompt):
-                if isinstance(msg, AIMessageChunk):
-                    if msg.content:
-                        log_msg = msg.content
-                    elif msg.tool_call_chunks:
-                        log_msg = "Tool Calls: " + str(
-                            [
-                                {
-                                    key: value
-                                    for key, value in tool_call.items()
-                                    if key != "index"
-                                }
-                                for tool_call in msg.tool_call_chunks
-                            ]
-                        )
-                    else:
-                        log_msg = ""
-                    logger.debug(f"Message {ind}:\n{log_msg}")
-                else:
-                    logger.debug(f"Message {ind}:\n{msg.content}")
-        if isinstance(prompt, str):
-            logger.debug(f"Prompt:\n{prompt}")
-
     def log_model_configs(self) -> None:
         logger.info(f"Config: {self.config}")
 
@@ -311,7 +285,7 @@ class DefaultMultiLLM(LLM):
             api_version=self._api_version,
         )
 
-    def invoke(
+    def _invoke_implementation(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
@@ -319,7 +293,6 @@ class DefaultMultiLLM(LLM):
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
-            self._log_prompt(prompt)
 
         response = cast(
             litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
@@ -328,7 +301,7 @@ class DefaultMultiLLM(LLM):
             response.choices[0].message
         )
 
-    def stream(
+    def _stream_implementation(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
@@ -336,7 +309,6 @@ class DefaultMultiLLM(LLM):
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
-            self._log_prompt(prompt)
 
         if DISABLE_LITELLM_STREAMING:
             yield self.invoke(prompt)
diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py
index 2c4c029aa2d..da71e0e5b65 100644
--- a/backend/danswer/llm/custom_llm.py
+++ b/backend/danswer/llm/custom_llm.py
@@ -76,7 +76,7 @@ class CustomModelServer(LLM):
     def log_model_configs(self) -> None:
         logger.debug(f"Custom model at: {self._endpoint}")
 
-    def invoke(
+    def _invoke_implementation(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
@@ -84,7 +84,7 @@ class CustomModelServer(LLM):
     ) -> BaseMessage:
         return self._execute(prompt)
 
-    def stream(
+    def _stream_implementation(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py
index e876403c421..63bd45ba76b 100644
--- a/backend/danswer/llm/interfaces.py
+++ b/backend/danswer/llm/interfaces.py
@@ -3,9 +3,12 @@ from collections.abc import Iterator
 from typing import Literal
 
 from langchain.schema.language_model import LanguageModelInput
+from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
 from pydantic import BaseModel
 
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
 from danswer.utils.logger import setup_logger
 
 
@@ -23,6 +26,32 @@ class LLMConfig(BaseModel):
     api_version: str | None
 
 
+def log_prompt(prompt: LanguageModelInput) -> None:
+    if isinstance(prompt, list):
+        for ind, msg in enumerate(prompt):
+            if isinstance(msg, AIMessageChunk):
+                if msg.content:
+                    log_msg = msg.content
+                elif msg.tool_call_chunks:
+                    log_msg = "Tool Calls: " + str(
+                        [
+                            {
+                                key: value
+                                for key, value in tool_call.items()
+                                if key != "index"
+                            }
+                            for tool_call in msg.tool_call_chunks
+                        ]
+                    )
+                else:
+                    log_msg = ""
+                logger.debug(f"Message {ind}:\n{log_msg}")
+            else:
+                logger.debug(f"Message {ind}:\n{msg.content}")
+    if isinstance(prompt, str):
+        logger.debug(f"Prompt:\n{prompt}")
+
+
 class LLM(abc.ABC):
     """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
     to use these implementations to connect to a variety of LLM providers."""
@@ -45,20 +74,48 @@ class LLM(abc.ABC):
     def log_model_configs(self) -> None:
         raise NotImplementedError
 
-    @abc.abstractmethod
+    def _precall(self, prompt: LanguageModelInput) -> None:
+        if DISABLE_GENERATIVE_AI:
+            raise Exception("Generative AI is disabled")
+        if LOG_DANSWER_MODEL_INTERACTIONS:
+            log_prompt(prompt)
+
     def invoke(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+    ) -> BaseMessage:
+        self._precall(prompt)
+        # TODO add a postcall to log model outputs independent of concrete class
+        # implementation
+        return self._invoke_implementation(prompt, tools, tool_choice)
+
+    @abc.abstractmethod
+    def _invoke_implementation(
+        self,
+        prompt: LanguageModelInput,
+        tools: list[dict] | None = None,
+        tool_choice: ToolChoiceOptions | None = None,
     ) -> BaseMessage:
         raise NotImplementedError
 
-    @abc.abstractmethod
     def stream(
         self,
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+    ) -> Iterator[BaseMessage]:
+        self._precall(prompt)
+        # TODO add a postcall to log model outputs independent of concrete class
+        # implementation
+        return self._stream_implementation(prompt, tools, tool_choice)
+
+    @abc.abstractmethod
+    def _stream_implementation(
+        self,
+        prompt: LanguageModelInput,
+        tools: list[dict] | None = None,
+        tool_choice: ToolChoiceOptions | None = None,
     ) -> Iterator[BaseMessage]:
         raise NotImplementedError