This commit is contained in:
pablodanswer 2024-10-18 17:57:56 -07:00
parent f745ca1e03
commit 802dc00f78
11 changed files with 202 additions and 52 deletions

View File

@ -4,7 +4,6 @@ import time
from datetime import timedelta
from typing import Any
import redis
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Celery
@ -79,6 +78,7 @@ def on_task_prerun(
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
@ -118,7 +118,7 @@ def on_task_postrun(
if not task_id:
return
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@ -171,7 +171,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
# (set at the command line)
tenant_id = kwargs.get("tenant_id")
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@ -207,7 +209,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@ -267,7 +269,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
@ -449,17 +451,18 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
# Requires the Hub component
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
@ -478,42 +481,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None:
try:
if not worker.primary_worker_lock:
if not celery_is_worker_primary(worker):
return
if not hasattr(worker, "primary_worker_lock"):
if not hasattr(worker, "primary_worker_locks"):
return
r = get_redis_client()
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
lock: redis.lock.Lock = worker.primary_worker_lock
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
r = get_redis_client(tenant_id=tenant_id)
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
else:
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops

View File

@ -31,7 +31,10 @@ logger = setup_logger()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@ -44,7 +47,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if not r.exists(rcd.fence_key):
return None
@ -54,9 +57,14 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None

View File

@ -24,7 +24,7 @@ from danswer.redis.redis_pool import get_redis_client
trail=False,
)
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,

View File

@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings

View File

@ -39,7 +39,7 @@ logger = setup_logger()
soft_time_limit=JOB_TIMEOUT,
)
def check_for_pruning(tenant_id: str | None) -> None:
r = get_redis_client()
r = get_redis_client(tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@ -204,7 +204,7 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
r = get_redis_client(tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)

View File

@ -78,7 +78,7 @@ def check_for_vespa_sync_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client()
r = get_redis_client(tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@ -640,7 +640,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False
"""
r = get_redis_client()
r = get_redis_client(tenant_id)
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,

View File

@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client()
# r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries):
# try:

View File

@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
self.redis_client = get_redis_client()
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def get_session(self) -> Iterator[Session]:

View File

@ -1,8 +1,16 @@
import threading
from typing import Any
from typing import cast
from typing import Optional
from typing import Union
import redis
from redis.client import Redis
from redis.typing import AbsExpiryT
from redis.typing import EncodableT
from redis.typing import ExpiryT
from redis.typing import KeyT
from redis.typing import ResponseT
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
@ -16,6 +24,108 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
# TODO: enforce typing strictly
class TenantRedis(redis.Redis):
    """Redis client that namespaces keys under a ``"<tenant_id>:"`` prefix.

    Only the commands overridden in this class (``lock``, ``incrby``, ``set``,
    ``get``, ``delete`` and ``exists``) are namespaced. Every other command
    inherited from ``redis.Redis`` still operates on the key exactly as given.
    """

    def __init__(self, tenant_id: str, *args: Any, **kwargs: Any):
        """Create the client.

        Args:
            tenant_id: Identifier used to build the key prefix.
            *args: Forwarded unchanged to ``redis.Redis``.
            **kwargs: Forwarded unchanged to ``redis.Redis``.
        """
        super().__init__(*args, **kwargs)
        self.tenant_id = tenant_id

    def _prefixed(
        self, key: Union[str, bytes, memoryview]
    ) -> Union[str, bytes, memoryview]:
        """Return ``key`` namespaced with ``"<tenant_id>:"``.

        The operation is idempotent: a key that already starts with the
        tenant prefix is returned unchanged. This matters for ``lock()``:
        the redis-py ``Lock`` object is created with the already-prefixed
        name and then calls back into this client's ``set()``/``get()`` with
        that name, while its release/extend/reacquire scripts use the name
        verbatim. Prefixing a second time would make acquire and release
        address two different keys.

        Raises:
            TypeError: If ``key`` is not ``str``, ``bytes`` or ``memoryview``.
        """
        prefix = f"{self.tenant_id}:"
        if isinstance(key, str):
            return key if key.startswith(prefix) else prefix + key

        prefix_bytes = prefix.encode()
        if isinstance(key, bytes):
            return key if key.startswith(prefix_bytes) else prefix_bytes + key
        if isinstance(key, memoryview):
            raw = key.tobytes()
            if raw.startswith(prefix_bytes):
                return key
            return memoryview(prefix_bytes + raw)
        raise TypeError(f"Unsupported key type: {type(key)}")

    def lock(
        self,
        name: str,
        timeout: Optional[float] = None,
        sleep: float = 0.1,
        blocking: bool = True,
        blocking_timeout: Optional[float] = None,
        lock_class: Union[None, Any] = None,
        thread_local: bool = True,
    ) -> Any:
        """Return a lock on the tenant-prefixed ``name``.

        All other arguments are forwarded unchanged to ``redis.Redis.lock``.
        """
        prefixed_name = cast(str, self._prefixed(name))
        return super().lock(
            prefixed_name,
            timeout=timeout,
            sleep=sleep,
            blocking=blocking,
            blocking_timeout=blocking_timeout,
            lock_class=lock_class,
            thread_local=thread_local,
        )

    def incrby(self, name: KeyT, amount: int = 1) -> ResponseT:
        """
        Increments the value of the tenant-prefixed ``name`` by ``amount``.
        If no key exists, the value will be initialized as ``amount``

        For more information see https://redis.io/commands/incrby
        """
        prefixed_name = self._prefixed(name)
        return super().incrby(prefixed_name, amount)

    def set(
        self,
        name: KeyT,
        value: EncodableT,
        ex: Union[ExpiryT, None] = None,
        px: Union[ExpiryT, None] = None,
        nx: bool = False,
        xx: bool = False,
        keepttl: bool = False,
        get: bool = False,
        exat: Union[AbsExpiryT, None] = None,
        pxat: Union[AbsExpiryT, None] = None,
    ) -> ResponseT:
        """Set ``value`` at the tenant-prefixed ``name``.

        All options are forwarded unchanged to ``redis.Redis.set``.
        """
        prefixed_name = self._prefixed(name)
        return super().set(
            prefixed_name,
            value,
            ex=ex,
            px=px,
            nx=nx,
            xx=xx,
            keepttl=keepttl,
            get=get,
            exat=exat,
            pxat=pxat,
        )

    def get(self, name: KeyT) -> ResponseT:
        """Return the value stored at the tenant-prefixed ``name``."""
        prefixed_name = self._prefixed(name)
        return super().get(prefixed_name)

    def delete(self, *names: KeyT) -> ResponseT:
        """Delete the tenant-prefixed form of each key in ``names``."""
        prefixed_names = [self._prefixed(name) for name in names]
        return super().delete(*prefixed_names)

    def exists(self, *names: KeyT) -> ResponseT:
        """Check existence of the tenant-prefixed form of each key in ``names``."""
        prefixed_names = [self._prefixed(name) for name in names]
        return super().exists(*prefixed_names)
# def expire(self, name: str, time: int, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().expire(prefixed_name, time, **kwargs)
# def ttl(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().ttl(prefixed_name, **kwargs)
# def type(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().type(prefixed_name, **kwargs)
class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
@ -32,8 +142,11 @@ class RedisPool:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
def get_client(self, tenant_id: str | None = None) -> Redis:
    """Return a Redis client backed by this pool.

    Args:
        tenant_id: When given, a ``TenantRedis`` bound to that tenant is
            returned; when ``None`` (the default), a plain ``redis.Redis``
            client is returned. The default keeps existing zero-argument
            ``get_client()`` callers working.

    Returns:
        A client that shares this pool's connection pool.
    """
    if tenant_id is None:
        return redis.Redis(connection_pool=self._pool)
    return TenantRedis(tenant_id, connection_pool=self._pool)
@staticmethod
def create_pool(
@ -84,8 +197,8 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client() -> Redis:
return redis_pool.get_client()
def get_redis_client(tenant_id: str | None = None) -> Redis:
    """Fetch a client from the module-level ``redis_pool`` for ``tenant_id``.

    ``tenant_id`` is passed straight through to ``redis_pool.get_client``;
    it defaults to ``None``.
    """
    client = redis_pool.get_client(tenant_id)
    return client
# # Usage example

View File

@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@ -90,6 +91,7 @@ def get_cc_pair_full_info(
cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
@ -136,6 +138,7 @@ def get_cc_pair_full_info(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
tenant_id=tenant_id,
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
@ -231,6 +234,7 @@ def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
@ -246,7 +250,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r):
raise HTTPException(

View File

@ -482,10 +482,11 @@ def get_connector_indexing_status(
get_editable: bool = Query(
False, description="If true, return editable document sets"
),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
@ -606,6 +607,7 @@ def get_connector_indexing_status(
connector_id=connector.id,
credential_id=credential.id,
db_session=db_session,
tenant_id=tenant_id,
),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair,
@ -775,7 +777,7 @@ def connector_run_once(
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids