danswer/backend/onyx/utils/threadpool_concurrency.py
2025-03-25 17:40:12 +00:00

357 lines
11 KiB
Python

import collections.abc
import contextvars
import copy
import threading
import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from onyx.utils.logger import setup_logger
# Module-level logger; the parallel runners in this module use it to record failed calls.
logger = setup_logger()
# Result type of the callables executed by the helpers in this module.
R = TypeVar("R")
KT = TypeVar("KT")  # Key type of ThreadSafeDict
VT = TypeVar("VT")  # Value type of ThreadSafeDict
_T = TypeVar("_T")  # Type of a caller-supplied default (see ThreadSafeDict.get)
class ThreadSafeDict(MutableMapping[KT, VT]):
    """
    A thread-safe dictionary implementation that uses a lock to ensure thread safety.
    Implements the MutableMapping interface to provide a complete dictionary-like interface.

    Each method call below is atomic with respect to the internal lock. Sequences of
    several calls are NOT atomic as a group; prefer the single-call helpers
    (`setdefault`, `pop`, `update`) or work on a `copy()` when a consistent view is needed.
    The lock is a plain (non-reentrant) `threading.Lock`, so code that already holds
    `self.lock` must not call back into these methods.

    `items()`, `keys()` and `values()` return views over a snapshot taken under the lock;
    they do not reflect later mutations.

    Example usage:
        # Create a thread-safe dictionary
        safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
        # Basic operations (atomic)
        safe_dict["key"] = 1
        value = safe_dict["key"]
        del safe_dict["key"]
        # Bulk operations (atomic)
        safe_dict.update({"key1": 1, "key2": 2})
    """

    # Sentinel distinguishing "no default supplied" from an explicit default of None.
    _MISSING: Any = object()

    def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
        # NOTE: a non-empty input dict is stored by reference (not copied); an empty or
        # missing one is replaced by a fresh dict.
        self._dict: dict[KT, VT] = input_dict or {}
        self.lock = threading.Lock()

    def __getitem__(self, key: KT) -> VT:
        with self.lock:
            return self._dict[key]

    def __setitem__(self, key: KT, value: VT) -> None:
        with self.lock:
            self._dict[key] = value

    def __delitem__(self, key: KT) -> None:
        with self.lock:
            del self._dict[key]

    def __iter__(self) -> Iterator[KT]:
        # Return a snapshot of keys to avoid potential modification during iteration
        with self.lock:
            return iter(list(self._dict.keys()))

    def __len__(self) -> int:
        with self.lock:
            return len(self._dict)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Validate as a plain dict first, then wrap the result via `validate`.
        return core_schema.no_info_after_validator_function(
            cls.validate, handler(dict[KT, VT])
        )

    @classmethod
    def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
        """Wrap a plain dict in a ThreadSafeDict; pass any other value through unchanged."""
        if isinstance(v, dict):
            return ThreadSafeDict(v)
        return v

    def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
        # Take a shallow snapshot under the lock so a concurrent writer cannot resize
        # the dict while it is being walked, then deep-copy outside the lock so that
        # element-level __deepcopy__ hooks can never deadlock on this (non-reentrant) lock.
        with self.lock:
            snapshot = self._dict.copy()
        return ThreadSafeDict(copy.deepcopy(snapshot, memo))

    def clear(self) -> None:
        """Remove all items from the dictionary atomically."""
        with self.lock:
            self._dict.clear()

    def copy(self) -> dict[KT, VT]:
        """Return a shallow copy of the dictionary (as a plain dict) atomically."""
        with self.lock:
            return self._dict.copy()

    @overload
    def get(self, key: KT) -> VT | None:
        ...

    @overload
    def get(self, key: KT, default: VT | _T) -> VT | _T:
        ...

    def get(self, key: KT, default: Any = None) -> Any:
        """Get a value with a default, atomically."""
        with self.lock:
            return self._dict.get(key, default)

    def pop(self, key: KT, default: Any = _MISSING) -> Any:
        """
        Remove and return a value, atomically.

        Raises KeyError if the key is missing and no default was supplied. An explicit
        default (including None) is returned instead of raising.
        """
        with self.lock:
            if default is self._MISSING:
                return self._dict.pop(key)
            return self._dict.pop(key, default)

    def setdefault(self, key: KT, default: VT) -> VT:
        """Set a default value if key is missing, atomically."""
        with self.lock:
            return self._dict.setdefault(key, default)

    def update(self, *args: Any, **kwargs: VT) -> None:
        """Update the dictionary atomically from another mapping or from kwargs."""
        with self.lock:
            self._dict.update(*args, **kwargs)

    def items(self) -> collections.abc.ItemsView[KT, VT]:
        """Return a view of (key, value) pairs over a snapshot taken atomically."""
        # A view over `self` would re-acquire the lock once per element while being
        # iterated, so a concurrent delete could raise KeyError mid-iteration.
        with self.lock:
            return self._dict.copy().items()

    def keys(self) -> collections.abc.KeysView[KT]:
        """Return a view of keys over a snapshot taken atomically."""
        with self.lock:
            return self._dict.copy().keys()

    def values(self) -> collections.abc.ValuesView[VT]:
        """Return a view of values over a snapshot taken atomically."""
        with self.lock:
            return self._dict.copy().values()
