Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-08-02 21:22:51 +02:00
major Agent Search Updates (#3994)
@@ -52,6 +52,18 @@ litellm.telemetry = False
 _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
 
 
+class LLMTimeoutError(Exception):
+    """
+    Exception raised when an LLM call times out.
+    """
+
+
+class LLMRateLimitError(Exception):
+    """
+    Exception raised when an LLM call is rate limited.
+    """
+
+
 def _base_msg_to_role(msg: BaseMessage) -> str:
     if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
         return "user"
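The two new exception types give callers a stable way to tell timeouts apart from rate limits without depending on litellm's own exception hierarchy. A minimal sketch of how calling code might branch on them; `invoke_with_retry` and its backoff policy are hypothetical, not part of this commit:

import time

def invoke_with_retry(llm, prompt, max_attempts=3):
    # Retry rate-limited calls with exponential backoff; let timeouts surface.
    for attempt in range(max_attempts):
        try:
            return llm.invoke(prompt)
        except LLMRateLimitError:
            if attempt == max_attempts - 1:
                raise
            time.sleep(2**attempt)  # back off 1s, then 2s
        except LLMTimeoutError:
            raise  # a timeout usually means the prompt/model pair needs attention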
@@ -389,6 +401,7 @@ class DefaultMultiLLM(LLM):
         tool_choice: ToolChoiceOptions | None,
         stream: bool,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
         # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
         # to a dict representation
@@ -419,7 +432,7 @@ class DefaultMultiLLM(LLM):
                 stream=stream,
                 # model params
                 temperature=0,
-                timeout=self._timeout,
+                timeout=timeout_override or self._timeout,
                 # For now, we don't support parallel tool calls
                 # NOTE: we can't pass this in if tools are not specified
                 # or else OpenAI throws an error
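Note the fallback semantics here: `timeout_override or self._timeout` treats any falsy override the same as no override, so a caller cannot request a zero-second timeout; an explicit `is None` check would be needed to allow that. A tiny illustration (DEFAULT_TIMEOUT is a stand-in for the instance default):

DEFAULT_TIMEOUT = 30

assert (None or DEFAULT_TIMEOUT) == 30  # no override -> instance default
assert (60 or DEFAULT_TIMEOUT) == 60    # explicit override wins
assert (0 or DEFAULT_TIMEOUT) == 30     # 0 is falsy, silently ignored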
@@ -438,6 +451,12 @@ class DefaultMultiLLM(LLM):
         except Exception as e:
             self._record_error(processed_prompt, e)
             # for break pointing
+            if isinstance(e, litellm.Timeout):
+                raise LLMTimeoutError(e)
+
+            elif isinstance(e, litellm.RateLimitError):
+                raise LLMRateLimitError(e)
+
             raise e
 
     @property
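A side note on the wrapping style: `raise LLMTimeoutError(e)` passes the original exception as a constructor argument, and Python's implicit exception context still keeps the old traceback, but explicit chaining would record litellm's exception as the direct cause. A sketch of that variant, as an alternative rather than what this commit does:

if isinstance(e, litellm.Timeout):
    raise LLMTimeoutError(e) from e  # records litellm.Timeout as __cause__
elif isinstance(e, litellm.RateLimitError):
    raise LLMRateLimitError(e) from e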
@@ -458,6 +477,7 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
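With the parameter threaded through the public `invoke` signature, a caller can stretch or tighten the deadline for a single request without changing the instance-wide default. A hypothetical call site; the 90-second figure is illustrative:

# Give a long-running agent step more time than the configured default.
response = llm.invoke(
    prompt,
    timeout_override=90,  # seconds, applies to this call only
)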
@@ -465,7 +485,12 @@ class DefaultMultiLLM(LLM):
         response = cast(
             litellm.ModelResponse,
             self._completion(
-                prompt, tools, tool_choice, False, structured_response_format
+                prompt=prompt,
+                tools=tools,
+                tool_choice=tool_choice,
+                stream=False,
+                structured_response_format=structured_response_format,
+                timeout_override=timeout_override,
             ),
         )
         choice = response.choices[0]
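Switching the `_completion` call from positional to keyword arguments guards against silent argument shifts as the signature grows. A deliberately broken sketch with hypothetical variables shows the hazard the keywords remove:

# BUG (illustrative): forgetting `stream` shifts every later argument one
# slot — response_format silently binds to the bool parameter `stream`.
llm._completion(prompt, tools, tool_choice, response_format)

# Keyword form: each value is bound by name, so the same omission raises a
# clear TypeError instead of misbinding.
llm._completion(
    prompt=prompt,
    tools=tools,
    tool_choice=tool_choice,
    stream=False,
    structured_response_format=response_format,
)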
@@ -483,19 +508,31 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         if DISABLE_LITELLM_STREAMING:
-            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
+            yield self.invoke(
+                prompt,
+                tools,
+                tool_choice,
+                structured_response_format,
+                timeout_override,
+            )
             return
 
         output = None
         response = cast(
             litellm.CustomStreamWrapper,
             self._completion(
-                prompt, tools, tool_choice, True, structured_response_format
+                prompt=prompt,
+                tools=tools,
+                tool_choice=tool_choice,
+                stream=True,
+                structured_response_format=structured_response_format,
+                timeout_override=timeout_override,
+            ),
+        )
+        try:
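End to end, the override now reaches the streaming path as well, so per-call deadlines apply to streamed responses too. A hypothetical caller, assuming the public streaming entry point of the LLM interface is named `stream` and yields message chunks with a `content` attribute:

# Stream tokens with a one-off 120s budget for a slow agent-search step.
for message_chunk in llm.stream(
    prompt,
    timeout_override=120,
):
    print(message_chunk.content, end="", flush=True)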