diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index a26d20d70a..40e978c3ef 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -830,6 +830,7 @@ def stream_chat_message_objects( db_session=db_session, persona=persona, actual_user_input=message_text, + user_id=user_id, ) if not search_tool_override_kwargs_for_user_files: latest_query_files.extend(in_memory_user_files) diff --git a/backend/onyx/chat/user_files/parse_user_files.py b/backend/onyx/chat/user_files/parse_user_files.py index ee3f93e7c7..37ef93345a 100644 --- a/backend/onyx/chat/user_files/parse_user_files.py +++ b/backend/onyx/chat/user_files/parse_user_files.py @@ -1,9 +1,11 @@ +from uuid import UUID + from sqlalchemy.orm import Session from onyx.db.models import Persona from onyx.db.models import UserFile from onyx.file_store.models import InMemoryChatFile -from onyx.file_store.utils import get_user_files +from onyx.file_store.utils import get_user_files_as_user from onyx.file_store.utils import load_in_memory_chat_files from onyx.tools.models import SearchToolOverrideKwargs from onyx.utils.logger import setup_logger @@ -18,6 +20,8 @@ def parse_user_files( db_session: Session, persona: Persona, actual_user_input: str, + # should only be None if auth is disabled + user_id: UUID | None, ) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]: """ Parse user files and folders into in-memory chat files and create search tool override kwargs. 
@@ -29,6 +33,7 @@ def parse_user_files( db_session: Database session persona: Persona to calculate available tokens actual_user_input: User's input message for token calculation + user_id: User ID to validate file ownership Returns: Tuple of ( @@ -49,9 +54,10 @@ def parse_user_files( db_session, ) - user_file_models = get_user_files( + user_file_models = get_user_files_as_user( user_file_ids or [], user_folder_ids or [], + user_id, db_session, ) diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py index 433e16d7dd..877ab12c16 100644 --- a/backend/onyx/context/search/pipeline.py +++ b/backend/onyx/context/search/pipeline.py @@ -42,6 +42,8 @@ from onyx.utils.threadpool_concurrency import FunctionCall from onyx.utils.threadpool_concurrency import run_functions_in_parallel from onyx.utils.timing import log_function_time from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop +from shared_configs.configs import MULTI_TENANT +from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -180,7 +182,11 @@ class SearchPipeline: filters = IndexFilters( user_file_ids=user_file_ids or [], user_folder_ids=user_folder_ids or [], + # NOTE: this can be None, since it's assumed that the user_file_ids / user_folder_ids + # have already been verified as owned by the user running this query + # TODO: make this more robust access_control_list=None, + tenant_id=get_current_tenant_id() if MULTI_TENANT else None, ) # Use a simplified query that skips all unnecessary processing diff --git a/backend/onyx/file_store/utils.py b/backend/onyx/file_store/utils.py index 0fc3d7fb70..9f3a56da6c 100644 --- a/backend/onyx/file_store/utils.py +++ b/backend/onyx/file_store/utils.py @@ -2,6 +2,7 @@ import base64 from collections.abc import Callable from io import BytesIO from typing import cast +from uuid import UUID from uuid import uuid4 import requests @@ -245,6 +246,26 @@ def get_user_files( return user_files 
+def get_user_files_as_user(
+    user_file_ids: list[int],
+    user_folder_ids: list[int],
+    user_id: UUID | None,
+    db_session: Session,
+) -> list[UserFile]:
+    """
+    Fetch the requested UserFile rows and confirm `user_id` owns every one.
+
+    The rows come from `get_user_files` for the given file / folder IDs.
+    A `user_id` of None (auth disabled) only matches files whose `user_id` is None.
+    Raises ValueError for the first file whose `user_id` differs.
+    """
+    files = get_user_files(user_file_ids, user_folder_ids, db_session)
+    foreign = next((f for f in files if f.user_id != user_id), None)
+    if foreign is not None:
+        raise ValueError(f"User {user_id} does not have access to file {foreign.id}")
+    return files
+
+
 def save_file_from_url(url: str) -> str:
     """NOTE: using multiple sessions here, since this is often called
     using multithreading. In practice, sharing a session has resulted in