Fix LLM warm up ()

Yuhong Sun 2023-09-11 14:47:36 -07:00 committed by GitHub
parent 9316b78f47
commit 101ff2f392
3 changed files with 17 additions and 11 deletions
backend/danswer

@@ -236,7 +236,8 @@ class QABlock(QAModel):
     def warm_up_model(self) -> None:
         """This is called during server start up to load the models into memory
         in case the chosen LLM is not accessed via API"""
-        self._llm.invoke("Ignore this!")
+        if self._llm.requires_warm_up:
+            self._llm.invoke("Ignore this!")
 
     def answer_question(
         self,
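
With this change, the throwaway request at startup happens only for models that report needing it; API-hosted models skip it. A minimal self-contained sketch of the behavior (the MiniLLM and InProcessLLM names are illustrative stand-ins, not code from this commit):

import abc


class MiniLLM(abc.ABC):
    @property
    def requires_warm_up(self) -> bool:
        return False  # default: API-backed models need no warm-up call

    @abc.abstractmethod
    def invoke(self, prompt: str) -> str:
        raise NotImplementedError


class InProcessLLM(MiniLLM):
    """Hypothetical model whose weights load lazily on first use."""

    @property
    def requires_warm_up(self) -> bool:
        return True  # first invoke is slow, so trigger it during startup

    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"


def warm_up_model(llm: MiniLLM) -> None:
    # Mirrors QABlock.warm_up_model above: only call through when needed.
    if llm.requires_warm_up:
        llm.invoke("Ignore this!")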

@@ -45,8 +45,8 @@ class GoogleColabDemo(LLM):
         response.raise_for_status()
         return json.loads(response.content).get("generated_text", "")
 
-    def invoke(self, input: LanguageModelInput) -> str:
-        return self._execute(input)
+    def invoke(self, prompt: LanguageModelInput) -> str:
+        return self._execute(prompt)
 
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
-        yield self._execute(input)
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
+        yield self._execute(prompt)
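
Beyond matching the updated base class, the rename stops shadowing Python's builtin input() inside these method bodies. Positional callers are unaffected; only keyword callers would need updating (llm below is an assumed instance):

llm.invoke("Hello")          # fine before and after the rename
llm.invoke(prompt="Hello")   # valid after this commit
# llm.invoke(input="Hello")  # would now raise TypeError (unexpected keyword)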

@@ -15,12 +15,17 @@ class LLM(abc.ABC):
     """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
     to use these implementations to connect to a variety of LLM providers."""
 
+    @property
+    def requires_warm_up(self) -> bool:
+        """Is this model running in memory and needs an initial call to warm it up?"""
+        return False
+
     @abc.abstractmethod
-    def invoke(self, input: LanguageModelInput) -> str:
+    def invoke(self, prompt: LanguageModelInput) -> str:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         raise NotImplementedError
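
A subclass must still implement both abstract methods before it can be instantiated, while requires_warm_up, having a concrete body, is optional to override. A sketch of a conforming in-memory implementation against this interface (class name and return values are placeholders):

class LocalWeightsLLM(LLM):
    @property
    def requires_warm_up(self) -> bool:
        return True  # loaded into process memory, so prime it at startup

    def invoke(self, prompt: LanguageModelInput) -> str:
        return "generated text"

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        yield "generated text"
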
@@ -35,10 +40,10 @@ class LangChainChatLLM(LLM, abc.ABC):
             f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
         )
 
-    def invoke(self, input: LanguageModelInput) -> str:
+    def invoke(self, prompt: LanguageModelInput) -> str:
         self._log_model_config()
-        return self.llm.invoke(input).content
+        return self.llm.invoke(prompt).content
 
-    def stream(self, input: LanguageModelInput) -> Iterator[str]:
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         self._log_model_config()
-        yield from message_generator_to_string_generator(self.llm.stream(input))
+        yield from message_generator_to_string_generator(self.llm.stream(prompt))
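
For context, LangChain chat models return message objects rather than plain strings, so invoke unwraps .content and stream funnels chunks through a converter. A plausible sketch of such a helper, assuming LangChain's message-chunk type from that era (this is not necessarily the repo's actual implementation):

from collections.abc import Iterator

from langchain.schema.messages import BaseMessageChunk


def message_generator_to_string_generator(
    messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
    # Unwrap each streamed chunk to its text content.
    for message in messages:
        yield message.content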