Fix LLM warm up (#433)

This commit is contained in:
Yuhong Sun 2023-09-11 14:47:36 -07:00 committed by GitHub
parent 9316b78f47
commit 101ff2f392
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 11 deletions

View File

@@ -236,6 +236,7 @@ class QABlock(QAModel):
    def warm_up_model(self) -> None:
        """This is called during server start up to load the models into memory
        in case the chosen LLM is not accessed via API"""
        if self._llm.requires_warm_up:
            self._llm.invoke("Ignore this!")

    def answer_question(

View File

@@ -45,8 +45,8 @@ class GoogleColabDemo(LLM):
        response.raise_for_status()
        return json.loads(response.content).get("generated_text", "")

    def invoke(self, prompt: LanguageModelInput) -> str:
        return self._execute(prompt)

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        yield self._execute(prompt)

View File

@@ -15,12 +15,17 @@ class LLM(abc.ABC):
    """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
    to use these implementations to connect to a variety of LLM providers."""

    @property
    def requires_warm_up(self) -> bool:
        """Is this model running in memory and needs an initial call to warm it up?"""
        return False

    @abc.abstractmethod
    def invoke(self, prompt: LanguageModelInput) -> str:
        raise NotImplementedError

    @abc.abstractmethod
    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        raise NotImplementedError
@@ -35,10 +40,10 @@ class LangChainChatLLM(LLM, abc.ABC):
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    def invoke(self, prompt: LanguageModelInput) -> str:
        self._log_model_config()
        return self.llm.invoke(prompt).content

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        self._log_model_config()
        yield from message_generator_to_string_generator(self.llm.stream(prompt))