class CallableProtocol(Protocol):
    """Structural type matching any object callable with arbitrary arguments."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the callable with any positional and keyword arguments."""
        ...
def run_functions_tuples_in_parallel(
    functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
    allow_failures: bool = False,
    max_workers: int | None = None,
) -> list[Any]:
    """
    Run several functions concurrently on a thread pool and collect their results.

    Each call is executed inside a copy of the caller's contextvars context, so values
    such as the tenant ID used when acquiring a db session remain visible in the workers.

    Args:
        functions_with_args: Sequence of (callable, positional-args tuple) pairs.
        allow_failures: When True, a call that raises contributes None to the output
            instead of propagating; when False the first observed exception is re-raised.
        max_workers: Upper bound on worker threads; defaults to one thread per call.

    Returns:
        list: One entry per input pair, in input order. Empty when there is nothing to
        run or the computed worker count is not positive.
    """
    task_count = len(functions_with_args)
    pool_size = task_count if max_workers is None else min(max_workers, task_count)
    if pool_size <= 0:
        return []

    # Pre-size the output so each result can be dropped straight into its input slot;
    # slots of failed calls simply keep their initial None.
    ordered_results: list[Any] = [None] * task_count

    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        # Context.run is expected to be low-overhead; a fresh copy is taken per call
        # because every submitted task needs its own Context to run in.
        index_by_future = {}
        for position, (func, args) in enumerate(functions_with_args):
            submitted = executor.submit(contextvars.copy_context().run, func, *args)
            index_by_future[submitted] = position

        for finished in as_completed(index_by_future):
            index = index_by_future[finished]
            try:
                ordered_results[index] = finished.result()
            except Exception as e:
                logger.exception(f"Function at index {index} failed due to {e}")
                if not allow_failures:
                    raise

    return ordered_results
class FunctionCall(Generic[R]):
    """
    A deferred call (function + arguments) for use with run_functions_in_parallel.

    Every instance receives a random `result_id`; look the call's result up under that
    key in the dictionary returned by run_functions_in_parallel.
    """

    def __init__(
        self, func: Callable[..., R], args: tuple = (), kwargs: dict | None = None
    ):
        # Unique key under which this call's result is reported.
        self.result_id = str(uuid.uuid4())
        self.func = func
        self.args = args
        # Fall back to a fresh dict so instances never share a mutable default.
        self.kwargs = {} if kwargs is None else kwargs

    def execute(self) -> R:
        """Invoke the stored function with the stored arguments and return its result."""
        positional, keyword = self.args, self.kwargs
        return self.func(*positional, **keyword)
def run_functions_in_parallel(
    function_calls: list[FunctionCall],
    allow_failures: bool = False,
) -> dict[str, Any]:
    """
    Run a list of FunctionCalls concurrently, one worker thread per call.

    Each call executes inside a copy of the caller's contextvars context.

    Args:
        function_calls: The calls to execute.
        allow_failures: When True, a call that raises is recorded as None; when False
            the first observed exception is re-raised.

    Returns:
        Mapping of each FunctionCall.result_id to the value its call returned
        (or None for a failed call when failures are allowed).
    """
    outcome: dict[str, Any] = {}
    pool_size = len(function_calls)
    if pool_size == 0:
        return outcome

    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        id_by_future = {}
        for func_call in function_calls:
            submitted = executor.submit(
                contextvars.copy_context().run, func_call.execute
            )
            id_by_future[submitted] = func_call.result_id

        for finished in as_completed(id_by_future):
            result_id = id_by_future[finished]
            try:
                value = finished.result()
            except Exception as e:
                logger.exception(f"Function with ID {result_id} failed due to {e}")
                outcome[result_id] = None
                if not allow_failures:
                    raise
            else:
                outcome[result_id] = value

    return outcome
