add multi-tenancy to redis

This commit is contained in:
pablodanswer 2024-10-23 13:29:26 -07:00
parent 14e75bbd24
commit 336f931c85
14 changed files with 324 additions and 141 deletions

View File

@ -19,6 +19,7 @@ from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
@ -56,7 +57,7 @@ def on_task_postrun(
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
@ -83,7 +84,19 @@ def on_task_postrun(
if not task_id:
return
r = get_redis_client()
# Get tenant_id directly from kwargs - each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
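Note: the handler above assumes every tenant-aware task is enqueued with an explicit tenant_id keyword argument. A minimal sketch of the sending side, assuming celery_app is the configured Celery instance (the task name is taken from this commit; the rest is illustrative):

    # tenant_id rides along as a kwarg so on_task_postrun can read it back
    celery_app.send_task(
        "check_for_indexing",
        kwargs={"tenant_id": tenant_id},  # None falls back to the shared "public" namespace
    )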
@ -124,7 +137,7 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
r = get_redis_client()
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@ -157,26 +170,44 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
r = get_redis_client()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)
if all_tenants_ready:
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)
logger.info(
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
@ -184,7 +215,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("All tenant primary workers are ready. Continuing...")
return
@ -196,14 +227,14 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
for tenant_id, lock in sender.primary_worker_locks.items():
if lock and lock.owned():
logger.debug(f"Releasing lock for tenant {tenant_id}")
lock.release()
sender.primary_worker_locks[tenant_id] = None
def on_setup_logging(

View File

@ -1,7 +1,6 @@
import multiprocessing
from typing import Any
import redis
from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
@ -24,6 +23,7 @@ from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@ -80,81 +80,83 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
sender.primary_worker_lock = lock
sender.primary_worker_locks[tenant_id] = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
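Since get_redis_client now returns a TenantRedis (see redis_pool.py below) that namespaces every key with the tenant ID, the same logical lock name produces one independent primary-worker lock per tenant. A minimal sketch with hypothetical tenant IDs and an illustrative timeout:

    from danswer.configs.constants import DanswerRedisLocks
    from danswer.redis.redis_pool import get_redis_client

    for tenant_id in ["tenant_a", "tenant_b"]:  # hypothetical tenant IDs
        r = get_redis_client(tenant_id=tenant_id)
        # stored under "<tenant_id>:<lock name>", so tenants never contend
        lock = r.lock(DanswerRedisLocks.PRIMARY_WORKER, timeout=60)
        if lock.acquire(blocking_timeout=5):
            print(f"{tenant_id}: primary worker lock acquired")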
# @worker_process_init.connect
@ -217,42 +219,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None:
try:
if not worker.primary_worker_lock:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
r = get_redis_client()
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock: redis.lock.Lock = worker.primary_worker_lock
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops

View File

@ -27,7 +27,10 @@ logger = setup_logger()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@ -40,7 +43,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
@ -50,9 +53,14 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None

View File

@ -24,8 +24,8 @@ from danswer.redis.redis_pool import get_redis_client
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, tenant_id: str | None) -> None:
r = get_redis_client()
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,

View File

@ -55,10 +55,10 @@ logger = setup_logger()
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@ -68,6 +68,7 @@ def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
cc_pair_ids: list[int] = []
@ -398,7 +399,7 @@ def connector_indexing_task(
attempt = None
n_final_progress = 0
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
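Each of these beat-style tasks guards itself with the same tenant-scoped, non-blocking lock so at most one instance per tenant runs at a time. A condensed sketch of the pattern (function, lock name, and timeout are illustrative, not from this commit):

    from danswer.redis.redis_pool import get_redis_client

    def run_once_per_tenant(tenant_id: str | None) -> None:
        r = get_redis_client(tenant_id=tenant_id)
        lock_beat = r.lock("da_lock:check_beat", timeout=120)  # illustrative
        if not lock_beat.acquire(blocking=False):
            return  # another worker already holds this tenant's beat lock
        try:
            ...  # periodic work for this tenant
        finally:
            if lock_beat.owned():
                lock_beat.release()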

View File

@ -41,8 +41,8 @@ logger = setup_logger()
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, tenant_id: str | None) -> None:
r = get_redis_client()
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@ -222,7 +222,7 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)

View File

@ -60,6 +60,7 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@ -67,6 +68,8 @@ from danswer.utils.variable_functionality import (
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@ -76,11 +79,11 @@ from danswer.utils.variable_functionality import noop_fallback
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@ -680,7 +683,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False
"""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,

View File

@ -242,7 +242,6 @@ def create_credential(
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,

View File

@ -319,30 +319,32 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Execute SET search_path outside of any transaction
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path TO "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Proceed to create a session using the connection
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
@ -356,6 +358,18 @@ def get_session_with_tenant(
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
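A minimal usage sketch for the context manager above (the tenant schema name is hypothetical): every statement on the yielded session runs with search_path pinned to the tenant's schema, and the checkout listener re-applies it whenever a pooled connection is handed out.

    from sqlalchemy import text
    from danswer.db.engine import get_session_with_tenant

    with get_session_with_tenant("tenant_abc") as db_session:  # hypothetical tenant
        # resolves against the "tenant_abc" schema rather than public
        db_session.execute(text("SELECT 1"))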
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
tenant_id = current_tenant_id.get()
with get_session_with_tenant(tenant_id) as session:

View File

@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
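Unlike the celery tasks, the KV store resolves its tenant from the current_tenant_id contextvar rather than an explicit parameter. A minimal sketch of that flow (tenant ID is hypothetical):

    from danswer.db.engine import current_tenant_id

    token = current_tenant_id.set("tenant_abc")  # hypothetical tenant ID
    try:
        store = PgRedisKVStore()  # its redis client is now scoped to tenant_abc
    finally:
        current_tenant_id.reset(token)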
@contextmanager
def get_session(self) -> Iterator[Session]:

View File

@ -1,4 +1,7 @@
import functools
import threading
from collections.abc import Callable
from typing import Any
from typing import Optional
import redis
@ -14,6 +17,98 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tenant_id: str = tenant_id
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
prefix: str = f"{self.tenant_id}:"
if isinstance(key, str):
if key.startswith(prefix):
return key
else:
return prefix + key
elif isinstance(key, bytes):
prefix_bytes = prefix.encode()
if key.startswith(prefix_bytes):
return key
else:
return prefix_bytes + key
elif isinstance(key, memoryview):
key_bytes = key.tobytes()
prefix_bytes = prefix.encode()
if key_bytes.startswith(prefix_bytes):
return key
else:
return memoryview(prefix_bytes + key_bytes)
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def _prefix_method(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
return wrapper
def _prefix_scan_iter(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Prefix the match pattern if provided
if "match" in kwargs:
kwargs["match"] = self._prefixed(kwargs["match"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
# Get the iterator
iterator = method(*args, **kwargs)
# Remove prefix from returned keys
prefix = f"{self.tenant_id}:".encode()
prefix_len = len(prefix)
for key in iterator:
if isinstance(key, bytes) and key.startswith(prefix):
yield key[prefix_len:]
else:
yield key
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"lock",
"unlock",
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Regular methods that need simple prefixing
if item == "scan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
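The wrapper is transparent to callers: __getattribute__ hands back prefix-aware bound methods, so existing code keeps using plain key names. A minimal demonstration sketch against a local redis (tenant IDs are hypothetical; assumes no other matching keys):

    pool = redis.ConnectionPool(host="localhost", port=6379)

    r_a = TenantRedis("tenant_a", connection_pool=pool)
    r_b = TenantRedis("tenant_b", connection_pool=pool)

    r_a.set("counter", "1")  # physically stored as "tenant_a:counter"
    r_b.set("counter", "2")  # physically stored as "tenant_b:counter"

    assert r_a.get("counter") == b"1"  # no cross-tenant visibility
    assert list(r_a.scan_iter("counter*")) == [b"counter"]  # prefix stripped on read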
class RedisPool:
@ -32,8 +127,10 @@ class RedisPool:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
@staticmethod
def create_pool(
@ -84,8 +181,8 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
# # Usage example

View File

@ -26,6 +26,7 @@ from danswer.db.connector_credential_pair import (
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@ -94,6 +95,7 @@ def get_cc_pair_full_info(
cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
r = get_redis_client()
@ -147,6 +149,7 @@ def get_cc_pair_full_info(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
tenant_id=tenant_id,
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
@ -243,6 +246,7 @@ def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
@ -258,7 +262,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(r):
raise HTTPException(

View File

@ -483,10 +483,11 @@ def get_connector_indexing_status(
get_editable: bool = Query(
False, description="If true, return editable document sets"
),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
@ -607,6 +608,7 @@ def get_connector_indexing_status(
connector_id=connector.id,
credential_id=credential.id,
db_session=db_session,
tenant_id=tenant_id,
),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair,
@ -684,15 +686,18 @@ def create_connector_with_mock_credential(
connector_response = create_connector(
db_session=db_session, connector_data=connector_data
)
mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
db_session=db_session,
user=user,
@ -776,7 +781,7 @@ def connector_run_once(
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids

View File

@ -39,7 +39,9 @@ global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
def sync_external_doc_permissions_task(
cc_pair_id: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@ -47,7 +49,7 @@ def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None
cc_pair_id: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@ -56,7 +58,7 @@ def sync_external_group_permissions_task(
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@ -69,7 +71,7 @@ def perform_ttl_management_task(
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions"""
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@ -86,7 +88,7 @@ def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions"""
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@ -103,7 +105,7 @@ def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(tenant_id: str | None) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
@ -127,7 +129,7 @@ def check_ttl_management_task(tenant_id: str | None) -> None:
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(