Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-07 21:50:17 +02:00)

This commit is contained in:
parent f745ca1e03
commit 802dc00f78
@@ -4,7 +4,6 @@ import time
 from datetime import timedelta
 from typing import Any
 
-import redis
 import sentry_sdk
 from celery import bootsteps  # type: ignore
 from celery import Celery
@@ -79,6 +78,7 @@ def on_task_prerun(
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple | None = None,
+    tenant_id: str | None = None,
     kwargs: dict | None = None,
     **kwds: Any,
 ) -> None:
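Note: the hunk above threads a tenant_id keyword through the task_prerun signal handler. As a rough sketch of how a tenant id can reach such a handler when it is sent as a task keyword argument (the app name, broker URL, and tenant value below are illustrative, not taken from this commit):

from celery import Celery
from celery.signals import task_prerun

app = Celery("sketch", broker="memory://")  # illustrative app and broker

@task_prerun.connect
def log_tenant(sender=None, task_id=None, task=None, args=None, kwargs=None, **kwds):
    # Celery passes the task's own args/kwargs to the handler, so a tenant_id
    # queued via .apply_async(kwargs={"tenant_id": ...}) can be recovered here.
    tenant_id = (kwargs or {}).get("tenant_id")
    print(f"task {task_id} is about to run for tenant {tenant_id}")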
@@ -118,7 +118,7 @@ def on_task_postrun(
     if not task_id:
         return
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
         r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
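Note: on_task_postrun's job here is to drop finished task ids out of a Redis "taskset". A minimal standalone sketch of that set bookkeeping, assuming a reachable local Redis; the key and task-id strings are illustrative:

import redis

r = redis.Redis(host="localhost", port=6379)  # assumes a reachable local Redis

taskset_key = "connectorcredentialpair_taskset"   # illustrative key name
r.sadd(taskset_key, "documentset_sync_abc123")    # a generator records an in-flight task id
# ... the task executes ...
r.srem(taskset_key, "documentset_sync_abc123")    # a postrun-style hook forgets the finished id
print(r.scard(taskset_key))                       # 0 -> nothing left in flight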
@@ -171,7 +171,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     # decide some initial startup settings based on the celery worker's hostname
     # (set at the command line)
+    tenant_id = kwargs.get("tenant_id")
 
     hostname = sender.hostname
     if hostname.startswith("light"):
         SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@@ -207,7 +209,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
         SqlEngine.init_engine(pool_size=8, max_overflow=0)
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
@@ -267,7 +269,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
 
     # This is singleton work that should be done on startup exactly once
     # by the primary worker
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     # For the moment, we're assuming that we are the only primary worker
     # that should be running.
@@ -449,17 +451,18 @@ def on_setup_logging(
 
 
 class HubPeriodicTask(bootsteps.StartStopStep):
-    """Regularly reacquires the primary worker lock outside of the task queue.
+    """Regularly reacquires the primary worker locks for all tenants outside of the task queue.
     Use the task_logger in this class to avoid double logging.
 
     This cannot be done inside a regular beat task because it must run on schedule and
     a queue of existing work would starve the task from running.
     """
 
-    # it's unclear to me whether using the hub's timer or the bootstep timer is better
+    # Requires the Hub component
     requires = {"celery.worker.components:Hub"}
 
     def __init__(self, worker: Any, **kwargs: Any) -> None:
+        super().__init__(worker, **kwargs)
         self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
         self.task_tref = None
 
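Note: HubPeriodicTask is a Celery bootstep that runs on the worker's event-loop hub instead of going through beat, so a backed-up task queue cannot starve it. A stripped-down sketch of the same pattern, assuming worker.hub.call_repeatedly behaves as in current Celery/kombu; the interval and callback here are illustrative:

from typing import Any
from celery import bootsteps

class HeartbeatStep(bootsteps.StartStopStep):
    """Toy bootstep that runs a callback periodically on the worker's hub timer."""

    requires = {"celery.worker.components:Hub"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        super().__init__(worker, **kwargs)
        self.tref = None

    def start(self, worker: Any) -> None:
        # schedule heartbeat() every 30 seconds on the hub's timer
        self.tref = worker.hub.call_repeatedly(30.0, self.heartbeat, worker)

    def heartbeat(self, worker: Any) -> None:
        print(f"{worker.hostname} is alive")

    def stop(self, worker: Any) -> None:
        if self.tref is not None:
            self.tref.cancel()
            self.tref = None

# registration, assuming `app` is the Celery application:
# app.steps["worker"].add(HeartbeatStep)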
@@ -478,42 +481,58 @@ class HubPeriodicTask(bootsteps.StartStopStep):
 
     def run_periodic_task(self, worker: Any) -> None:
         try:
-            if not worker.primary_worker_lock:
+            if not celery_is_worker_primary(worker):
                 return
 
-            if not hasattr(worker, "primary_worker_lock"):
+            if not hasattr(worker, "primary_worker_locks"):
                 return
 
-            r = get_redis_client()
-
-            lock: redis.lock.Lock = worker.primary_worker_lock
-
-            if lock.owned():
-                task_logger.debug("Reacquiring primary worker lock.")
-                lock.reacquire()
-            else:
-                task_logger.warning(
-                    "Full acquisition of primary worker lock. "
-                    "Reasons could be computer sleep or a clock change."
-                )
-                lock = r.lock(
-                    DanswerRedisLocks.PRIMARY_WORKER,
-                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-                )
-
-                task_logger.info("Primary worker lock: Acquire starting.")
-                acquired = lock.acquire(
-                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
-                )
-                if acquired:
-                    task_logger.info("Primary worker lock: Acquire succeeded.")
-                else:
-                    task_logger.error("Primary worker lock: Acquire failed!")
-                    raise TimeoutError("Primary worker lock could not be acquired!")
-
-                worker.primary_worker_lock = lock
-        except Exception:
-            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
+            # Retrieve all tenant IDs
+            tenant_ids = get_all_tenant_ids()
+
+            for tenant_id in tenant_ids:
+                lock = worker.primary_worker_locks.get(tenant_id)
+                if not lock:
+                    continue  # Skip if no lock for this tenant
+
+                r = get_redis_client(tenant_id=tenant_id)
+
+                if lock.owned():
+                    task_logger.debug(
+                        f"Reacquiring primary worker lock for tenant {tenant_id}."
+                    )
+                    lock.reacquire()
+                else:
+                    task_logger.warning(
+                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
+                        "Reasons could be worker restart or lock expiration."
+                    )
+                    lock = r.lock(
+                        DanswerRedisLocks.PRIMARY_WORKER,
+                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+                    )
+
+                    task_logger.info(
+                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
+                    )
+                    acquired = lock.acquire(
+                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
+                    )
+                    if acquired:
+                        task_logger.info(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
+                        )
+                        worker.primary_worker_locks[tenant_id] = lock
+                    else:
+                        task_logger.error(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
+                        )
+                        raise TimeoutError(
+                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
+                        )
+
+        except Exception as e:
+            task_logger.error(f"Error in periodic task: {e}")
 
     def stop(self, worker: Any) -> None:
         # Cancel the scheduled task when the worker stops
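Note: the loop above leans on redis-py's Lock lifecycle: acquire with a blocking timeout, check ownership, and reacquire to refresh the TTL before it lapses. A minimal standalone sketch of that lifecycle, assuming a reachable Redis; the lock name and timeouts are illustrative:

import redis

r = redis.Redis(host="localhost", port=6379)  # assumes a reachable local Redis

lock = r.lock("primary_worker_lock_demo", timeout=60)  # auto-expires after 60s
if lock.acquire(blocking_timeout=5):
    try:
        # ... do a slice of periodic work ...
        if lock.owned():
            lock.reacquire()  # reset the TTL so the lock does not expire mid-work
    finally:
        lock.release()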
@@ -31,7 +31,10 @@ logger = setup_logger()
 
 
 def _get_deletion_status(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> TaskQueueState | None:
     """We no longer store TaskQueueState in the DB for a deletion attempt.
     This function populates TaskQueueState by just checking redis.
@@ -44,7 +47,7 @@ def _get_deletion_status(
 
     rcd = RedisConnectorDeletion(cc_pair.id)
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     if not r.exists(rcd.fence_key):
         return None
 
@@ -54,9 +57,14 @@ def _get_deletion_status(
 
 
 def get_deletion_attempt_snapshot(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> DeletionAttemptSnapshot | None:
-    deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
+    deletion_task = _get_deletion_status(
+        connector_id, credential_id, db_session, tenant_id
+    )
     if not deletion_task:
         return None
 
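Note: _get_deletion_status only reports an attempt while the deletion "fence" key still exists in Redis. A small sketch of that fence pattern, assuming a reachable Redis; the key name and the convention of storing the outstanding-task count in the fence value are illustrative, not lifted from this commit:

import redis

r = redis.Redis(host="localhost", port=6379)  # assumes a reachable local Redis

fence_key = "connectordeletion_fence_42"  # illustrative key name
# producer: raise the fence with the number of subtasks still outstanding
r.set(fence_key, 3, ex=3600)

# consumer: only report a deletion attempt while the fence exists
if r.exists(fence_key):
    remaining = int(r.get(fence_key))
    print(f"deletion in progress, {remaining} tasks left")
else:
    print("no deletion in flight")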
@@ -24,7 +24,7 @@ from danswer.redis.redis_pool import get_redis_client
     trail=False,
 )
 def check_for_connector_deletion_task(tenant_id: str | None) -> None:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -64,7 +64,10 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
     try:
         # these tasks should never overlap
         if not lock_beat.acquire(blocking=False):
+            task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
             return None
+        else:
+            task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
 
         with get_session_with_tenant(tenant_id) as db_session:
             # Get the primary search settings
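Note: the beat tasks in these hunks guard themselves with a non-blocking Redis lock so overlapping runs bail out immediately. A minimal sketch of that guard, assuming a reachable Redis; the lock name and timeout are illustrative:

import redis

r = redis.Redis(host="localhost", port=6379)  # assumes a reachable local Redis

lock_beat = r.lock("check_for_indexing_beat_demo", timeout=120)

# beat runs should never overlap: skip this run if another one holds the lock
if not lock_beat.acquire(blocking=False):
    print("another beat run is active, skipping")
else:
    try:
        pass  # ... enumerate connector/credential pairs and enqueue work ...
    finally:
        if lock_beat.owned():
            lock_beat.release()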
@@ -39,7 +39,7 @@ logger = setup_logger()
     soft_time_limit=JOB_TIMEOUT,
 )
 def check_for_pruning(tenant_id: str | None) -> None:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -204,7 +204,7 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id)
 
     rcp = RedisConnectorPruning(cc_pair_id)
 
@@ -78,7 +78,7 @@ def check_for_vespa_sync_task(tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id)
 
     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -640,7 +640,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
 
     Returns True if the task actually did work, False
     """
-    r = get_redis_client()
+    r = get_redis_client(tenant_id)
 
     lock_beat: redis.lock.Lock = r.lock(
         DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
@@ -41,7 +41,7 @@ class ConfluenceRateLimitError(Exception):
 
 # # for testing purposes, rate limiting is written to fall back to a simpler
 # # rate limiting approach when redis is not available
-# r = get_redis_client()
+# r = get_redis_client(tenant_id=tenant_id)
 
 # for attempt in range(max_retries):
 #     try:
@@ -28,7 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24  # 1 Day
 
 class PgRedisKVStore(KeyValueStore):
     def __init__(self) -> None:
-        self.redis_client = get_redis_client()
+        tenant_id = current_tenant_id.get()
+        self.redis_client = get_redis_client(tenant_id=tenant_id)
 
     @contextmanager
     def get_session(self) -> Iterator[Session]:
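Note: PgRedisKVStore now reads the tenant from a context variable instead of taking it as an argument. A self-contained sketch of that ContextVar pattern; the variable and helper below stand in for danswer's current_tenant_id and are not the real implementations:

from contextvars import ContextVar

# stand-in for danswer's current_tenant_id context variable
current_tenant_id: ContextVar[str | None] = ContextVar("current_tenant_id", default=None)

def build_kv_store_client() -> str:
    # read whatever tenant the surrounding request or task set on this context
    tenant_id = current_tenant_id.get()
    return f"redis client scoped to tenant {tenant_id!r}"

token = current_tenant_id.set("tenant_a")
print(build_kv_store_client())  # -> redis client scoped to tenant 'tenant_a'
current_tenant_id.reset(token)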
@@ -1,8 +1,16 @@
 import threading
+from typing import Any
+from typing import cast
 from typing import Optional
+from typing import Union
 
 import redis
 from redis.client import Redis
+from redis.typing import AbsExpiryT
+from redis.typing import EncodableT
+from redis.typing import ExpiryT
+from redis.typing import KeyT
+from redis.typing import ResponseT
 
 from danswer.configs.app_configs import REDIS_DB_NUMBER
 from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
@@ -16,6 +24,108 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
 from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
 
 
+# TODO: enforce typing strictly
+class TenantRedis(redis.Redis):
+    def __init__(self, tenant_id: str, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.tenant_id = tenant_id
+
+    def _prefixed(
+        self, key: Union[str, bytes, memoryview]
+    ) -> Union[str, bytes, memoryview]:
+        prefix = f"{self.tenant_id}:"
+        if isinstance(key, str):
+            return prefix + key
+        elif isinstance(key, bytes):
+            return prefix.encode() + key
+        elif isinstance(key, memoryview):
+            return memoryview(prefix.encode() + key.tobytes())
+        else:
+            raise TypeError(f"Unsupported key type: {type(key)}")
+
+    def lock(
+        self,
+        name: str,
+        timeout: Optional[float] = None,
+        sleep: float = 0.1,
+        blocking: bool = True,
+        blocking_timeout: Optional[float] = None,
+        lock_class: Union[None, Any] = None,
+        thread_local: bool = True,
+    ) -> Any:
+        prefixed_name = cast(str, self._prefixed(name))
+        return super().lock(
+            prefixed_name,
+            timeout=timeout,
+            sleep=sleep,
+            blocking=blocking,
+            blocking_timeout=blocking_timeout,
+            lock_class=lock_class,
+            thread_local=thread_local,
+        )
+
+    def incrby(self, name: KeyT, amount: int = 1) -> ResponseT:
+        """
+        Increments the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as ``amount``
+
+        For more information see https://redis.io/commands/incrby
+        """
+        prefixed_name = self._prefixed(name)
+        return super().incrby(prefixed_name, amount)
+
+    def set(
+        self,
+        name: KeyT,
+        value: EncodableT,
+        ex: Union[ExpiryT, None] = None,
+        px: Union[ExpiryT, None] = None,
+        nx: bool = False,
+        xx: bool = False,
+        keepttl: bool = False,
+        get: bool = False,
+        exat: Union[AbsExpiryT, None] = None,
+        pxat: Union[AbsExpiryT, None] = None,
+    ) -> ResponseT:
+        prefixed_name = self._prefixed(name)
+        return super().set(
+            prefixed_name,
+            value,
+            ex=ex,
+            px=px,
+            nx=nx,
+            xx=xx,
+            keepttl=keepttl,
+            get=get,
+            exat=exat,
+            pxat=pxat,
+        )
+
+    def get(self, name: KeyT) -> ResponseT:
+        prefixed_name = self._prefixed(name)
+        return super().get(prefixed_name)
+
+    def delete(self, *names: KeyT) -> ResponseT:
+        prefixed_names = [self._prefixed(name) for name in names]
+        return super().delete(*prefixed_names)
+
+    def exists(self, *names: KeyT) -> ResponseT:
+        prefixed_names = [self._prefixed(name) for name in names]
+        return super().exists(*prefixed_names)
+
+    # def expire(self, name: str, time: int, **kwargs: Any) -> Any:
+    #     prefixed_name = self._prefixed(name)
+    #     return super().expire(prefixed_name, time, **kwargs)
+
+    # def ttl(self, name: str, **kwargs: Any) -> Any:
+    #     prefixed_name = self._prefixed(name)
+    #     return super().ttl(prefixed_name, **kwargs)
+
+    # def type(self, name: str, **kwargs: Any) -> Any:
+    #     prefixed_name = self._prefixed(name)
+    #     return super().type(prefixed_name, **kwargs)
+
+
 class RedisPool:
     _instance: Optional["RedisPool"] = None
     _lock: threading.Lock = threading.Lock()
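Note: TenantRedis isolates tenants by silently prefixing every key with "<tenant_id>:" before delegating to redis.Redis. A quick usage sketch, assuming the TenantRedis class added above; construction is lazy, so the prefixing itself can be inspected without a running server:

# the tenant names and key are illustrative
client = TenantRedis("tenant_a", host="localhost", port=6379)
print(client._prefixed("indexing:counter"))    # -> "tenant_a:indexing:counter"
print(client._prefixed(b"indexing:counter"))   # -> b"tenant_a:indexing:counter"

# with a reachable Redis, tenants stay isolated even on the same database:
# TenantRedis("tenant_a", ...).set("flag", 1) writes the key "tenant_a:flag",
# while TenantRedis("tenant_b", ...).get("flag") looks up "tenant_b:flag" and misses it.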
@@ -32,8 +142,11 @@ class RedisPool:
     def _init_pool(self) -> None:
         self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
 
-    def get_client(self) -> Redis:
-        return redis.Redis(connection_pool=self._pool)
+    def get_client(self, tenant_id: str | None) -> Redis:
+        if tenant_id is not None:
+            return TenantRedis(tenant_id, connection_pool=self._pool)
+        else:
+            return redis.Redis(connection_pool=self._pool)
 
     @staticmethod
     def create_pool(
@@ -84,8 +197,8 @@ class RedisPool:
 redis_pool = RedisPool()
 
 
-def get_redis_client() -> Redis:
-    return redis_pool.get_client()
+def get_redis_client(tenant_id: str | None = None) -> Redis:
+    return redis_pool.get_client(tenant_id)
 
 
 # # Usage example
@@ -24,6 +24,7 @@ from danswer.db.connector_credential_pair import (
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import current_tenant_id
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -90,6 +91,7 @@ def get_cc_pair_full_info(
     cc_pair_id: int,
     user: User | None = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> CCPairFullInfo:
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id, db_session, user, get_editable=False
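Note: the API routes in these hunks pick up the tenant through FastAPI's dependency injection and thread it along to the Redis helpers. A minimal sketch of that wiring; the header-based get_current_tenant_id below is a stand-in for illustration, not danswer's real dependency:

from fastapi import Depends, FastAPI, Header

app = FastAPI()

def get_current_tenant_id(x_tenant_id: str | None = Header(default=None)) -> str | None:
    # stand-in: the real dependency resolves the tenant from auth/context
    return x_tenant_id

@app.get("/manage/admin/cc-pair/{cc_pair_id}")
def get_cc_pair_full_info(
    cc_pair_id: int,
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> dict:
    # the tenant id would be threaded through to helpers such as
    # get_deletion_attempt_snapshot(..., tenant_id=tenant_id)
    return {"cc_pair_id": cc_pair_id, "tenant_id": tenant_id}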
@@ -136,6 +138,7 @@ def get_cc_pair_full_info(
             connector_id=cc_pair.connector_id,
             credential_id=cc_pair.credential_id,
             db_session=db_session,
+            tenant_id=tenant_id,
         ),
         num_docs_indexed=documents_indexed,
         is_editable_for_current_user=is_editable_for_current_user,
@@ -231,6 +234,7 @@ def prune_cc_pair(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse[list[int]]:
     """Triggers pruning on a particular cc_pair immediately"""
 
@@ -246,7 +250,7 @@ def prune_cc_pair(
             detail="Connection not found for current user's permissions",
         )
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     rcp = RedisConnectorPruning(cc_pair_id)
     if rcp.is_pruning(db_session, r):
         raise HTTPException(
@@ -482,10 +482,11 @@ def get_connector_indexing_status(
     get_editable: bool = Query(
         False, description="If true, return editable document sets"
     ),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> list[ConnectorIndexingStatus]:
     indexing_statuses: list[ConnectorIndexingStatus] = []
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     # NOTE: If the connector is deleting behind the scenes,
     # accessing cc_pairs can be inconsistent and members like
@@ -606,6 +607,7 @@ def get_connector_indexing_status(
                     connector_id=connector.id,
                     credential_id=credential.id,
                     db_session=db_session,
+                    tenant_id=tenant_id,
                 ),
                 is_deletable=check_deletion_attempt_is_allowed(
                     connector_credential_pair=cc_pair,
@@ -775,7 +777,7 @@ def connector_run_once(
     """Used to trigger indexing on a set of cc_pairs associated with a
     single connector."""
 
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
 
     connector_id = run_info.connector_id
     specified_credential_ids = run_info.credential_ids