add multi tenancy to redis

This commit is contained in:
pablodanswer 2024-10-23 13:07:00 -07:00
parent f745ca1e03
commit 668cd7bb49
14 changed files with 335 additions and 207 deletions

View File

@ -4,7 +4,6 @@ import time
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
import redis
import sentry_sdk import sentry_sdk
from celery import bootsteps # type: ignore from celery import bootsteps # type: ignore
from celery import Celery from celery import Celery
@ -79,6 +78,7 @@ def on_task_prerun(
task_id: str | None = None, task_id: str | None = None,
task: Task | None = None, task: Task | None = None,
args: tuple | None = None, args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
**kwds: Any, **kwds: Any,
) -> None: ) -> None:
@ -91,7 +91,7 @@ def on_task_postrun(
task_id: str | None = None, task_id: str | None = None,
task: Task | None = None, task: Task | None = None,
args: tuple | None = None, args: tuple | None = None,
kwargs: dict | None = None, kwargs: dict[str, Any] | None = None,
retval: Any | None = None, retval: Any | None = None,
state: str | None = None, state: str | None = None,
**kwds: Any, **kwds: Any,
@ -110,7 +110,17 @@ def on_task_postrun(
if not task: if not task:
return return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") # Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
if state not in READY_STATES: if state not in READY_STATES:
return return
@ -118,7 +128,7 @@ def on_task_postrun(
if not task_id: if not task_id:
return return
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX): if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@ -171,7 +181,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname # decide some initial startup settings based on the celery worker's hostname
# (set at the command line) # (set at the command line)'
hostname = sender.hostname hostname = sender.hostname
if hostname.startswith("light"): if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@ -182,166 +193,155 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("indexing"): elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME) SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0) SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
# TODO: why is this necessary for the indexer to do? for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session: # TODO: why is this necessary for the indexer to do?
check_index_swap(db_session=db_session) with get_session_with_tenant(tenant_id) as db_session:
search_settings = get_current_search_settings(db_session) check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first # So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed # batch of documents indexed
if search_settings.provider_type is None: if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model") logger.notice(
embedding_model = EmbeddingModel.from_db_model( "Running a first inference to warm up embedding model"
search_settings=search_settings, )
server_host=INDEXING_MODEL_SERVER_HOST, embedding_model = EmbeddingModel.from_db_model(
server_port=MODEL_SERVER_PORT, search_settings=search_settings,
) server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder( warm_up_bi_encoder(
embedding_model=embedding_model, embedding_model=embedding_model,
) )
logger.notice("First inference complete.") logger.notice("First inference complete.")
else: else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0) SqlEngine.init_engine(pool_size=8, max_overflow=0)
r = get_redis_client() if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
WAIT_INTERVAL = 5 tenant_ids = get_all_tenant_ids()
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not celery_is_worker_primary(sender): if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.") logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...") for tenant_id in tenant_ids:
time_start = time.monotonic() r = get_redis_client(tenant_id=tenant_id)
while True: WAIT_INTERVAL = 5
if r.exists(DanswerRedisLocks.PRIMARY_WORKER): WAIT_LIMIT = 60
break time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
time.monotonic() while True:
time_elapsed = time.monotonic() - time_start # Log all the locks in Redis
logger.info( all_locks = r.keys("*")
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" logger.notice(f"Current Redis locks: {all_locks}")
) if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
if time_elapsed > WAIT_LIMIT: break
msg = ( time_elapsed = time.monotonic() - time_start
f"Primary worker was not ready within the timeout. " logger.info(
f"({WAIT_LIMIT} seconds). Exiting..." f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
) )
logger.error(msg) if time_elapsed > WAIT_LIMIT:
raise WorkerShutdown(msg) msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
time.sleep(WAIT_INTERVAL) for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
logger.info("Wait for primary worker completed successfully. Continuing...") WAIT_INTERVAL = 5
return WAIT_LIMIT = 60
logger.info("Running as the primary celery worker.") time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once # This is singleton work that should be done on startup exactly once
# by the primary worker # by the primary worker
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker # For the moment, we're assuming that we are the only primary worker
# that should be running. # that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it # TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER) r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order. # this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary # it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't # worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet. # implemented yet.
lock = r.lock( lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER, DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
) )
logger.info("Primary worker lock: Acquire starting.") logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired: if acquired:
logger.info("Primary worker lock: Acquire succeeded.") logger.info("Primary worker lock: Acquire succeeded.")
else: else:
logger.error("Primary worker lock: Acquire failed!") logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!") raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis # As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway) # to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key()) r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key()) r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"): for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"): for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"): for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"): for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"): for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"): for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"): for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"): for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key) r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key) r.delete(key)
# @worker_process_init.connect # @worker_process_init.connect
@ -367,14 +367,15 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender): if not celery_is_worker_primary(sender):
return return
if not sender.primary_worker_lock: if not hasattr(sender, "primary_worker_locks"):
return return
logger.info("Releasing primary worker lock.") logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock for tenant_id, lock in sender.primary_worker_locks.items():
if lock.owned(): logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
lock.release() if lock.owned():
sender.primary_worker_lock = None lock.release()
sender.primary_worker_locks = {}
class CeleryTaskPlainFormatter(PlainFormatter): class CeleryTaskPlainFormatter(PlainFormatter):
@ -449,17 +450,18 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep): class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue. """Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging. Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running. a queue of existing work would starve the task from running.
""" """
# it's unclear to me whether using the hub's timer or the bootstep timer is better # Requires the Hub component
requires = {"celery.worker.components:Hub"} requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None: def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None self.task_tref = None
@ -478,42 +480,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None: def run_periodic_task(self, worker: Any) -> None:
try: try:
if not worker.primary_worker_lock: if not celery_is_worker_primary(worker):
return return
if not hasattr(worker, "primary_worker_lock"): if not hasattr(worker, "primary_worker_locks"):
return return
r = get_redis_client() # Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock: redis.lock.Lock = worker.primary_worker_lock for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned(): r = get_redis_client(tenant_id=tenant_id)
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.info("Primary worker lock: Acquire starting.") if lock.owned():
acquired = lock.acquire( task_logger.debug(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2 f"Reacquiring primary worker lock for tenant {tenant_id}."
) )
if acquired: lock.reacquire()
task_logger.info("Primary worker lock: Acquire succeeded.")
else: else:
task_logger.error("Primary worker lock: Acquire failed!") task_logger.warning(
raise TimeoutError("Primary worker lock could not be acquired!") f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
worker.primary_worker_lock = lock task_logger.info(
except Exception: f"Primary worker lock for tenant {tenant_id}: Acquire starting."
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.") )
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None: def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops # Cancel the scheduled task when the worker stops
@ -583,14 +601,14 @@ tasks_to_schedule = [
# Build the celery beat schedule dynamically # Build the celery beat schedule dynamically
beat_schedule = {} beat_schedule = {}
for tenant_id in tenant_ids: for id in tenant_ids:
for task in tasks_to_schedule: for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
beat_schedule[task_name] = { beat_schedule[task_name] = {
"task": task["task"], "task": task["task"],
"schedule": task["schedule"], "schedule": task["schedule"],
"options": task["options"], "options": task["options"],
"args": (tenant_id,), # Must pass tenant_id as an argument "kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
} }
# Include any existing beat schedules # Include any existing beat schedules

View File

@ -31,7 +31,10 @@ logger = setup_logger()
def _get_deletion_status( def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> TaskQueueState | None: ) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt. """We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis. This function populates TaskQueueState by just checking redis.
@ -44,7 +47,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id) rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key): if not r.exists(rcd.fence_key):
return None return None
@ -54,9 +57,14 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot( def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None: ) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session) deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task: if not deletion_task:
return None return None

