Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-09 20:39:29 +02:00)
thread utils respect contextvars (#4074)
* thread utils respect contextvars now
* address pablo comments
* removed tenant id from places it was already being passed
* fix rate limit check and pablo comment
This commit is contained in: parent 1f2af373e1, commit 4a4e4a6c50
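The core of this change is that the thread helpers in onyx.utils.threadpool_concurrency now copy the caller's contextvars into each worker thread. As orientation before the hunks below, here is a minimal standalone sketch of that pattern; the tenant contextvar name is illustrative (onyx's real one lives in shared_configs.contextvars):

import contextvars
from concurrent.futures import ThreadPoolExecutor

# Illustrative stand-in for the tenant id contextvar.
tenant_cv = contextvars.ContextVar("tenant_cv", default="public")

def whoami() -> str:
    # A fresh worker thread would otherwise see the default value,
    # because each new thread starts with its own empty context.
    return tenant_cv.get()

tenant_cv.set("tenant_abc")
with ThreadPoolExecutor(max_workers=2) as executor:
    # copy_context() snapshots the submitting thread's contextvars;
    # Context.run executes the function inside that snapshot.
    future = executor.submit(contextvars.copy_context().run, whoami)
    assert future.result() == "tenant_abc"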
@@ -13,7 +13,7 @@ from sqlalchemy import select
 from sqlalchemy.orm import Session
 
 from onyx.db.api_key import is_api_key_email_address
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.models import ChatMessage
 from onyx.db.models import ChatSession
 from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 
 
-def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
+def _check_token_rate_limits(user: User | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
-        _user_is_rate_limited_by_global(tenant_id)
+        _user_is_rate_limited_by_global()
 
     elif is_api_key_email_address(user.email):
         # API keys are only rate limited by global settings
-        _user_is_rate_limited_by_global(tenant_id)
+        _user_is_rate_limited_by_global()
 
     else:
         run_functions_tuples_in_parallel(
             [
-                (_user_is_rate_limited, (user.id, tenant_id)),
-                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
-                (_user_is_rate_limited_by_global, (tenant_id,)),
+                (_user_is_rate_limited, (user.id,)),
+                (_user_is_rate_limited_by_group, (user.id,)),
+                (_user_is_rate_limited_by_global, ()),
             ]
         )
 
@@ -52,8 +52,8 @@ User rate limits
 """
 
 
-def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+def _user_is_rate_limited(user_id: UUID) -> None:
+    with get_session_with_current_tenant() as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -93,8 +93,8 @@ User Group rate limits
 """
 
 
-def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+def _user_is_rate_limited_by_group(user_id: UUID) -> None:
+    with get_session_with_current_tenant() as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
 
         if group_rate_limits:
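Dropping the tenant_id parameters above works because get_session_with_current_tenant can resolve the tenant from a contextvar that the parallel runners now propagate. A hedged sketch of that idea, with illustrative names rather than onyx's actual implementation:

import contextvars
from collections.abc import Iterator
from contextlib import contextmanager

# Illustrative contextvar; onyx keeps its tenant id in shared_configs.contextvars.
_tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="public")

@contextmanager
def get_session_with_current_tenant_sketch() -> Iterator[str]:
    # A real implementation would yield a SQLAlchemy Session scoped to the
    # tenant's schema; a labeled string stands in for it here.
    yield f"session(schema={_tenant_id.get()})"

_tenant_id.set("tenant_42")
with get_session_with_current_tenant_sketch() as db_session:
    assert db_session == "session(schema=tenant_42)"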
@@ -870,7 +870,6 @@ def stream_chat_message_objects(
                     for img in img_generation_response
                     if img.image_data
                 ],
-                tenant_id=tenant_id,
             )
             info.ai_message_files.extend(
                 [
@@ -8,7 +8,7 @@ import requests
 from sqlalchemy.orm import Session
 
 from onyx.configs.constants import FileOrigin
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.models import ChatMessage
 from onyx.file_store.file_store import get_default_file_store
 from onyx.file_store.models import FileDescriptor
@@ -53,11 +53,11 @@ def load_all_chat_files(
     return files
 
 
-def save_file_from_url(url: str, tenant_id: str) -> str:
+def save_file_from_url(url: str) -> str:
     """NOTE: using multiple sessions here, since this is often called
     using multithreading. In practice, sharing a session has resulted in
     weird errors."""
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         response = requests.get(url)
         response.raise_for_status()
 
@@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
     return unique_id
 
 
-def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+def save_file_from_base64(base64_string: str) -> str:
+    with get_session_with_current_tenant() as db_session:
         unique_id = str(uuid4())
         file_store = get_default_file_store(db_session)
         file_store.save_file(
@@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
 
 
 def save_file(
-    tenant_id: str,
     url: str | None = None,
     base64_data: str | None = None,
 ) -> str:
     """Save a file from either a URL or base64 encoded string.
 
     Args:
-        tenant_id: The tenant ID to save the file under
         url: URL to download file from
         base64_data: Base64 encoded file data
 
@@ -111,22 +109,22 @@ def save_file(
         raise ValueError("Cannot specify both url and base64_data")
 
     if url is not None:
-        return save_file_from_url(url, tenant_id)
+        return save_file_from_url(url)
    elif base64_data is not None:
-        return save_file_from_base64(base64_data, tenant_id)
+        return save_file_from_base64(base64_data)
     else:
         raise ValueError("Must specify either url or base64_data")
 
 
-def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
+def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
     # NOTE: be explicit about typing so that if we change things, we get notified
     funcs: list[
         tuple[
-            Callable[[str, str | None, str | None], str],
-            tuple[str, str | None, str | None],
+            Callable[[str | None, str | None], str],
+            tuple[str | None, str | None],
         ]
-    ] = [(save_file, (tenant_id, url, None)) for url in urls] + [
-        (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
+    ] = [(save_file, (url, None)) for url in urls] + [
+        (save_file, (None, base64_file)) for base64_file in base64_files
    ]
 
     return run_functions_tuples_in_parallel(funcs)
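The explicit annotation on funcs in the hunk above exists so a type checker flags the call list if save_file's signature drifts again. A minimal standalone illustration of the same idea, using a stub save_file rather than onyx's:

from collections.abc import Callable

def save_file(url: str | None = None, base64_data: str | None = None) -> str:
    return url or base64_data or ""

# Spelling out the element type means mypy errors at this assignment if
# save_file ever gains or loses a parameter, instead of failing at runtime.
funcs: list[
    tuple[
        Callable[[str | None, str | None], str],
        tuple[str | None, str | None],
    ]
] = [
    (save_file, ("https://example.com/a.png", None)),
    (save_file, (None, "aGVsbG8=")),
]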
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
 
 from onyx.auth.users import current_chat_accesssible_user
 from onyx.db.engine import get_session_context_manager
-from onyx.db.engine import get_session_with_tenant
 from onyx.db.models import ChatMessage
 from onyx.db.models import ChatSession
 from onyx.db.models import TokenRateLimit
@@ -21,7 +20,6 @@ from onyx.db.models import User
 from onyx.db.token_limit import fetch_all_global_token_rate_limits
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import fetch_versioned_implementation
-from shared_configs.contextvars import get_current_tenant_id
 
 
 logger = setup_logger()
@@ -39,13 +37,13 @@ def check_token_rate_limits(
         return
 
     versioned_rate_limit_strategy = fetch_versioned_implementation(
-        "onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
+        "onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__
     )
-    return versioned_rate_limit_strategy(user, get_current_tenant_id())
+    return versioned_rate_limit_strategy(user)
 
 
-def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
-    _user_is_rate_limited_by_global(tenant_id)
+def _check_token_rate_limits(_: User | None) -> None:
+    _user_is_rate_limited_by_global()
 
 
 """
@@ -53,8 +51,8 @@ Global rate limits
 """
 
 
-def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+def _user_is_rate_limited_by_global() -> None:
+    with get_session_context_manager() as db_session:
         global_rate_limits = fetch_all_global_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
        )
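Passing _check_token_rate_limits.__name__ instead of the string literal keeps the lookup in sync if the function is ever renamed by a refactor. A sketch of the idea; the two-argument signature matches how fetch_versioned_implementation is called in the hunk above, but the body here is illustrative only (the real onyx helper also handles edition fallback):

import importlib
from typing import Any

def fetch_versioned_implementation(module: str, attribute: str) -> Any:
    # Illustrative body: resolve module.attribute dynamically.
    return getattr(importlib.import_module(module), attribute)

def _check_token_rate_limits(_: object | None) -> None:
    pass

# __name__ on both sides means an IDE rename updates this call site too.
impl = fetch_versioned_implementation(__name__, _check_token_rate_limits.__name__)
assert impl is _check_token_rate_limits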
@@ -1,3 +1,4 @@
+import contextvars
 import threading
 import uuid
 from collections.abc import Callable
@@ -14,10 +15,6 @@ logger = setup_logger()
 R = TypeVar("R")
 
 
-# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
-# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
-# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
-# is not safe, update this comment.
 def run_functions_tuples_in_parallel(
     functions_with_args: list[tuple[Callable, tuple]],
     allow_failures: bool = False,
@@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel(
 
     results = []
     with ThreadPoolExecutor(max_workers=workers) as executor:
+        # The primary reason for propagating contextvars is to allow acquiring a db session
+        # that respects tenant id. Context.run is expected to be low-overhead, but if we later
+        # find that it is increasing latency we can make using it optional.
         future_to_index = {
-            executor.submit(func, *args): i
+            executor.submit(contextvars.copy_context().run, func, *args): i
             for i, (func, args) in enumerate(functions_with_args)
         }
 
@@ -83,10 +83,6 @@ class FunctionCall(Generic[R]):
         return self.func(*self.args, **self.kwargs)
 
 
-# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
-# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
-# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
-# is not safe, update this comment.
 def run_functions_in_parallel(
     function_calls: list[FunctionCall],
     allow_failures: bool = False,
@@ -102,7 +98,9 @@ def run_functions_in_parallel(
 
     with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
         future_to_id = {
-            executor.submit(func_call.execute): func_call.result_id
+            executor.submit(
+                contextvars.copy_context().run, func_call.execute
+            ): func_call.result_id
             for func_call in function_calls
         }
 
@@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread):
         )
 
 
-# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
-# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
-# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
-# is not safe, update this comment.
 def run_with_timeout(
     timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
 ) -> R:
@@ -154,7 +148,8 @@ def run_with_timeout(
     Executes a function with a timeout. If the function doesn't complete within the specified
     timeout, raises TimeoutError.
     """
-    task = TimeoutThread(timeout, func, *args, **kwargs)
+    context = contextvars.copy_context()
+    task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
     task.start()
     task.join(timeout)
 
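Note that run_with_timeout copies the caller's context once, up front on the calling thread, so the timed function sees the caller's contextvars. A minimal standalone sketch of the same pattern (names are illustrative, not onyx's TimeoutThread):

import contextvars
import threading

request_id = contextvars.ContextVar("request_id", default="-")

def run_with_timeout_sketch(timeout: float, func, *args):
    """Copy the caller's context, then run func inside it on a worker thread."""
    ctx = contextvars.copy_context()  # snapshot taken on the calling thread
    result: dict[str, object] = {}

    def target() -> None:
        result["value"] = ctx.run(func, *args)

    thread = threading.Thread(target=target, daemon=True)
    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        raise TimeoutError(f"function did not finish within {timeout}s")
    return result["value"]

request_id.set("req-7")
assert run_with_timeout_sketch(1.0, request_id.get) == "req-7"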
backend/tests/unit/onyx/utils/test_threadpool_contextvars.py (new file, +131 lines)
@@ -0,0 +1,131 @@
+import contextvars
+import time
+
+from onyx.utils.threadpool_concurrency import FunctionCall
+from onyx.utils.threadpool_concurrency import run_functions_in_parallel
+from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from onyx.utils.threadpool_concurrency import run_with_timeout
+
+# Create a test contextvar
+test_var = contextvars.ContextVar("test_var", default="default")
+
+
+def get_contextvar_value() -> str:
+    """Helper function that runs in a thread and returns the contextvar value"""
+    # Add a small sleep to ensure we're actually running in a different thread
+    time.sleep(0.1)
+    return test_var.get()
+
+
+def test_run_with_timeout_preserves_contextvar() -> None:
+    """Test that run_with_timeout preserves contextvar values"""
+    # Set a value in the main thread
+    test_var.set("test_value")
+
+    # Run function with timeout and verify the value is preserved
+    result = run_with_timeout(1.0, get_contextvar_value)
+    assert result == "test_value"
+
+
+def test_run_functions_in_parallel_preserves_contextvar() -> None:
+    """Test that run_functions_in_parallel preserves contextvar values"""
+    # Set a value in the main thread
+    test_var.set("parallel_test")
+
+    # Create multiple function calls
+    function_calls = [
+        FunctionCall(get_contextvar_value),
+        FunctionCall(get_contextvar_value),
+    ]
+
+    # Run in parallel and verify all results have the correct value
+    results = run_functions_in_parallel(function_calls)
+
+    for result_id, value in results.items():
+        assert value == "parallel_test"
+
+
+def test_run_functions_tuples_preserves_contextvar() -> None:
+    """Test that run_functions_tuples_in_parallel preserves contextvar values"""
+    # Set a value in the main thread
+    test_var.set("tuple_test")
+
+    # Create list of function tuples
+    functions_with_args = [
+        (get_contextvar_value, ()),
+        (get_contextvar_value, ()),
+    ]
+
+    # Run in parallel and verify all results have the correct value
+    results = run_functions_tuples_in_parallel(functions_with_args)
+
+    for result in results:
+        assert result == "tuple_test"
+
+
+def test_nested_contextvar_modifications() -> None:
+    """Test that modifications to contextvars in threads don't affect other threads"""
+
+    def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
+        """Helper that modifies the contextvar and returns both values"""
+        original = test_var.get()
+        test_var.set(new_value)
+        time.sleep(0.1)  # Ensure threads overlap
+        return original, test_var.get()
+
+    # Set initial value
+    test_var.set("initial")
+
+    # Run multiple functions that modify the contextvar
+    functions_with_args = [
+        (modify_and_return_contextvar, ("thread1",)),
+        (modify_and_return_contextvar, ("thread2",)),
+    ]
+
+    results = run_functions_tuples_in_parallel(functions_with_args)
+
+    # Verify each thread saw the initial value and its own modification
+    for original, modified in results:
+        assert original == "initial"  # Each thread should see the initial value
+        assert modified in [
+            "thread1",
+            "thread2",
+        ]  # Each thread should see its own modification
+
+    # Verify the main thread's value wasn't affected
+    assert test_var.get() == "initial"
+
+
+def test_contextvar_isolation_between_runs() -> None:
+    """Test that contextvar changes don't leak between separate parallel runs"""
+
+    def set_and_return_contextvar(value: str) -> str:
+        test_var.set(value)
+        return test_var.get()
+
+    # First run
+    test_var.set("first_run")
+    first_results = run_functions_tuples_in_parallel(
+        [
+            (set_and_return_contextvar, ("thread1",)),
+            (set_and_return_contextvar, ("thread2",)),
+        ]
+    )
+
+    # Verify first run results
+    assert all(result in ["thread1", "thread2"] for result in first_results)
+
+    # Second run should still see the main thread's value
+    assert test_var.get() == "first_run"
+
+    # Second run with different value
+    test_var.set("second_run")
+    second_results = run_functions_tuples_in_parallel(
+        [
+            (set_and_return_contextvar, ("thread3",)),
+            (set_and_return_contextvar, ("thread4",)),
+        ]
+    )
+
+    # Verify second run results
+    assert all(result in ["thread3", "thread4"] for result in second_results)