diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index 46f07e176..dcc6e6fc5 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -236,7 +236,8 @@ class QABlock(QAModel):
     def warm_up_model(self) -> None:
         """This is called during server start up to load the models into memory
         in case the chosen LLM is not accessed via API"""
-        self._llm.invoke("Ignore this!")
+        if self._llm.requires_warm_up:
+            self._llm.invoke("Ignore this!")
 
     def answer_question(
         self,
diff --git a/backend/danswer/llm/google_colab_demo.py b/backend/danswer/llm/google_colab_demo.py
index 9d409002f..d1bdcf390 100644
--- a/backend/danswer/llm/google_colab_demo.py
+++ b/backend/danswer/llm/google_colab_demo.py
@@ -45,8 +45,8 @@ class GoogleColabDemo(LLM):
         response.raise_for_status()
         return json.loads(response.content).get("generated_text", "")
 
-    def invoke(self, input: LanguageModelInput) -> str:
-        return self._execute(input)
+    def invoke(self, prompt: LanguageModelInput) -> str:
+        return self._execute(prompt)
 
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
-        yield self._execute(input)
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
+        yield self._execute(prompt)
diff --git a/backend/danswer/llm/llm.py b/backend/danswer/llm/llm.py
index 985994769..c008fab43 100644
--- a/backend/danswer/llm/llm.py
+++ b/backend/danswer/llm/llm.py
@@ -15,12 +15,17 @@ class LLM(abc.ABC):
     """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
     to use these implementations to connect to a variety of LLM providers."""
 
+    @property
+    def requires_warm_up(self) -> bool:
+        """Does this model run in memory and need an initial call to warm it up?"""
+        return False
+
     @abc.abstractmethod
-    def invoke(self, input: LanguageModelInput) -> str:
+    def invoke(self, prompt: LanguageModelInput) -> str:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         raise NotImplementedError
 
 
@@ -35,10 +40,10 @@ class LangChainChatLLM(LLM, abc.ABC):
             f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
         )
 
-    def invoke(self, input: LanguageModelInput) -> str:
+    def invoke(self, prompt: LanguageModelInput) -> str:
         self._log_model_config()
-        return self.llm.invoke(input).content
+        return self.llm.invoke(prompt).content
 
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         self._log_model_config()
-        yield from message_generator_to_string_generator(self.llm.stream(input))
+        yield from message_generator_to_string_generator(self.llm.stream(prompt))
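
With `requires_warm_up` defaulting to False on the LLM base class, only implementations that actually host a model in process memory pay the dummy-invoke cost in `QABlock.warm_up_model`; API-backed providers keep the default and skip the call. Below is a minimal sketch of how a subclass would opt in. The `LocalModel` class is hypothetical (not part of this diff), and the import path for `LanguageModelInput` is assumed, since the diff does not show where that type alias is defined.

from collections.abc import Iterator

from langchain.schema.language_model import LanguageModelInput  # assumed import path

from danswer.llm.llm import LLM


class LocalModel(LLM):  # hypothetical subclass, for illustration only
    @property
    def requires_warm_up(self) -> bool:
        # The model weights live in process memory, so the first call is slow;
        # returning True makes QABlock.warm_up_model issue the dummy invoke.
        return True

    def invoke(self, prompt: LanguageModelInput) -> str:
        # Placeholder: a real implementation would run the in-memory model.
        return "stubbed response"

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        # Placeholder: a real implementation would yield tokens as generated.
        yield self.invoke(prompt)

Keeping the default on the base class means existing API-backed implementations (like LangChainChatLLM subclasses) need no changes to pick up the new behavior.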