mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-02 11:09:20 +02:00
* added timeouts for agent llm calls * timing suggestions in agent config * improved timeout that actually exits early * added new global timeout and connection timeout distinction * fixed error raising bug and made entity extraction recoverable * warnings and refactor * mypy --------- Co-authored-by: joachim-danswer <joachim@danswer.ai>
167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
import threading
|
|
import uuid
|
|
from collections.abc import Callable
|
|
from concurrent.futures import as_completed
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any
|
|
from typing import Generic
|
|
from typing import TypeVar
|
|
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
logger = setup_logger()

# Generic result type shared by the typed helpers below
# (FunctionCall, TimeoutThread, run_with_timeout).
R = TypeVar("R")
|
|
|
|
|
|
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
|
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
|
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
|
# is not safe, update this comment.
|
|
def run_functions_tuples_in_parallel(
|
|
functions_with_args: list[tuple[Callable, tuple]],
|
|
allow_failures: bool = False,
|
|
max_workers: int | None = None,
|
|
) -> list[Any]:
|
|
"""
|
|
Executes multiple functions in parallel and returns a list of the results for each function.
|
|
|
|
Args:
|
|
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
|
|
allow_failures: if set to True, then the function result will just be None
|
|
max_workers: Max number of worker threads
|
|
|
|
Returns:
|
|
dict: A dictionary mapping function names to their results or error messages.
|
|
"""
|
|
workers = (
|
|
min(max_workers, len(functions_with_args))
|
|
if max_workers is not None
|
|
else len(functions_with_args)
|
|
)
|
|
|
|
if workers <= 0:
|
|
return []
|
|
|
|
results = []
|
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
|
future_to_index = {
|
|
executor.submit(func, *args): i
|
|
for i, (func, args) in enumerate(functions_with_args)
|
|
}
|
|
|
|
for future in as_completed(future_to_index):
|
|
index = future_to_index[future]
|
|
try:
|
|
results.append((index, future.result()))
|
|
except Exception as e:
|
|
logger.exception(f"Function at index {index} failed due to {e}")
|
|
results.append((index, None))
|
|
|
|
if not allow_failures:
|
|
raise
|
|
|
|
results.sort(key=lambda x: x[0])
|
|
return [result for index, result in results]
|
|
|
|
|
|
class FunctionCall(Generic[R]):
|
|
"""
|
|
Container for run_functions_in_parallel, fetch the results from the output of
|
|
run_functions_in_parallel via the FunctionCall.result_id.
|
|
"""
|
|
|
|
def __init__(
|
|
self, func: Callable[..., R], args: tuple = (), kwargs: dict | None = None
|
|
):
|
|
self.func = func
|
|
self.args = args
|
|
self.kwargs = kwargs if kwargs is not None else {}
|
|
self.result_id = str(uuid.uuid4())
|
|
|
|
def execute(self) -> R:
|
|
return self.func(*self.args, **self.kwargs)
|
|
|
|
|
|
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
    function_calls: list[FunctionCall],
    allow_failures: bool = False,
) -> dict[str, Any]:
    """
    Runs every FunctionCall concurrently, one worker thread per call.

    Returns a dict keyed by each call's result_id whose value is the call's
    return value (or None for a failed call when allow_failures is True).
    """
    results: dict[str, Any] = {}

    # Nothing to schedule -> nothing to return.
    if not function_calls:
        return results

    with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
        future_to_id = {}
        for func_call in function_calls:
            future_to_id[executor.submit(func_call.execute)] = func_call.result_id

        for future in as_completed(future_to_id):
            result_id = future_to_id[future]
            try:
                results[result_id] = future.result()
            except Exception as e:
                logger.exception(f"Function with ID {result_id} failed due to {e}")
                results[result_id] = None

                if not allow_failures:
                    raise

    return results
|
|
|
|
|
|
class TimeoutThread(threading.Thread, Generic[R]):
    """
    Runs ``func(*args, **kwargs)`` on a separate thread and records either the
    return value (``self.result``) or the raised exception (``self.exception``)
    for the spawning thread to inspect after join().
    """

    # Set only if func returns without raising.
    result: R

    def __init__(
        self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
    ):
        super().__init__()
        self.timeout = timeout
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.exception: Exception | None = None

    def run(self) -> None:
        try:
            self.result = self.func(*self.args, **self.kwargs)
        except Exception as e:
            # Stash the exception so the caller's thread can re-raise it.
            self.exception = e

    def end(self) -> None:
        # Raises in the *caller's* thread; the worker itself cannot be killed
        # and may keep running in the background.
        raise TimeoutError(
            f"Function {self.func.__name__} timed out after {self.timeout} seconds"
        )


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
    """
    Executes a function with a timeout. If the function doesn't complete within the specified
    timeout, raises TimeoutError.

    Note: the worker thread cannot be forcibly stopped — on timeout it keeps
    running in the background, but as a daemon it will not block interpreter
    shutdown.
    """
    task = TimeoutThread(timeout, func, *args, **kwargs)
    # Daemonize so a worker that outlives its timeout cannot prevent process exit.
    task.daemon = True
    task.start()
    task.join(timeout)

    if task.exception is not None:
        raise task.exception
    if task.is_alive():
        # Still running past the deadline -> surface a TimeoutError to the caller.
        task.end()

    return task.result
|