class TimeoutThread(threading.Thread, Generic[R]):
    """
    Thread that runs `func(*args, **kwargs)` and records how the call ended.

    After the thread has finished, exactly one of the following holds:
      * `result` is set to the function's return value (the attribute does not exist
        until the call succeeds), or
      * `exception` holds the `Exception` the function raised.

    `timeout` is only stored for use in the message raised by `end()`; this class does
    not enforce it.
    """

    def __init__(
        self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
    ):
        super().__init__()
        self.timeout = timeout
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.exception: Exception | None = None

    def run(self) -> None:
        """Thread body: call the function, capturing its result or its exception."""
        try:
            self.result = self.func(*self.args, **self.kwargs)
        except Exception as e:
            self.exception = e

    def end(self) -> None:
        """Raise TimeoutError describing this task; does not stop the thread."""
        # Not every callable has __name__ (e.g. functools.partial objects or callable
        # instances); fall back to repr so the TimeoutError is never masked by an
        # AttributeError.
        func_name = getattr(self.func, "__name__", repr(self.func))
        raise TimeoutError(
            f"Function {func_name} timed out after {self.timeout} seconds"
        )
def run_with_timeout(
    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
    """
    Executes a function with a timeout. If the function doesn't complete within the specified
    timeout, raises TimeoutError.

    The function runs in a separate thread inside a copy of the caller's contextvars
    context. On timeout the worker thread is NOT stopped; it keeps running in the
    background and its eventual result is discarded.

    Args:
        timeout: Maximum number of seconds to wait for the function.
        func: The function to call.
        *args, **kwargs: Arguments forwarded to `func`.

    Returns:
        Whatever `func` returned.

    Raises:
        TimeoutError: if `func` is still running after `timeout` seconds.
        Exception: any exception raised by `func` is re-raised in the caller.
    """
    context = contextvars.copy_context()
    task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
    task.start()
    task.join(timeout)

    # Sample liveness BEFORE inspecting the outcome. If the thread is no longer alive,
    # its exception/result are final. Checking the exception first and liveness second
    # leaves a window where the worker fails in between, and the code then falls
    # through to a `result` attribute that was never set.
    timed_out = task.is_alive()

    if task.exception is not None:
        raise task.exception
    if timed_out:
        # Report the caller's function; the thread's own target is the Context.run
        # wrapper, whose name would be meaningless in the message.
        func_name = getattr(func, "__name__", repr(func))
        raise TimeoutError(f"Function {func_name} timed out after {timeout} seconds")
    return task.result  # type: ignore
# NOTE: run_in_background should really only be used when run_functions_tuples_in_parallel
# is difficult to use. It is up to the caller to call wait_on_background on the returned
# thread once the code meant to run in parallel with it has finished. As with all Python
# thread parallelism, this is only useful for I/O-bound tasks.
def run_in_background(
    func: Callable[..., R], *args: Any, **kwargs: Any
) -> TimeoutThread[R]:
    """
    Start `func(*args, **kwargs)` on a background thread and return immediately.

    The call runs inside a copy of the caller's contextvars context. Pass the returned
    TimeoutThread to wait_on_background to block for, and retrieve, the result.
    """
    caller_context = contextvars.copy_context()
    # The timeout argument is a placeholder (-1): nothing waits with a deadline here.
    background_task = TimeoutThread(  # type: ignore
        -1, caller_context.run, func, *args, **kwargs
    )
    background_task.start()
    return cast(TimeoutThread[R], background_task)
def wait_on_background(task: TimeoutThread[R]) -> R:
    """
    Companion to run_in_background: block until the background task has finished.

    Returns the value the task's function returned, or re-raises the exception the
    function raised.
    """
    task.join()
    failure = task.exception
    if failure is None:
        return task.result
    raise failure
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
    """Advance `g` once; return (ind, item), with item None once `g` is exhausted."""
    item = next(g, None)
    return (ind, item)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
    """
    Interleave the items of several iterators, advancing them on a thread pool.

    Each iterator has at most one `next()` call in flight at a time; once that call
    completes, its item is yielded and the same iterator is scheduled again. Items are
    therefore yielded in completion order, not in the order of `gens`.

    NOTE: exhaustion is signalled with None (see _next_or_none), so an iterator that
    itself produces None is treated as finished at that point.

    Args:
        gens: The iterators to drain.
        max_workers: Size of the thread pool used to advance the iterators.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Futures for the `next()` calls that have been scheduled but not yet handled.
        in_flight: set[Future[tuple[int, R | None]]] = set()
        for position, gen in enumerate(gens):
            in_flight.add(executor.submit(_next_or_none, position, gen))

        while in_flight:
            finished, _ = wait(in_flight, return_when=FIRST_COMPLETED)
            for future in finished:
                source, item = future.result()
                if item is not None:
                    yield item
                    # The source iterator may have more items: schedule its next step.
                    in_flight.add(
                        executor.submit(_next_or_none, source, gens[source])
                    )
                in_flight.remove(future)