mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-25 04:13:25 +02:00
thread utils respect contextvars (#4074)
* thread utils respect contextvars now * address pablo comments * removed tenant id from places it was already being passed * fix rate limit check and pablo comment
This commit is contained in:
@@ -13,7 +13,7 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from onyx.db.api_key import is_api_key_email_address
|
from onyx.db.api_key import is_api_key_email_address
|
||||||
from onyx.db.engine import get_session_with_tenant
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
from onyx.db.models import ChatMessage
|
from onyx.db.models import ChatMessage
|
||||||
from onyx.db.models import ChatSession
|
from onyx.db.models import ChatSession
|
||||||
from onyx.db.models import TokenRateLimit
|
from onyx.db.models import TokenRateLimit
|
||||||
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
|
|||||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||||
|
|
||||||
|
|
||||||
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
|
def _check_token_rate_limits(user: User | None) -> None:
|
||||||
if user is None:
|
if user is None:
|
||||||
# Unauthenticated users are only rate limited by global settings
|
# Unauthenticated users are only rate limited by global settings
|
||||||
_user_is_rate_limited_by_global(tenant_id)
|
_user_is_rate_limited_by_global()
|
||||||
|
|
||||||
elif is_api_key_email_address(user.email):
|
elif is_api_key_email_address(user.email):
|
||||||
# API keys are only rate limited by global settings
|
# API keys are only rate limited by global settings
|
||||||
_user_is_rate_limited_by_global(tenant_id)
|
_user_is_rate_limited_by_global()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
run_functions_tuples_in_parallel(
|
run_functions_tuples_in_parallel(
|
||||||
[
|
[
|
||||||
(_user_is_rate_limited, (user.id, tenant_id)),
|
(_user_is_rate_limited, (user.id,)),
|
||||||
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
|
(_user_is_rate_limited_by_group, (user.id,)),
|
||||||
(_user_is_rate_limited_by_global, (tenant_id,)),
|
(_user_is_rate_limited_by_global, ()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,8 +52,8 @@ User rate limits
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
|
def _user_is_rate_limited(user_id: UUID) -> None:
|
||||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
with get_session_with_current_tenant() as db_session:
|
||||||
user_rate_limits = fetch_all_user_token_rate_limits(
|
user_rate_limits = fetch_all_user_token_rate_limits(
|
||||||
db_session=db_session, enabled_only=True, ordered=False
|
db_session=db_session, enabled_only=True, ordered=False
|
||||||
)
|
)
|
||||||
@@ -93,8 +93,8 @@ User Group rate limits
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
|
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
|
||||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
with get_session_with_current_tenant() as db_session:
|
||||||
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
|
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
|
||||||
|
|
||||||
if group_rate_limits:
|
if group_rate_limits:
|
||||||
|
@@ -870,7 +870,6 @@ def stream_chat_message_objects(
|
|||||||
for img in img_generation_response
|
for img in img_generation_response
|
||||||
if img.image_data
|
if img.image_data
|
||||||
],
|
],
|
||||||
tenant_id=tenant_id,
|
|
||||||
)
|
)
|
||||||
info.ai_message_files.extend(
|
info.ai_message_files.extend(
|
||||||
[
|
[
|
||||||
|
@@ -8,7 +8,7 @@ import requests
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from onyx.configs.constants import FileOrigin
|
from onyx.configs.constants import FileOrigin
|
||||||
from onyx.db.engine import get_session_with_tenant
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
from onyx.db.models import ChatMessage
|
from onyx.db.models import ChatMessage
|
||||||
from onyx.file_store.file_store import get_default_file_store
|
from onyx.file_store.file_store import get_default_file_store
|
||||||
from onyx.file_store.models import FileDescriptor
|
from onyx.file_store.models import FileDescriptor
|
||||||
@@ -53,11 +53,11 @@ def load_all_chat_files(
|
|||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
def save_file_from_url(url: str, tenant_id: str) -> str:
|
def save_file_from_url(url: str) -> str:
|
||||||
"""NOTE: using multiple sessions here, since this is often called
|
"""NOTE: using multiple sessions here, since this is often called
|
||||||
using multithreading. In practice, sharing a session has resulted in
|
using multithreading. In practice, sharing a session has resulted in
|
||||||
weird errors."""
|
weird errors."""
|
||||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
with get_session_with_current_tenant() as db_session:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
|||||||
return unique_id
|
return unique_id
|
||||||
|
|
||||||
|
|
||||||
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
|
def save_file_from_base64(base64_string: str) -> str:
|
||||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
with get_session_with_current_tenant() as db_session:
|
||||||
unique_id = str(uuid4())
|
unique_id = str(uuid4())
|
||||||
file_store = get_default_file_store(db_session)
|
file_store = get_default_file_store(db_session)
|
||||||
file_store.save_file(
|
file_store.save_file(
|
||||||
@@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def save_file(
|
def save_file(
|
||||||
tenant_id: str,
|
|
||||||
url: str | None = None,
|
url: str | None = None,
|
||||||
base64_data: str | None = None,
|
base64_data: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Save a file from either a URL or base64 encoded string.
|
"""Save a file from either a URL or base64 encoded string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tenant_id: The tenant ID to save the file under
|
|
||||||
url: URL to download file from
|
url: URL to download file from
|
||||||
base64_data: Base64 encoded file data
|
base64_data: Base64 encoded file data
|
||||||
|
|
||||||
@@ -111,22 +109,22 @@ def save_file(
|
|||||||
raise ValueError("Cannot specify both url and base64_data")
|
raise ValueError("Cannot specify both url and base64_data")
|
||||||
|
|
||||||
if url is not None:
|
if url is not None:
|
||||||
return save_file_from_url(url, tenant_id)
|
return save_file_from_url(url)
|
||||||
elif base64_data is not None:
|
elif base64_data is not None:
|
||||||
return save_file_from_base64(base64_data, tenant_id)
|
return save_file_from_base64(base64_data)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Must specify either url or base64_data")
|
raise ValueError("Must specify either url or base64_data")
|
||||||
|
|
||||||
|
|
||||||
def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
|
def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
|
||||||
# NOTE: be explicit about typing so that if we change things, we get notified
|
# NOTE: be explicit about typing so that if we change things, we get notified
|
||||||
funcs: list[
|
funcs: list[
|
||||||
tuple[
|
tuple[
|
||||||
Callable[[str, str | None, str | None], str],
|
Callable[[str | None, str | None], str],
|
||||||
tuple[str, str | None, str | None],
|
tuple[str | None, str | None],
|
||||||
]
|
]
|
||||||
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
|
] = [(save_file, (url, None)) for url in urls] + [
|
||||||
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
|
(save_file, (None, base64_file)) for base64_file in base64_files
|
||||||
]
|
]
|
||||||
|
|
||||||
return run_functions_tuples_in_parallel(funcs)
|
return run_functions_tuples_in_parallel(funcs)
|
||||||
|
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from onyx.auth.users import current_chat_accesssible_user
|
from onyx.auth.users import current_chat_accesssible_user
|
||||||
from onyx.db.engine import get_session_context_manager
|
from onyx.db.engine import get_session_context_manager
|
||||||
from onyx.db.engine import get_session_with_tenant
|
|
||||||
from onyx.db.models import ChatMessage
|
from onyx.db.models import ChatMessage
|
||||||
from onyx.db.models import ChatSession
|
from onyx.db.models import ChatSession
|
||||||
from onyx.db.models import TokenRateLimit
|
from onyx.db.models import TokenRateLimit
|
||||||
@@ -21,7 +20,6 @@ from onyx.db.models import User
|
|||||||
from onyx.db.token_limit import fetch_all_global_token_rate_limits
|
from onyx.db.token_limit import fetch_all_global_token_rate_limits
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||||
from shared_configs.contextvars import get_current_tenant_id
|
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@@ -39,13 +37,13 @@ def check_token_rate_limits(
|
|||||||
return
|
return
|
||||||
|
|
||||||
versioned_rate_limit_strategy = fetch_versioned_implementation(
|
versioned_rate_limit_strategy = fetch_versioned_implementation(
|
||||||
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
|
"onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__
|
||||||
)
|
)
|
||||||
return versioned_rate_limit_strategy(user, get_current_tenant_id())
|
return versioned_rate_limit_strategy(user)
|
||||||
|
|
||||||
|
|
||||||
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
|
def _check_token_rate_limits(_: User | None) -> None:
|
||||||
_user_is_rate_limited_by_global(tenant_id)
|
_user_is_rate_limited_by_global()
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -53,8 +51,8 @@ Global rate limits
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
|
def _user_is_rate_limited_by_global() -> None:
|
||||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
global_rate_limits = fetch_all_global_token_rate_limits(
|
global_rate_limits = fetch_all_global_token_rate_limits(
|
||||||
db_session=db_session, enabled_only=True, ordered=False
|
db_session=db_session, enabled_only=True, ordered=False
|
||||||
)
|
)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import contextvars
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
@@ -14,10 +15,6 @@ logger = setup_logger()
|
|||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
|
||||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
|
||||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
|
||||||
# is not safe, update this comment.
|
|
||||||
def run_functions_tuples_in_parallel(
|
def run_functions_tuples_in_parallel(
|
||||||
functions_with_args: list[tuple[Callable, tuple]],
|
functions_with_args: list[tuple[Callable, tuple]],
|
||||||
allow_failures: bool = False,
|
allow_failures: bool = False,
|
||||||
@@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel(
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
|
# The primary reason for propagating contextvars is to allow acquiring a db session
|
||||||
|
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
|
||||||
|
# find that it is increasing latency we can make using it optional.
|
||||||
future_to_index = {
|
future_to_index = {
|
||||||
executor.submit(func, *args): i
|
executor.submit(contextvars.copy_context().run, func, *args): i
|
||||||
for i, (func, args) in enumerate(functions_with_args)
|
for i, (func, args) in enumerate(functions_with_args)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,10 +83,6 @@ class FunctionCall(Generic[R]):
|
|||||||
return self.func(*self.args, **self.kwargs)
|
return self.func(*self.args, **self.kwargs)
|
||||||
|
|
||||||
|
|
||||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
|
||||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
|
||||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
|
||||||
# is not safe, update this comment.
|
|
||||||
def run_functions_in_parallel(
|
def run_functions_in_parallel(
|
||||||
function_calls: list[FunctionCall],
|
function_calls: list[FunctionCall],
|
||||||
allow_failures: bool = False,
|
allow_failures: bool = False,
|
||||||
@@ -102,7 +98,9 @@ def run_functions_in_parallel(
|
|||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
|
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
|
||||||
future_to_id = {
|
future_to_id = {
|
||||||
executor.submit(func_call.execute): func_call.result_id
|
executor.submit(
|
||||||
|
contextvars.copy_context().run, func_call.execute
|
||||||
|
): func_call.result_id
|
||||||
for func_call in function_calls
|
for func_call in function_calls
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
|
||||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
|
||||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
|
||||||
# is not safe, update this comment.
|
|
||||||
def run_with_timeout(
|
def run_with_timeout(
|
||||||
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||||
) -> R:
|
) -> R:
|
||||||
@@ -154,7 +148,8 @@ def run_with_timeout(
|
|||||||
Executes a function with a timeout. If the function doesn't complete within the specified
|
Executes a function with a timeout. If the function doesn't complete within the specified
|
||||||
timeout, raises TimeoutError.
|
timeout, raises TimeoutError.
|
||||||
"""
|
"""
|
||||||
task = TimeoutThread(timeout, func, *args, **kwargs)
|
context = contextvars.copy_context()
|
||||||
|
task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
|
||||||
task.start()
|
task.start()
|
||||||
task.join(timeout)
|
task.join(timeout)
|
||||||
|
|
||||||
|
131
backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
Normal file
131
backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
import contextvars
|
||||||
|
import time
|
||||||
|
|
||||||
|
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||||
|
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||||
|
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||||
|
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||||
|
|
||||||
|
# Create a test contextvar
|
||||||
|
test_var = contextvars.ContextVar("test_var", default="default")
|
||||||
|
|
||||||
|
|
||||||
|
def get_contextvar_value() -> str:
|
||||||
|
"""Helper function that runs in a thread and returns the contextvar value"""
|
||||||
|
# Sleep briefly so the call takes measurable time in the worker (the sleep itself does not force a separate thread)
|
||||||
|
time.sleep(0.1)
|
||||||
|
return test_var.get()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_timeout_preserves_contextvar() -> None:
|
||||||
|
"""Test that run_with_timeout preserves contextvar values"""
|
||||||
|
# Set a value in the main thread
|
||||||
|
test_var.set("test_value")
|
||||||
|
|
||||||
|
# Run function with timeout and verify the value is preserved
|
||||||
|
result = run_with_timeout(1.0, get_contextvar_value)
|
||||||
|
assert result == "test_value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_functions_in_parallel_preserves_contextvar() -> None:
|
||||||
|
"""Test that run_functions_in_parallel preserves contextvar values"""
|
||||||
|
# Set a value in the main thread
|
||||||
|
test_var.set("parallel_test")
|
||||||
|
|
||||||
|
# Create multiple function calls
|
||||||
|
function_calls = [
|
||||||
|
FunctionCall(get_contextvar_value),
|
||||||
|
FunctionCall(get_contextvar_value),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Run in parallel and verify all results have the correct value
|
||||||
|
results = run_functions_in_parallel(function_calls)
|
||||||
|
|
||||||
|
for result_id, value in results.items():
|
||||||
|
assert value == "parallel_test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_functions_tuples_preserves_contextvar() -> None:
|
||||||
|
"""Test that run_functions_tuples_in_parallel preserves contextvar values"""
|
||||||
|
# Set a value in the main thread
|
||||||
|
test_var.set("tuple_test")
|
||||||
|
|
||||||
|
# Create list of function tuples
|
||||||
|
functions_with_args = [
|
||||||
|
(get_contextvar_value, ()),
|
||||||
|
(get_contextvar_value, ()),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Run in parallel and verify all results have the correct value
|
||||||
|
results = run_functions_tuples_in_parallel(functions_with_args)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
assert result == "tuple_test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_contextvar_modifications() -> None:
|
||||||
|
"""Test that modifications to contextvars in threads don't affect other threads"""
|
||||||
|
|
||||||
|
def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
|
||||||
|
"""Helper that modifies the contextvar and returns both values"""
|
||||||
|
original = test_var.get()
|
||||||
|
test_var.set(new_value)
|
||||||
|
time.sleep(0.1) # Ensure threads overlap
|
||||||
|
return original, test_var.get()
|
||||||
|
|
||||||
|
# Set initial value
|
||||||
|
test_var.set("initial")
|
||||||
|
|
||||||
|
# Run multiple functions that modify the contextvar
|
||||||
|
functions_with_args = [
|
||||||
|
(modify_and_return_contextvar, ("thread1",)),
|
||||||
|
(modify_and_return_contextvar, ("thread2",)),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = run_functions_tuples_in_parallel(functions_with_args)
|
||||||
|
|
||||||
|
# Verify each thread saw the initial value and its own modification
|
||||||
|
for original, modified in results:
|
||||||
|
assert original == "initial" # Each thread should see the initial value
|
||||||
|
assert modified in [
|
||||||
|
"thread1",
|
||||||
|
"thread2",
|
||||||
|
] # Each thread should see its own modification
|
||||||
|
|
||||||
|
# Verify the main thread's value wasn't affected
|
||||||
|
assert test_var.get() == "initial"
|
||||||
|
|
||||||
|
|
||||||
|
def test_contextvar_isolation_between_runs() -> None:
|
||||||
|
"""Test that contextvar changes don't leak between separate parallel runs"""
|
||||||
|
|
||||||
|
def set_and_return_contextvar(value: str) -> str:
|
||||||
|
test_var.set(value)
|
||||||
|
return test_var.get()
|
||||||
|
|
||||||
|
# First run
|
||||||
|
test_var.set("first_run")
|
||||||
|
first_results = run_functions_tuples_in_parallel(
|
||||||
|
[
|
||||||
|
(set_and_return_contextvar, ("thread1",)),
|
||||||
|
(set_and_return_contextvar, ("thread2",)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify first run results
|
||||||
|
assert all(result in ["thread1", "thread2"] for result in first_results)
|
||||||
|
|
||||||
|
# The main thread's value must be unchanged by the first run's worker threads
|
||||||
|
assert test_var.get() == "first_run"
|
||||||
|
|
||||||
|
# Second run with different value
|
||||||
|
test_var.set("second_run")
|
||||||
|
second_results = run_functions_tuples_in_parallel(
|
||||||
|
[
|
||||||
|
(set_and_return_contextvar, ("thread3",)),
|
||||||
|
(set_and_return_contextvar, ("thread4",)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify second run results
|
||||||
|
assert all(result in ["thread3", "thread4"] for result in second_results)
|
Reference in New Issue
Block a user