diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index 702ab5205..0b03f3bcc 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -4,7 +4,6 @@ import time
 from datetime import timedelta
 from typing import Any
 
-import redis
 import sentry_sdk
 from celery import bootsteps  # type: ignore
 from celery import Celery
@@ -79,6 +78,7 @@ def on_task_prerun(
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple | None = None,
+    tenant_id: str | None = None,
     kwargs: dict | None = None,
     **kwds: Any,
 ) -> None:
@@ -91,7 +91,7 @@ def on_task_postrun(
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple | None = None,
-    kwargs: dict | None = None,
+    kwargs: dict[str, Any] | None = None,
     retval: Any | None = None,
     state: str | None = None,
     **kwds: Any,
@@ -110,7 +110,17 @@ def on_task_postrun(
     if not task:
         return
 
-    task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
+    # Get tenant_id directly from kwargs - each celery task has a tenant_id kwarg
+    if not kwargs:
+        logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
+        tenant_id = None
+    else:
+        tenant_id = kwargs.get("tenant_id")
+
+    task_logger.debug(
+        f"Task {task.name} (ID: {task_id}) completed with state: {state} "
+        f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
+    )
 
     if state not in READY_STATES:
         return
@@ -118,7 +128,7 @@ def on_task_postrun(
     if not task_id:
         return
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
         r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -171,7 +181,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     # decide some initial startup settings based on the celery worker's hostname
     # (set at the command line)
+
     hostname = sender.hostname
     if hostname.startswith("light"):
         SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@@ -182,166 +193,155 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     elif hostname.startswith("indexing"):
         SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
         SqlEngine.init_engine(pool_size=8, max_overflow=0)
+        tenant_ids = get_all_tenant_ids()
 
-        # TODO: why is this necessary for the indexer to do?
-        with get_session_with_tenant(tenant_id) as db_session:
-            check_index_swap(db_session=db_session)
-            search_settings = get_current_search_settings(db_session)
+        for tenant_id in tenant_ids:
+            # TODO: why is this necessary for the indexer to do?
+            with get_session_with_tenant(tenant_id) as db_session:
+                check_index_swap(db_session=db_session)
+                search_settings = get_current_search_settings(db_session)
 
-        # So that the first time users aren't surprised by really slow speed of first
-        # batch of documents indexed
+                # So that the first time users aren't surprised by really slow speed of first
+                # batch of documents indexed
 
-        if search_settings.provider_type is None:
-            logger.notice("Running a first inference to warm up embedding model")
-            embedding_model = EmbeddingModel.from_db_model(
-                search_settings=search_settings,
-                server_host=INDEXING_MODEL_SERVER_HOST,
-                server_port=MODEL_SERVER_PORT,
-            )
+                if search_settings.provider_type is None:
+                    logger.notice(
+                        "Running a first inference to warm up embedding model"
+                    )
+                    embedding_model = EmbeddingModel.from_db_model(
+                        search_settings=search_settings,
+                        server_host=INDEXING_MODEL_SERVER_HOST,
+                        server_port=MODEL_SERVER_PORT,
+                    )
 
-            warm_up_bi_encoder(
-                embedding_model=embedding_model,
-            )
-            logger.notice("First inference complete.")
+                    warm_up_bi_encoder(
+                        embedding_model=embedding_model,
+                    )
+                    logger.notice("First inference complete.")
     else:
         SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
         SqlEngine.init_engine(pool_size=8, max_overflow=0)
 
-    r = get_redis_client()
+    if not hasattr(sender, "primary_worker_locks"):
+        sender.primary_worker_locks = {}
 
-    WAIT_INTERVAL = 5
-    WAIT_LIMIT = 60
-
-    time_start = time.monotonic()
-    logger.info("Redis: Readiness check starting.")
-    while True:
-        try:
-            if r.ping():
-                break
-        except Exception:
-            pass
-
-        time_elapsed = time.monotonic() - time_start
-        logger.info(
-            f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
-        )
-        if time_elapsed > WAIT_LIMIT:
-            msg = (
-                f"Redis: Readiness check did not succeed within the timeout "
-                f"({WAIT_LIMIT} seconds). Exiting..."
-            )
-            logger.error(msg)
-            raise WorkerShutdown(msg)
-
-        time.sleep(WAIT_INTERVAL)
-
-    logger.info("Redis: Readiness check succeeded. Continuing...")
+    tenant_ids = get_all_tenant_ids()
 
     if not celery_is_worker_primary(sender):
         logger.info("Running as a secondary celery worker.")
-        logger.info("Waiting for primary worker to be ready...")
-        time_start = time.monotonic()
-        while True:
-            if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
-                break
-
-            time.monotonic()
-            time_elapsed = time.monotonic() - time_start
-            logger.info(
-                f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
-            )
-            if time_elapsed > WAIT_LIMIT:
-                msg = (
-                    f"Primary worker was not ready within the timeout. "
-                    f"({WAIT_LIMIT} seconds). Exiting..."
-                )
-                logger.error(msg)
-                raise WorkerShutdown(msg)
+        for tenant_id in tenant_ids:
+            r = get_redis_client(tenant_id=tenant_id)
+            WAIT_INTERVAL = 5
+            WAIT_LIMIT = 60
+            time_start = time.monotonic()
+            logger.notice("Waiting for primary worker to be ready...")
+            while True:
+                # Log all the keys in Redis for visibility while waiting
+                all_locks = r.keys("*")
+                logger.notice(f"Current Redis keys: {all_locks}")
+                if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
+                    break
+                time_elapsed = time.monotonic() - time_start
+                logger.info(
+                    f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
+                )
+                if time_elapsed > WAIT_LIMIT:
+                    msg = (
+                        "Primary worker was not ready within the timeout "
+                        f"({WAIT_LIMIT} seconds). Exiting..."
+                    )
+                    logger.error(msg)
+                    raise WorkerShutdown(msg)
+                time.sleep(WAIT_INTERVAL)
+            logger.info("Wait for primary worker completed successfully. Continuing...")
Continuing...") + return # Exit the function for secondary workers - time.sleep(WAIT_INTERVAL) + for tenant_id in tenant_ids: + r = get_redis_client(tenant_id=tenant_id) - logger.info("Wait for primary worker completed successfully. Continuing...") - return + WAIT_INTERVAL = 5 + WAIT_LIMIT = 60 - logger.info("Running as the primary celery worker.") + time_start = time.monotonic() + logger.info("Running as the primary celery worker.") - # This is singleton work that should be done on startup exactly once - # by the primary worker - r = get_redis_client() + # This is singleton work that should be done on startup exactly once + # by the primary worker + r = get_redis_client(tenant_id=tenant_id) - # For the moment, we're assuming that we are the only primary worker - # that should be running. - # TODO: maybe check for or clean up another zombie primary worker if we detect it - r.delete(DanswerRedisLocks.PRIMARY_WORKER) + # For the moment, we're assuming that we are the only primary worker + # that should be running. + # TODO: maybe check for or clean up another zombie primary worker if we detect it + r.delete(DanswerRedisLocks.PRIMARY_WORKER) - # this process wide lock is taken to help other workers start up in order. - # it is planned to use this lock to enforce singleton behavior on the primary - # worker, since the primary worker does redis cleanup on startup, but this isn't - # implemented yet. - lock = r.lock( - DanswerRedisLocks.PRIMARY_WORKER, - timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, - ) + # this process wide lock is taken to help other workers start up in order. + # it is planned to use this lock to enforce singleton behavior on the primary + # worker, since the primary worker does redis cleanup on startup, but this isn't + # implemented yet. 
+        lock = r.lock(
+            DanswerRedisLocks.PRIMARY_WORKER,
+            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+        )
 
-    logger.info("Primary worker lock: Acquire starting.")
-    acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
-    if acquired:
-        logger.info("Primary worker lock: Acquire succeeded.")
-    else:
-        logger.error("Primary worker lock: Acquire failed!")
-        raise WorkerShutdown("Primary worker lock could not be acquired!")
+        logger.info("Primary worker lock: Acquire starting.")
+        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
+        if acquired:
+            logger.info("Primary worker lock: Acquire succeeded.")
+        else:
+            logger.error("Primary worker lock: Acquire failed!")
+            raise WorkerShutdown("Primary worker lock could not be acquired!")
 
-    sender.primary_worker_lock = lock
+        sender.primary_worker_locks[tenant_id] = lock
 
-    # As currently designed, when this worker starts as "primary", we reinitialize redis
-    # to a clean state (for our purposes, anyway)
-    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
-    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
+        # As currently designed, when this worker starts as "primary", we reinitialize redis
+        # to a clean state (for our purposes, anyway)
+        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
+        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
-    r.delete(RedisConnectorCredentialPair.get_taskset_key())
-    r.delete(RedisConnectorCredentialPair.get_fence_key())
+        r.delete(RedisConnectorCredentialPair.get_taskset_key())
+        r.delete(RedisConnectorCredentialPair.get_fence_key())
 
-    for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)
 
-    for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
+            r.delete(key)
 
 
 # @worker_process_init.connect
@@ -367,14 +367,15 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
     if not celery_is_worker_primary(sender):
         return
 
-    if not sender.primary_worker_lock:
+    if not hasattr(sender, "primary_worker_locks"):
         return
 
     logger.info("Releasing primary worker lock.")
-    lock = sender.primary_worker_lock
-    if lock.owned():
-        lock.release()
-    sender.primary_worker_lock = None
+    for tenant_id, lock in sender.primary_worker_locks.items():
+        logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
+        if lock.owned():
+            lock.release()
+    sender.primary_worker_locks = {}
 
 
 class CeleryTaskPlainFormatter(PlainFormatter):
@@ -449,17 +450,18 @@ def on_setup_logging(
 
 
 class HubPeriodicTask(bootsteps.StartStopStep):
-    """Regularly reacquires the primary worker lock outside of the task queue.
+    """Regularly reacquires the primary worker locks for all tenants outside of the task queue.
     Use the task_logger in this class to avoid double logging.
 
     This cannot be done inside a regular beat task because it must run on schedule and
     a queue of existing work would starve the task from running.
     """
 
-    # it's unclear to me whether using the hub's timer or the bootstep timer is better
+    # Requires the Hub component
     requires = {"celery.worker.components:Hub"}
 
     def __init__(self, worker: Any, **kwargs: Any) -> None:
+        super().__init__(worker, **kwargs)
         self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
         self.task_tref = None
 
@@ -478,42 +480,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
 
     def run_periodic_task(self, worker: Any) -> None:
         try:
-            if not worker.primary_worker_lock:
+            if not celery_is_worker_primary(worker):
                 return
 
-            if not hasattr(worker, "primary_worker_lock"):
+            if not hasattr(worker, "primary_worker_locks"):
                 return
 
-            r = get_redis_client()
+            # Retrieve all tenant IDs
+            tenant_ids = get_all_tenant_ids()
 
-            lock: redis.lock.Lock = worker.primary_worker_lock
+            for tenant_id in tenant_ids:
+                lock = worker.primary_worker_locks.get(tenant_id)
+                if not lock:
+                    continue  # Skip if no lock for this tenant
 
-            if lock.owned():
-                task_logger.debug("Reacquiring primary worker lock.")
-                lock.reacquire()
-            else:
-                task_logger.warning(
-                    "Full acquisition of primary worker lock. "
-                    "Reasons could be computer sleep or a clock change."
-                )
-                lock = r.lock(
-                    DanswerRedisLocks.PRIMARY_WORKER,
-                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-                )
+                r = get_redis_client(tenant_id=tenant_id)
 
-                task_logger.info("Primary worker lock: Acquire starting.")
-                acquired = lock.acquire(
-                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
-                )
-                if acquired:
-                    task_logger.info("Primary worker lock: Acquire succeeded.")
+                if lock.owned():
+                    task_logger.debug(
+                        f"Reacquiring primary worker lock for tenant {tenant_id}."
+                    )
+                    lock.reacquire()
                 else:
-                    task_logger.error("Primary worker lock: Acquire failed!")
-                    raise TimeoutError("Primary worker lock could not be acquired!")
+                    task_logger.warning(
+                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
+                        "Reasons could be worker restart or lock expiration."
+                    )
+                    lock = r.lock(
+                        DanswerRedisLocks.PRIMARY_WORKER,
+                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+                    )
 
-                worker.primary_worker_lock = lock
-        except Exception:
-            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
+                    task_logger.info(
+                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
+                    )
+                    acquired = lock.acquire(
+                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
+                    )
+                    if acquired:
+                        task_logger.info(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
+                        )
+                        worker.primary_worker_locks[tenant_id] = lock
+                    else:
+                        task_logger.error(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
+                        )
+                        raise TimeoutError(
+                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
+                        )
+
+        except Exception as e:
+            task_logger.error(f"Error in periodic task: {e}")
 
     def stop(self, worker: Any) -> None:
         # Cancel the scheduled task when the worker stops
@@ -583,14 +601,14 @@ tasks_to_schedule = [
 
 # Build the celery beat schedule dynamically
 beat_schedule = {}
-for tenant_id in tenant_ids:
+for id in tenant_ids:
     for task in tasks_to_schedule:
-        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
+        task_name = f"{task['name']}-{id}"  # Unique name for each scheduled task
         beat_schedule[task_name] = {
            "task": task["task"],
            "schedule": task["schedule"],
            "options": task["options"],
-           "args": (tenant_id,),  # Must pass tenant_id as an argument
+           "kwargs": {"tenant_id": id},  # Must pass tenant_id as a keyword argument
        }
 
 # Include any existing beat schedules
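
For readers skimming the diff: the two pieces above are coupled. The beat schedule now passes `tenant_id` through `kwargs`, which is exactly where `on_task_postrun` looks for it. A minimal sketch of the resulting schedule shape, assuming a hypothetical two-tenant list and one stub task entry (the real values come from `get_all_tenant_ids()` and `tasks_to_schedule`):

```python
from datetime import timedelta

# Hypothetical tenant list; the real code gets this from get_all_tenant_ids().
tenant_ids = ["public", "tenant_a"]

tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": 1},
    },
]

beat_schedule = {}
for tenant_id in tenant_ids:
    for task in tasks_to_schedule:
        beat_schedule[f"{task['name']}-{tenant_id}"] = {
            "task": task["task"],
            "schedule": task["schedule"],
            "options": task["options"],
            # Keyword args flow through to the task and are visible to
            # on_task_postrun via its `kwargs` parameter, so the signal
            # handler can recover the tenant without task-specific logic:
            # tenant_id = (kwargs or {}).get("tenant_id")
            "kwargs": {"tenant_id": tenant_id},
        }
```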
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index b76e148e2..404feb88c 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -31,7 +31,10 @@ logger = setup_logger()
 
 
 def _get_deletion_status(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> TaskQueueState | None:
     """We no longer store TaskQueueState in the DB for a deletion attempt.
     This function populates TaskQueueState by just checking redis.
@@ -44,7 +47,7 @@ def _get_deletion_status(
 
     rcd = RedisConnectorDeletion(cc_pair.id)
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     if not r.exists(rcd.fence_key):
         return None
 
@@ -54,9 +57,14 @@ def _get_deletion_status(
 
 
 def get_deletion_attempt_snapshot(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> DeletionAttemptSnapshot | None:
-    deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
+    deletion_task = _get_deletion_status(
+        connector_id, credential_id, db_session, tenant_id
+    )
     if not deletion_task:
         return None
 
diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index b13daff61..caab63299 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
     soft_time_limit=JOB_TIMEOUT,
     trail=False,
 )
-def check_for_connector_deletion_task(tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py
index fefbae032..a2631d369 100644
--- a/backend/danswer/background/celery/tasks/indexing/tasks.py
+++ b/backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -51,10 +51,10 @@ logger = setup_logger()
     name="check_for_indexing",
     soft_time_limit=300,
 )
-def check_for_indexing(tenant_id: str | None) -> int | None:
+def check_for_indexing(*, tenant_id: str | None) -> int | None:
     tasks_created = 0
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
     try:
         # these tasks should never overlap
         if not lock_beat.acquire(blocking=False):
+            task_logger.info(f"Lock not acquired for tenant: {tenant_id}")
             return None
+        else:
+            task_logger.info(f"Lock acquired for tenant: {tenant_id}")
 
         with get_session_with_tenant(tenant_id) as db_session:
             # Get the primary search settings
@@ -367,7 +370,7 @@ def connector_indexing_task(
     attempt = None
     n_final_progress = 0
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
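
The signature changes in these task files all add a bare `*`, making `tenant_id` keyword-only. A small illustration, using hypothetical stand-in functions, of why this pairs well with the `"kwargs": {"tenant_id": ...}` beat entries: a stale positional call now fails immediately instead of silently binding the wrong argument:

```python
def check_for_indexing_old(tenant_id: str | None) -> None:
    ...  # tenant_id could be supplied positionally, e.g. via "args": (tenant_id,)


def check_for_indexing_new(*, tenant_id: str | None) -> None:
    ...  # tenant_id must now be named explicitly at every call site


check_for_indexing_new(tenant_id="tenant_a")  # OK
try:
    check_for_indexing_new("tenant_a")  # type: ignore[misc]
except TypeError as e:
    # "check_for_indexing_new() takes 0 positional arguments but 1 was given"
    print(e)
```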
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index ee5adfd10..19cedf6d2 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -38,8 +38,8 @@ logger = setup_logger()
     name="check_for_pruning",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_pruning(tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_pruning(*, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -204,7 +204,7 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     rcp = RedisConnectorPruning(cc_pair_id)
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index 9830d71f7..bd299fa63 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -59,6 +59,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import VespaDocumentFields
 from danswer.redis.redis_pool import get_redis_client
+from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import (
     fetch_versioned_implementation_with_fallback,
@@ -66,6 +67,8 @@ from danswer.utils.variable_functionality import (
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import noop_fallback
 
+logger = setup_logger()
+
 
 # celery auto associates tasks created inside another task,
 # which bloats the result metadata considerably. trail=False prevents this.
@@ -74,11 +77,11 @@ from danswer.utils.variable_functionality import noop_fallback
     soft_time_limit=JOB_TIMEOUT,
     trail=False,
 )
-def check_for_vespa_sync_task(tenant_id: str | None) -> None:
+def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -640,7 +643,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
     Returns True if the task actually did work, False
     """
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat: redis.lock.Lock = r.lock(
         DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
diff --git a/backend/danswer/connectors/confluence/rate_limit_handler.py b/backend/danswer/connectors/confluence/rate_limit_handler.py
index 8dbdeba1a..4f740ecbf 100644
--- a/backend/danswer/connectors/confluence/rate_limit_handler.py
+++ b/backend/danswer/connectors/confluence/rate_limit_handler.py
@@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
 #     # for testing purposes, rate limiting is written to fall back to a simpler
 #     # rate limiting approach when redis is not available
-#     r = get_redis_client()
+#     r = get_redis_client(tenant_id=tenant_id)
 
 #     for attempt in range(max_retries):
 #         try:
diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py
index abab904cc..7729b675e 100644
--- a/backend/danswer/db/credentials.py
+++ b/backend/danswer/db/credentials.py
@@ -242,7 +242,6 @@ def create_credential(
     )
     db_session.add(credential)
     db_session.flush()  # This ensures the credential gets an ID
-
     _relate_credential_to_user_groups__no_commit(
         db_session=db_session,
         credential_id=credential.id,
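
The beat tasks above (connector deletion, indexing, pruning, vespa sync) all share one shape: take a tenant-scoped Redis client, try a non-blocking beat lock, and return early if another worker already holds it. A condensed sketch of that pattern, with a placeholder lock name and timeout rather than the real `DanswerRedisLocks` constants:

```python
from redis import Redis


def run_beat_task(r: Redis) -> None:
    # One lock per tenant: the TenantRedis wrapper (see redis_pool.py below)
    # prefixes the lock key with the tenant id, so tenants never contend
    # for the same lock even though they share a Redis instance.
    lock_beat = r.lock("check_something_beat_lock", timeout=120)
    if not lock_beat.acquire(blocking=False):
        return  # another worker holds this tenant's beat lock; skip this tick
    try:
        pass  # ... do the periodic work for this tenant ...
    finally:
        if lock_beat.owned():
            lock_beat.release()
```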
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 1c6a6a3a3..af1dd3a0f 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -295,30 +295,32 @@ async def get_async_session_with_tenant(
 def get_session_with_tenant(
     tenant_id: str | None = None,
 ) -> Generator[Session, None, None]:
-    """Generate a database session with the appropriate tenant schema set."""
+    """Generate a database session bound to a connection with the appropriate tenant schema set."""
     engine = get_sqlalchemy_engine()
+
     if tenant_id is None:
         tenant_id = current_tenant_id.get()
+    else:
+        current_tenant_id.set(tenant_id)
+
+    event.listen(engine, "checkout", set_search_path_on_checkout)
 
     if not is_valid_schema_name(tenant_id):
         raise HTTPException(status_code=400, detail="Invalid tenant ID")
 
-    # Establish a raw connection without starting a transaction
+    # Establish a raw connection
     with engine.connect() as connection:
-        # Access the raw DBAPI connection
+        # Access the raw DBAPI connection and set the search_path
         dbapi_connection = connection.connection
 
-        # Execute SET search_path outside of any transaction
+        # Set the search_path outside of any transaction
         cursor = dbapi_connection.cursor()
         try:
-            cursor.execute(f'SET search_path TO "{tenant_id}"')
-            # Optionally verify the search_path was set correctly
-            cursor.execute("SHOW search_path")
-            cursor.fetchone()
+            cursor.execute(f'SET search_path = "{tenant_id}"')
         finally:
             cursor.close()
 
-        # Proceed to create a session using the connection
+        # Bind the session to the connection
         with Session(bind=connection, expire_on_commit=False) as session:
             try:
                 yield session
@@ -332,6 +334,18 @@ def get_session_with_tenant(
                 cursor.close()
 
 
+def set_search_path_on_checkout(
+    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
+) -> None:
+    tenant_id = current_tenant_id.get()
+    if tenant_id and is_valid_schema_name(tenant_id):
+        with dbapi_conn.cursor() as cursor:
+            cursor.execute(f'SET search_path TO "{tenant_id}"')
+            logger.debug(
+                f"Set search_path to {tenant_id} for connection {connection_record}"
+            )
+
+
 def get_session_generator_with_tenant(
     tenant_id: str | None = None,
 ) -> Generator[Session, None, None]:
diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py
index 240ff355b..9e59f2db8 100644
--- a/backend/danswer/key_value_store/store.py
+++ b/backend/danswer/key_value_store/store.py
@@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24  # 1 Day
 
 class PgRedisKVStore(KeyValueStore):
     def __init__(self) -> None:
-        self.redis_client = get_redis_client()
+        tenant_id = current_tenant_id.get()
+        self.redis_client = get_redis_client(tenant_id=tenant_id)
 
     @contextmanager
     def get_session(self) -> Iterator[Session]:
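
The engine change routes tenancy through a ContextVar: `get_session_with_tenant` records the tenant, and the pool's `checkout` event replays `SET search_path` whenever a pooled connection is handed out, so a connection last used by one tenant is re-pointed before reuse. A standalone sketch of that mechanism, assuming a psycopg2-style DBAPI connection and a `current_tenant_id` that mirrors the real ContextVar:

```python
from contextvars import ContextVar
from typing import Any

# Mirrors the real ContextVar defined in danswer.db.engine.
current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")


def set_search_path_on_checkout(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    # Runs on every pool checkout: a pooled connection last pointed at
    # tenant A's schema is re-pointed at the current tenant's schema
    # before SQLAlchemy hands it to the caller.
    tenant_id = current_tenant_id.get()
    with dbapi_conn.cursor() as cursor:
        cursor.execute(f'SET search_path TO "{tenant_id}"')


# Registration with SQLAlchemy's pool events would look like:
# from sqlalchemy import event
# event.listen(engine, "checkout", set_search_path_on_checkout)
```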
diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py
index fd08b9157..fda71a143 100644
--- a/backend/danswer/redis/redis_pool.py
+++ b/backend/danswer/redis/redis_pool.py
@@ -1,4 +1,7 @@
+import functools
 import threading
+from collections.abc import Callable
+from typing import Any
 from typing import Optional
 
 import redis
@@ -14,6 +17,72 @@ from danswer.configs.app_configs import REDIS_SSL
 from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
 from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
 from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class TenantRedis(redis.Redis):
+    def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.tenant_id: str = tenant_id
+
+    def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
+        prefix: str = f"{self.tenant_id}:"
+        if isinstance(key, str):
+            if key.startswith(prefix):
+                return key
+            else:
+                return prefix + key
+        elif isinstance(key, bytes):
+            prefix_bytes = prefix.encode()
+            if key.startswith(prefix_bytes):
+                return key
+            else:
+                return prefix_bytes + key
+        elif isinstance(key, memoryview):
+            key_bytes = key.tobytes()
+            prefix_bytes = prefix.encode()
+            if key_bytes.startswith(prefix_bytes):
+                return key
+            else:
+                return memoryview(prefix_bytes + key_bytes)
+        else:
+            raise TypeError(f"Unsupported key type: {type(key)}")
+
+    def _prefix_method(self, method: Callable) -> Callable:
+        @functools.wraps(method)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if "name" in kwargs:
+                kwargs["name"] = self._prefixed(kwargs["name"])
+            elif len(args) > 0:
+                args = (self._prefixed(args[0]),) + args[1:]
+            return method(*args, **kwargs)
+
+        return wrapper
+
+    def __getattribute__(self, item: str) -> Any:
+        original_attr = super().__getattribute__(item)
+        methods_to_wrap = [
+            "lock",
+            "unlock",
+            "get",
+            "set",
+            "delete",
+            "exists",
+            "incrby",
+            "hset",
+            "hget",
+            "getset",
+            "scan_iter",
+            "owned",
+            "reacquire",
+            "create_lock",
+            "startswith",
+        ]  # Add all methods that need prefixing
+        if item in methods_to_wrap and callable(original_attr):
+            return self._prefix_method(original_attr)
+        return original_attr
 
 
 class RedisPool:
@@ -32,8 +101,10 @@ class RedisPool:
     def _init_pool(self) -> None:
         self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
 
-    def get_client(self) -> Redis:
-        return redis.Redis(connection_pool=self._pool)
+    def get_client(self, tenant_id: str | None) -> Redis:
+        if tenant_id is None:
+            tenant_id = "public"
+        return TenantRedis(tenant_id, connection_pool=self._pool)
 
     @staticmethod
     def create_pool(
@@ -84,8 +155,8 @@ class RedisPool:
 redis_pool = RedisPool()
 
 
-def get_redis_client() -> Redis:
-    return redis_pool.get_client()
+def get_redis_client(*, tenant_id: str | None) -> Redis:
+    return redis_pool.get_client(tenant_id)
 
 
 # # Usage example
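
The effect of `TenantRedis` is that every wrapped call gets its key argument rewritten to `<tenant_id>:<key>`, with already-prefixed keys passed through untouched, so tenants sharing one Redis instance never collide. A standalone re-implementation of the `_prefixed` logic for illustration (str and bytes cases only):

```python
def prefixed(tenant_id: str, key: str | bytes) -> str | bytes:
    # Namespace a key under the tenant, idempotently.
    prefix = f"{tenant_id}:"
    if isinstance(key, bytes):
        bprefix = prefix.encode()
        return key if key.startswith(bprefix) else bprefix + key
    return key if key.startswith(prefix) else prefix + key


assert prefixed("tenant_a", "fence_key") == "tenant_a:fence_key"
assert prefixed("tenant_a", "tenant_a:fence_key") == "tenant_a:fence_key"  # idempotent
assert prefixed("tenant_a", b"fence_key") == b"tenant_a:fence_key"
```

So a call like `r.set("x", 1)` on a tenant_a client actually writes the key `tenant_a:x`, which is why the beat locks and task fences above no longer need tenant-aware names at the call sites.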
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 9cfe72275..c9802038e 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import current_tenant_id
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -90,6 +91,7 @@ def get_cc_pair_full_info(
     cc_pair_id: int,
     user: User | None = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> CCPairFullInfo:
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id, db_session, user, get_editable=False
@@ -136,6 +138,7 @@ def get_cc_pair_full_info(
             connector_id=cc_pair.connector_id,
             credential_id=cc_pair.credential_id,
             db_session=db_session,
+            tenant_id=tenant_id,
         ),
         num_docs_indexed=documents_indexed,
         is_editable_for_current_user=is_editable_for_current_user,
@@ -231,6 +234,7 @@ def prune_cc_pair(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse[list[int]]:
     """Triggers pruning on a particular cc_pair immediately"""
 
@@ -246,7 +250,7 @@ def prune_cc_pair(
             detail="Connection not found for current user's permissions",
         )
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     rcp = RedisConnectorPruning(cc_pair_id)
     if rcp.is_pruning(db_session, r):
         raise HTTPException(
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index 8de42db38..683fe7195 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -482,10 +482,11 @@ def get_connector_indexing_status(
     get_editable: bool = Query(
         False, description="If true, return editable document sets"
     ),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> list[ConnectorIndexingStatus]:
     indexing_statuses: list[ConnectorIndexingStatus] = []
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     # NOTE: If the connector is deleting behind the scenes,
     # accessing cc_pairs can be inconsistent and members like
@@ -606,6 +607,7 @@ def get_connector_indexing_status(
                 connector_id=connector.id,
                 credential_id=credential.id,
                 db_session=db_session,
+                tenant_id=tenant_id,
             ),
             is_deletable=check_deletion_attempt_is_allowed(
                 connector_credential_pair=cc_pair,
@@ -683,15 +685,18 @@ def create_connector_with_mock_credential(
         connector_response = create_connector(
             db_session=db_session, connector_data=connector_data
         )
+
         mock_credential = CredentialBase(
             credential_json={}, admin_public=True, source=connector_data.source
         )
         credential = create_credential(
             mock_credential, user=user, db_session=db_session
         )
+
         access_type = (
             AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
         )
+
         response = add_credential_to_connector(
             db_session=db_session,
             user=user,
@@ -775,7 +780,7 @@ def connector_run_once(
     """Used to trigger indexing on a set of cc_pairs associated with a
     single connector."""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     connector_id = run_info.connector_id
     specified_credential_ids = run_info.credential_ids
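
The endpoint changes all follow the same FastAPI pattern: resolve the tenant via dependency injection and thread it into the Redis helpers. A minimal sketch of that shape; the resolver body here is a hypothetical stand-in, since the real `get_current_tenant_id` lives in `danswer.db.engine`:

```python
from fastapi import Depends, FastAPI

app = FastAPI()


def get_current_tenant_id() -> str | None:
    # Hypothetical stand-in: the real dependency resolves the tenant
    # from the request / auth context.
    return "public"


@app.get("/cc-pair/{cc_pair_id}")
def get_cc_pair_full_info(
    cc_pair_id: int,
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> dict:
    # tenant_id is now in scope to pass into get_redis_client(...)
    # and get_deletion_attempt_snapshot(...), as in the diff above.
    return {"cc_pair_id": cc_pair_id, "tenant_id": tenant_id}
```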
diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py
index afc77c146..f1282307b 100644
--- a/backend/ee/danswer/background/celery/celery_app.py
+++ b/backend/ee/danswer/background/celery/celery_app.py
@@ -42,7 +42,9 @@ global_version.set_ee()
 
 @build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
+def sync_external_doc_permissions_task(
+    cc_pair_id: int, *, tenant_id: str | None
+) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
 
@@ -50,7 +52,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
 @build_celery_task_wrapper(name_sync_external_group_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_external_group_permissions_task(
-    cc_pair_id: int, tenant_id: str | None
+    cc_pair_id: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@@ -59,7 +61,7 @@ def sync_external_group_permissions_task(
 @build_celery_task_wrapper(name_chat_ttl_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def perform_ttl_management_task(
-    retention_limit_days: int, tenant_id: str | None
+    retention_limit_days: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -72,7 +74,7 @@ def perform_ttl_management_task(
     name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -89,7 +91,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
     name="check_sync_external_group_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external group permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -106,7 +108,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
     name="check_ttl_management_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_ttl_management_task(tenant_id: str | None) -> None:
+def check_ttl_management_task(*, tenant_id: str | None) -> None:
     """Runs periodically to check if any ttl tasks should be run and adds them
     to the queue"""
     token = None
@@ -130,7 +132,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
     name="autogenerate_usage_report_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def autogenerate_usage_report_task(tenant_id: str | None) -> None:
+def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
     """This generates usage report under the /admin/generate-usage/report endpoint"""
     with get_session_with_tenant(tenant_id) as db_session:
         create_new_usage_report(
@@ -179,7 +181,7 @@ for tenant_id in tenant_ids:
         beat_schedule[task_name] = {
             "task": task["task"],
             "schedule": task["schedule"],
-            "args": (tenant_id,),  # Must pass tenant_id as an argument
+            "kwargs": {"tenant_id": tenant_id},  # Pass tenant_id as a keyword argument
         }
 
 # Include any existing beat schedules