From 336f931c85bbaad07051e102814b1a6cf7e80613 Mon Sep 17 00:00:00 2001
From: pablodanswer
Date: Wed, 23 Oct 2024 13:29:26 -0700
Subject: [PATCH] add multi tenancy to redis

---
 .../background/celery/apps/app_base.py        |  67 +++++--
 .../danswer/background/celery/apps/primary.py | 184 ++++++++++--------
 .../danswer/background/celery/celery_utils.py |  16 +-
 .../celery/tasks/connector_deletion/tasks.py  |   4 +-
 .../background/celery/tasks/indexing/tasks.py |   7 +-
 .../background/celery/tasks/pruning/tasks.py  |   6 +-
 .../background/celery/tasks/vespa/tasks.py    |   9 +-
 backend/danswer/db/credentials.py             |   1 -
 backend/danswer/db/engine.py                  |  32 ++-
 backend/danswer/key_value_store/store.py      |   3 +-
 backend/danswer/redis/redis_pool.py           | 105 +++++++++-
 backend/danswer/server/documents/cc_pair.py   |   6 +-
 backend/danswer/server/documents/connector.py |   9 +-
 .../danswer/background/celery/apps/primary.py |  16 +-
 14 files changed, 324 insertions(+), 141 deletions(-)

diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py
index 2a52abde5..0fb997d4f 100644
--- a/backend/danswer/background/celery/apps/app_base.py
+++ b/backend/danswer/background/celery/apps/app_base.py
@@ -19,6 +19,7 @@ from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
 from danswer.configs.constants import DanswerRedisLocks
+from danswer.db.engine import get_all_tenant_ids
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import ColoredFormatter
 from danswer.utils.logger import PlainFormatter
@@ -56,7 +57,7 @@ def on_task_postrun(
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple | None = None,
-    kwargs: dict | None = None,
+    kwargs: dict[str, Any] | None = None,
     retval: Any | None = None,
     state: str | None = None,
     **kwds: Any,
@@ -83,7 +84,19 @@ def on_task_postrun(
     if not task_id:
         return

-    r = get_redis_client()
+    # Get tenant_id directly from kwargs - each celery task has a tenant_id kwarg
+    if not kwargs:
+        logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
+        tenant_id = None
+    else:
+        tenant_id = kwargs.get("tenant_id")
+
+    task_logger.debug(
+        f"Task {task.name} (ID: {task_id}) completed with state: {state} "
+        f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
+    )
+
+    r = get_redis_client(tenant_id=tenant_id)

     if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
         r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -124,7 +137,7 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None


 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=None)

     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
@@ -157,26 +170,44 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:


 def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
-    r = get_redis_client()
-
     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60

     logger.info("Running as a secondary celery worker.")
-    logger.info("Waiting for primary worker to be ready...")
+    logger.info("Waiting for all tenant primary workers to be ready...")
     time_start = time.monotonic()
+
     while True:
-        if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
+        tenant_ids = get_all_tenant_ids()
+        # Check if we have a primary worker lock for each tenant
+        all_tenants_ready = all(
+            get_redis_client(tenant_id=tenant_id).exists(
+                DanswerRedisLocks.PRIMARY_WORKER
+            )
+            for tenant_id in tenant_ids
+        )
+
+        if all_tenants_ready:
             break
-        time.monotonic()
         time_elapsed = time.monotonic() - time_start
-        logger.info(
-            f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
+        ready_tenants = sum(
+            1
+            for tenant_id in tenant_ids
+            if get_redis_client(tenant_id=tenant_id).exists(
+                DanswerRedisLocks.PRIMARY_WORKER
+            )
         )
+
+        logger.info(
+            f"Not all tenant primary workers are ready yet. "
+            f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
+            f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
+        )
+
         if time_elapsed > WAIT_LIMIT:
             msg = (
-                f"Primary worker was not ready within the timeout. "
+                f"Not all tenant primary workers were ready within the timeout "
                 f"({WAIT_LIMIT} seconds). Exiting..."
             )
             logger.error(msg)
@@ -184,7 +215,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:

         time.sleep(WAIT_INTERVAL)

-    logger.info("Wait for primary worker completed successfully. Continuing...")
+    logger.info("All tenant primary workers are ready. Continuing...")
     return
@@ -196,14 +227,14 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
     if not celery_is_worker_primary(sender):
         return

-    if not sender.primary_worker_lock:
+    if not hasattr(sender, "primary_worker_locks"):
         return

-    logger.info("Releasing primary worker lock.")
-    lock = sender.primary_worker_lock
-    if lock.owned():
-        lock.release()
-    sender.primary_worker_lock = None
+    for tenant_id, lock in sender.primary_worker_locks.items():
+        if lock and lock.owned():
+            logger.debug(f"Releasing lock for tenant {tenant_id}")
+            lock.release()
+            sender.primary_worker_locks[tenant_id] = None


 def on_setup_logging(
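
The hook above relies on a convention introduced throughout this patch: every Celery task carries a keyword-only tenant_id argument, so on_task_postrun can recover the tenant from kwargs and talk to the right tenant-scoped Redis client. A minimal sketch of the convention, with a hypothetical task name and broker URL that are not part of this patch:

    from celery import Celery
    from celery.signals import task_postrun

    app = Celery("sketch", broker="redis://localhost:6379/0")

    @app.task
    def example_task(*, tenant_id: str | None = None) -> None:
        ...  # per-tenant work happens here

    @task_postrun.connect
    def handle_postrun(task_id=None, task=None, kwargs=None, **extra) -> None:
        # kwargs is only populated when the task was invoked with keyword
        # arguments - hence the keyword-only tenant_id in the real signatures.
        tenant_id = (kwargs or {}).get("tenant_id")
        print(f"task {task_id} finished for tenant {tenant_id}")

    # Callers must pass tenant_id as a keyword:
    # example_task.apply_async(kwargs={"tenant_id": "tenant_abc"})
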
diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py
index 58e464f37..2b77a90d8 100644
--- a/backend/danswer/background/celery/apps/primary.py
+++ b/backend/danswer/background/celery/apps/primary.py
@@ -1,7 +1,6 @@
 import multiprocessing
 from typing import Any

-import redis
 from celery import bootsteps  # type: ignore
 from celery import Celery
 from celery import signals
@@ -24,6 +23,7 @@ from danswer.background.celery.celery_utils import celery_is_worker_primary
 from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
+from danswer.db.engine import get_all_tenant_ids
 from danswer.db.engine import SqlEngine
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import setup_logger
@@ -80,81 +80,83 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

     # This is singleton work that should be done on startup exactly once
     # by the primary worker
-    r = get_redis_client()
+    tenant_ids = get_all_tenant_ids()
+    for tenant_id in tenant_ids:
+        r = get_redis_client(tenant_id=tenant_id)

-    # For the moment, we're assuming that we are the only primary worker
-    # that should be running.
-    # TODO: maybe check for or clean up another zombie primary worker if we detect it
-    r.delete(DanswerRedisLocks.PRIMARY_WORKER)
+        # For the moment, we're assuming that we are the only primary worker
+        # that should be running.
+        # TODO: maybe check for or clean up another zombie primary worker if we detect it
+        r.delete(DanswerRedisLocks.PRIMARY_WORKER)

-    # this process wide lock is taken to help other workers start up in order.
-    # it is planned to use this lock to enforce singleton behavior on the primary
-    # worker, since the primary worker does redis cleanup on startup, but this isn't
-    # implemented yet.
-    lock = r.lock(
-        DanswerRedisLocks.PRIMARY_WORKER,
-        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-    )
+        # this process wide lock is taken to help other workers start up in order.
+        # it is planned to use this lock to enforce singleton behavior on the primary
+        # worker, since the primary worker does redis cleanup on startup, but this isn't
+        # implemented yet.
+        lock = r.lock(
+            DanswerRedisLocks.PRIMARY_WORKER,
+            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+        )

-    logger.info("Primary worker lock: Acquire starting.")
-    acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
-    if acquired:
-        logger.info("Primary worker lock: Acquire succeeded.")
-    else:
-        logger.error("Primary worker lock: Acquire failed!")
-        raise WorkerShutdown("Primary worker lock could not be acquired!")
+        logger.info("Primary worker lock: Acquire starting.")
+        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
+        if acquired:
+            logger.info("Primary worker lock: Acquire succeeded.")
+        else:
+            logger.error("Primary worker lock: Acquire failed!")
+            raise WorkerShutdown("Primary worker lock could not be acquired!")

-    sender.primary_worker_lock = lock
+        sender.primary_worker_locks[tenant_id] = lock

-    # As currently designed, when this worker starts as "primary", we reinitialize redis
-    # to a clean state (for our purposes, anyway)
-    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
-    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
+        # As currently designed, when this worker starts as "primary", we reinitialize redis
+        # to a clean state (for our purposes, anyway)
+        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
+        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

-    r.delete(RedisConnectorCredentialPair.get_taskset_key())
-    r.delete(RedisConnectorCredentialPair.get_fence_key())
+        r.delete(RedisConnectorCredentialPair.get_taskset_key())
+        r.delete(RedisConnectorCredentialPair.get_fence_key())

-    for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
+            r.delete(key)


 # @worker_process_init.connect
@@ -217,42 +219,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):

     def run_periodic_task(self, worker: Any) -> None:
         try:
-            if not worker.primary_worker_lock:
+            if not celery_is_worker_primary(worker):
                 return

-            if not hasattr(worker, "primary_worker_lock"):
+            if not hasattr(worker, "primary_worker_locks"):
                 return

-            r = get_redis_client()
+            # Retrieve all tenant IDs
+            tenant_ids = get_all_tenant_ids()

-            lock: redis.lock.Lock = worker.primary_worker_lock
+            for tenant_id in tenant_ids:
+                lock = worker.primary_worker_locks.get(tenant_id)
+                if not lock:
+                    continue  # Skip if no lock for this tenant

-            if lock.owned():
-                task_logger.debug("Reacquiring primary worker lock.")
-                lock.reacquire()
-            else:
-                task_logger.warning(
-                    "Full acquisition of primary worker lock. "
-                    "Reasons could be computer sleep or a clock change."
-                )
-                lock = r.lock(
-                    DanswerRedisLocks.PRIMARY_WORKER,
-                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-                )
+                r = get_redis_client(tenant_id=tenant_id)

-                task_logger.info("Primary worker lock: Acquire starting.")
-                acquired = lock.acquire(
-                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
-                )
-                if acquired:
-                    task_logger.info("Primary worker lock: Acquire succeeded.")
+                if lock.owned():
+                    task_logger.debug(
+                        f"Reacquiring primary worker lock for tenant {tenant_id}."
+                    )
+                    lock.reacquire()
                 else:
-                    task_logger.error("Primary worker lock: Acquire failed!")
-                    raise TimeoutError("Primary worker lock could not be acquired!")
+                    task_logger.warning(
+                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
+                        "Reasons could be worker restart or lock expiration."
+                    )
+                    lock = r.lock(
+                        DanswerRedisLocks.PRIMARY_WORKER,
+                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+                    )

-            worker.primary_worker_lock = lock
-        except Exception:
-            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
+                    task_logger.info(
+                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
+                    )
+                    acquired = lock.acquire(
+                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
+                    )
+                    if acquired:
+                        task_logger.info(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
+                        )
+                        worker.primary_worker_locks[tenant_id] = lock
+                    else:
+                        task_logger.error(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
+                        )
+                        raise TimeoutError(
+                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
+                        )
+
+        except Exception:
+            task_logger.exception("Error in HubPeriodicTask.run_periodic_task.")

     def stop(self, worker: Any) -> None:
         # Cancel the scheduled task when the worker stops
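
Both on_worker_init and HubPeriodicTask above follow the same pattern: one Redis lock per tenant, acquired by the primary worker and periodically renewed before its TTL lapses. The pattern in isolation, using redis-py directly; the tenant ids, key name, and timeout are illustrative, not values from this patch:

    import redis

    TENANTS = ["tenant_a", "tenant_b"]
    LOCK_TIMEOUT = 120  # stands in for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT

    r = redis.Redis()
    locks: dict[str, redis.lock.Lock] = {}

    for tenant_id in TENANTS:
        # One key per tenant, so each tenant's lock is independent.
        lock = r.lock(f"{tenant_id}:da_lock:primary_worker", timeout=LOCK_TIMEOUT)
        if lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2):
            locks[tenant_id] = lock

    # A heartbeat later extends every lock this process still holds:
    for tenant_id, lock in locks.items():
        if lock.owned():
            lock.reacquire()  # resets the TTL to the original timeout

In the patch itself the prefixing is invisible to callers: TenantRedis (added in redis_pool.py below) rewrites the key, so the worker locks DanswerRedisLocks.PRIMARY_WORKER under each tenant's prefix without naming the tenant in the key.
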
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index 794f89232..b1e9c2113 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -27,7 +27,10 @@ logger = setup_logger()


 def _get_deletion_status(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> TaskQueueState | None:
     """We no longer store TaskQueueState in the DB for a deletion attempt.
     This function populates TaskQueueState by just checking redis.
@@ -40,7 +43,7 @@ def _get_deletion_status(

     rcd = RedisConnectorDeletion(cc_pair.id)

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     if not r.exists(rcd.fence_key):
         return None

@@ -50,9 +53,14 @@ def _get_deletion_status(


 def get_deletion_attempt_snapshot(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> DeletionAttemptSnapshot | None:
-    deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
+    deletion_task = _get_deletion_status(
+        connector_id, credential_id, db_session, tenant_id
+    )
     if not deletion_task:
         return None
diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index b3c2eea30..f6a59d03e 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -24,8 +24,8 @@ from danswer.redis.redis_pool import get_redis_client
     trail=False,
     bind=True,
 )
-def check_for_connector_deletion_task(self: Task, tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py
index ed08787d5..dd3a69931 100644
--- a/backend/danswer/background/celery/tasks/indexing/tasks.py
+++ b/backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -55,10 +55,10 @@ logger = setup_logger()
     soft_time_limit=300,
     bind=True,
 )
-def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
+def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
     tasks_created = 0

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -68,6 +68,7 @@ def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
     try:
         # these tasks should never overlap
         if not lock_beat.acquire(blocking=False):
+            task_logger.info(f"check_for_indexing - Lock not acquired for tenant: {tenant_id}")
             return None

         cc_pair_ids: list[int] = []
@@ -398,7 +399,7 @@ def connector_indexing_task(
     attempt = None
     n_final_progress = 0

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index 698c29372..9f290d6f2 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -41,8 +41,8 @@ logger = setup_logger()
     soft_time_limit=JOB_TIMEOUT,
     bind=True,
 )
-def check_for_pruning(self: Task, tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -222,7 +222,7 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     rcp = RedisConnectorPruning(cc_pair_id)
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index 53e26be69..812074b91 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -60,6 +60,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import VespaDocumentFields
 from danswer.redis.redis_pool import get_redis_client
+from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import (
     fetch_versioned_implementation_with_fallback,
@@ -67,6 +68,8 @@ from danswer.utils.variable_functionality import (
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import noop_fallback

+logger = setup_logger()
+

 # celery auto associates tasks created inside another task,
 # which bloats the result metadata considerably. trail=False prevents this.
@@ -76,11 +79,11 @@ from danswer.utils.variable_functionality import noop_fallback
     trail=False,
     bind=True,
 )
-def check_for_vespa_sync_task(self: Task, tenant_id: str | None) -> None:
+def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -680,7 +683,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
     Returns True if the task actually did work, False
     """
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat: redis.lock.Lock = r.lock(
         DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py
index abab904cc..7729b675e 100644
--- a/backend/danswer/db/credentials.py
+++ b/backend/danswer/db/credentials.py
@@ -242,7 +242,6 @@ def create_credential(
     )
     db_session.add(credential)
     db_session.flush()  # This ensures the credential gets an ID
-
     _relate_credential_to_user_groups__no_commit(
         db_session=db_session,
         credential_id=credential.id,
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 7bf813b44..59311de27 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -319,30 +319,34 @@ async def get_async_session_with_tenant(
 def get_session_with_tenant(
     tenant_id: str | None = None,
 ) -> Generator[Session, None, None]:
-    """Generate a database session with the appropriate tenant schema set."""
+    """Generate a database session bound to a connection with the appropriate tenant schema set."""
     engine = get_sqlalchemy_engine()
+
     if tenant_id is None:
         tenant_id = current_tenant_id.get()
+    else:
+        current_tenant_id.set(tenant_id)
+
+    # Guard against registering duplicate listeners on repeated calls
+    if not event.contains(engine, "checkout", set_search_path_on_checkout):
+        event.listen(engine, "checkout", set_search_path_on_checkout)

     if not is_valid_schema_name(tenant_id):
         raise HTTPException(status_code=400, detail="Invalid tenant ID")

-    # Establish a raw connection without starting a transaction
+    # Establish a raw connection
     with engine.connect() as connection:
-        # Access the raw DBAPI connection
+        # Access the raw DBAPI connection and set the search_path
         dbapi_connection = connection.connection

-        # Execute SET search_path outside of any transaction
+        # Set the search_path outside of any transaction
         cursor = dbapi_connection.cursor()
         try:
-            cursor.execute(f'SET search_path TO "{tenant_id}"')
-            # Optionally verify the search_path was set correctly
-            cursor.execute("SHOW search_path")
-            cursor.fetchone()
+            cursor.execute(f'SET search_path = "{tenant_id}"')
         finally:
             cursor.close()

-        # Proceed to create a session using the connection
+        # Bind the session to the connection
         with Session(bind=connection, expire_on_commit=False) as session:
             try:
                 yield session
@@ -356,6 +360,18 @@ def get_session_with_tenant(
             cursor.close()


+def set_search_path_on_checkout(
+    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
+) -> None:
+    tenant_id = current_tenant_id.get()
+    if tenant_id and is_valid_schema_name(tenant_id):
+        with dbapi_conn.cursor() as cursor:
+            cursor.execute(f'SET search_path TO "{tenant_id}"')
+            logger.debug(
+                f"Set search_path to {tenant_id} for connection {connection_record}"
+            )
+
+
 def get_session_generator_with_tenant() -> Generator[Session, None, None]:
     tenant_id = current_tenant_id.get()
     with get_session_with_tenant(tenant_id) as session:
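
The engine changes pin the tenant's Postgres schema to the session: get_session_with_tenant sets search_path on a dedicated connection, and the checkout listener re-applies it whenever the pool hands a connection out. A standalone sketch of the core mechanism; the DSN and schema name are illustrative, and as in the patch the schema name must be validated before it is interpolated into SQL:

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("postgresql://user:pass@localhost/danswer")
    tenant_id = "tenant_abc"  # must pass a check like is_valid_schema_name first

    with engine.connect() as connection:
        cursor = connection.connection.cursor()  # raw DBAPI cursor
        try:
            cursor.execute(f'SET search_path = "{tenant_id}"')
        finally:
            cursor.close()

        # Unqualified table names now resolve inside the tenant's schema.
        with Session(bind=connection, expire_on_commit=False) as session:
            print(session.execute(text("SELECT current_schema()")).scalar())
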
diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py
index b461ca22f..b442009cc 100644
--- a/backend/danswer/key_value_store/store.py
+++ b/backend/danswer/key_value_store/store.py
@@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24  # 1 Day

 class PgRedisKVStore(KeyValueStore):
     def __init__(self) -> None:
-        self.redis_client = get_redis_client()
+        tenant_id = current_tenant_id.get()
+        self.redis_client = get_redis_client(tenant_id=tenant_id)

     @contextmanager
     def get_session(self) -> Iterator[Session]:
diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py
index fd08b9157..3f2ec03d7 100644
--- a/backend/danswer/redis/redis_pool.py
+++ b/backend/danswer/redis/redis_pool.py
@@ -1,4 +1,7 @@
+import functools
 import threading
+from collections.abc import Callable
+from typing import Any
 from typing import Optional

 import redis
@@ -14,6 +17,98 @@ from danswer.configs.app_configs import REDIS_SSL
 from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
 from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
 from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class TenantRedis(redis.Redis):
+    def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.tenant_id: str = tenant_id
+
+    def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
+        prefix: str = f"{self.tenant_id}:"
+        if isinstance(key, str):
+            if key.startswith(prefix):
+                return key
+            else:
+                return prefix + key
+        elif isinstance(key, bytes):
+            prefix_bytes = prefix.encode()
+            if key.startswith(prefix_bytes):
+                return key
+            else:
+                return prefix_bytes + key
+        elif isinstance(key, memoryview):
+            key_bytes = key.tobytes()
+            prefix_bytes = prefix.encode()
+            if key_bytes.startswith(prefix_bytes):
+                return key
+            else:
+                return memoryview(prefix_bytes + key_bytes)
+        else:
+            raise TypeError(f"Unsupported key type: {type(key)}")
+
+    def _prefix_method(self, method: Callable) -> Callable:
+        @functools.wraps(method)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if "name" in kwargs:
+                kwargs["name"] = self._prefixed(kwargs["name"])
+            elif len(args) > 0:
+                args = (self._prefixed(args[0]),) + args[1:]
+            return method(*args, **kwargs)
+
+        return wrapper
+
+    def _prefix_scan_iter(self, method: Callable) -> Callable:
+        @functools.wraps(method)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            # Prefix the match pattern if provided
+            if "match" in kwargs:
+                kwargs["match"] = self._prefixed(kwargs["match"])
+            elif len(args) > 0:
+                args = (self._prefixed(args[0]),) + args[1:]
+
+            # Get the iterator
+            iterator = method(*args, **kwargs)
+
+            # Remove prefix from returned keys
+            prefix = f"{self.tenant_id}:".encode()
+            prefix_len = len(prefix)
+
+            for key in iterator:
+                if isinstance(key, bytes) and key.startswith(prefix):
+                    yield key[prefix_len:]
+                else:
+                    yield key
+
+        return wrapper
+
+    def __getattribute__(self, item: str) -> Any:
+        original_attr = super().__getattribute__(item)
+        methods_to_wrap = [
+            "lock",
+            "unlock",
+            "get",
+            "set",
+            "delete",
+            "exists",
+            "incrby",
+            "hset",
+            "hget",
+            "getset",
+            "owned",
+            "reacquire",
+            "create_lock",
+            "startswith",
+        ]  # Regular methods that need simple prefixing
+
+        if item == "scan_iter":
+            return self._prefix_scan_iter(original_attr)
+        elif item in methods_to_wrap and callable(original_attr):
+            return self._prefix_method(original_attr)
+        return original_attr


 class RedisPool:
@@ -32,8 +127,10 @@ class RedisPool:
     def _init_pool(self) -> None:
         self._pool = RedisPool.create_pool(ssl=REDIS_SSL)

-    def get_client(self) -> Redis:
-        return redis.Redis(connection_pool=self._pool)
+    def get_client(self, tenant_id: str | None) -> Redis:
+        if tenant_id is None:
+            tenant_id = "public"
+        return TenantRedis(tenant_id, connection_pool=self._pool)

     @staticmethod
     def create_pool(
@@ -84,8 +181,8 @@ class RedisPool:

 redis_pool = RedisPool()


-def get_redis_client() -> Redis:
-    return redis_pool.get_client()
+def get_redis_client(*, tenant_id: str | None) -> Redis:
+    return redis_pool.get_client(tenant_id)


 #
 # Usage example
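
The effect of TenantRedis, from a caller's point of view: keys are written and read unprefixed, while the wrapper namespaces them per tenant, and scan_iter strips the prefix back off. A short sketch using the patch's own API, assuming a reachable Redis instance; the tenant id and key are illustrative:

    from danswer.redis.redis_pool import get_redis_client

    r = get_redis_client(tenant_id="tenant_abc")

    r.set("da_lock:primary_worker", "1")
    # Stored under "tenant_abc:da_lock:primary_worker", so an identical
    # key written by another tenant's client cannot collide with it.

    for key in r.scan_iter("da_lock:*"):
        print(key)  # b"da_lock:primary_worker" - prefix already stripped

Note that only the methods named in methods_to_wrap (plus scan_iter) are rewritten; commands outside that list, such as the srem call in on_task_postrun above, still operate on raw, unprefixed keys.
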
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 92a94a638..d8451b099 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -26,6 +26,7 @@ from danswer.db.connector_credential_pair import (
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import current_tenant_id
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -94,6 +95,7 @@ def get_cc_pair_full_info(
     cc_pair_id: int,
     user: User | None = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> CCPairFullInfo:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

@@ -147,6 +149,7 @@ def get_cc_pair_full_info(
         connector_id=cc_pair.connector_id,
         credential_id=cc_pair.credential_id,
         db_session=db_session,
+        tenant_id=tenant_id,
     ),
     num_docs_indexed=documents_indexed,
     is_editable_for_current_user=is_editable_for_current_user,
@@ -243,6 +246,7 @@ def prune_cc_pair(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse[list[int]]:
     """Triggers pruning on a particular cc_pair immediately"""

@@ -258,7 +262,7 @@ def prune_cc_pair(
             detail="Connection not found for current user's permissions",
         )

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     rcp = RedisConnectorPruning(cc_pair_id)
     if rcp.is_pruning(r):
         raise HTTPException(
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index 54d11e867..4c7bc5603 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -483,10 +483,11 @@ def get_connector_indexing_status(
     get_editable: bool = Query(
         False, description="If true, return editable document sets"
     ),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> list[ConnectorIndexingStatus]:
     indexing_statuses: list[ConnectorIndexingStatus] = []

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     # NOTE: If the connector is deleting behind the scenes,
     # accessing cc_pairs can be inconsistent and members like
@@ -607,6 +608,7 @@ def get_connector_indexing_status(
                 connector_id=connector.id,
                 credential_id=credential.id,
                 db_session=db_session,
+                tenant_id=tenant_id,
             ),
             is_deletable=check_deletion_attempt_is_allowed(
                 connector_credential_pair=cc_pair,
@@ -684,15 +686,18 @@ def create_connector_with_mock_credential(
     connector_response = create_connector(
         db_session=db_session, connector_data=connector_data
     )
+
     mock_credential = CredentialBase(
         credential_json={}, admin_public=True, source=connector_data.source
     )
     credential = create_credential(
         mock_credential, user=user, db_session=db_session
     )
+
     access_type = (
         AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
     )
+
     response = add_credential_to_connector(
         db_session=db_session,
         user=user,
@@ -776,7 +781,7 @@ def connector_run_once(
     """Used to trigger indexing on a set of cc_pairs associated with a
     single connector."""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     connector_id = run_info.connector_id
     specified_credential_ids = run_info.credential_ids
diff --git a/backend/ee/danswer/background/celery/apps/primary.py b/backend/ee/danswer/background/celery/apps/primary.py
index 97c5b0221..169b38ba4 100644
--- a/backend/ee/danswer/background/celery/apps/primary.py
+++ b/backend/ee/danswer/background/celery/apps/primary.py
@@ -39,7 +39,9 @@ global_version.set_ee()

 @build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
+def sync_external_doc_permissions_task(
+    cc_pair_id: int, *, tenant_id: str | None
+) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)

@@ -47,7 +49,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
 @build_celery_task_wrapper(name_sync_external_group_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_external_group_permissions_task(
-    cc_pair_id: int, tenant_id: str | None
+    cc_pair_id: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@@ -56,7 +58,7 @@ def sync_external_group_permissions_task(
 @build_celery_task_wrapper(name_chat_ttl_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def perform_ttl_management_task(
-    retention_limit_days: int, tenant_id: str | None
+    retention_limit_days: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -69,7 +71,7 @@ def perform_ttl_management_task(
     name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -86,7 +88,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
     name="check_sync_external_group_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external group permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -103,7 +105,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
     name="check_ttl_management_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_ttl_management_task(tenant_id: str | None) -> None:
+def check_ttl_management_task(*, tenant_id: str | None) -> None:
     """Runs periodically to check if any ttl tasks should be run and adds them
     to the queue"""
     token = None
@@ -127,7 +129,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
     name="autogenerate_usage_report_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def autogenerate_usage_report_task(tenant_id: str | None) -> None:
+def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
     """This generates usage report under the /admin/generate-usage/report endpoint"""
     with get_session_with_tenant(tenant_id) as db_session:
         create_new_usage_report(
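
As in the other task modules, the bare * makes tenant_id keyword-only, so every invocation must spell it out and on_task_postrun can always find it in kwargs. A sketch of how a per-tenant scheduler might fan such tasks out; the dispatch loop is illustrative, since this patch only changes the task signatures:

    from celery import Celery
    from danswer.db.engine import get_all_tenant_ids

    celery_app = Celery("sketch")

    def dispatch_per_tenant(task_name: str) -> None:
        # One enqueue per tenant; tenant_id travels in kwargs so both the
        # task body and the postrun hook can see it.
        for tenant_id in get_all_tenant_ids():
            celery_app.send_task(task_name, kwargs={"tenant_id": tenant_id})

    # dispatch_per_tenant("check_ttl_management_task")
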