fresh indexing feature branch (#2790)

* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a624220687affdda3de347e30f2011136f64bda.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* fix where num_indexing_workers falls back

* remove extra brace
rkuo-danswer 2024-10-18 15:40:05 -07:00 committed by GitHub
parent 12cbbe6cee
commit 6913efef90
29 changed files with 1679 additions and 765 deletions


@ -1,4 +1,5 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
@ -12,6 +13,7 @@ from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@ -21,23 +23,32 @@ from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.update import get_all_tenant_ids
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@ -62,8 +73,20 @@ celery_app.config_from_object(
) # Load configuration from 'celeryconfig.py'
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@signals.task_postrun.connect
def celery_task_postrun(
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
@ -80,6 +103,9 @@ def celery_task_postrun(
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which is how all of our
long running workers are configured)
"""
if not task:
return
@ -101,32 +127,38 @@ def celery_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
rds = RedisDocumentSet(int(document_set_id))
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
rug = RedisUserGroup(int(usergroup_id))
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(cc_pair_id)
rcd = RedisConnectorDeletion(int(cc_pair_id))
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(cc_pair_id)
rcp = RedisConnectorPruning(int(cc_pair_id))
r.srem(rcp.taskset_key, task_id)
return
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
@ -135,6 +167,9 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)
hostname = sender.hostname
@ -144,6 +179,30 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that first-time users aren't surprised by the really slow speed of the first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
@ -234,6 +293,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
sender.primary_worker_lock = lock
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
@ -270,6 +331,31 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@ -318,7 +404,7 @@ def on_setup_logging(
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats celery's worker logger
# reformats the root logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
@ -441,6 +527,7 @@ celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
@ -467,9 +554,15 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_prune_task_2",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},


@ -29,8 +29,8 @@ class RedisObjectHelper(ABC):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
def __init__(self, id: str):
self._id: str = id
@property
def task_id_prefix(self) -> str:
@ -47,7 +47,7 @@ class RedisObjectHelper(ABC):
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
def get_id_from_fence_key(key: str) -> str | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
@ -61,15 +61,11 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
object_id = parts[2]
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
def get_id_from_task_id(task_id: str) -> str | None:
"""
Extracts the object ID from a task ID string.
@ -93,11 +89,7 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
object_id = parts[1]
return object_id
@abstractmethod
@ -117,6 +109,9 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@ -128,7 +123,7 @@ class RedisDocumentSet(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id, current_only=False)
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@ -164,6 +159,9 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@ -187,7 +185,7 @@ class RedisUserGroup(RedisObjectHelper):
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@ -219,13 +217,19 @@ class RedisUserGroup(RedisObjectHelper):
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class differs from the default in that the taskset used spans
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@ -252,7 +256,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@ -298,6 +302,9 @@ class RedisConnectorDeletion(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@ -309,7 +316,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@ -386,9 +393,7 @@ class RedisConnectorPruning(RedisObjectHelper):
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
super().__init__(str(id))
self.documents_to_prune: set[str] = set()
@property
@ -420,7 +425,7 @@ class RedisConnectorPruning(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
if not cc_pair:
return None
@ -463,7 +468,7 @@ class RedisConnectorPruning(RedisObjectHelper):
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=self._id, db_session=db_session
cc_pair_id=int(self._id), db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
@ -474,6 +479,66 @@ class RedisConnectorPruning(RedisObjectHelper):
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_settings_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of indexing task ids
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorindexing_generator_progress_2/5
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorindexing_generator_complete_2/5
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
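A minimal sketch (not part of the commit) of how the composite ID used by RedisConnectorIndexing maps onto Redis key names and how the fence key gates duplicate indexing work. The fence_key property itself lives on the base RedisObjectHelper class and is not shown in this hunk, so the "PREFIX_fence_X" format below is taken from the get_id_from_fence_key docstring.

```python
from redis import Redis

# Keys follow the documented "PREFIX_fence_X" convention, where X is the
# composite "cc_pair_id/search_settings_id" ID, e.g. "2/5".
PREFIX = "connectorindexing"


def fence_key(cc_pair_id: int, search_settings_id: int) -> str:
    composite_id = f"{cc_pair_id}/{search_settings_id}"  # e.g. "2/5"
    return f"{PREFIX}_fence_{composite_id}"  # "connectorindexing_fence_2/5"


def indexing_already_fenced(r: Redis, cc_pair_id: int, search_settings_id: int) -> bool:
    # check_for_indexing and try_creating_indexing_task skip a cc_pair / search
    # settings combination whenever its fence key already exists in Redis
    return bool(r.exists(fence_key(cc_pair_id, search_settings_id)))
```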


@ -3,10 +3,13 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@ -16,6 +19,7 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
@ -124,10 +128,30 @@ def celery_is_worker_primary(worker: Any) -> bool:
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("light"):
return False
if hostname.startswith("primary"):
return True
if hostname.startswith("heavy"):
return False
return False
return True
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
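A hypothetical sketch of how a per-tenant dispatcher could use get_all_tenant_ids(); the actual wiring lives in update.py, which this commit comments out, so kick_off_indexing_checks and the send_task call below are illustrative only. In single-tenant mode the function returns [None], so the task runs once with tenant_id=None.

```python
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import DanswerCeleryPriority


def kick_off_indexing_checks() -> None:
    # one "check_for_indexing" task per tenant schema (or a single run with
    # tenant_id=None when MULTI_TENANT is off)
    for tenant_id in get_all_tenant_ids():
        celery_app.send_task(
            "check_for_indexing",
            kwargs={"tenant_id": tenant_id},
            priority=DanswerCeleryPriority.HIGH,
        )
```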


@ -41,6 +41,11 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT


@ -0,0 +1,452 @@
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
)
def check_for_indexing(tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair.id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings_instance,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
return tasks_created
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
def try_creating_indexing_task(
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# create the index attempt ... just for tracking purposes
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
return None
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job:
return
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
break
return
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
"""
attempt = None
n_final_progress = 0
r = get_redis_client()
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
return n_final_progress


@ -3,7 +3,6 @@ from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
@ -15,7 +14,9 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
@ -30,15 +31,14 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name="check_for_prune_task_2",
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task_2(tenant_id: str | None) -> None:
def check_for_pruning(tenant_id: str | None) -> None:
r = get_redis_client()
lock_beat = r.lock(
@ -54,13 +54,17 @@ def check_for_prune_task_2(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, tenant_id, r, lock_beat
lock_beat.reacquire()
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@ -72,13 +76,11 @@ def check_for_prune_task_2(tenant_id: str | None) -> None:
lock_beat.release()
def ccpair_pruning_generator_task_creation_helper(
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
tenant_id: str | None,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
@ -89,24 +91,30 @@ def ccpair_pruning_generator_task_creation_helper(
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return None
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return None
return False
return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id)
return True
def try_creating_prune_generator_task(
@ -119,50 +127,78 @@ def try_creating_prune_generator_task(
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately.
is used to trigger prunes immediately, e.g. via the web ui.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
rcp = RedisConnectorPruning(cc_pair.id)
LOCK_TIMEOUT = 30
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
@shared_task(
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
)
def connector_pruning_generator_task(
connector_id: int, credential_id: int, tenant_id: str | None
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@ -170,8 +206,22 @@ def connector_pruning_generator_task(
r = get_redis_client()
with get_session_with_tenant(tenant_id) as db_session:
try:
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@ -180,14 +230,13 @@ def connector_pruning_generator_task(
if not cc_pair:
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
f"cc_pair not found for {connector_id} {credential_id}"
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
@ -240,12 +289,13 @@ def connector_pruning_generator_task(
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():
lock.release()


@ -1,6 +1,9 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
@ -17,6 +20,13 @@ from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
@ -46,6 +56,8 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
@ -111,11 +123,17 @@ def document_by_cc_pair_cleanup_task(
pass
task_logger.info(
f"document_id={document_id} action={action} refcount={count} chunks={chunks_affected}"
f"tenant_id={tenant_id} "
f"document_id={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
except Exception as e:
task_logger.exception("Unexpected exception")
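A minimal sketch (with assumed example values) of the round trip RedisConnectorIndexingFenceData makes through Redis: try_creating_indexing_task serializes it with model_dump_json() into the indexing fence key, and monitor_ccpair_indexing_taskset later reads it back with model_validate_json() to recover the celery task id and submission time.

```python
from datetime import datetime, timezone

from redis import Redis

from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData

r = Redis()  # assumed local connection, for illustration only
fence_key = "connectorindexing_fence_2/5"  # assumed example fence key

# producer side (try_creating_indexing_task): write the fence payload
fence_value = RedisConnectorIndexingFenceData(
    index_attempt_id=123,  # assumed index attempt id
    started=None,
    submitted=datetime.now(timezone.utc),
    celery_task_id="0b8a9f2c-example",  # assumed celery task id
)
r.set(fence_key, fence_value.model_dump_json())

# consumer side (monitor_ccpair_indexing_taskset): read it back and validate
raw = r.get(fence_key)
if raw is not None:
    fence_data = RedisConnectorIndexingFenceData.model_validate_json(raw.decode("utf-8"))
    elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
```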


@ -1,10 +1,15 @@
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from sqlalchemy.orm import Session
@ -14,9 +19,11 @@ from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
@ -40,8 +47,13 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
@ -296,11 +308,13 @@ def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id_str is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
@ -315,7 +329,8 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
f"Document set sync progress: document_set_id={document_set_id} "
f"remaining={count} initial={initial_count}"
)
if count > 0:
return
@ -345,11 +360,13 @@ def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
@ -458,13 +475,15 @@ def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
@ -488,7 +507,7 @@ def monitor_ccpair_pruning_taskset(
if count > 0:
return
mark_ccpair_as_pruned(cc_pair_id, db_session)
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
@ -499,14 +518,127 @@ def monitor_ccpair_pruning_taskset(
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
# Read result state BEFORE generator_complete_key to avoid a race condition
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early (e.g. because the
beat lock was already held).
"""
r = get_redis_client()
@ -518,11 +650,14 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
return False
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
@ -534,7 +669,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
)
task_logger.info(
f"Queue lengths: celery={n_celery} sync={n_sync} deletion={n_deletion} pruning={n_pruning}"
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
)
lock_beat.reacquire()
@ -565,6 +704,29 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@ -577,6 +739,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
@shared_task(
name="vespa_metadata_sync_task",


@ -1,5 +1,6 @@
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -88,12 +89,18 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
@ -236,6 +243,8 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
@ -254,6 +263,9 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
@ -382,6 +394,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
) -> None:
try:
if is_ee:
@ -404,7 +417,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt, tenant_id)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "

File diff suppressed because it is too large


@ -42,6 +42,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
@ -73,6 +75,16 @@ KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@ -196,14 +208,19 @@ class DanswerCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_INDEXING = "connector_indexing"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
class DanswerCeleryPriority(int, Enum):


@ -1,4 +1,6 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from sqlalchemy import and_
from sqlalchemy import delete
@ -19,8 +21,6 @@ from danswer.db.models import SearchSettings
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
logger = setup_logger()
@ -66,7 +66,7 @@ def create_index_attempt(
return new_attempt.id
def get_inprogress_index_attempts(
def get_in_progress_index_attempts(
connector_id: int | None,
db_session: Session,
) -> list[IndexAttempt]:
@ -81,13 +81,15 @@ def get_inprogress_index_attempts(
return list(incomplete_attempts.all())
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
def get_all_index_attempts_by_status(
status: IndexingStatus, db_session: Session
) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-lived indexing jobs, which would otherwise cause increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.where(IndexAttempt.status == status)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
@ -202,6 +204,8 @@ def mark_attempt_failed(
.with_for_update()
).scalar_one()
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
@ -210,9 +214,6 @@ def mark_attempt_failed(
db_session.rollback()
raise
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
def update_docs_indexed(
db_session: Session,

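With get_not_started_index_attempts generalized into get_all_index_attempts_by_status, the old call site becomes a status-parameterized lookup. A minimal sketch of the equivalent call (db_session assumed to be an open Session):

from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.models import IndexingStatus

# equivalent of the removed get_not_started_index_attempts(db_session)
new_attempts = get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
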
View File

@ -1,5 +1,6 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@ -14,7 +15,9 @@ def get_default_document_index(
index both need to be updated, updates are applied to both indices"""
# Currently only supporting Vespa
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
multitenant=MULTI_TENANT,
)

View File

@ -124,9 +124,15 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
def __init__(
self,
index_name: str,
secondary_index_name: str | None,
multitenant: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.multitenant = multitenant
def ensure_indices_exist(
self,
@ -341,6 +347,7 @@ class VespaIndex(DocumentIndex):
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
multitenant=self.multitenant,
executor=executor,
)

View File

@ -123,6 +123,7 @@ def _index_vespa_chunk(
chunk: DocMetadataAwareIndexChunk,
index_name: str,
http_client: httpx.Client,
multitenant: bool,
) -> None:
json_header = {
"Content-Type": "application/json",
@ -179,8 +180,9 @@ def _index_vespa_chunk(
BOOST: chunk.boost,
}
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
if multitenant:
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
logger.debug(f'Indexing to URL "{vespa_url}"')
@ -200,6 +202,7 @@ def batch_index_vespa_chunks(
chunks: list[DocMetadataAwareIndexChunk],
index_name: str,
http_client: httpx.Client,
multitenant: bool,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
external_executor = True
@ -210,7 +213,9 @@ def batch_index_vespa_chunks(
try:
chunk_index_future = {
executor.submit(_index_vespa_chunk, chunk, index_name, http_client): chunk
executor.submit(
_index_vespa_chunk, chunk, index_name, http_client, multitenant
): chunk
for chunk in chunks
}
for future in concurrent.futures.as_completed(chunk_index_future):

View File

@ -15,7 +15,9 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
@ -63,19 +65,22 @@ from danswer.db.credentials import delete_google_drive_service_account_credentia
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.index_attempt import get_latest_index_attempts_by_status
from danswer.db.models import IndexingStatus
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@ -480,6 +485,8 @@ def get_connector_indexing_status(
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client()
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
# connector or credential may be None.
@ -531,6 +538,12 @@ def get_connector_indexing_status(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@ -542,6 +555,12 @@ def get_connector_indexing_status(
# This may happen if background deletion is happening
continue
in_progress = False
if search_settings:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
if r.exists(rci.fence_key):
in_progress = True
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
(connector.id, credential.id)
)
@ -595,6 +614,7 @@ def get_connector_indexing_status(
allow_scheduled=True,
)
is None,
in_progress=in_progress,
)
)
@ -750,7 +770,13 @@ def connector_run_once(
run_info: RunConnectorRequest,
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client()
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids
@ -804,16 +830,24 @@ def connector_run_once(
if credential_id not in skipped_credentials
]
index_attempt_ids = [
create_index_attempt(
connector_credential_pair_id=connector_credential_pair.id,
search_settings_id=search_settings.id,
from_beginning=run_info.from_beginning,
db_session=db_session,
)
for connector_credential_pair in connector_credential_pairs
if connector_credential_pair is not None
]
index_attempt_ids = []
for cc_pair in connector_credential_pairs:
if cc_pair is not None:
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings,
run_info.from_beginning,
db_session,
r,
tenant_id,
)
if attempt_id:
logger.info(
f"try_creating_indexing_task succeeded: cc_pair={cc_pair.id} attempt_id={attempt_id}"
)
index_attempt_ids.append(attempt_id)
else:
logger.info(f"try_creating_indexing_task failed: cc_pair={cc_pair.id}")
if not index_attempt_ids:
raise HTTPException(

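Both the in_progress flag in get_connector_indexing_status and the new try_creating_indexing_task path hinge on a Redis fence key per (cc_pair, search settings). A minimal sketch of that check, using only names that appear in this diff (everything else assumed in scope):

from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.redis.redis_pool import get_redis_client

def indexing_in_progress(cc_pair_id: int, search_settings_id: int) -> bool:
    # a live fence key means celery is still indexing (or cleaning up) this pair
    r = get_redis_client()
    rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
    return bool(r.exists(rci.fence_key))
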
View File

@ -307,6 +307,10 @@ class ConnectorIndexingStatus(BaseModel):
deletion_attempt: DeletionAttemptSnapshot | None
is_deletable: bool
# index attempt in db can be marked successful while celery/redis
# is still running/cleaning up
in_progress: bool
class ConnectorCredentialPairIdentifier(BaseModel):
connector_id: int

View File

@ -182,3 +182,24 @@ def setup_logger(
logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs) # type: ignore
return DanswerLoggingAdapter(logger, extra=extra)
def print_loggers() -> None:
root_logger = logging.getLogger()
loggers: list[logging.Logger | logging.PlaceHolder] = [root_logger]
loggers.extend(logging.Logger.manager.loggerDict.values())
for logger in loggers:
if isinstance(logger, logging.PlaceHolder):
# Skip placeholders that aren't actual loggers
continue
print(f"Logger: '{logger.name}' (Level: {logging.getLevelName(logger.level)})")
if logger.handlers:
for handler in logger.handlers:
print(f" Handler: {handler}")
else:
print(" No handlers")
print(f" Propagate: {logger.propagate}")
print()

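print_loggers is purely a debugging aid; one plausible use is calling it once at worker startup to verify handler and level configuration:

from danswer.utils.logger import print_loggers

# dumps every registered logger with its level, handlers, and propagate flag
print_loggers()
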
View File

@ -1,8 +1,8 @@
from datetime import timedelta
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.update import get_all_tenant_ids
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.chat import delete_chat_sessions_older_than

View File

@ -16,11 +16,17 @@ logger = setup_logger()
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(fence_key)
if not usergroup_id:
usergroup_id_str = RedisUserGroup.get_id_from_fence_key(fence_key)
if not usergroup_id_str:
task_logger.warning(f"Could not parse usergroup id from {fence_key}")
return
try:
usergroup_id = int(usergroup_id_str)
except ValueError:
task_logger.exception(f"usergroup_id ({usergroup_id_str}) is not an integer!")
raise
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:

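The monitor fix above separates extracting the id string from validating it as an integer. A standalone sketch of the same defensive pattern, assuming the id is simply the trailing "_"-separated token of the fence key (the real layout lives in RedisUserGroup):

def parse_trailing_id(fence_key: str) -> int | None:
    # assumption: keys look like "<prefix>_<id>"
    id_str = fence_key.rsplit("_", 1)[-1]
    try:
        return int(id_str)
    except ValueError:
        return None
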
View File

@ -1,5 +1,4 @@
import argparse
import os
import subprocess
import threading
@ -17,7 +16,7 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
break
def run_jobs(exclude_indexing: bool) -> None:
def run_jobs() -> None:
# command setup
cmd_worker_primary = [
"celery",
@ -26,6 +25,7 @@ def run_jobs(exclude_indexing: bool) -> None:
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"-n",
"primary@%n",
@ -40,6 +40,7 @@ def run_jobs(exclude_indexing: bool) -> None:
"worker",
"--pool=threads",
"--concurrency=16",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"-n",
"light@%n",
@ -54,6 +55,7 @@ def run_jobs(exclude_indexing: bool) -> None:
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"-n",
"heavy@%n",
@ -61,6 +63,20 @@ def run_jobs(exclude_indexing: bool) -> None:
"connector_pruning",
]
cmd_worker_indexing = [
"celery",
"-A",
"ee.danswer.background.celery.celery_app",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"-n",
"indexing@%n",
"--queues=connector_indexing",
]
cmd_beat = [
"celery",
"-A",
@ -82,6 +98,10 @@ def run_jobs(exclude_indexing: bool) -> None:
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_indexing_process = subprocess.Popen(
cmd_worker_indexing, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
@ -96,44 +116,26 @@ def run_jobs(exclude_indexing: bool) -> None:
worker_heavy_thread = threading.Thread(
target=monitor_process, args=("HEAVY", worker_heavy_process)
)
worker_indexing_thread = threading.Thread(
target=monitor_process, args=("INDEX", worker_indexing_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_primary_thread.start()
worker_light_thread.start()
worker_heavy_thread.start()
worker_indexing_thread.start()
beat_thread.start()
if not exclude_indexing:
update_env = os.environ.copy()
update_env["PYTHONPATH"] = "."
cmd_indexing = ["python", "danswer/background/update.py"]
indexing_process = subprocess.Popen(
cmd_indexing,
env=update_env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
indexing_thread = threading.Thread(
target=monitor_process, args=("INDEXING", indexing_process)
)
indexing_thread.start()
indexing_thread.join()
worker_primary_thread.join()
worker_light_thread.join()
worker_heavy_thread.join()
worker_indexing_thread.join()
beat_thread.join()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run background jobs.")
parser.add_argument(
"--no-indexing", action="store_true", help="Do not run indexing process"
)
args = parser.parse_args()
run_jobs(args.no_indexing)
run_jobs()

View File

@ -3,17 +3,6 @@ nodaemon=true
user=root
logfile=/var/log/supervisord.log
# Indexing is the heaviest job, also requires some CPU intensive steps
# Cannot place this in Celery for now because Celery must run as a single process (see note below)
# Indexing uses multi-processing to speed things up
[program:document_indexing]
environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true
command=python danswer/background/update.py
stdout_logfile=/var/log/document_indexing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
# Background jobs that must be run async due to long time to completion
# NOTE: due to an issue with Celery + SQLAlchemy
# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
@ -73,6 +62,21 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_indexing]
command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \
--pool=threads \
--concurrency=${CELERY_WORKER_INDEXING_CONCURRENCY:-${NUM_INDEXING_WORKERS:-1}} \
--prefetch-multiplier=1 \
--loglevel=INFO \
--hostname=indexing@%%n \
-Q connector_indexing"
stdout_logfile=/var/log/celery_worker_indexing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat
@ -103,7 +107,7 @@ command=tail -qF
/var/log/celery_worker_primary.log
/var/log/celery_worker_light.log
/var/log/celery_worker_heavy.log
/var/log/document_indexing.log
/var/log/celery_worker_indexing.log
/var/log/slack_bot.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout

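The ${CELERY_WORKER_INDEXING_CONCURRENCY:-${NUM_INDEXING_WORKERS:-1}} expansion above resolves the worker's concurrency from the first env var, then NUM_INDEXING_WORKERS, then 1. A minimal Python sketch of the same fallback chain (it also treats empty strings as unset, which the shell form does not):

import os

concurrency = int(
    os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
    or os.environ.get("NUM_INDEXING_WORKERS")
    or 1
)
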
View File

@ -239,24 +239,24 @@ class CCPairManager:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if fetched_cc_pair.in_progress:
continue
if (
fetched_cc_pair.last_success
and fetched_cc_pair.last_success > after
):
print(f"cc_pair {cc_pair.id} indexing complete.")
print(f"CC pair {cc_pair.id} indexing complete.")
return
else:
print("cc_pair found but not finished:")
# print(fetched_cc_pair.__dict__)
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"CC pair indexing was not completed within {timeout} seconds"
f"CC pair {cc_pair.id} indexing was not completed within {timeout} seconds"
)
print(
f"Waiting for CC indexing to complete. elapsed={elapsed:.2f} timeout={timeout}"
f"CC pair {cc_pair.id} indexing to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)

View File

@ -135,6 +135,7 @@ class DocumentSetManager:
all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets)
if all_up_to_date:
print("Document sets synced successfully.")
break
if time.time() - start > MAX_DELAY:

View File

@ -146,6 +146,7 @@ class UserGroupManager:
if user_group.id in check_ids
]
if all(ug.is_up_to_date for ug in user_groups):
print("User groups synced successfully.")
return
if time.time() - start > MAX_DELAY:

View File

@ -174,8 +174,9 @@ services:
- GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
# Celery Configs (defaults are set in the supervisord.conf file, prefer doing that to have on source
# of defaults)
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_CONCURRENCY=${CELERY_WORKER_LIGHT_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-}

View File

@ -186,8 +186,9 @@ services:
# Log all of Danswer prompts and interactions with the LLM
- LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
# Celery Configs (defaults are set in the supervisord.conf file, prefer doing that to have on source
# of defaults)
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_CONCURRENCY=${CELERY_WORKER_LIGHT_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-}

View File

@ -439,6 +439,7 @@ export function CCPairIndexingStatusTable({
error_msg: "",
deletion_attempt: null,
is_deletable: true,
in_progress: false,
groups: [], // Add this line
}}
isEditable={false}

View File

@ -111,6 +111,7 @@ export interface ConnectorIndexingStatus<
latest_index_attempt: IndexAttemptSnapshot | null;
deletion_attempt: DeletionAttemptSnapshot | null;
is_deletable: boolean;
in_progress: boolean;
}
export interface CCPairBasicInfo {