mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-12 09:00:53 +02:00
Fix LLM warm up (#433)
This commit is contained in:
parent
9316b78f47
commit
101ff2f392
@ -236,6 +236,7 @@ class QABlock(QAModel):
|
|||||||
def warm_up_model(self) -> None:
|
def warm_up_model(self) -> None:
|
||||||
"""This is called during server start up to load the models into memory
|
"""This is called during server start up to load the models into memory
|
||||||
in case the chosen LLM is not accessed via API"""
|
in case the chosen LLM is not accessed via API"""
|
||||||
|
if self._llm.requires_warm_up:
|
||||||
self._llm.invoke("Ignore this!")
|
self._llm.invoke("Ignore this!")
|
||||||
|
|
||||||
def answer_question(
|
def answer_question(
|
||||||
|
@ -45,8 +45,8 @@ class GoogleColabDemo(LLM):
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return json.loads(response.content).get("generated_text", "")
|
return json.loads(response.content).get("generated_text", "")
|
||||||
|
|
||||||
def invoke(self, input: LanguageModelInput) -> str:
|
def invoke(self, prompt: LanguageModelInput) -> str:
|
||||||
return self._execute(input)
|
return self._execute(prompt)
|
||||||
|
|
||||||
def stream(self, input: LanguageModelInput) -> Iterator[str]:
|
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
|
||||||
yield self._execute(input)
|
yield self._execute(prompt)
|
||||||
|
@ -15,12 +15,17 @@ class LLM(abc.ABC):
|
|||||||
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
||||||
to use these implementations to connect to a variety of LLM providers."""
|
to use these implementations to connect to a variety of LLM providers."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_warm_up(self) -> bool:
|
||||||
|
"""Is this model running in memory and needs an initial call to warm it up?"""
|
||||||
|
return False
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def invoke(self, input: LanguageModelInput) -> str:
|
def invoke(self, prompt: LanguageModelInput) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def stream(self, input: LanguageModelInput) -> Iterator[str]:
|
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -35,10 +40,10 @@ class LangChainChatLLM(LLM, abc.ABC):
|
|||||||
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
|
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, input: LanguageModelInput) -> str:
|
def invoke(self, prompt: LanguageModelInput) -> str:
|
||||||
self._log_model_config()
|
self._log_model_config()
|
||||||
return self.llm.invoke(input).content
|
return self.llm.invoke(prompt).content
|
||||||
|
|
||||||
def stream(self, input: LanguageModelInput) -> Iterator[str]:
|
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
|
||||||
self._log_model_config()
|
self._log_model_config()
|
||||||
yield from message_generator_to_string_generator(self.llm.stream(input))
|
yield from message_generator_to_string_generator(self.llm.stream(prompt))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user