View File

@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
trail=False, trail=False,
) )
def check_for_connector_deletion_task(tenant_id: str | None) -> None: def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock( lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,

View File

@ -51,10 +51,10 @@ logger = setup_logger()
name="check_for_indexing", name="check_for_indexing",
soft_time_limit=300, soft_time_limit=300,
) )
def check_for_indexing(tenant_id: str | None) -> int | None: def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0 tasks_created = 0
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock( lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK, DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
try: try:
# these tasks should never overlap # these tasks should never overlap
if not lock_beat.acquire(blocking=False): if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings # Get the primary search settings
@ -367,7 +370,7 @@ def connector_indexing_task(
attempt = None attempt = None
n_final_progress = 0 n_final_progress = 0
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id) rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)

View File

@ -38,8 +38,8 @@ logger = setup_logger()
name="check_for_pruning", name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
) )
def check_for_pruning(tenant_id: str | None) -> None: def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock( lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK, DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@ -204,7 +204,7 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list""" from the most recently pulled document ID list"""
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id) rcp = RedisConnectorPruning(cc_pair_id)

View File

@ -59,6 +59,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_pool import get_redis_client from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import ( from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback, fetch_versioned_implementation_with_fallback,
@ -66,6 +67,8 @@ from danswer.utils.variable_functionality import (
from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task, # celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this. # which bloats the result metadata considerably. trail=False prevents this.
@ -74,11 +77,11 @@ from danswer.utils.variable_functionality import noop_fallback
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
trail=False, trail=False,
) )
def check_for_vespa_sync_task(tenant_id: str | None) -> None: def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing. """Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed.""" Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock( lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@ -640,7 +643,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False Returns True if the task actually did work, False
""" """
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock( lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,

View File

@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
# # for testing purposes, rate limiting is written to fall back to a simpler # # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available # # rate limiting approach when redis is not available
# r = get_redis_client() # r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries): # for attempt in range(max_retries):
# try: # try:

View File

@ -242,7 +242,6 @@ def create_credential(
) )
db_session.add(credential) db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID db_session.flush() # This ensures the credential gets an ID
_relate_credential_to_user_groups__no_commit( _relate_credential_to_user_groups__no_commit(
db_session=db_session, db_session=db_session,
credential_id=credential.id, credential_id=credential.id,

View File

@ -295,30 +295,32 @@ async def get_async_session_with_tenant(
def get_session_with_tenant( def get_session_with_tenant(
tenant_id: str | None = None, tenant_id: str | None = None,
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set.""" """Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine() engine = get_sqlalchemy_engine()
if tenant_id is None: if tenant_id is None:
tenant_id = current_tenant_id.get() tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id): if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID") raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction # Establish a raw connection
with engine.connect() as connection: with engine.connect() as connection:
# Access the raw DBAPI connection # Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection dbapi_connection = connection.connection
# Execute SET search_path outside of any transaction # Set the search_path outside of any transaction
cursor = dbapi_connection.cursor() cursor = dbapi_connection.cursor()
try: try:
cursor.execute(f'SET search_path TO "{tenant_id}"') cursor.execute(f'SET search_path = "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
finally: finally:
cursor.close() cursor.close()
# Proceed to create a session using the connection # Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session: with Session(bind=connection, expire_on_commit=False) as session:
try: try:
yield session yield session
@ -332,6 +334,18 @@ def get_session_with_tenant(
cursor.close() cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant( def get_session_generator_with_tenant(
tenant_id: str | None = None, tenant_id: str | None = None,
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:

View File

@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore): class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None: def __init__(self) -> None:
self.redis_client = get_redis_client() tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager @contextmanager
def get_session(self) -> Iterator[Session]: def get_session(self) -> Iterator[Session]:

View File

@ -1,4 +1,7 @@
import functools
import threading import threading
from collections.abc import Callable
from typing import Any
from typing import Optional from typing import Optional
import redis import redis
@ -14,6 +17,72 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tenant_id: str = tenant_id
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
prefix: str = f"{self.tenant_id}:"
if isinstance(key, str):
if key.startswith(prefix):
return key
else:
return prefix + key
elif isinstance(key, bytes):
prefix_bytes = prefix.encode()
if key.startswith(prefix_bytes):
return key
else:
return prefix_bytes + key
elif isinstance(key, memoryview):
key_bytes = key.tobytes()
prefix_bytes = prefix.encode()
if key_bytes.startswith(prefix_bytes):
return key
else:
return memoryview(prefix_bytes + key_bytes)
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def _prefix_method(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"lock",
"unlock",
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
"scan_iter",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
class RedisPool: class RedisPool:
@ -32,8 +101,10 @@ class RedisPool:
def _init_pool(self) -> None: def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL) self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self) -> Redis: def get_client(self, tenant_id: str | None) -> Redis:
return redis.Redis(connection_pool=self._pool) if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
@staticmethod @staticmethod
def create_pool( def create_pool(
@ -84,8 +155,8 @@ class RedisPool:
redis_pool = RedisPool() redis_pool = RedisPool()
def get_redis_client() -> Redis: def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client() return redis_pool.get_client(tenant_id)
# # Usage example # # Usage example

View File

@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
) )
from danswer.db.document import get_document_counts_for_cc_pairs from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.enums import AccessType from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.enums import ConnectorCredentialPairStatus
@ -90,6 +91,7 @@ def get_cc_pair_full_info(
cc_pair_id: int, cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user), user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo: ) -> CCPairFullInfo:
cc_pair = get_connector_credential_pair_from_id( cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False cc_pair_id, db_session, user, get_editable=False
@ -136,6 +138,7 @@ def get_cc_pair_full_info(
connector_id=cc_pair.connector_id, connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id, credential_id=cc_pair.credential_id,
db_session=db_session, db_session=db_session,
tenant_id=tenant_id,
), ),
num_docs_indexed=documents_indexed, num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user, is_editable_for_current_user=is_editable_for_current_user,
@ -231,6 +234,7 @@ def prune_cc_pair(
cc_pair_id: int, cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user), user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]: ) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately""" """Triggers pruning on a particular cc_pair immediately"""
@ -246,7 +250,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions", detail="Connection not found for current user's permissions",
) )
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id) rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r): if rcp.is_pruning(db_session, r):
raise HTTPException( raise HTTPException(

View File

@ -482,10 +482,11 @@ def get_connector_indexing_status(
get_editable: bool = Query( get_editable: bool = Query(
False, description="If true, return editable document sets" False, description="If true, return editable document sets"
), ),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]: ) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = [] indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes, # NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like # accessing cc_pairs can be inconsistent and members like
@ -606,6 +607,7 @@ def get_connector_indexing_status(
connector_id=connector.id, connector_id=connector.id,
credential_id=credential.id, credential_id=credential.id,
db_session=db_session, db_session=db_session,
tenant_id=tenant_id,
), ),
is_deletable=check_deletion_attempt_is_allowed( is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, connector_credential_pair=cc_pair,
@ -683,15 +685,18 @@ def create_connector_with_mock_credential(
connector_response = create_connector( connector_response = create_connector(
db_session=db_session, connector_data=connector_data db_session=db_session, connector_data=connector_data
) )
mock_credential = CredentialBase( mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source credential_json={}, admin_public=True, source=connector_data.source
) )
credential = create_credential( credential = create_credential(
mock_credential, user=user, db_session=db_session mock_credential, user=user, db_session=db_session
) )
access_type = ( access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
) )
response = add_credential_to_connector( response = add_credential_to_connector(
db_session=db_session, db_session=db_session,
user=user, user=user,
@ -775,7 +780,7 @@ def connector_run_once(
"""Used to trigger indexing on a set of cc_pairs associated with a """Used to trigger indexing on a set of cc_pairs associated with a
single connector.""" single connector."""
r = get_redis_client() r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids specified_credential_ids = run_info.credential_ids

View File

@ -42,7 +42,9 @@ global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task) @build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT) @celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None: def sync_external_doc_permissions_task(
cc_pair_id: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@ -50,7 +52,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -
@build_celery_task_wrapper(name_sync_external_group_permissions_task) @build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT) @celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task( def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None cc_pair_id: int, *, tenant_id: str | None
) -> None: ) -> None:
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id) run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@ -59,7 +61,7 @@ def sync_external_group_permissions_task(
@build_celery_task_wrapper(name_chat_ttl_task) @build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT) @celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task( def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None retention_limit_days: int, *, tenant_id: str | None
) -> None: ) -> None:
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session) delete_chat_sessions_older_than(retention_limit_days, db_session)
@ -72,7 +74,7 @@ def perform_ttl_management_task(
name="check_sync_external_doc_permissions_task", name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
) )
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None: def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions""" """Runs periodically to sync external permissions"""
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session) cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@ -89,7 +91,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
name="check_sync_external_group_permissions_task", name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
) )
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None: def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions""" """Runs periodically to sync external group permissions"""
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session) cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@ -106,7 +108,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
name="check_ttl_management_task", name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
) )
def check_ttl_management_task(tenant_id: str | None) -> None: def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them """Runs periodically to check if any ttl tasks should be run and adds them
to the queue""" to the queue"""
token = None token = None
@ -130,7 +132,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
name="autogenerate_usage_report_task", name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT, soft_time_limit=JOB_TIMEOUT,
) )
def autogenerate_usage_report_task(tenant_id: str | None) -> None: def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint""" """This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report( create_new_usage_report(
@ -179,7 +181,7 @@ for tenant_id in tenant_ids:
beat_schedule[task_name] = { beat_schedule[task_name] = {
"task": task["task"], "task": task["task"],
"schedule": task["schedule"], "schedule": task["schedule"],
"args": (tenant_id,), # Must pass tenant_id as an argument "kwargs": {"tenant_id": tenant_id}, # Pass tenant_id as a keyword argument
} }
# Include any existing beat schedules # Include any existing beat schedules