From a98dcbc7dee4eb8e50ee4d85c79cb5539cad2430 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Tue, 25 Feb 2025 19:53:46 -0800 Subject: [PATCH] Update tenant logic (#4122) * k * k * k * quick nit * nit --- .../ee/onyx/background/celery/apps/primary.py | 23 ++++++------------- .../background/celery/tasks/vespa/tasks.py | 2 +- .../ee/onyx/server/tenants/user_mapping.py | 8 +++---- backend/onyx/auth/api_key.py | 4 ++-- backend/onyx/auth/email_utils.py | 5 ++-- backend/onyx/auth/users.py | 4 ++-- .../onyx/background/celery/apps/app_base.py | 5 ++-- .../onyx/background/celery/celery_utils.py | 4 ++-- .../celery/tasks/connector_deletion/tasks.py | 12 ++++------ .../tasks/doc_permission_syncing/tasks.py | 13 +++++------ .../tasks/external_group_syncing/tasks.py | 11 ++++----- .../background/celery/tasks/indexing/tasks.py | 12 +++++----- .../background/celery/tasks/indexing/utils.py | 6 ++--- .../celery/tasks/llm_model_update/tasks.py | 2 +- .../celery/tasks/monitoring/tasks.py | 6 ++--- .../background/celery/tasks/periodic/tasks.py | 2 +- .../background/celery/tasks/pruning/tasks.py | 12 +++++----- .../celery/tasks/shared/RetryDocumentIndex.py | 4 ++-- .../background/celery/tasks/shared/tasks.py | 4 ++-- .../background/celery/tasks/vespa/tasks.py | 14 +++++------ .../onyx/background/indexing/run_indexing.py | 10 ++++---- backend/onyx/connectors/factory.py | 7 ------ backend/onyx/connectors/file/connector.py | 11 ++------- backend/onyx/db/api_key.py | 3 +-- backend/onyx/db/engine.py | 6 ++--- .../document_index/document_index_utils.py | 6 ++--- backend/onyx/document_index/interfaces.py | 10 ++++---- backend/onyx/document_index/vespa/index.py | 8 +++---- backend/onyx/indexing/indexing_pipeline.py | 6 ++--- backend/onyx/indexing/models.py | 4 ++-- backend/onyx/main.py | 2 +- backend/onyx/onyxbot/slack/blocks.py | 4 ++-- .../onyxbot/slack/handlers/handle_buttons.py | 2 +- .../onyxbot/slack/handlers/handle_message.py | 2 +- .../slack/handlers/handle_regular_answer.py | 2 +- backend/onyx/onyxbot/slack/listener.py | 16 ++++++------- backend/onyx/onyxbot/slack/utils.py | 6 ++--- backend/onyx/redis/redis_connector.py | 4 ++-- .../redis/redis_connector_credential_pair.py | 4 ++-- backend/onyx/redis/redis_connector_delete.py | 4 ++-- .../redis/redis_connector_doc_perm_sync.py | 4 ++-- .../redis/redis_connector_ext_group_sync.py | 4 ++-- backend/onyx/redis/redis_connector_index.py | 4 ++-- backend/onyx/redis/redis_connector_prune.py | 4 ++-- backend/onyx/redis/redis_connector_stop.py | 4 ++-- backend/onyx/redis/redis_document_set.py | 4 ++-- backend/onyx/redis/redis_object_helper.py | 6 ++--- backend/onyx/redis/redis_usergroup.py | 4 ++-- backend/onyx/seeding/load_docs.py | 8 ++++--- backend/onyx/server/documents/cc_pair.py | 2 +- backend/onyx/server/documents/connector.py | 1 - backend/onyx/server/documents/credential.py | 3 --- backend/onyx/server/settings/models.py | 3 ++- backend/onyx/setup.py | 2 +- backend/scripts/debugging/onyx_vespa.py | 8 +++---- .../scripts/force_delete_connector_by_id.py | 3 ++- backend/scripts/orphan_doc_cleanup_script.py | 5 +++- backend/shared_configs/contextvars.py | 2 ++ 58 files changed, 155 insertions(+), 186 deletions(-) diff --git a/backend/ee/onyx/background/celery/apps/primary.py b/backend/ee/onyx/background/celery/apps/primary.py index a12ad0d42..d984ab720 100644 --- a/backend/ee/onyx/background/celery/apps/primary.py +++ b/backend/ee/onyx/background/celery/apps/primary.py @@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app from 
onyx.background.task_utils import build_celery_task_wrapper from onyx.configs.app_configs import JOB_TIMEOUT from onyx.db.chat import delete_chat_sessions_older_than -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.server.settings.store import load_settings from onyx.utils.logger import setup_logger -from shared_configs.configs import MULTI_TENANT -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -18,10 +16,8 @@ logger = setup_logger() @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def perform_ttl_management_task( - retention_limit_days: int, *, tenant_id: str | None -) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None: + with get_session_with_current_tenant() as db_session: delete_chat_sessions_older_than(retention_limit_days, db_session) @@ -35,24 +31,19 @@ def perform_ttl_management_task( ignore_result=True, soft_time_limit=JOB_TIMEOUT, ) -def check_ttl_management_task(*, tenant_id: str | None) -> None: +def check_ttl_management_task(*, tenant_id: str) -> None: """Runs periodically to check if any ttl tasks should be run and adds them to the queue""" - token = None - if MULTI_TENANT and tenant_id is not None: - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) settings = load_settings() retention_limit_days = settings.maximum_chat_retention_days - with get_session_with_tenant(tenant_id=tenant_id) as db_session: + with get_session_with_current_tenant() as db_session: if should_perform_chat_ttl_check(retention_limit_days, db_session): perform_ttl_management_task.apply_async( kwargs=dict( retention_limit_days=retention_limit_days, tenant_id=tenant_id ), ) - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) @celery_app.task( @@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None: ignore_result=True, soft_time_limit=JOB_TIMEOUT, ) -def autogenerate_usage_report_task(*, tenant_id: str | None) -> None: +def autogenerate_usage_report_task(*, tenant_id: str) -> None: """This generates usage report under the /admin/generate-usage/report endpoint""" - with get_session_with_tenant(tenant_id=tenant_id) as db_session: + with get_session_with_current_tenant() as db_session: create_new_usage_report( db_session=db_session, user_id=None, diff --git a/backend/ee/onyx/background/celery/tasks/vespa/tasks.py b/backend/ee/onyx/background/celery/tasks/vespa/tasks.py index 45c65b73c..ac0a7b882 100644 --- a/backend/ee/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/ee/onyx/background/celery/tasks/vespa/tasks.py @@ -18,7 +18,7 @@ logger = setup_logger() def monitor_usergroup_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session + tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: """This function is likely to move in the worker refactor happening next.""" fence_key = key_bytes.decode("utf-8") diff --git a/backend/ee/onyx/server/tenants/user_mapping.py b/backend/ee/onyx/server/tenants/user_mapping.py index e7dfc2d3c..b5b0fe196 100644 --- a/backend/ee/onyx/server/tenants/user_mapping.py +++ b/backend/ee/onyx/server/tenants/user_mapping.py @@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str: def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant(tenant_id=None) as db_session: + with 
get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: result = ( db_session.query(UserTenantMapping) .filter(UserTenantMapping.email == email) @@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool: def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(tenant_id=None) as db_session: + with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: try: for email in emails: db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) @@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(tenant_id=None) as db_session: + with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: try: mappings_to_delete = ( db_session.query(UserTenantMapping) @@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: def remove_all_users_from_tenant(tenant_id: str) -> None: - with get_session_with_tenant(tenant_id=None) as db_session: + with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: db_session.query(UserTenantMapping).filter( UserTenantMapping.tenant_id == tenant_id ).delete() diff --git a/backend/onyx/auth/api_key.py b/backend/onyx/auth/api_key.py index ebb974f8b..e6c8c0c58 100644 --- a/backend/onyx/auth/api_key.py +++ b/backend/onyx/auth/api_key.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from onyx.auth.schemas import UserRole from onyx.configs.app_configs import API_KEY_HASH_ROUNDS +from shared_configs.configs import MULTI_TENANT _API_KEY_HEADER_NAME = "Authorization" @@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel): def generate_api_key(tenant_id: str | None = None) -> str: - # For backwards compatibility, if no tenant_id, generate old style key - if not tenant_id: + if not MULTI_TENANT or not tenant_id: return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN) encoded_tenant = quote(tenant_id) # URL encode the tenant ID diff --git a/backend/onyx/auth/email_utils.py b/backend/onyx/auth/email_utils.py index c4e72bacd..6f677b12e 100644 --- a/backend/onyx/auth/email_utils.py +++ b/backend/onyx/auth/email_utils.py @@ -15,6 +15,7 @@ from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import AuthType from onyx.configs.constants import TENANT_ID_COOKIE_NAME from onyx.db.models import User +from shared_configs.configs import MULTI_TENANT HTML_EMAIL_TEMPLATE = """\ @@ -242,13 +243,13 @@ def send_user_email_invite( def send_forgot_password_email( user_email: str, token: str, + tenant_id: str, mail_from: str = EMAIL_FROM, - tenant_id: str | None = None, ) -> None: # Builds a forgot password email with or without fancy HTML subject = "Onyx Forgot Password" link = f"{WEB_DOMAIN}/auth/reset-password?token={token}" - if tenant_id: + if MULTI_TENANT: link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}" message = f"

Click the following link to reset your password:

{link}

" html_content = build_html_email("Reset Your Password", message) diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 2b0a671e6..f28825f1f 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None: raise PermissionError("User not on allowed user whitelist") -def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None: +def verify_email_in_whitelist(email: str, tenant_id: str) -> None: with get_session_with_tenant(tenant_id=tenant_id) as db_session: if not get_user_by_email(email, db_session): verify_email_is_invited(email) @@ -553,7 +553,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): async_return_default_schema, )(email=user.email) - send_forgot_password_email(user.email, token, tenant_id=tenant_id) + send_forgot_password_email(user.email, tenant_id=tenant_id, token=token) async def on_after_request_verify( self, user: User, token: str, request: Optional[Request] = None diff --git a/backend/onyx/background/celery/apps/app_base.py b/backend/onyx/background/celery/apps/app_base.py index 4dc7ca17a..bce48fb4f 100644 --- a/backend/onyx/background/celery/apps/app_base.py +++ b/backend/onyx/background/celery/apps/app_base.py @@ -2,6 +2,7 @@ import logging import multiprocessing import time from typing import Any +from typing import cast import sentry_sdk from celery import Task @@ -131,9 +132,9 @@ def on_task_postrun( # Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg if not kwargs: logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs") - tenant_id = None + tenant_id = POSTGRES_DEFAULT_SCHEMA else: - tenant_id = kwargs.get("tenant_id") + tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)) task_logger.debug( f"Task {task.name} (ID: {task_id}) completed with state: {state} " diff --git a/backend/onyx/background/celery/celery_utils.py b/backend/onyx/background/celery/celery_utils.py index 60b3dfacc..ffb201dca 100644 --- a/backend/onyx/background/celery/celery_utils.py +++ b/backend/onyx/background/celery/celery_utils.py @@ -34,7 +34,7 @@ def _get_deletion_status( connector_id: int, credential_id: int, db_session: Session, - tenant_id: str | None = None, + tenant_id: str, ) -> TaskQueueState | None: """We no longer store TaskQueueState in the DB for a deletion attempt. This function populates TaskQueueState by just checking redis. 
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot( connector_id: int, credential_id: int, db_session: Session, - tenant_id: str | None = None, + tenant_id: str, ) -> DeletionAttemptSnapshot | None: deletion_task = _get_deletion_status( connector_id, credential_id, db_session, tenant_id diff --git a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py index baaa60135..c1d458c1a 100644 --- a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py @@ -109,9 +109,7 @@ def revoke_tasks_blocking_deletion( trail=False, bind=True, ) -def check_for_connector_deletion_task( - self: Task, *, tenant_id: str | None -) -> bool | None: +def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None: r = get_redis_client() r_replica = get_redis_replica_client() r_celery: Redis = self.app.broker_connection().channel().client # type: ignore @@ -224,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks( cc_pair_id: int, db_session: Session, lock_beat: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> int | None: """Returns an int if syncing is needed. The int represents the number of sync tasks generated. Note that syncing can still be required even if the number of sync tasks generated is zero. @@ -345,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks( def monitor_connector_deletion_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis + tenant_id: str, key_bytes: bytes, r: Redis ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) @@ -500,7 +498,7 @@ def monitor_connector_deletion_taskset( def validate_connector_deletion_fences( - tenant_id: str | None, + tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, @@ -540,7 +538,7 @@ def validate_connector_deletion_fences( def validate_connector_deletion_fence( - tenant_id: str | None, + tenant_id: str, key_bytes: bytes, queued_tasks: set[str], r: Redis, diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index 563e87179..b308e5a18 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -221,7 +221,7 @@ def try_creating_permissions_sync_task( app: Celery, cc_pair_id: int, r: Redis, - tenant_id: str | None, + tenant_id: str, ) -> str | None: """Returns a randomized payload id on success. 
Returns None if no syncing is required.""" @@ -320,7 +320,7 @@ def try_creating_permissions_sync_task( def connector_permission_sync_generator_task( self: Task, cc_pair_id: int, - tenant_id: str | None, + tenant_id: str, ) -> None: """ Permission sync task that handles document permission syncing for a given connector credential pair @@ -410,7 +410,6 @@ def connector_permission_sync_generator_task( cc_pair.connector.id, cc_pair.credential.id, db_session, - tenant_id, enforce_creation=False, ) if not created: @@ -510,7 +509,7 @@ def connector_permission_sync_generator_task( ) def update_external_document_permissions_task( self: Task, - tenant_id: str | None, + tenant_id: str, serialized_doc_external_access: dict, source_string: str, connector_id: int, @@ -585,7 +584,7 @@ def update_external_document_permissions_task( def validate_permission_sync_fences( - tenant_id: str | None, + tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, @@ -632,7 +631,7 @@ def validate_permission_sync_fences( def validate_permission_sync_fence( - tenant_id: str | None, + tenant_id: str, key_bytes: bytes, queued_tasks: set[str], reserved_tasks: set[str], @@ -842,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface): def monitor_ccpair_permissions_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session + tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 73ca958b0..6aa257305 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: soft_time_limit=JOB_TIMEOUT, bind=True, ) -def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None: +def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None: # we need to use celery's redis client to access its redis data # (which lives on a different db number) r = get_redis_client() @@ -220,7 +220,7 @@ def try_creating_external_group_sync_task( app: Celery, cc_pair_id: int, r: Redis, - tenant_id: str | None, + tenant_id: str, ) -> str | None: """Returns a randomized payload id on success.
Returns None if no syncing is required.""" @@ -306,7 +306,7 @@ def try_creating_external_group_sync_task( def connector_external_group_sync_generator_task( self: Task, cc_pair_id: int, - tenant_id: str | None, + tenant_id: str, ) -> None: """ External group sync task for a given connector credential pair @@ -392,7 +392,6 @@ def connector_external_group_sync_generator_task( cc_pair.connector.id, cc_pair.credential.id, db_session, - tenant_id, enforce_creation=False, ) if not created: @@ -494,7 +493,7 @@ def connector_external_group_sync_generator_task( def validate_external_group_sync_fences( - tenant_id: str | None, + tenant_id: str, celery_app: Celery, r: Redis, r_replica: Redis, @@ -526,7 +525,7 @@ def validate_external_group_sync_fences( def validate_external_group_sync_fence( - tenant_id: str | None, + tenant_id: str, key_bytes: bytes, reserved_tasks: set[str], r_celery: Redis, diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index bbc0ed47c..e8063e575 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -182,7 +182,7 @@ class SimpleJobResult: class ConnectorIndexingContext(BaseModel): - tenant_id: str | None + tenant_id: str cc_pair_id: int search_settings_id: int index_attempt_id: int @@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder: def monitor_ccpair_indexing_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session + tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: # if the fence doesn't exist, there's nothing to do fence_key = key_bytes.decode("utf-8") @@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset( soft_time_limit=300, bind=True, ) -def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: +def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: """a lightweight task used to kick off indexing tasks. Occasionally does some validation of existing state to clear up error conditions""" @@ -598,7 +598,7 @@ def connector_indexing_task( cc_pair_id: int, search_settings_id: int, is_ee: bool, - tenant_id: str | None, + tenant_id: str, ) -> int | None: """Indexing task. For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing @@ -890,7 +890,7 @@ def connector_indexing_proxy_task( index_attempt_id: int, cc_pair_id: int, search_settings_id: int, - tenant_id: str | None, + tenant_id: str, ) -> None: """celery out of process task execution strategy is pool=prefork, but it uses fork, and forking is inherently unstable.
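Narrowing every task signature to tenant_id: str is safe because the scheduler always has a concrete schema to hand out: per the engine.py hunk later in this patch, get_all_tenant_ids() returns [POSTGRES_DEFAULT_SCHEMA] rather than [None] on self-hosted installs. A hedged sketch of that per-tenant fan-out; send_task stands in for the Celery app and the multi-tenant branch is stubbed:

from typing import Any, Callable

MULTI_TENANT = False
POSTGRES_DEFAULT_SCHEMA = "public"

def get_all_tenant_ids() -> list[str]:
    # Self-hosted installs report the default schema as their one tenant,
    # so no caller has to special-case None anymore.
    if not MULTI_TENANT:
        return [POSTGRES_DEFAULT_SCHEMA]
    raise NotImplementedError("multi-tenant lookup reads tenant schemas from Postgres")

def enqueue_per_tenant(send_task: Callable[..., Any]) -> None:
    for tenant_id in get_all_tenant_ids():
        # Every invocation carries a concrete tenant_id kwarg; this is the
        # same kwarg on_task_postrun reads back out in app_base.py.
        send_task("check_for_indexing", kwargs={"tenant_id": tenant_id})

enqueue_per_tenant(lambda name, kwargs: print(name, kwargs))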
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task( name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP, soft_time_limit=300, ) -def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None: +def check_for_checkpoint_cleanup(*, tenant_id: str) -> None: """Clean up old checkpoints that are older than 7 days.""" locked = False redis_client = get_redis_client(tenant_id=tenant_id) diff --git a/backend/onyx/background/celery/tasks/indexing/utils.py b/backend/onyx/background/celery/tasks/indexing/utils.py index 48acf28fe..cfb528799 100644 --- a/backend/onyx/background/celery/tasks/indexing/utils.py +++ b/backend/onyx/background/celery/tasks/indexing/utils.py @@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase): def validate_indexing_fence( - tenant_id: str | None, + tenant_id: str, key_bytes: bytes, reserved_tasks: set[str], r_celery: Redis, @@ -311,7 +311,7 @@ def validate_indexing_fence( def validate_indexing_fences( - tenant_id: str | None, + tenant_id: str, r_replica: Redis, r_celery: Redis, lock_beat: RedisLock, @@ -442,7 +442,7 @@ def try_creating_indexing_task( reindex: bool, db_session: Session, r: Redis, - tenant_id: str | None, + tenant_id: str, ) -> int | None: """Checks for any conditions that should block the indexing task from being created, then creates the task. diff --git a/backend/onyx/background/celery/tasks/llm_model_update/tasks.py b/backend/onyx/background/celery/tasks/llm_model_update/tasks.py index 34e42cc0d..ae0eafe81 100644 --- a/backend/onyx/background/celery/tasks/llm_model_update/tasks.py +++ b/backend/onyx/background/celery/tasks/llm_model_update/tasks.py @@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]: trail=False, bind=True, ) -def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None: +def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None: if not LLM_MODEL_UPDATE_API_URL: raise ValueError("LLM model update API URL not configured") diff --git a/backend/onyx/background/celery/tasks/monitoring/tasks.py b/backend/onyx/background/celery/tasks/monitoring/tasks.py index 9f93298d1..2fa67788f 100644 --- a/backend/onyx/background/celery/tasks/monitoring/tasks.py +++ b/backend/onyx/background/celery/tasks/monitoring/tasks.py @@ -91,7 +91,7 @@ class Metric(BaseModel): } task_logger.info(json.dumps(data)) - def emit(self, tenant_id: str | None) -> None: + def emit(self, tenant_id: str) -> None: # Convert value to appropriate type based on the input value bool_value = None float_value = None @@ -656,7 +656,7 @@ def build_job_id( queue=OnyxCeleryQueues.MONITORING, bind=True, ) -def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None: +def monitor_background_processes(self: Task, *, tenant_id: str) -> None: """Collect and emit metrics about background processes. 
This task runs periodically to gather metrics about: - Queue lengths for different Celery queues @@ -864,7 +864,7 @@ def cloud_monitor_celery_queues( @shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True) -def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None: +def monitor_celery_queues(self: Task, *, tenant_id: str) -> None: return monitor_celery_queues_helper(self) diff --git a/backend/onyx/background/celery/tasks/periodic/tasks.py b/backend/onyx/background/celery/tasks/periodic/tasks.py index 7f3ff3caf..de0672cc8 100644 --- a/backend/onyx/background/celery/tasks/periodic/tasks.py +++ b/backend/onyx/background/celery/tasks/periodic/tasks.py @@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant bind=True, base=AbortableTask, ) -def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int: +def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: """Runs periodically to clean up the kombu_message table""" # we will select messages older than this amount to clean up diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py index 2de5338f2..739eba3b0 100644 --- a/backend/onyx/background/celery/tasks/pruning/tasks.py +++ b/backend/onyx/background/celery/tasks/pruning/tasks.py @@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool: soft_time_limit=JOB_TIMEOUT, bind=True, ) -def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None: +def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None: r = get_redis_client() r_replica = get_redis_replica_client() r_celery: Redis = self.app.broker_connection().channel().client # type: ignore @@ -211,7 +211,7 @@ def try_creating_prune_generator_task( cc_pair: ConnectorCredentialPair, db_session: Session, r: Redis, - tenant_id: str | None, + tenant_id: str, ) -> str | None: """Checks for any conditions that should block the pruning generator task from being created, then creates the task. @@ -333,7 +333,7 @@ def connector_pruning_generator_task( cc_pair_id: int, connector_id: int, credential_id: int, - tenant_id: str | None, + tenant_id: str, ) -> None: """connector pruning task. 
For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing @@ -521,7 +521,7 @@ def connector_pruning_generator_task( def monitor_ccpair_pruning_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session + tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) @@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset( def validate_pruning_fences( - tenant_id: str | None, + tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, @@ -615,7 +615,7 @@ def validate_pruning_fences( def validate_pruning_fence( - tenant_id: str | None, + tenant_id: str, key_bytes: bytes, reserved_tasks: set[str], queued_tasks: set[str], diff --git a/backend/onyx/background/celery/tasks/shared/RetryDocumentIndex.py b/backend/onyx/background/celery/tasks/shared/RetryDocumentIndex.py index 34a3e0a88..fc94807c1 100644 --- a/backend/onyx/background/celery/tasks/shared/RetryDocumentIndex.py +++ b/backend/onyx/background/celery/tasks/shared/RetryDocumentIndex.py @@ -32,7 +32,7 @@ class RetryDocumentIndex: self, doc_id: str, *, - tenant_id: str | None, + tenant_id: str, chunk_count: int | None, ) -> int: return self.index.delete_single( @@ -50,7 +50,7 @@ class RetryDocumentIndex: self, doc_id: str, *, - tenant_id: str | None, + tenant_id: str, chunk_count: int | None, fields: VespaDocumentFields, ) -> int: diff --git a/backend/onyx/background/celery/tasks/shared/tasks.py b/backend/onyx/background/celery/tasks/shared/tasks.py index 474f579f9..4b5a96ca2 100644 --- a/backend/onyx/background/celery/tasks/shared/tasks.py +++ b/backend/onyx/background/celery/tasks/shared/tasks.py @@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task( document_id: str, connector_id: int, credential_id: int, - tenant_id: str | None, + tenant_id: str, ) -> bool: """A lightweight subtask used to clean up document to cc pair relationships. Created by connection deletion and connector pruning parent tasks.""" @@ -297,7 +297,7 @@ def cloud_beat_task_generator( return None last_lock_time = time.monotonic() - tenant_ids: list[str] | list[None] = [] + tenant_ids: list[str] = [] try: tenant_ids = get_all_tenant_ids() diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index d58120edf..a3c84acec 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -76,7 +76,7 @@ logger = setup_logger() trail=False, bind=True, ) -def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None: +def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None: """Runs periodically to check if any document needs syncing. 
Generates sets of tasks for Celery if syncing is needed.""" @@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks( db_session: Session, r: Redis, lock_beat: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> int | None: # the fence is up, do nothing @@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks( db_session: Session, r: Redis, lock_beat: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> int | None: lock_beat.reacquire() @@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks( db_session: Session, r: Redis, lock_beat: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> int | None: lock_beat.reacquire() @@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None: def monitor_document_set_taskset( - tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session + tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: fence_key = key_bytes.decode("utf-8") document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key) @@ -523,9 +523,7 @@ def monitor_document_set_taskset( time_limit=LIGHT_TIME_LIMIT, max_retries=3, ) -def vespa_metadata_sync_task( - self: Task, document_id: str, *, tenant_id: str | None -) -> bool: +def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool: start = time.monotonic() completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index a3be32807..d569d6dfd 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -55,6 +55,7 @@ from onyx.utils.logger import setup_logger from onyx.utils.logger import TaskAttemptSingleton from onyx.utils.telemetry import create_milestone_and_report from onyx.utils.variable_functionality import global_version +from shared_configs.configs import MULTI_TENANT logger = setup_logger() @@ -67,7 +68,6 @@ def _get_connector_runner( batch_size: int, start_time: datetime, end_time: datetime, - tenant_id: str | None, leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE, ) -> ConnectorRunner: """ @@ -86,7 +86,6 @@ def _get_connector_runner( input_type=task, connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config, credential=attempt.connector_credential_pair.credential, - tenant_id=tenant_id, ) # validate the connector settings @@ -241,7 +240,7 @@ def _check_failure_threshold( def _run_indexing( db_session: Session, index_attempt_id: int, - tenant_id: str | None, + tenant_id: str, callback: IndexingHeartbeatInterface | None = None, ) -> None: """ @@ -388,7 +387,6 @@ def _run_indexing( batch_size=INDEX_BATCH_SIZE, start_time=window_start, end_time=window_end, - tenant_id=tenant_id, ) # don't use a checkpoint if we're explicitly indexing from @@ -681,7 +679,7 @@ def _run_indexing( def run_indexing_entrypoint( index_attempt_id: int, - tenant_id: str | None, + tenant_id: str, connector_credential_pair_id: int, is_ee: bool = False, callback: IndexingHeartbeatInterface | None = None, @@ -701,7 +699,7 @@ def run_indexing_entrypoint( attempt = transition_attempt_to_in_progress(index_attempt_id, db_session) tenant_str = "" - if tenant_id is not None: + if MULTI_TENANT: tenant_str = f" for tenant {tenant_id}" connector_name = attempt.connector_credential_pair.connector.name diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index 0fd0da49c..14221d2e3 100644 --- 
a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -5,7 +5,6 @@ from sqlalchemy.orm import Session from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DocumentSource -from onyx.configs.constants import DocumentSourceRequiringTenantContext from onyx.connectors.airtable.airtable_connector import AirtableConnector from onyx.connectors.asana.connector import AsanaConnector from onyx.connectors.axero.connector import AxeroConnector @@ -164,13 +163,9 @@ def instantiate_connector( input_type: InputType, connector_specific_config: dict[str, Any], credential: Credential, - tenant_id: str | None = None, ) -> BaseConnector: connector_class = identify_connector_class(source, input_type) - if source in DocumentSourceRequiringTenantContext: - connector_specific_config["tenant_id"] = tenant_id - connector = connector_class(**connector_specific_config) new_credentials = connector.load_credentials(credential.credential_json) @@ -184,7 +179,6 @@ def validate_ccpair_for_user( connector_id: int, credential_id: int, db_session: Session, - tenant_id: str | None, enforce_creation: bool = True, ) -> bool: if INTEGRATION_TESTS_MODE: @@ -216,7 +210,6 @@ def validate_ccpair_for_user( input_type=connector.input_type, connector_specific_config=connector.connector_specific_config, credential=credential, - tenant_id=tenant_id, ) except ConnectorValidationError as e: raise e diff --git a/backend/onyx/connectors/file/connector.py b/backend/onyx/connectors/file/connector.py index 4c99d457e..751495056 100644 --- a/backend/onyx/connectors/file/connector.py +++ b/backend/onyx/connectors/file/connector.py @@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import Section -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.file_processing.extract_file_text import detect_encoding from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext @@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file from onyx.file_processing.extract_file_text import read_text_file from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector): def __init__( self, file_locations: list[Path | str], - tenant_id: str = POSTGRES_DEFAULT_SCHEMA, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.file_locations = [Path(file_location) for file_location in file_locations] self.batch_size = batch_size - self.tenant_id = tenant_id self.pdf_pass: str | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: @@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector): def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document] = [] - token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id) - with get_session_with_tenant(tenant_id=self.tenant_id) as db_session: + with get_session_with_current_tenant() as db_session: for file_path in self.file_locations: current_datetime = datetime.now(timezone.utc) files = _read_files_and_metadata( @@ -203,8 +198,6 @@ class 
LocalFileConnector(LoadConnector): if documents: yield documents - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) - if __name__ == "__main__": connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]]) diff --git a/backend/onyx/db/api_key.py b/backend/onyx/db/api_key.py index e7cd5e10a..1a992fd51 100644 --- a/backend/onyx/db/api_key.py +++ b/backend/onyx/db/api_key.py @@ -16,7 +16,6 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER from onyx.db.models import ApiKey from onyx.db.models import User from onyx.server.api_key.models import APIKeyArgs -from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id @@ -73,7 +72,7 @@ def insert_api_key( # Get tenant_id from context var (will be default schema for single tenant) tenant_id = get_current_tenant_id() - api_key = generate_api_key(tenant_id if MULTI_TENANT else None) + api_key = generate_api_key(tenant_id) api_key_user_id = uuid.uuid4() display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index 668ef68e3..840b5d93d 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -258,11 +258,11 @@ class SqlEngine: cls._engine = None -def get_all_tenant_ids() -> list[str] | list[None]: +def get_all_tenant_ids() -> list[str]: - """Returning [None] means the only tenant is the 'public' or self hosted tenant.""" + """Returning [POSTGRES_DEFAULT_SCHEMA] means the only tenant is the 'public' or self hosted tenant.""" if not MULTI_TENANT: - return [None] + return [POSTGRES_DEFAULT_SCHEMA] with get_session_with_shared_schema() as session: result = session.execute( @@ -417,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]: @contextmanager -def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]: +def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]: """ Generate a database session for a specific tenant. """ diff --git a/backend/onyx/document_index/document_index_utils.py b/backend/onyx/document_index/document_index_utils.py index 28529d21c..dc7892141 100644 --- a/backend/onyx/document_index/document_index_utils.py +++ b/backend/onyx/document_index/document_index_utils.py @@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float: # Vespa's Document API. def get_document_chunk_ids( enriched_document_info_list: list[EnrichedDocumentIndexingInfo], - tenant_id: str | None, + tenant_id: str, large_chunks_enabled: bool, ) -> list[UUID]: doc_chunk_ids = [] @@ -139,7 +139,7 @@ def get_uuid_from_chunk_info( *, document_id: str, chunk_id: int, - tenant_id: str | None, + tenant_id: str, large_chunk_id: int | None = None, ) -> UUID: """NOTE: be VERY careful about changing this function.
If changed without a migration, @@ -154,7 +154,7 @@ def get_uuid_from_chunk_info( "large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id) ) unique_identifier_string = "_".join([doc_str, chunk_index]) - if tenant_id and MULTI_TENANT: + if MULTI_TENANT: unique_identifier_string += "_" + tenant_id uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string) diff --git a/backend/onyx/document_index/interfaces.py b/backend/onyx/document_index/interfaces.py index 08dbfdc9e..663e5feee 100644 --- a/backend/onyx/document_index/interfaces.py +++ b/backend/onyx/document_index/interfaces.py @@ -43,7 +43,7 @@ class IndexBatchParams: doc_id_to_previous_chunk_cnt: dict[str, int | None] doc_id_to_new_chunk_cnt: dict[str, int] - tenant_id: str | None + tenant_id: str large_chunks_enabled: bool @@ -222,7 +222,7 @@ class Deletable(abc.ABC): self, doc_id: str, *, - tenant_id: str | None, + tenant_id: str, chunk_count: int | None, ) -> int: """ @@ -249,7 +249,7 @@ class Updatable(abc.ABC): self, doc_id: str, *, - tenant_id: str | None, + tenant_id: str, chunk_count: int | None, fields: VespaDocumentFields, ) -> int: @@ -270,9 +270,7 @@ class Updatable(abc.ABC): raise NotImplementedError @abc.abstractmethod - def update( - self, update_requests: list[UpdateRequest], *, tenant_id: str | None - ) -> None: + def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None: """ Updates some set of chunks. The document and fields to update are specified in the update requests. Each update request in the list applies its changes to a list of document ids. diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index 2d61aad01..c2e631f6c 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -468,9 +468,7 @@ class VespaIndex(DocumentIndex): failure_msg = f"Failed to update document: {future_to_document_id[future]}" raise requests.HTTPError(failure_msg) from e - def update( - self, update_requests: list[UpdateRequest], *, tenant_id: str | None - ) -> None: + def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None: logger.debug(f"Updating {len(update_requests)} documents in Vespa") # Handle Vespa character limitations @@ -618,7 +616,7 @@ class VespaIndex(DocumentIndex): doc_id: str, *, chunk_count: int | None, - tenant_id: str | None, + tenant_id: str, fields: VespaDocumentFields, ) -> int: """Note: if the document id does not exist, the update will be a no-op and the @@ -661,7 +659,7 @@ class VespaIndex(DocumentIndex): self, doc_id: str, *, - tenant_id: str | None, + tenant_id: str, chunk_count: int | None, ) -> int: total_chunks_deleted = 0 diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index 76dc20ecb..fe95f2a9b 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -158,8 +158,8 @@ def index_doc_batch_with_handler( document_batch: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, + tenant_id: str, ignore_time_skip: bool = False, - tenant_id: str | None = None, ) -> IndexingPipelineResult: try: index_pipeline_result = index_doc_batch( @@ -317,8 +317,8 @@ def index_doc_batch( document_index: DocumentIndex, index_attempt_metadata: IndexAttemptMetadata, db_session: Session, + tenant_id: str, ignore_time_skip: bool = False, - tenant_id: str | None = None, filter_fnc: Callable[[list[Document]], list[Document]] = 
filter_documents, ) -> IndexingPipelineResult: """Takes different pieces of the indexing pipeline and applies it to a batch of documents @@ -525,9 +525,9 @@ def build_indexing_pipeline( embedder: IndexingEmbedder, document_index: DocumentIndex, db_session: Session, + tenant_id: str, chunker: Chunker | None = None, ignore_time_skip: bool = False, - tenant_id: str | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> IndexingPipelineProtocol: """Builds a pipeline which takes in a list (batch) of docs and indexes them.""" diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index f62a29f13..0c4451cc7 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): negative -> ranked lower. """ - tenant_id: str | None = None + tenant_id: str access: "DocumentAccess" document_sets: set[str] boost: int @@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): access: "DocumentAccess", document_sets: set[str], boost: int, - tenant_id: str | None, + tenant_id: str, ) -> "DocMetadataAwareIndexChunk": index_chunk_data = index_chunk.model_dump() return cls( diff --git a/backend/onyx/main.py b/backend/onyx/main.py index c852fd0fe..2444e6f19 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # If we are multi-tenant, we need to only set up initial public tables with Session(engine) as db_session: - setup_onyx(db_session, None) + setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA) else: setup_multitenant_onyx() diff --git a/backend/onyx/onyxbot/slack/blocks.py b/backend/onyx/onyxbot/slack/blocks.py index 66415cea9..2c2138253 100644 --- a/backend/onyx/onyxbot/slack/blocks.py +++ b/backend/onyx/onyxbot/slack/blocks.py @@ -410,7 +410,7 @@ def _build_qa_response_blocks( def _build_continue_in_web_ui_block( - tenant_id: str | None, + tenant_id: str, message_id: int | None, ) -> Block: if message_id is None: @@ -482,7 +482,7 @@ def build_follow_up_resolved_blocks( def build_slack_response_blocks( answer: ChatOnyxBotResponse, - tenant_id: str | None, + tenant_id: str, message_info: SlackMessageInfo, channel_conf: ChannelConfig | None, use_citations: bool, diff --git a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py index 928b281bd..42428f231 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py @@ -151,7 +151,7 @@ def handle_slack_feedback( user_id_to_post_confirmation: str, channel_id_to_post_confirmation: str, thread_ts_to_post_confirmation: str, - tenant_id: str | None, + tenant_id: str, ) -> None: message_id, doc_id, doc_rank = decompose_action_id(feedback_id) diff --git a/backend/onyx/onyxbot/slack/handlers/handle_message.py b/backend/onyx/onyxbot/slack/handlers/handle_message.py index 751c212e3..3d38417e2 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_message.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_message.py @@ -109,7 +109,7 @@ def handle_message( slack_channel_config: SlackChannelConfig, client: WebClient, feedback_reminder_id: str | None, - tenant_id: str | None, + tenant_id: str, ) -> bool: """Potentially respond to the user message depending on filters and if an answer was generated diff --git a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py index 
b9e711465..f7c7d8f1a 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py @@ -72,7 +72,7 @@ def handle_regular_answer( channel: str, logger: OnyxLoggingAdapter, feedback_reminder_id: str | None, - tenant_id: str | None, + tenant_id: str, num_retries: int = DANSWER_BOT_NUM_RETRIES, thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE, should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS, diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py index 325d26dd9..6fb590622 100644 --- a/backend/onyx/onyxbot/slack/listener.py +++ b/backend/onyx/onyxbot/slack/listener.py @@ -123,13 +123,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT" class SlackbotHandler: def __init__(self) -> None: logger.info("Initializing SlackbotHandler") - self.tenant_ids: Set[str | None] = set() + self.tenant_ids: Set[str] = set() # The keys for these dictionaries are tuples of (tenant_id, slack_bot_id) - self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {} - self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {} + self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {} + self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {} # Store Redis lock objects here so we can release them properly - self.redis_locks: Dict[str | None, Lock] = {} + self.redis_locks: Dict[str, Lock] = {} self.running = True self.pod_id = self.get_pod_id() @@ -193,7 +193,7 @@ class SlackbotHandler: self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL) def _manage_clients_per_tenant( - self, db_session: Session, tenant_id: str | None, bot: SlackBot + self, db_session: Session, tenant_id: str, bot: SlackBot ) -> None: """ - If the tokens are missing or empty, close the socket client and remove them. @@ -385,7 +385,7 @@ class SlackbotHandler: finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) - def _remove_tenant(self, tenant_id: str | None) -> None: + def _remove_tenant(self, tenant_id: str) -> None: """ Helper to remove a tenant from `self.tenant_ids` and close any socket clients. (Lock release now happens in `acquire_tenants()`, not here.) 
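The SlackbotHandler changes above make the tenant component of every key non-optional: socket clients and bot tokens are keyed by (tenant_id, slack_bot_id) tuples. A small sketch of the bookkeeping this enables, with a stub standing in for TenantSocketModeClient:

from typing import Dict

class StubSocketClient:
    # Stand-in for TenantSocketModeClient; only the identity fields matter here.
    def __init__(self, tenant_id: str, slack_bot_id: int) -> None:
        self.tenant_id = tenant_id
        self.slack_bot_id = slack_bot_id

    def close(self) -> None:
        print(f"closed client for {self.tenant_id}/bot {self.slack_bot_id}")

socket_clients: Dict[tuple[str, int], StubSocketClient] = {}

def start_client(tenant_id: str, slack_bot_id: int) -> None:
    socket_clients[(tenant_id, slack_bot_id)] = StubSocketClient(tenant_id, slack_bot_id)

def remove_tenant(tenant_id: str) -> None:
    # Sweep every bot client belonging to this tenant; other tenants are untouched.
    for key in [k for k in socket_clients if k[0] == tenant_id]:
        socket_clients.pop(key).close()

start_client("tenant_a", 1)
start_client("tenant_a", 2)
start_client("public", 7)
remove_tenant("tenant_a")  # closes bots 1 and 2; the "public" client keeps running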
@@ -415,7 +415,7 @@ class SlackbotHandler: ) def start_socket_client( - self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens + self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens ) -> None: socket_client: TenantSocketModeClient = _get_socket_client( slack_bot_tokens, tenant_id, slack_bot_id @@ -912,7 +912,7 @@ def create_process_slack_event() -> ( def _get_socket_client( - slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int + slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int ) -> TenantSocketModeClient: # For more info on how to set this up, checkout the docs: # https://docs.onyx.app/slack_bot_setup diff --git a/backend/onyx/onyxbot/slack/utils.py b/backend/onyx/onyxbot/slack/utils.py index bc08dfdf5..f5d2209b0 100644 --- a/backend/onyx/onyxbot/slack/utils.py +++ b/backend/onyx/onyxbot/slack/utils.py @@ -570,7 +570,7 @@ def read_slack_thread( def slack_usage_report( - action: str, sender_id: str | None, client: WebClient, tenant_id: str | None + action: str, sender_id: str | None, client: WebClient, tenant_id: str ) -> None: if DISABLE_TELEMETRY: return @@ -663,9 +663,7 @@ def get_feedback_visibility() -> FeedbackVisibility: class TenantSocketModeClient(SocketModeClient): - def __init__( - self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any - ): + def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.tenant_id = tenant_id self.slack_bot_id = slack_bot_id diff --git a/backend/onyx/redis/redis_connector.py b/backend/onyx/redis/redis_connector.py index 196f2306c..d8e4854c6 100644 --- a/backend/onyx/redis/redis_connector.py +++ b/backend/onyx/redis/redis_connector.py @@ -16,10 +16,10 @@ class RedisConnector: """Composes several classes to simplify interacting with a connector and its associated background tasks / associated redis interactions.""" - def __init__(self, tenant_id: str | None, id: int) -> None: + def __init__(self, tenant_id: str, id: int) -> None: """id: a connector credential pair id""" - self.tenant_id: str | None = tenant_id + self.tenant_id: str = tenant_id self.id: int = id self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id) diff --git a/backend/onyx/redis/redis_connector_credential_pair.py b/backend/onyx/redis/redis_connector_credential_pair.py index 463a0c1a2..5bbbd2e08 100644 --- a/backend/onyx/redis/redis_connector_credential_pair.py +++ b/backend/onyx/redis/redis_connector_credential_pair.py @@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper): PREFIX = "connectorsync" TASKSET_PREFIX = PREFIX + "_taskset" - def __init__(self, tenant_id: str | None, id: int) -> None: + def __init__(self, tenant_id: str, id: int) -> None: super().__init__(tenant_id, str(id)) # documents that should be skipped @@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper): db_session: Session, redis_client: Redis, lock: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> tuple[int, int] | None: """We can limit the number of tasks generated here, which is useful to prevent one tenant from overwhelming the sync queue. 
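The docstring above points at the throttling idea behind these generate_tasks helpers: cap how many tasks one pass may create so a single tenant cannot flood the shared sync queue. A minimal sketch of that contract, where enqueue is a hypothetical stand-in for issuing a Celery task and the return value mirrors the tuple[int, int] shape (tasks generated, candidate docs) used throughout:

def generate_tasks(doc_ids: list[str], max_tasks: int) -> tuple[int, int]:
    def enqueue(doc_id: str) -> None:
        # Hypothetical stand-in for sending one sync task to the queue.
        print(f"queued sync for {doc_id}")

    generated = 0
    for doc_id in doc_ids:
        if generated >= max_tasks:
            break  # leave the remainder for the next beat cycle
        enqueue(doc_id)
        generated += 1
    return generated, len(doc_ids)

print(generate_tasks([f"doc_{i}" for i in range(5)], max_tasks=3))  # prints (3, 5)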
diff --git a/backend/onyx/redis/redis_connector_delete.py b/backend/onyx/redis/redis_connector_delete.py index d475c2545..98a36fe78 100644 --- a/backend/onyx/redis/redis_connector_delete.py +++ b/backend/onyx/redis/redis_connector_delete.py @@ -39,8 +39,8 @@ class RedisConnectorDelete: ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = 3600 - def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: - self.tenant_id: str | None = tenant_id + def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: + self.tenant_id: str = tenant_id self.id = id self.redis = redis diff --git a/backend/onyx/redis/redis_connector_doc_perm_sync.py b/backend/onyx/redis/redis_connector_doc_perm_sync.py index 1dc8dfff5..5c420d59e 100644 --- a/backend/onyx/redis/redis_connector_doc_perm_sync.py +++ b/backend/onyx/redis/redis_connector_doc_perm_sync.py @@ -52,8 +52,8 @@ class RedisConnectorPermissionSync: ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2 - def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: - self.tenant_id: str | None = tenant_id + def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: + self.tenant_id: str = tenant_id self.id = id self.redis = redis diff --git a/backend/onyx/redis/redis_connector_ext_group_sync.py b/backend/onyx/redis/redis_connector_ext_group_sync.py index a63463df8..7cc0f2d20 100644 --- a/backend/onyx/redis/redis_connector_ext_group_sync.py +++ b/backend/onyx/redis/redis_connector_ext_group_sync.py @@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync: ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = 3600 - def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: - self.tenant_id: str | None = tenant_id + def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: + self.tenant_id: str = tenant_id self.id = id self.redis = redis diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py index 868af24c7..5a67879e5 100644 --- a/backend/onyx/redis/redis_connector_index.py +++ b/backend/onyx/redis/redis_connector_index.py @@ -52,12 +52,12 @@ class RedisConnectorIndex: def __init__( self, - tenant_id: str | None, + tenant_id: str, id: int, search_settings_id: int, redis: redis.Redis, ) -> None: - self.tenant_id: str | None = tenant_id + self.tenant_id: str = tenant_id self.id = id self.search_settings_id = search_settings_id self.redis = redis diff --git a/backend/onyx/redis/redis_connector_prune.py b/backend/onyx/redis/redis_connector_prune.py index 3a61c059d..e36c1ce58 100644 --- a/backend/onyx/redis/redis_connector_prune.py +++ b/backend/onyx/redis/redis_connector_prune.py @@ -52,8 +52,8 @@ class RedisConnectorPrune: ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2 - def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: - self.tenant_id: str | None = tenant_id + def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: + self.tenant_id: str = tenant_id self.id = id self.redis = redis diff --git a/backend/onyx/redis/redis_connector_stop.py b/backend/onyx/redis/redis_connector_stop.py index 3567cf9b6..5dc1e7364 100644 --- a/backend/onyx/redis/redis_connector_stop.py +++ b/backend/onyx/redis/redis_connector_stop.py @@ -13,8 +13,8 @@ class RedisConnectorStop: TIMEOUT_PREFIX = f"{PREFIX}_timeout" TIMEOUT_TTL = 300 - def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: - self.tenant_id: str | 
None = tenant_id + def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: + self.tenant_id: str = tenant_id self.id: int = id self.redis = redis diff --git a/backend/onyx/redis/redis_document_set.py b/backend/onyx/redis/redis_document_set.py index 6fd5a453b..9bedb0965 100644 --- a/backend/onyx/redis/redis_document_set.py +++ b/backend/onyx/redis/redis_document_set.py @@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper): FENCE_PREFIX = PREFIX + "_fence" TASKSET_PREFIX = PREFIX + "_taskset" - def __init__(self, tenant_id: str | None, id: int) -> None: + def __init__(self, tenant_id: str, id: int) -> None: super().__init__(tenant_id, str(id)) @property @@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper): db_session: Session, redis_client: Redis, lock: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> tuple[int, int] | None: """Max tasks is ignored for now until we can build the logic to mark the document set up to date over multiple batches. diff --git a/backend/onyx/redis/redis_object_helper.py b/backend/onyx/redis/redis_object_helper.py index 34b301882..e166b75e9 100644 --- a/backend/onyx/redis/redis_object_helper.py +++ b/backend/onyx/redis/redis_object_helper.py @@ -14,8 +14,8 @@ class RedisObjectHelper(ABC): FENCE_PREFIX = PREFIX + "_fence" TASKSET_PREFIX = PREFIX + "_taskset" - def __init__(self, tenant_id: str | None, id: str): - self._tenant_id: str | None = tenant_id + def __init__(self, tenant_id: str, id: str): + self._tenant_id: str = tenant_id self._id: str = id self.redis = get_redis_client(tenant_id=tenant_id) @@ -87,7 +87,7 @@ class RedisObjectHelper(ABC): db_session: Session, redis_client: Redis, lock: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> tuple[int, int] | None: """First element should be the number of actual tasks generated, second should be the number of docs that were candidates to be synced for the cc pair. diff --git a/backend/onyx/redis/redis_usergroup.py b/backend/onyx/redis/redis_usergroup.py index 92ff5548c..e5793dd4c 100644 --- a/backend/onyx/redis/redis_usergroup.py +++ b/backend/onyx/redis/redis_usergroup.py @@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper): FENCE_PREFIX = PREFIX + "_fence" TASKSET_PREFIX = PREFIX + "_taskset" - def __init__(self, tenant_id: str | None, id: int) -> None: + def __init__(self, tenant_id: str, id: int) -> None: super().__init__(tenant_id, str(id)) @property @@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper): db_session: Session, redis_client: Redis, lock: RedisLock, - tenant_id: str | None, + tenant_id: str, ) -> tuple[int, int] | None: """Max tasks is ignored for now until we can build the logic to mark the user group up to date over multiple batches. 
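Every Redis helper above now binds a concrete tenant_id at construction and obtains its client via get_redis_client(tenant_id=...). The diff does not show how that client isolates tenants; a common approach, assumed in this sketch, is to namespace each key by tenant so fences and tasksets from different tenants can never collide:

_SHARED_STORE: dict[str, str] = {}  # toy stand-in for the Redis server itself

class TenantPrefixedRedis:
    def __init__(self, tenant_id: str) -> None:
        self.tenant_id = tenant_id

    def _key(self, key: str) -> str:
        # Namespace every key by tenant, e.g. "tenant_a:usergroup_fence_7".
        return f"{self.tenant_id}:{key}"

    def set(self, key: str, value: str) -> None:
        _SHARED_STORE[self._key(key)] = value

    def get(self, key: str) -> str | None:
        return _SHARED_STORE.get(self._key(key))

r_a = TenantPrefixedRedis("tenant_a")
r_b = TenantPrefixedRedis("tenant_b")
r_a.set("usergroup_fence_7", "payload")
print(r_a.get("usergroup_fence_7"))  # payload
print(r_b.get("usergroup_fence_7"))  # None: no cross-tenant collisions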
diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py
index 40e848b3d..38ad52345 100644
--- a/backend/onyx/seeding/load_docs.py
+++ b/backend/onyx/seeding/load_docs.py
@@ -37,13 +37,15 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
 from onyx.server.documents.models import ConnectorBase
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.configs import MULTI_TENANT
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()


 def _create_indexable_chunks(
     preprocessed_docs: list[dict],
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
     ids_to_documents = {}
     chunks = []
@@ -86,7 +88,7 @@ def _create_indexable_chunks(
                 mini_chunk_embeddings=[],
             ),
             title_embedding=preprocessed_doc["title_embedding"],
-            tenant_id=tenant_id,
+            tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
             access=default_public_access,
             document_sets=set(),
             boost=DEFAULT_BOOST,
@@ -111,7 +113,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:


 def seed_initial_documents(
-    db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
+    db_session: Session, tenant_id: str, cohere_enabled: bool = False
 ) -> None:
     """
     Seed initial documents so users don't have an empty index to start
diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py
index 18a7be23d..281ab125d 100644
--- a/backend/onyx/server/documents/cc_pair.py
+++ b/backend/onyx/server/documents/cc_pair.py
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
     )

     try:
-        validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
+        validate_ccpair_for_user(connector_id, credential_id, db_session)

         response = add_credential_to_connector(
             db_session=db_session,
diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py
index 1130b38cd..60511ae93 100644
--- a/backend/onyx/server/documents/connector.py
+++ b/backend/onyx/server/documents/connector.py
@@ -902,7 +902,6 @@ def create_connector_with_mock_credential(
         connector_id=connector_id,
         credential_id=credential_id,
         db_session=db_session,
-        tenant_id=tenant_id,
     )
     response = add_credential_to_connector(
         db_session=db_session,
diff --git a/backend/onyx/server/documents/credential.py b/backend/onyx/server/documents/credential.py
index 6ef587682..5060d15a3 100644
--- a/backend/onyx/server/documents/credential.py
+++ b/backend/onyx/server/documents/credential.py
@@ -18,7 +18,6 @@ from onyx.db.credentials import fetch_credentials_by_source_for_user
 from onyx.db.credentials import fetch_credentials_for_user
 from onyx.db.credentials import swap_credentials_connector
 from onyx.db.credentials import update_credential
-from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session
 from onyx.db.models import DocumentSource
 from onyx.db.models import User
@@ -100,13 +99,11 @@ def swap_credentials_for_connector(
     credential_swap_req: CredentialSwapRequest,
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
-    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse:
     validate_ccpair_for_user(
         credential_swap_req.connector_id,
         credential_swap_req.new_credential_id,
         db_session,
-        tenant_id,
     )

     connector_credential_pair = swap_credentials_connector(
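The load_docs.py hunk above tags each seeded chunk with `tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA`. The same fallback, pulled out as a standalone sketch (`resolve_chunk_tenant` is a hypothetical wrapper; the two config names are the ones imported in that hunk):

    from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA


    def resolve_chunk_tenant(tenant_id: str) -> str:
        # Multi-tenant deployments keep the real tenant id on each chunk;
        # single-tenant deployments stamp the default Postgres schema instead.
        return tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
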
diff --git a/backend/onyx/server/settings/models.py b/backend/onyx/server/settings/models.py
index 1f57525c9..58dd0a51f 100644
--- a/backend/onyx/server/settings/models.py
+++ b/backend/onyx/server/settings/models.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel

 from onyx.configs.constants import NotificationType
 from onyx.db.models import Notification as NotificationDBModel
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA


 class PageType(str, Enum):
@@ -54,4 +55,4 @@ class Settings(BaseModel):
 class UserSettings(Settings):
     notifications: list[Notification]
     needs_reindexing: bool
-    tenant_id: str | None = None
+    tenant_id: str = POSTGRES_DEFAULT_SCHEMA
diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py
index 94279aad1..fc4b9f983 100644
--- a/backend/onyx/setup.py
+++ b/backend/onyx/setup.py
@@ -65,7 +65,7 @@ logger = setup_logger()


 def setup_onyx(
-    db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
+    db_session: Session, tenant_id: str, cohere_enabled: bool = False
 ) -> None:
     """
     Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema
diff --git a/backend/scripts/debugging/onyx_vespa.py b/backend/scripts/debugging/onyx_vespa.py
index 1e9465963..954072feb 100644
--- a/backend/scripts/debugging/onyx_vespa.py
+++ b/backend/scripts/debugging/onyx_vespa.py
@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
 def search_for_document(
     index_name: str,
     document_id: str | None = None,
-    tenant_id: str | None = None,
+    tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
     max_hits: int | None = 10,
 ) -> List[Dict[str, Any]]:
     yql_query = f"select * from sources {index_name}"
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(

 class VespaDebugging:
     # Class for managing Vespa debugging actions.
-    def __init__(self, tenant_id: str | None = None):
+    def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
         CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
-        self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
+        self.tenant_id = tenant_id
         self.index_name = get_index_name(self.tenant_id)

     def sample_document_counts(self) -> None:
@@ -603,7 +603,7 @@ class VespaDebugging:
         delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)

     def search_for_document(
-        self, document_id: str | None = None, tenant_id: str | None = None
+        self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
     ) -> List[Dict[str, Any]]:
         return search_for_document(self.index_name, document_id, tenant_id)

diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py
index 90a3e3cee..8f3e120a8 100755
--- a/backend/scripts/force_delete_connector_by_id.py
+++ b/backend/scripts/force_delete_connector_by_id.py
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
 from onyx.db.document import delete_documents_complete__no_commit
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.search_settings import get_active_search_settings
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 # Modify sys.path
 current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -74,7 +75,7 @@ def _unsafe_deletion(
     for document in documents:
         document_index.delete_single(
             doc_id=document.id,
-            tenant_id=None,
+            tenant_id=POSTGRES_DEFAULT_SCHEMA,
             chunk_count=document.chunk_count,
         )

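With the onyx_vespa.py defaults flipped from `None` to `POSTGRES_DEFAULT_SCHEMA`, the debugging helper can now be constructed with no arguments in a single-tenant setup. A hypothetical session (the import path is assumed to be relative to backend/, and the document id is a placeholder):

    from scripts.debugging.onyx_vespa import VespaDebugging

    dbg = VespaDebugging()  # tenant_id defaults to POSTGRES_DEFAULT_SCHEMA
    hits = dbg.search_for_document(document_id="example-doc-id")
    print(f"found {len(hits)} hits")
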
diff --git a/backend/scripts/orphan_doc_cleanup_script.py b/backend/scripts/orphan_doc_cleanup_script.py
index 499096387..413039936 100644
--- a/backend/scripts/orphan_doc_cleanup_script.py
+++ b/backend/scripts/orphan_doc_cleanup_script.py
@@ -6,6 +6,7 @@ from sqlalchemy import text
 from sqlalchemy.orm import Session

 from onyx.document_index.document_index_utils import get_multipass_config
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 # makes it so `PYTHONPATH=.` is not required when running this script
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -96,7 +97,9 @@ def main() -> None:
         try:
             print(f"Deleting document {doc_id} in Vespa")
             chunks_deleted = vespa_index.delete_single(
-                doc_id, tenant_id=None, chunk_count=document.chunk_count
+                doc_id,
+                tenant_id=POSTGRES_DEFAULT_SCHEMA,
+                chunk_count=document.chunk_count,
             )
             if chunks_deleted > 0:
                 print(
diff --git a/backend/shared_configs/contextvars.py b/backend/shared_configs/contextvars.py
index 4237f3e97..01ecd4073 100644
--- a/backend/shared_configs/contextvars.py
+++ b/backend/shared_configs/contextvars.py
@@ -18,5 +18,7 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
 def get_current_tenant_id() -> str:
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     if tenant_id is None:
+        if not MULTI_TENANT:
+            return POSTGRES_DEFAULT_SCHEMA
         raise RuntimeError("Tenant ID is not set. This should never happen.")
     return tenant_id
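The contextvars.py change gives single-tenant deployments a safe default, which appears to be what lets the signatures above drop `| None`. A behavior sketch, assuming the context var defaults to None and using `tenant_abc` as a placeholder tenant id:

    from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
    from shared_configs.contextvars import (
        CURRENT_TENANT_ID_CONTEXTVAR,
        get_current_tenant_id,
    )

    # Unset context var: single-tenant mode now falls back to the default
    # schema instead of raising RuntimeError.
    if not MULTI_TENANT:
        assert get_current_tenant_id() == POSTGRES_DEFAULT_SCHEMA

    # A set tenant id is returned unchanged in either mode.
    token = CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc")
    assert get_current_tenant_id() == "tenant_abc"
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)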