diff --git a/backend/ee/onyx/server/query_and_chat/token_limit.py b/backend/ee/onyx/server/query_and_chat/token_limit.py index 5ee53b8f3..c6cd8486e 100644 --- a/backend/ee/onyx/server/query_and_chat/token_limit.py +++ b/backend/ee/onyx/server/query_and_chat/token_limit.py @@ -13,7 +13,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.api_key import is_api_key_email_address -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import TokenRateLimit @@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel -def _check_token_rate_limits(user: User | None, tenant_id: str) -> None: +def _check_token_rate_limits(user: User | None) -> None: if user is None: # Unauthenticated users are only rate limited by global settings - _user_is_rate_limited_by_global(tenant_id) + _user_is_rate_limited_by_global() elif is_api_key_email_address(user.email): # API keys are only rate limited by global settings - _user_is_rate_limited_by_global(tenant_id) + _user_is_rate_limited_by_global() else: run_functions_tuples_in_parallel( [ - (_user_is_rate_limited, (user.id, tenant_id)), - (_user_is_rate_limited_by_group, (user.id, tenant_id)), - (_user_is_rate_limited_by_global, (tenant_id,)), + (_user_is_rate_limited, (user.id,)), + (_user_is_rate_limited_by_group, (user.id,)), + (_user_is_rate_limited_by_global, ()), ] ) @@ -52,8 +52,8 @@ User rate limits """ -def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited(user_id: UUID) -> None: + with get_session_with_current_tenant() as db_session: user_rate_limits = fetch_all_user_token_rate_limits( db_session=db_session, enabled_only=True, ordered=False ) @@ -93,8 +93,8 @@ User Group rate limits """ -def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited_by_group(user_id: UUID) -> None: + with get_session_with_current_tenant() as db_session: group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session) if group_rate_limits: diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index a3cce6b7b..54c2c4375 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -870,7 +870,6 @@ def stream_chat_message_objects( for img in img_generation_response if img.image_data ], - tenant_id=tenant_id, ) info.ai_message_files.extend( [ diff --git a/backend/onyx/file_store/utils.py b/backend/onyx/file_store/utils.py index 384990718..91198790a 100644 --- a/backend/onyx/file_store/utils.py +++ b/backend/onyx/file_store/utils.py @@ -8,7 +8,7 @@ import requests from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.db.models import ChatMessage from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import FileDescriptor @@ -53,11 +53,11 @@ def load_all_chat_files( return files -def save_file_from_url(url: str, tenant_id: str) -> str: +def save_file_from_url(url: str) -> str: """NOTE: using multiple sessions here, since this is often called using multithreading. In practice, sharing a session has resulted in weird errors.""" - with get_session_with_tenant(tenant_id=tenant_id) as db_session: + with get_session_with_current_tenant() as db_session: response = requests.get(url) response.raise_for_status() @@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str: return unique_id -def save_file_from_base64(base64_string: str, tenant_id: str) -> str: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def save_file_from_base64(base64_string: str) -> str: + with get_session_with_current_tenant() as db_session: unique_id = str(uuid4()) file_store = get_default_file_store(db_session) file_store.save_file( @@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str: def save_file( - tenant_id: str, url: str | None = None, base64_data: str | None = None, ) -> str: """Save a file from either a URL or base64 encoded string. Args: - tenant_id: The tenant ID to save the file under url: URL to download file from base64_data: Base64 encoded file data @@ -111,22 +109,22 @@ def save_file( raise ValueError("Cannot specify both url and base64_data") if url is not None: - return save_file_from_url(url, tenant_id) + return save_file_from_url(url) elif base64_data is not None: - return save_file_from_base64(base64_data, tenant_id) + return save_file_from_base64(base64_data) else: raise ValueError("Must specify either url or base64_data") -def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]: +def save_files(urls: list[str], base64_files: list[str]) -> list[str]: # NOTE: be explicit about typing so that if we change things, we get notified funcs: list[ tuple[ - Callable[[str, str | None, str | None], str], - tuple[str, str | None, str | None], + Callable[[str | None, str | None], str], + tuple[str | None, str | None], ] - ] = [(save_file, (tenant_id, url, None)) for url in urls] + [ - (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files + ] = [(save_file, (url, None)) for url in urls] + [ + (save_file, (None, base64_file)) for base64_file in base64_files ] return run_functions_tuples_in_parallel(funcs) diff --git a/backend/onyx/server/query_and_chat/token_limit.py b/backend/onyx/server/query_and_chat/token_limit.py index b94903a28..fc0bc629d 100644 --- a/backend/onyx/server/query_and_chat/token_limit.py +++ b/backend/onyx/server/query_and_chat/token_limit.py @@ -13,7 +13,6 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_chat_accesssible_user from onyx.db.engine import get_session_context_manager -from onyx.db.engine import get_session_with_tenant from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import TokenRateLimit @@ -21,7 +20,6 @@ from onyx.db.models import User from onyx.db.token_limit import fetch_all_global_token_rate_limits from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation -from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -39,13 +37,13 @@ def check_token_rate_limits( return versioned_rate_limit_strategy = fetch_versioned_implementation( - "onyx.server.query_and_chat.token_limit", "_check_token_rate_limits" + "onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__ ) - return versioned_rate_limit_strategy(user, get_current_tenant_id()) + return versioned_rate_limit_strategy(user) -def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None: - _user_is_rate_limited_by_global(tenant_id) +def _check_token_rate_limits(_: User | None) -> None: + _user_is_rate_limited_by_global() """ @@ -53,8 +51,8 @@ Global rate limits """ -def _user_is_rate_limited_by_global(tenant_id: str | None) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited_by_global() -> None: + with get_session_context_manager() as db_session: global_rate_limits = fetch_all_global_token_rate_limits( db_session=db_session, enabled_only=True, ordered=False ) diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py index f6a2b3fbe..4ef87348f 100644 --- a/backend/onyx/utils/threadpool_concurrency.py +++ b/backend/onyx/utils/threadpool_concurrency.py @@ -1,3 +1,4 @@ +import contextvars import threading import uuid from collections.abc import Callable @@ -14,10 +15,6 @@ logger = setup_logger() R = TypeVar("R") -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. def run_functions_tuples_in_parallel( functions_with_args: list[tuple[Callable, tuple]], allow_failures: bool = False, @@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel( results = [] with ThreadPoolExecutor(max_workers=workers) as executor: + # The primary reason for propagating contextvars is to allow acquiring a db session + # that respects tenant id. Context.run is expected to be low-overhead, but if we later + # find that it is increasing latency we can make using it optional. future_to_index = { - executor.submit(func, *args): i + executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args) } @@ -83,10 +83,6 @@ class FunctionCall(Generic[R]): return self.func(*self.args, **self.kwargs) -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. def run_functions_in_parallel( function_calls: list[FunctionCall], allow_failures: bool = False, @@ -102,7 +98,9 @@ def run_functions_in_parallel( with ThreadPoolExecutor(max_workers=len(function_calls)) as executor: future_to_id = { - executor.submit(func_call.execute): func_call.result_id + executor.submit( + contextvars.copy_context().run, func_call.execute + ): func_call.result_id for func_call in function_calls } @@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread): ) -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. def run_with_timeout( timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any ) -> R: @@ -154,7 +148,8 @@ def run_with_timeout( Executes a function with a timeout. If the function doesn't complete within the specified timeout, raises TimeoutError. """ - task = TimeoutThread(timeout, func, *args, **kwargs) + context = contextvars.copy_context() + task = TimeoutThread(timeout, context.run, func, *args, **kwargs) task.start() task.join(timeout) diff --git a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py new file mode 100644 index 000000000..4d6d9a6a3 --- /dev/null +++ b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py @@ -0,0 +1,131 @@ +import contextvars +import time + +from onyx.utils.threadpool_concurrency import FunctionCall +from onyx.utils.threadpool_concurrency import run_functions_in_parallel +from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from onyx.utils.threadpool_concurrency import run_with_timeout + +# Create a test contextvar +test_var = contextvars.ContextVar("test_var", default="default") + + +def get_contextvar_value() -> str: + """Helper function that runs in a thread and returns the contextvar value""" + # Add a small sleep to ensure we're actually running in a different thread + time.sleep(0.1) + return test_var.get() + + +def test_run_with_timeout_preserves_contextvar() -> None: + """Test that run_with_timeout preserves contextvar values""" + # Set a value in the main thread + test_var.set("test_value") + + # Run function with timeout and verify the value is preserved + result = run_with_timeout(1.0, get_contextvar_value) + assert result == "test_value" + + +def test_run_functions_in_parallel_preserves_contextvar() -> None: + """Test that run_functions_in_parallel preserves contextvar values""" + # Set a value in the main thread + test_var.set("parallel_test") + + # Create multiple function calls + function_calls = [ + FunctionCall(get_contextvar_value), + FunctionCall(get_contextvar_value), + ] + + # Run in parallel and verify all results have the correct value + results = run_functions_in_parallel(function_calls) + + for result_id, value in results.items(): + assert value == "parallel_test" + + +def test_run_functions_tuples_preserves_contextvar() -> None: + """Test that run_functions_tuples_in_parallel preserves contextvar values""" + # Set a value in the main thread + test_var.set("tuple_test") + + # Create list of function tuples + functions_with_args = [ + (get_contextvar_value, ()), + (get_contextvar_value, ()), + ] + + # Run in parallel and verify all results have the correct value + results = run_functions_tuples_in_parallel(functions_with_args) + + for result in results: + assert result == "tuple_test" + + +def test_nested_contextvar_modifications() -> None: + """Test that modifications to contextvars in threads don't affect other threads""" + + def modify_and_return_contextvar(new_value: str) -> tuple[str, str]: + """Helper that modifies the contextvar and returns both values""" + original = test_var.get() + test_var.set(new_value) + time.sleep(0.1) # Ensure threads overlap + return original, test_var.get() + + # Set initial value + test_var.set("initial") + + # Run multiple functions that modify the contextvar + functions_with_args = [ + (modify_and_return_contextvar, ("thread1",)), + (modify_and_return_contextvar, ("thread2",)), + ] + + results = run_functions_tuples_in_parallel(functions_with_args) + + # Verify each thread saw the initial value and its own modification + for original, modified in results: + assert original == "initial" # Each thread should see the initial value + assert modified in [ + "thread1", + "thread2", + ] # Each thread should see its own modification + + # Verify the main thread's value wasn't affected + assert test_var.get() == "initial" + + +def test_contextvar_isolation_between_runs() -> None: + """Test that contextvar changes don't leak between separate parallel runs""" + + def set_and_return_contextvar(value: str) -> str: + test_var.set(value) + return test_var.get() + + # First run + test_var.set("first_run") + first_results = run_functions_tuples_in_parallel( + [ + (set_and_return_contextvar, ("thread1",)), + (set_and_return_contextvar, ("thread2",)), + ] + ) + + # Verify first run results + assert all(result in ["thread1", "thread2"] for result in first_results) + + # Second run should still see the main thread's value + assert test_var.get() == "first_run" + + # Second run with different value + test_var.set("second_run") + second_results = run_functions_tuples_in_parallel( + [ + (set_and_return_contextvar, ("thread3",)), + (set_and_return_contextvar, ("thread4",)), + ] + ) + + # Verify second run results + assert all(result in ["thread3", "thread4"] for result in second_results)