major Agent Search Updates (#3994)

This commit is contained in:
joachim-danswer
2025-02-14 11:40:21 -08:00
committed by GitHub
parent ec78f78f3c
commit 6687d5d499
36 changed files with 2115 additions and 431 deletions


@@ -52,6 +52,18 @@ litellm.telemetry = False
 _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
 
 
+class LLMTimeoutError(Exception):
+    """
+    Exception raised when an LLM call times out.
+    """
+
+
+class LLMRateLimitError(Exception):
+    """
+    Exception raised when an LLM call is rate limited.
+    """
+
+
 def _base_msg_to_role(msg: BaseMessage) -> str:
     if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
         return "user"
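These two exception classes give callers a provider-agnostic way to distinguish timeouts and rate limits from other failures without importing litellm themselves. A minimal caller-side sketch, assuming the exceptions are importable from the patched module; the helper and its degrade-to-None policy are illustrative, not part of this commit:

    from onyx.llm.chat_llm import LLMRateLimitError, LLMTimeoutError  # assumed module path

    def answer_or_none(llm, prompt):
        # hypothetical degradation policy: drop this answer rather than
        # fail the whole agent run on a slow or throttled model call
        try:
            return llm.invoke(prompt)
        except (LLMTimeoutError, LLMRateLimitError):
            return None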
@@ -389,6 +401,7 @@ class DefaultMultiLLM(LLM):
         tool_choice: ToolChoiceOptions | None,
         stream: bool,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
         # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
         # to a dict representation
@@ -419,7 +432,7 @@ class DefaultMultiLLM(LLM):
             stream=stream,
             # model params
             temperature=0,
-            timeout=self._timeout,
+            timeout=timeout_override or self._timeout,
             # For now, we don't support parallel tool calls
             # NOTE: we can't pass this in if tools are not specified
             # or else OpenAI throws an error
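One subtlety of the `or` fallback: an explicit `timeout_override=0` is falsy, so it also falls back to `self._timeout`. A sketch of the more explicit equivalent, should a zero override ever need to be distinguishable:

    # only fall back when no override was supplied at all
    timeout = timeout_override if timeout_override is not None else self._timeout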
@@ -438,6 +451,12 @@ class DefaultMultiLLM(LLM):
         except Exception as e:
             self._record_error(processed_prompt, e)
             # for break pointing
+
+            if isinstance(e, litellm.Timeout):
+                raise LLMTimeoutError(e)
+            elif isinstance(e, litellm.RateLimitError):
+                raise LLMRateLimitError(e)
+
             raise e
 
     @property
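Because `litellm.Timeout` and `litellm.RateLimitError` are re-raised as the domain exceptions above, retry policy can live with the caller instead of inside the wrapper. A sketch of one such policy; the helper name, retry count, and backoff schedule are illustrative assumptions:

    import time

    from onyx.llm.chat_llm import LLMRateLimitError  # assumed path, as above

    def invoke_with_backoff(llm, prompt, retries=3):
        # hypothetical policy: exponential backoff on rate limits only;
        # timeouts and other errors propagate immediately
        for attempt in range(retries):
            try:
                return llm.invoke(prompt)
            except LLMRateLimitError:
                time.sleep(2**attempt)
        return llm.invoke(prompt)  # final attempt, exceptions propagate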
@@ -458,6 +477,7 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
@@ -465,7 +485,12 @@ class DefaultMultiLLM(LLM):
         response = cast(
             litellm.ModelResponse,
             self._completion(
-                prompt, tools, tool_choice, False, structured_response_format
+                prompt=prompt,
+                tools=tools,
+                tool_choice=tool_choice,
+                stream=False,
+                structured_response_format=structured_response_format,
+                timeout_override=timeout_override,
             ),
         )
         choice = response.choices[0]
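With the call spelled out as keyword arguments, `invoke` now threads the per-call timeout straight through to the underlying completion. A usage sketch; the 30-second budget and the surrounding context are illustrative assumptions:

    # e.g. give a quick intermediate agent step a tighter budget than the default
    message = llm.invoke(prompt, timeout_override=30)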
@@ -483,19 +508,31 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
         structured_response_format: dict | None = None,
+        timeout_override: int | None = None,
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         if DISABLE_LITELLM_STREAMING:
-            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
+            yield self.invoke(
+                prompt,
+                tools,
+                tool_choice,
+                structured_response_format,
+                timeout_override,
+            )
             return
 
         output = None
         response = cast(
             litellm.CustomStreamWrapper,
             self._completion(
-                prompt, tools, tool_choice, True, structured_response_format
+                prompt=prompt,
+                tools=tools,
+                tool_choice=tool_choice,
+                stream=True,
+                structured_response_format=structured_response_format,
+                timeout_override=timeout_override,
             ),
         )
         try:
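`stream` accepts the same override, so a streaming consumer can bound the whole call and still catch the mapped exceptions. A sketch; everything except `stream`, `timeout_override`, and the exception classes is an assumption:

    try:
        for chunk in llm.stream(prompt, timeout_override=60):
            print(chunk.content, end="")  # stand-in for a real token consumer
    except LLMTimeoutError:
        pass  # degrade gracefully instead of surfacing litellm.Timeout to the UI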