Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-03 19:20:53 +02:00)
Feature/celery multi (#2470)
* first cut at redis
* some new helper functions for the db
* ignore kombu tables in alembic migrations (used by celery)
* multiline commands for readability, add vespa_metadata_sync queue to worker
* typo fix
* fix returning tuple fields
* add constants
* fix _get_access_for_document
* docstrings!
* fix double function declaration and typing
* fix type hinting
* add a global redis pool
* Add get_document function
* use task_logger in various celery tasks
* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit
* Add celery redis helper. used in a subsequent PR
* kombu warning getting spammy since celery is not self managing its queue in Postgres any more
* add last_modified and last_synced to documents
* fix task naming convention
* use celeryconfig.py
* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc
* change vespa index log line to debug
* mypy fixes
* update alembic migration
* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call
* mypy
* switch to monotonic time
* fix startup dependencies on redis
* rebase alembic migration
* kombu cleanup - fail silently
* mypy
* add redis_host environment override
* update REDIS_HOST env var in docker-compose.dev.yml
* update the rest of the docker files
* in flight
* harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now.
* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date
* add back writing to vespa on indexing
* actually working connector deletion
* update contributing guide
* backporting fixes from background_deletion
* renaming cache to cache_volume
* add redis password to various deployments
* try setting up pr testing for helm
* fix indent
* hopefully this release version actually exists
* fix command line option to --chart-dirs
* fetch-depth 0
* edit values.yaml
* try setting ct working directory
* bypass testing only on change for now
* move files and lint them
* update helm testing
* some issues suggest using --config works
* add vespa repo
* add postgresql repo
* increase timeout
* try amd64 runner
* fix redis password reference
* add comment to helm chart testing workflow
* rename helm testing workflow to disable it
* adding clarifying comments
* address code review
* missed a file
* remove commented warning ... just not needed
* fix imports
* refactor to use update_single
* mypy fixes
* add vespa test
* multiple celery workers
* update logs as well and set prefetch multipliers appropriate to the worker intent
* add db refresh to connector deletion
* add some preliminary locking
* organize tasks into separate files
* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this.
* code review fixes
* move monitor_usergroup_taskset to ee, improve logging
* add multi workers to dev_run_background_jobs.py
* update supervisord with some recommended settings for celery
* name celery workers and shorten dev script prefixing
* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)
* fix comments
* autoscale sqlalchemy pool size to celery concurrency (allow override later?)
* supervisord needs the percent symbols escaped
* use name as primary check, some minor refactoring and type hinting too.
* addressing code review
* fix import
* fix prune_documents_task references

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
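The commit splits background work across a beat scheduler and several named workers with dedicated queues, configured through a celeryconfig.py module that is only mentioned here, not included in the excerpt. The sketch below shows what such a module could plausibly contain; only the task names and the "vespa_metadata_sync" queue name come from this commit, while the broker/result URLs, schedule intervals, and prefetch value are assumptions.

# Hypothetical celeryconfig.py-style sketch, not the actual Danswer configuration.
broker_url = "redis://localhost:6379/0"      # assumed; the commit adds REDIS_HOST overrides
result_backend = "redis://localhost:6379/1"  # assumed

# route the per-document sync task onto its dedicated queue (queue name from the commit message)
task_routes = {
    "vespa_metadata_sync_task": {"queue": "vespa_metadata_sync"},
}

# beat entries for the periodic "check" and "monitor" tasks added in this commit;
# the intervals below are placeholders
beat_schedule = {
    "check-for-vespa-sync": {"task": "check_for_vespa_sync_task", "schedule": 5.0},
    "check-for-connector-deletion": {"task": "check_for_connector_deletion_task", "schedule": 20.0},
    "monitor-vespa-sync": {"task": "monitor_vespa_sync", "schedule": 5.0},
}

# the commit message mentions setting prefetch multipliers per worker intent;
# a low value suits long-running tasks
worker_prefetch_multiplier = 1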
@@ -1,5 +1,6 @@
 from datetime import datetime
 from datetime import timezone
+from typing import Any
 
 from sqlalchemy.orm import Session
 
@@ -160,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
         all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
 
     return all_connector_doc_ids
+
+
+def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
+    """Checks to see if we're listening to the named queue"""
+
+    # how to get a list of queues this worker is listening to
+    # https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime
+    queue_names = list(worker.app.amqp.queues.consume_from.keys())
+    for queue_name in queue_names:
+        if queue_name == name:
+            return True
+
+    return False
+
+
+def celery_is_worker_primary(worker: Any) -> bool:
+    """There are multiple approaches that could be taken, but the way we do it is to
+    check the hostname set for the celery worker, either in celeryconfig.py or on the
+    command line."""
+    hostname = worker.hostname
+    if hostname.startswith("light"):
+        return False
+
+    if hostname.startswith("heavy"):
+        return False
+
+    return True
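The two helpers above are presumably consulted at worker startup to decide which process takes on the "primary" role (the commit message mentions naming celery workers and adding preliminary locking). The celery_app.py wiring is not part of this excerpt, so what follows is only a rough sketch of how they could be hooked into Celery's worker_init signal. The lock name and timeout constants come from the constants.py hunk further down; the handler shape, the import path for celery_is_worker_primary, and the assumption that the signal sender exposes the worker object are mine.

# Rough sketch only -- the real startup wiring lives in celery_app.py, not shown here.
# Assumes the worker_init sender exposes the worker object (with .hostname) that
# celery_is_worker_primary expects, and that the helper lives in celery_utils.
from typing import Any

from celery.signals import worker_init

from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.redis.redis_pool import RedisPool

redis_pool = RedisPool()


@worker_init.connect
def on_worker_init(sender: Any = None, **kwargs: Any) -> None:
    if not celery_is_worker_primary(sender):
        return  # light/heavy workers skip primary-only setup

    # only one primary worker should ever be running at a time
    r = redis_pool.get_client()
    lock = r.lock(
        DanswerRedisLocks.PRIMARY_WORKER,
        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
    )
    if not lock.acquire(blocking=False):
        raise RuntimeError("Another primary worker appears to be running.")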
@@ -0,0 +1,133 @@ (new file)

import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError

from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_pool import RedisPool

redis_pool = RedisPool()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="check_for_connector_deletion_task",
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
)
def check_for_connector_deletion_task() -> None:
    r = redis_pool.get_client()

    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        with Session(get_sqlalchemy_engine()) as db_session:
            cc_pairs = get_connector_credential_pairs(db_session)
            for cc_pair in cc_pairs:
                try_generate_document_cc_pair_cleanup_tasks(
                    cc_pair, db_session, r, lock_beat
                )
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception")
    finally:
        if lock_beat.owned():
            lock_beat.release()


def try_generate_document_cc_pair_cleanup_tasks(
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
    r: Redis,
    lock_beat: redis.lock.Lock,
) -> int | None:
    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
    Note that syncing can still be required even if the number of sync tasks generated is zero.
    Returns None if no syncing is required.
    """

    lock_beat.reacquire()

    rcd = RedisConnectorDeletion(cc_pair.id)

    # don't generate sync tasks if tasks are still pending
    if r.exists(rcd.fence_key):
        return None

    # we need to refresh the state of the object inside the fence
    # to avoid a race condition with db.commit/fence deletion
    # at the end of this taskset
    try:
        db_session.refresh(cc_pair)
    except ObjectDeletedError:
        return None

    if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
        return None

    search_settings = get_current_search_settings(db_session)

    last_indexing = get_last_attempt(
        connector_id=cc_pair.connector_id,
        credential_id=cc_pair.credential_id,
        search_settings_id=search_settings.id,
        db_session=db_session,
    )
    if last_indexing:
        if (
            last_indexing.status == IndexingStatus.IN_PROGRESS
            or last_indexing.status == IndexingStatus.NOT_STARTED
        ):
            return None

    # add tasks to celery and build up the task set to monitor in redis
    r.delete(rcd.taskset_key)

    # Add all documents that need to be updated into the queue
    task_logger.info(
        f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
    )
    tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # Currently we are allowing the sync to proceed with 0 tasks.
    # It's possible for sets/groups to be generated initially with no entries
    # and they still need to be marked as up to date.
    # if tasks_generated == 0:
    #     return 0

    task_logger.info(
        f"RedisConnectorDeletion.generate_tasks finished. "
        f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
    )

    # set this only after all tasks have been added
    r.set(rcd.fence_key, tasks_generated)
    return tasks_generated
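The task above leans on RedisConnectorDeletion from celery_redis.py, which is not included in this excerpt. For orientation, here is a hypothetical, minimal illustration of the fence/taskset convention it and its sibling classes appear to follow: a "fence" key publishes the number of tasks generated, a Redis set (the "taskset") tracks outstanding task ids, and a monitor cleans both up once the set drains. The class name, key prefixes, and method shape below are guesses, not the real helper.

# Hypothetical illustration only; the real RedisConnectorDeletion in
# celery_redis.py is not shown in this diff and almost certainly differs.
import uuid

from celery import Celery
from redis import Redis


class ExampleFenceAndTaskset:
    FENCE_PREFIX = "examplefence"      # hypothetical key prefixes
    TASKSET_PREFIX = "exampletaskset"

    def __init__(self, object_id: int) -> None:
        self.fence_key = f"{self.FENCE_PREFIX}_{object_id}"
        self.taskset_key = f"{self.TASKSET_PREFIX}_{object_id}"

    def generate_tasks(
        self, celery_app: Celery, r: Redis, document_ids: list[str]
    ) -> int:
        # enqueue one celery task per document, remembering each task id in
        # the taskset so a monitor can watch the work drain
        for document_id in document_ids:
            custom_task_id = str(uuid.uuid4())
            r.sadd(self.taskset_key, custom_task_id)
            celery_app.send_task(
                "vespa_metadata_sync_task",  # task name defined in this commit
                kwargs={"document_id": document_id},
                task_id=custom_task_id,
            )
        return len(document_ids)

# After generate_tasks, the caller sets the fence to publish the total:
#     r.set(helper.fence_key, tasks_generated)
# A completion hook (not shown in this excerpt) is expected to remove each
# finished task id from the taskset; the monitor_* tasks later in this diff
# delete both the taskset and the fence once r.scard(taskset_key) reaches zero.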
backend/danswer/background/celery/tasks/periodic/tasks.py (new file, 140 lines)

#####
# Periodic Tasks
#####
import json
from typing import Any

from celery import shared_task
from celery.contrib.abortable import AbortableTask  # type: ignore
from celery.exceptions import TaskRevokedError
from celery.utils.log import get_task_logger
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session

from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine  # type: ignore

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="kombu_message_cleanup_task",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
    """Runs periodically to clean up the kombu_message table"""

    # we will select messages older than this amount to clean up
    KOMBU_MESSAGE_CLEANUP_AGE = 7  # days
    KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000

    ctx = {}
    ctx["last_processed_id"] = 0
    ctx["deleted"] = 0
    ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
    ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
    with Session(get_sqlalchemy_engine()) as db_session:
        # Exit the task if we can't take the advisory lock
        result = db_session.execute(
            text("SELECT pg_try_advisory_lock(:id)"),
            {"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
        ).scalar()
        if not result:
            return 0

        while True:
            if self.is_aborted():
                raise TaskRevokedError("kombu_message_cleanup_task was aborted.")

            b = kombu_message_cleanup_task_helper(ctx, db_session)
            if not b:
                break

            db_session.commit()

    if ctx["deleted"] > 0:
        task_logger.info(
            f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
        )

    return ctx["deleted"]


def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
    """
    Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.

    This function retrieves messages from the `kombu_message` table that are no longer visible and
    older than a specified interval. It checks if the corresponding task_id exists in the
    `celery_taskmeta` table. If the task_id does not exist, the message is deleted.

    Args:
        ctx (dict): A context dictionary containing configuration parameters such as:
            - 'cleanup_age' (int): The age in days after which messages are considered old.
            - 'page_limit' (int): The maximum number of messages to process in one batch.
            - 'last_processed_id' (int): The ID of the last processed message to handle pagination.
            - 'deleted' (int): A counter to track the number of deleted messages.
        db_session (Session): The SQLAlchemy database session for executing queries.

    Returns:
        bool: Returns True if there are more rows to process, False if not.
    """

    inspector = inspect(db_session.bind)
    if not inspector:
        return False

    # With the move to redis as celery's broker and backend, kombu tables may not even exist.
    # We can fail silently.
    if not inspector.has_table("kombu_message"):
        return False

    query = text(
        """
    SELECT id, timestamp, payload
    FROM kombu_message WHERE visible = 'false'
    AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
    AND id > :last_processed_id
    ORDER BY id
    LIMIT :page_limit
    """
    )
    kombu_messages = db_session.execute(
        query,
        {
            "interval_days": f"{ctx['cleanup_age']} days",
            "page_limit": ctx["page_limit"],
            "last_processed_id": ctx["last_processed_id"],
        },
    ).fetchall()

    if len(kombu_messages) == 0:
        return False

    for msg in kombu_messages:
        payload = json.loads(msg[2])
        task_id = payload["headers"]["id"]

        # Check if task_id exists in celery_taskmeta
        task_exists = db_session.execute(
            text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
            {"task_id": task_id},
        ).fetchone()

        # If task_id does not exist, delete the message
        if not task_exists:
            result = db_session.execute(
                text("DELETE FROM kombu_message WHERE id = :message_id"),
                {"message_id": msg[0]},
            )
            if result.rowcount > 0:  # type: ignore
                ctx["deleted"] += 1

        ctx["last_processed_id"] = msg[0]

    return True
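The cleanup task above guards itself with a Postgres advisory lock so that only one worker at a time touches kombu_message. Below is a small, standalone sketch of that locking idiom with an explicit unlock; the lock id and connection string are arbitrary placeholders, not the value of PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID or the real DSN.

# Minimal sketch of the pg_try_advisory_lock pattern used above.
# Session-level advisory locks are also released automatically when the
# database session/connection ends.
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

EXAMPLE_LOCK_ID = 424242  # placeholder, not the real constant

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # assumed DSN

with Session(engine) as db_session:
    got_lock = db_session.execute(
        text("SELECT pg_try_advisory_lock(:id)"), {"id": EXAMPLE_LOCK_ID}
    ).scalar()

    if got_lock:
        try:
            ...  # do the single-owner cleanup work here
        finally:
            # explicit release; otherwise the lock is held until the session ends
            db_session.execute(
                text("SELECT pg_advisory_unlock(:id)"), {"id": EXAMPLE_LOCK_ID}
            )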
backend/danswer/background/celery/tasks/pruning/tasks.py (new file, 120 lines)

from celery import shared_task
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index


# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


@shared_task(
    name="check_for_prune_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
    """Runs periodically to check if any prune tasks should be run and adds them
    to the queue"""

    with Session(get_sqlalchemy_engine()) as db_session:
        all_cc_pairs = get_connector_credential_pairs(db_session)

        for cc_pair in all_cc_pairs:
            if should_prune_cc_pair(
                connector=cc_pair.connector,
                credential=cc_pair.credential,
                db_session=db_session,
            ):
                task_logger.info(f"Pruning the {cc_pair.connector.name} connector")

                prune_documents_task.apply_async(
                    kwargs=dict(
                        connector_id=cc_pair.connector.id,
                        credential_id=cc_pair.credential.id,
                    )
                )


@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
    """connector pruning task. For a cc pair, this task pulls all document IDs from the source
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list"""
    with Session(get_sqlalchemy_engine()) as db_session:
        try:
            cc_pair = get_connector_credential_pair(
                db_session=db_session,
                connector_id=connector_id,
                credential_id=credential_id,
            )

            if not cc_pair:
                task_logger.warning(
                    f"ccpair not found for {connector_id} {credential_id}"
                )
                return

            runnable_connector = instantiate_connector(
                db_session,
                cc_pair.connector.source,
                InputType.PRUNE,
                cc_pair.connector.connector_specific_config,
                cc_pair.credential,
            )

            all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
                runnable_connector
            )

            all_indexed_document_ids = {
                doc.id
                for doc in get_documents_for_connector_credential_pair(
                    db_session=db_session,
                    connector_id=connector_id,
                    credential_id=credential_id,
                )
            }

            doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)

            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            if len(doc_ids_to_remove) == 0:
                task_logger.info(
                    f"No docs to prune from {cc_pair.connector.source} connector"
                )
                return

            task_logger.info(
                f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
            )
            delete_connector_credential_pair_batch(
                document_ids=doc_ids_to_remove,
                connector_id=connector_id,
                credential_id=credential_id,
                document_index=document_index,
            )
        except Exception as e:
            task_logger.exception(
                f"Failed to run pruning for connector id {connector_id}."
            )
            raise e
backend/danswer/background/celery/tasks/vespa/tasks.py (new file, 526 lines)

import traceback
from typing import cast

import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
    delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
    fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback

redis_pool = RedisPool()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)


# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
    name="check_for_vespa_sync_task",
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
)
def check_for_vespa_sync_task() -> None:
    """Runs periodically to check if any document needs syncing.
    Generates sets of tasks for Celery if syncing is needed."""

    r = redis_pool.get_client()

    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        with Session(get_sqlalchemy_engine()) as db_session:
            try_generate_stale_document_sync_tasks(db_session, r, lock_beat)

            # check if any document sets are not synced
            document_set_info = fetch_document_sets(
                user_id=None, db_session=db_session, include_outdated=True
            )
            for document_set, _ in document_set_info:
                try_generate_document_set_sync_tasks(
                    document_set, db_session, r, lock_beat
                )

            # check if any user groups are not synced
            try:
                fetch_user_groups = fetch_versioned_implementation(
                    "danswer.db.user_group", "fetch_user_groups"
                )

                user_groups = fetch_user_groups(
                    db_session=db_session, only_up_to_date=False
                )
                for usergroup in user_groups:
                    try_generate_user_group_sync_tasks(
                        usergroup, db_session, r, lock_beat
                    )
            except ModuleNotFoundError:
                # Always exceptions on the MIT version, which is expected
                pass
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception")
    finally:
        if lock_beat.owned():
            lock_beat.release()


def try_generate_stale_document_sync_tasks(
    db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    # the fence is up, do nothing
    if r.exists(RedisConnectorCredentialPair.get_fence_key()):
        return None

    r.delete(RedisConnectorCredentialPair.get_taskset_key())  # delete the taskset

    # add tasks to celery and build up the task set to monitor in redis
    stale_doc_count = count_documents_by_needs_sync(db_session)
    if stale_doc_count == 0:
        return None

    task_logger.info(
        f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
    )

    task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")

    # rkuo: we could technically sync all stale docs in one big pass.
    # but I feel it's more understandable to group the docs by cc_pair
    total_tasks_generated = 0
    cc_pairs = get_connector_credential_pairs(db_session)
    for cc_pair in cc_pairs:
        rc = RedisConnectorCredentialPair(cc_pair.id)
        tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)

        if tasks_generated is None:
            continue

        if tasks_generated == 0:
            continue

        task_logger.info(
            f"RedisConnector.generate_tasks finished for single cc_pair. "
            f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
        )

        total_tasks_generated += tasks_generated

    task_logger.info(
        f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
    )

    r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
    return total_tasks_generated


def try_generate_document_set_sync_tasks(
    document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    lock_beat.reacquire()

    rds = RedisDocumentSet(document_set.id)

    # don't generate document set sync tasks if tasks are still pending
    if r.exists(rds.fence_key):
        return None

    # don't generate sync tasks if we're up to date
    # race condition with the monitor/cleanup function if we use a cached result!
    db_session.refresh(document_set)
    if document_set.is_up_to_date:
        return None

    # add tasks to celery and build up the task set to monitor in redis
    r.delete(rds.taskset_key)

    task_logger.info(
        f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
    )

    # Add all documents that need to be updated into the queue
    tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # Currently we are allowing the sync to proceed with 0 tasks.
    # It's possible for sets/groups to be generated initially with no entries
    # and they still need to be marked as up to date.
    # if tasks_generated == 0:
    #     return 0

    task_logger.info(
        f"RedisDocumentSet.generate_tasks finished. "
        f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
    )

    # set this only after all tasks have been added
    r.set(rds.fence_key, tasks_generated)
    return tasks_generated


def try_generate_user_group_sync_tasks(
    usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
    lock_beat.reacquire()

    rug = RedisUserGroup(usergroup.id)

    # don't generate sync tasks if tasks are still pending
    if r.exists(rug.fence_key):
        return None

    # race condition with the monitor/cleanup function if we use a cached result!
    db_session.refresh(usergroup)
    if usergroup.is_up_to_date:
        return None

    # add tasks to celery and build up the task set to monitor in redis
    r.delete(rug.taskset_key)

    # Add all documents that need to be updated into the queue
    task_logger.info(
        f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
    )
    tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
    if tasks_generated is None:
        return None

    # Currently we are allowing the sync to proceed with 0 tasks.
    # It's possible for sets/groups to be generated initially with no entries
    # and they still need to be marked as up to date.
    # if tasks_generated == 0:
    #     return 0

    task_logger.info(
        f"RedisUserGroup.generate_tasks finished. "
        f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
    )

    # set this only after all tasks have been added
    r.set(rug.fence_key, tasks_generated)
    return tasks_generated


def monitor_connector_taskset(r: Redis) -> None:
    fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
    task_logger.info(
        f"Stale document sync progress: remaining={count} initial={initial_count}"
    )
    if count == 0:
        r.delete(RedisConnectorCredentialPair.get_taskset_key())
        r.delete(RedisConnectorCredentialPair.get_fence_key())
        task_logger.info(f"Successfully synced stale documents. count={initial_count}")


def monitor_document_set_taskset(
    key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    fence_key = key_bytes.decode("utf-8")
    document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
    if document_set_id is None:
        task_logger.warning("could not parse document set id from {key}")
        return

    rds = RedisDocumentSet(document_set_id)

    fence_value = r.get(rds.fence_key)
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = cast(int, r.scard(rds.taskset_key))
    task_logger.info(
        f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        return

    document_set = cast(
        DocumentSet,
        get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
    )  # casting since we "know" a document set with this ID exists
    if document_set:
        if not document_set.connector_credential_pairs:
            # if there are no connectors, then delete the document set.
            delete_document_set(document_set_row=document_set, db_session=db_session)
            task_logger.info(
                f"Successfully deleted document set with ID: '{document_set_id}'!"
            )
        else:
            mark_document_set_as_synced(document_set_id, db_session)
            task_logger.info(
                f"Successfully synced document set with ID: '{document_set_id}'!"
            )

    r.delete(rds.taskset_key)
    r.delete(rds.fence_key)


def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
    if cc_pair_id is None:
        task_logger.warning("could not parse document set id from {key}")
        return

    rcd = RedisConnectorDeletion(cc_pair_id)

    fence_value = r.get(rcd.fence_key)
    if fence_value is None:
        return

    try:
        initial_count = int(cast(int, fence_value))
    except ValueError:
        task_logger.error("The value is not an integer.")
        return

    count = cast(int, r.scard(rcd.taskset_key))
    task_logger.info(
        f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        return

    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
        if not cc_pair:
            return

        try:
            # clean up the rest of the related Postgres entities
            # index attempts
            delete_index_attempts(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
            )

            # document sets
            delete_document_set_cc_pair_relationship__no_commit(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
                credential_id=cc_pair.credential_id,
            )

            # user groups
            cleanup_user_groups = fetch_versioned_implementation_with_fallback(
                "danswer.db.user_group",
                "delete_user_group_cc_pair_relationship__no_commit",
                noop_fallback,
            )
            cleanup_user_groups(
                cc_pair_id=cc_pair.id,
                db_session=db_session,
            )

            # finally, delete the cc-pair
            delete_connector_credential_pair__no_commit(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
                credential_id=cc_pair.credential_id,
            )
            # if there are no credentials left, delete the connector
            connector = fetch_connector_by_id(
                db_session=db_session,
                connector_id=cc_pair.connector_id,
            )
            if not connector or not len(connector.credentials):
                task_logger.info(
                    "Found no credentials left for connector, deleting connector"
                )
                db_session.delete(connector)
            db_session.commit()
        except Exception as e:
            stack_trace = traceback.format_exc()
            error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
            add_deletion_failure_message(db_session, cc_pair.id, error_message)
            task_logger.exception(
                f"Failed to run connector_deletion. "
                f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
            )
            raise e

        task_logger.info(
            f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
            f"and credential_id: '{cc_pair.credential_id}'. "
            f"Deleted {initial_count} docs."
        )

    r.delete(rcd.taskset_key)
    r.delete(rcd.fence_key)


@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
    """This is a celery beat task that monitors and finalizes metadata sync tasksets.
    It scans for fence values and then gets the counts of any associated tasksets.
    If the count is 0, that means all tasks finished and we should clean up.

    This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
    do anything too expensive in this function!
    """
    r = redis_pool.get_client()

    lock_beat = r.lock(
        DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # prevent overlapping tasks
        if not lock_beat.acquire(blocking=False):
            return

        if r.exists(RedisConnectorCredentialPair.get_fence_key()):
            monitor_connector_taskset(r)

        for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
            monitor_connector_deletion_taskset(key_bytes, r)

        with Session(get_sqlalchemy_engine()) as db_session:
            for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
                monitor_document_set_taskset(key_bytes, r, db_session)

            for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
                monitor_usergroup_taskset = (
                    fetch_versioned_implementation_with_fallback(
                        "danswer.background.celery.tasks.vespa.tasks",
                        "monitor_usergroup_taskset",
                        noop_fallback,
                    )
                )
                monitor_usergroup_taskset(key_bytes, r, db_session)

        # uncomment for debugging if needed
        # r_celery = celery_app.broker_connection().channel().client
        # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
        # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    finally:
        if lock_beat.owned():
            lock_beat.release()


@shared_task(
    name="vespa_metadata_sync_task",
    bind=True,
    soft_time_limit=45,
    time_limit=60,
    max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
    task_logger.info(f"document_id={document_id}")

    try:
        with Session(get_sqlalchemy_engine()) as db_session:
            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            doc = get_document(document_id, db_session)
            if not doc:
                return False

            # document set sync
            doc_sets = fetch_document_sets_for_document(document_id, db_session)
            update_doc_sets: set[str] = set(doc_sets)

            # User group sync
            doc_access = get_access_for_document(
                document_id=document_id, db_session=db_session
            )
            update_request = UpdateRequest(
                document_ids=[document_id],
                document_sets=update_doc_sets,
                access=doc_access,
                boost=doc.boost,
                hidden=doc.hidden,
            )

            # update Vespa
            document_index.update(update_requests=[update_request])

            # update db last. Worst case = we crash right before this and
            # the sync might repeat again later
            mark_document_as_synced(document_id, db_session)
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
    except Exception as e:
        task_logger.exception("Unexpected exception")

        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
        countdown = 2 ** (self.request.retries + 4)
        self.retry(exc=e, countdown=countdown)

    return True
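The commit message notes that tasks are now routed to dedicated queues with priorities. The actual dispatch of vespa_metadata_sync_task happens inside the generate_tasks helpers in celery_redis.py, which are not part of this excerpt; the following is only a hedged sketch of what routing a single document sync onto its queue could look like. The queue name comes from the commit message; the priority value is an assumption and its meaning depends on how the broker is configured.

# Sketch only: the real dispatch lives in celery_redis.py, not shown here.
from danswer.background.celery.celery_app import celery_app

document_id = "example-doc-id"  # placeholder

celery_app.send_task(
    "vespa_metadata_sync_task",
    kwargs={"document_id": document_id},
    queue="vespa_metadata_sync",  # queue name mentioned in the commit message
    priority=1,                   # assumed value; semantics depend on the broker setup
)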
@@ -10,15 +10,27 @@ are multiple connector / credential pairs that have indexed it
 connector / credential pair from the access list
 (6) delete all relevant entries from postgres
 """
+from celery import shared_task
+from celery import Task
+from celery.exceptions import SoftTimeLimitExceeded
+from celery.utils.log import get_task_logger
 from sqlalchemy.orm import Session
 
+from danswer.access.access import get_access_for_document
 from danswer.access.access import get_access_for_documents
+from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
 from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
 from danswer.db.document import delete_documents_complete__no_commit
+from danswer.db.document import get_document
+from danswer.db.document import get_document_connector_count
 from danswer.db.document import get_document_connector_counts
+from danswer.db.document import mark_document_as_synced
 from danswer.db.document import prepare_to_modify_documents
+from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import fetch_document_sets_for_documents
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.document_index.document_index_utils import get_both_index_names
+from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.interfaces import UpdateRequest
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier

@@ -26,6 +38,9 @@ from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
 
+# use this within celery tasks to get celery task specific logging
+task_logger = get_task_logger(__name__)
+
 _DELETION_BATCH_SIZE = 1000
@@ -108,3 +123,89 @@ def delete_connector_credential_pair_batch(
         ),
     )
     db_session.commit()
+
+
+@shared_task(
+    name="document_by_cc_pair_cleanup_task",
+    bind=True,
+    soft_time_limit=45,
+    time_limit=60,
+    max_retries=3,
+)
+def document_by_cc_pair_cleanup_task(
+    self: Task, document_id: str, connector_id: int, credential_id: int
+) -> bool:
+    task_logger.info(f"document_id={document_id}")
+
+    try:
+        with Session(get_sqlalchemy_engine()) as db_session:
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+
+            count = get_document_connector_count(db_session, document_id)
+            if count == 1:
+                # count == 1 means this is the only remaining cc_pair reference to the doc
+                # delete it from vespa and the db
+                document_index.delete(doc_ids=[document_id])
+                delete_documents_complete__no_commit(
+                    db_session=db_session,
+                    document_ids=[document_id],
+                )
+            elif count > 1:
+                # count > 1 means the document still has cc_pair references
+                doc = get_document(document_id, db_session)
+                if not doc:
+                    return False
+
+                # the below functions do not include cc_pairs being deleted.
+                # i.e. they will correctly omit access for the current cc_pair
+                doc_access = get_access_for_document(
+                    document_id=document_id, db_session=db_session
+                )
+
+                doc_sets = fetch_document_sets_for_document(document_id, db_session)
+                update_doc_sets: set[str] = set(doc_sets)
+
+                update_request = UpdateRequest(
+                    document_ids=[document_id],
+                    document_sets=update_doc_sets,
+                    access=doc_access,
+                    boost=doc.boost,
+                    hidden=doc.hidden,
+                )
+
+                # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
+                document_index.update_single(update_request=update_request)
+
+                # there are still other cc_pair references to the doc, so just resync to Vespa
+                delete_document_by_connector_credential_pair__no_commit(
+                    db_session=db_session,
+                    document_id=document_id,
+                    connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
+                        connector_id=connector_id,
+                        credential_id=credential_id,
+                    ),
+                )
+
+                mark_document_as_synced(document_id, db_session)
+            else:
+                pass
+
+                # update_docs_last_modified__no_commit(
+                #     db_session=db_session,
+                #     document_ids=[document_id],
+                # )
+
+            db_session.commit()
+    except SoftTimeLimitExceeded:
+        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
+    except Exception as e:
+        task_logger.exception("Unexpected exception")
+
+        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
+        countdown = 2 ** (self.request.retries + 4)
+        self.retry(exc=e, countdown=countdown)
+
+    return True
@@ -416,6 +416,7 @@ def update_loop(
         warm_up_bi_encoder(
             embedding_model=embedding_model,
         )
+        logger.notice("First inference complete.")
 
     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient

@@ -444,6 +445,7 @@ def update_loop(
 
     existing_jobs: dict[int, Future | SimpleJob] = {}
 
+    logger.notice("Startup complete. Waiting for indexing jobs...")
     while True:
         start = time.time()
         start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
@@ -34,7 +34,9 @@ POSTGRES_WEB_APP_NAME = "web"
 POSTGRES_INDEXER_APP_NAME = "indexer"
 POSTGRES_CELERY_APP_NAME = "celery"
 POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
-POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
+POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
+POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
+POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
 POSTGRES_PERMISSIONS_APP_NAME = "permissions"
 POSTGRES_UNKNOWN_APP_NAME = "unknown"

@@ -62,6 +64,7 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
 KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
 
 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
+CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
 
 
 class DocumentSource(str, Enum):

@@ -186,6 +189,7 @@ class DanswerCeleryQueues:
 
 
 class DanswerRedisLocks:
+    PRIMARY_WORKER = "da_lock:primary_worker"
     CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
     MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
     CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
@ -1,8 +1,10 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
from typing import ContextManager
|
from typing import ContextManager
|
||||||
|
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
@ -32,14 +34,9 @@ logger = setup_logger()
|
|||||||
SYNC_DB_API = "psycopg2"
|
SYNC_DB_API = "psycopg2"
|
||||||
ASYNC_DB_API = "asyncpg"
|
ASYNC_DB_API = "asyncpg"
|
||||||
|
|
||||||
POSTGRES_APP_NAME = (
|
|
||||||
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
|
|
||||||
)
|
|
||||||
|
|
||||||
# global so we don't create more than one engine per process
|
# global so we don't create more than one engine per process
|
||||||
# outside of being best practice, this is needed so we can properly pool
|
# outside of being best practice, this is needed so we can properly pool
|
||||||
# connections and not create a new pool on every request
|
# connections and not create a new pool on every request
|
||||||
_SYNC_ENGINE: Engine | None = None
|
|
||||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||||
|
|
||||||
SessionFactory: sessionmaker[Session] | None = None
|
SessionFactory: sessionmaker[Session] | None = None
|
||||||
@@ -108,6 +105,67 @@ def get_db_current_time(db_session: Session) -> datetime:
     return result
 
 
+class SqlEngine:
+    """Class to manage a global sql alchemy engine (needed for proper resource control)
+    Will eventually subsume most of the standalone functions in this file.
+    Sync only for now"""
+
+    _engine: Engine | None = None
+    _lock: threading.Lock = threading.Lock()
+    _app_name: str = POSTGRES_UNKNOWN_APP_NAME
+
+    # Default parameters for engine creation
+    DEFAULT_ENGINE_KWARGS = {
+        "pool_size": 40,
+        "max_overflow": 10,
+        "pool_pre_ping": POSTGRES_POOL_PRE_PING,
+        "pool_recycle": POSTGRES_POOL_RECYCLE,
+    }
+
+    def __init__(self) -> None:
+        pass
+
+    @classmethod
+    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
+        """Private helper method to create and return an Engine."""
+        connection_string = build_connection_string(
+            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
+        )
+        merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
+        return create_engine(connection_string, **merged_kwargs)
+
+    @classmethod
+    def init_engine(cls, **engine_kwargs: Any) -> None:
+        """Allow the caller to init the engine with extra params. Different clients
+        such as the API server and different celery workers and tasks
+        need different settings."""
+        with cls._lock:
+            if not cls._engine:
+                cls._engine = cls._init_engine(**engine_kwargs)
+
+    @classmethod
+    def get_engine(cls) -> Engine:
+        """Gets the sql alchemy engine. Will init a default engine if init hasn't
+        already been called. You probably want to init first!"""
+        if not cls._engine:
+            with cls._lock:
+                if not cls._engine:
+                    cls._engine = cls._init_engine()
+        return cls._engine
+
+    @classmethod
+    def set_app_name(cls, app_name: str) -> None:
+        """Class method to set the app name."""
+        cls._app_name = app_name
+
+    @classmethod
+    def get_app_name(cls) -> str:
+        """Class method to get current app name."""
+        if not cls._app_name:
+            return ""
+        return cls._app_name
+
+
 def build_connection_string(
     *,
     db_api: str = ASYNC_DB_API,
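
For reference, a minimal sketch of how a process such as a celery worker might use the new SqlEngine class at startup. The app name string and pool numbers below are illustrative assumptions, not values from this diff; the point is that init_engine() lets each client size the pool to its own concurrency while get_engine() keeps returning the single per-process engine.

    # hypothetical startup hook for one worker process; names and numbers are illustrative
    from danswer.db.engine import SqlEngine

    def init_db_for_worker(concurrency: int) -> None:
        SqlEngine.set_app_name("celery_worker_example")  # visible in pg_stat_activity
        # pool roughly sized to worker concurrency, overriding DEFAULT_ENGINE_KWARGS
        SqlEngine.init_engine(pool_size=concurrency, max_overflow=4)
        _ = SqlEngine.get_engine()  # later calls reuse the same engine
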
@@ -125,24 +183,11 @@ def build_connection_string(
 
 
 def init_sqlalchemy_engine(app_name: str) -> None:
-    global POSTGRES_APP_NAME
-    POSTGRES_APP_NAME = app_name
+    SqlEngine.set_app_name(app_name)
 
 
 def get_sqlalchemy_engine() -> Engine:
-    global _SYNC_ENGINE
-    if _SYNC_ENGINE is None:
-        connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
-        )
-        _SYNC_ENGINE = create_engine(
-            connection_string,
-            pool_size=40,
-            max_overflow=10,
-            pool_pre_ping=POSTGRES_POOL_PRE_PING,
-            pool_recycle=POSTGRES_POOL_RECYCLE,
-        )
-    return _SYNC_ENGINE
+    return SqlEngine.get_engine()
 
 
 def get_sqlalchemy_async_engine() -> AsyncEngine:
@@ -154,7 +199,9 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
         _ASYNC_ENGINE = create_async_engine(
             connection_string,
             connect_args={
-                "server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
+                "server_settings": {
+                    "application_name": SqlEngine.get_app_name() + "_async"
+                }
             },
             pool_size=40,
             max_overflow=10,
@@ -239,7 +239,7 @@ def prune_cc_pair(
     db_session: Session = Depends(get_session),
 ) -> StatusResponse[list[int]]:
     # avoiding circular refs
-    from danswer.background.celery.celery_app import prune_documents_task
+    from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
 
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_curator_or_admin_user
+from danswer.background.celery.celery_app import celery_app
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import DanswerCeleryPriority
 from danswer.configs.constants import DocumentSource
@@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id(
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
 ) -> None:
-    from danswer.background.celery.celery_app import (
-        check_for_connector_deletion_task,
-    )
-
     connector_id = connector_credential_pair_identifier.connector_id
     credential_id = connector_credential_pair_identifier.credential_id
 
@@ -193,8 +190,11 @@ def create_deletion_attempt_for_connector_id(
         status=ConnectorCredentialPairStatus.DELETING,
     )
 
-    # run the beat task to pick up this deletion early
-    check_for_connector_deletion_task.apply_async(
+    db_session.commit()
+
+    # run the beat task to pick up this deletion from the db immediately
+    celery_app.send_task(
+        "check_for_connector_deletion_task",
         priority=DanswerCeleryPriority.HIGH,
     )
 
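
The switch from apply_async on an imported task object to celery_app.send_task by name is what removes the circular import above: the route only needs the string the task was registered under. A hedged sketch of what that registration might look like on the task side (the decorator body here is an assumption for illustration, not copied from this PR):

    # illustrative only: a task registered under an explicit name can be invoked via
    # celery_app.send_task("check_for_connector_deletion_task", ...) without the
    # caller importing the module that defines it
    from danswer.background.celery.celery_app import celery_app

    @celery_app.task(name="check_for_connector_deletion_task")
    def check_for_connector_deletion_task() -> None:
        ...  # real implementation lives in the celery tasks modules
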
new file: backend/ee/danswer/background/celery/tasks/vespa/tasks.py (52 lines)
@@ -0,0 +1,52 @@
+from typing import cast
+
+from redis import Redis
+from sqlalchemy.orm import Session
+
+from danswer.background.celery.celery_app import task_logger
+from danswer.background.celery.celery_redis import RedisUserGroup
+from danswer.utils.logger import setup_logger
+from ee.danswer.db.user_group import delete_user_group
+from ee.danswer.db.user_group import fetch_user_group
+from ee.danswer.db.user_group import mark_user_group_as_synced
+
+logger = setup_logger()
+
+
+def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
+    """This function is likely to move in the worker refactor happening next."""
+    key = key_bytes.decode("utf-8")
+    usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
+    if not usergroup_id:
+        task_logger.warning("Could not parse usergroup id from {key}")
+        return
+
+    rug = RedisUserGroup(usergroup_id)
+    fence_value = r.get(rug.fence_key)
+    if fence_value is None:
+        return
+
+    try:
+        initial_count = int(cast(int, fence_value))
+    except ValueError:
+        task_logger.error("The value is not an integer.")
+        return
+
+    count = cast(int, r.scard(rug.taskset_key))
+    task_logger.info(
+        f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
+    )
+    if count > 0:
+        return
+
+    user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
+    if user_group:
+        if user_group.is_up_for_deletion:
+            delete_user_group(db_session=db_session, user_group=user_group)
+            task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
+        else:
+            mark_user_group_as_synced(db_session=db_session, user_group=user_group)
+            task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
+
+    r.delete(rug.taskset_key)
+    r.delete(rug.fence_key)
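
The monitor above leans on a small Redis protocol: a per-usergroup "fence" key holding the number of sync tasks that were queued, and a "taskset" set whose members shrink as individual tasks finish; once the set is empty the group is considered synced. A rough sketch of the producer side of that protocol, assuming RedisUserGroup exposes the same fence_key and taskset_key attributes used above (the actual generator logic in this PR may differ):

    # illustrative producer-side sketch of the fence/taskset pattern
    from redis import Redis
    from danswer.background.celery.celery_redis import RedisUserGroup

    def queue_usergroup_sync(r: Redis, usergroup_id: int, task_ids: list[str]) -> None:
        rug = RedisUserGroup(usergroup_id)
        for task_id in task_ids:
            r.sadd(rug.taskset_key, task_id)  # one member per outstanding sync task
        r.set(rug.fence_key, len(task_ids))   # fence value = initial task count
        # each completed task would srem() its member; monitor_usergroup_taskset
        # runs the cleanup once r.scard(taskset_key) reaches zero
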
@@ -1,11 +1,5 @@
-from typing import cast
-
-from redis import Redis
 from sqlalchemy.orm import Session
 
-from danswer.background.celery.celery_app import task_logger
-from danswer.background.celery.celery_redis import RedisUserGroup
-from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.enums import AccessType
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.tasks import check_task_is_live_and_not_timed_out
@@ -18,9 +12,6 @@ from ee.danswer.background.task_name_builders import (
 from ee.danswer.background.task_name_builders import (
     name_sync_external_group_permissions_task,
 )
-from ee.danswer.db.user_group import delete_user_group
-from ee.danswer.db.user_group import fetch_user_group
-from ee.danswer.db.user_group import mark_user_group_as_synced
 
 logger = setup_logger()
 
@@ -79,43 +70,3 @@ def should_perform_external_group_permissions_check(
         return False
 
     return True
-
-
-def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None:
-    """This function is likely to move in the worker refactor happening next."""
-    key = key_bytes.decode("utf-8")
-    usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
-    if not usergroup_id:
-        task_logger.warning("Could not parse usergroup id from {key}")
-        return
-
-    rug = RedisUserGroup(usergroup_id)
-    fence_value = r.get(rug.fence_key)
-    if fence_value is None:
-        return
-
-    try:
-        initial_count = int(cast(int, fence_value))
-    except ValueError:
-        task_logger.error("The value is not an integer.")
-        return
-
-    count = cast(int, r.scard(rug.taskset_key))
-    task_logger.info(
-        f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
-    )
-    if count > 0:
-        return
-
-    with Session(get_sqlalchemy_engine()) as db_session:
-        user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
-        if user_group:
-            if user_group.is_up_for_deletion:
-                delete_user_group(db_session=db_session, user_group=user_group)
-                task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
-            else:
-                mark_user_group_as_synced(db_session=db_session, user_group=user_group)
-                task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
-
-        r.delete(rug.taskset_key)
-        r.delete(rug.fence_key)
@@ -18,7 +18,8 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
 
 
 def run_jobs(exclude_indexing: bool) -> None:
-    cmd_worker = [
+    # command setup
+    cmd_worker_primary = [
         "celery",
         "-A",
         "ee.danswer.background.celery.celery_app",
@@ -26,8 +27,38 @@ def run_jobs(exclude_indexing: bool) -> None:
         "--pool=threads",
         "--concurrency=6",
         "--loglevel=INFO",
+        "-n",
+        "primary@%n",
         "-Q",
-        "celery,vespa_metadata_sync,connector_deletion",
+        "celery",
+    ]
+
+    cmd_worker_light = [
+        "celery",
+        "-A",
+        "ee.danswer.background.celery.celery_app",
+        "worker",
+        "--pool=threads",
+        "--concurrency=16",
+        "--loglevel=INFO",
+        "-n",
+        "light@%n",
+        "-Q",
+        "vespa_metadata_sync,connector_deletion",
+    ]
+
+    cmd_worker_heavy = [
+        "celery",
+        "-A",
+        "ee.danswer.background.celery.celery_app",
+        "worker",
+        "--pool=threads",
+        "--concurrency=6",
+        "--loglevel=INFO",
+        "-n",
+        "heavy@%n",
+        "-Q",
+        "connector_pruning",
     ]
 
     cmd_beat = [
@@ -38,19 +69,38 @@ def run_jobs(exclude_indexing: bool) -> None:
         "--loglevel=INFO",
     ]
 
-    worker_process = subprocess.Popen(
-        cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
+    # spawn processes
+    worker_primary_process = subprocess.Popen(
+        cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
     )
 
+    worker_light_process = subprocess.Popen(
+        cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
+    )
+
+    worker_heavy_process = subprocess.Popen(
+        cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
+    )
+
     beat_process = subprocess.Popen(
         cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
     )
 
-    worker_thread = threading.Thread(
-        target=monitor_process, args=("WORKER", worker_process)
+    # monitor threads
+    worker_primary_thread = threading.Thread(
+        target=monitor_process, args=("PRIMARY", worker_primary_process)
+    )
+    worker_light_thread = threading.Thread(
+        target=monitor_process, args=("LIGHT", worker_light_process)
+    )
+    worker_heavy_thread = threading.Thread(
+        target=monitor_process, args=("HEAVY", worker_heavy_process)
     )
     beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
 
-    worker_thread.start()
+    worker_primary_thread.start()
+    worker_light_thread.start()
+    worker_heavy_thread.start()
     beat_thread.start()
 
     if not exclude_indexing:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
worker_thread.join()
|
worker_primary_thread.join()
|
||||||
|
worker_light_thread.join()
|
||||||
|
worker_heavy_thread.join()
|
||||||
beat_thread.join()
|
beat_thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
@@ -24,16 +24,50 @@ autorestart=true
 # on a system, but this should be okay for now since all our celery tasks are
 # relatively compute-light (e.g. they tend to just make a bunch of requests to
 # Vespa / Postgres)
-[program:celery_worker]
+[program:celery_worker_primary]
 command=celery -A danswer.background.celery.celery_run:celery_app worker
     --pool=threads
-    --concurrency=6
+    --concurrency=4
+    --prefetch-multiplier=1
     --loglevel=INFO
-    --logfile=/var/log/celery_worker_supervisor.log
-    -Q celery,vespa_metadata_sync,connector_deletion
-environment=LOG_FILE_NAME=celery_worker
+    --logfile=/var/log/celery_worker_primary_supervisor.log
+    --hostname=primary@%%n
+    -Q celery
+environment=LOG_FILE_NAME=celery_worker_primary
 redirect_stderr=true
 autorestart=true
+startsecs=10
+stopasgroup=true
+
+[program:celery_worker_light]
+command=celery -A danswer.background.celery.celery_run:celery_app worker
+    --pool=threads
+    --concurrency=16
+    --prefetch-multiplier=8
+    --loglevel=INFO
+    --logfile=/var/log/celery_worker_light_supervisor.log
+    --hostname=light@%%n
+    -Q vespa_metadata_sync,connector_deletion
+environment=LOG_FILE_NAME=celery_worker_light
+redirect_stderr=true
+autorestart=true
+startsecs=10
+stopasgroup=true
+
+[program:celery_worker_heavy]
+command=celery -A danswer.background.celery.celery_run:celery_app worker
+    --pool=threads
+    --concurrency=4
+    --prefetch-multiplier=1
+    --loglevel=INFO
+    --logfile=/var/log/celery_worker_heavy_supervisor.log
+    --hostname=heavy@%%n
+    -Q connector_pruning
+environment=LOG_FILE_NAME=celery_worker_heavy
+redirect_stderr=true
+autorestart=true
+startsecs=10
+stopasgroup=true
 
 # Job scheduler for periodic tasks
 [program:celery_beat]
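
One detail worth calling out in the worker split above: Celery reserves roughly prefetch-multiplier x concurrency messages per worker, so the light worker (16 x 8) can buffer a large batch of cheap vespa_metadata_sync and connector_deletion tasks, while the primary and heavy workers (4 x 1) avoid hoarding long-running work. A quick back-of-the-envelope check, purely illustrative:

    # approximate reserved-message budget per worker (concurrency * prefetch multiplier)
    workers = {"primary": (4, 1), "light": (16, 8), "heavy": (4, 1)}
    for name, (concurrency, multiplier) in workers.items():
        print(f"{name}: ~{concurrency * multiplier} prefetched messages")
    # primary: ~4, light: ~128, heavy: ~4
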
@@ -41,6 +75,8 @@ command=celery -A danswer.background.celery.celery_run:celery_app beat
     --logfile=/var/log/celery_beat_supervisor.log
 environment=LOG_FILE_NAME=celery_beat
 redirect_stderr=true
+startsecs=10
+stopasgroup=true
 
 # Listens for Slack messages and responds with answers
 # for all channels that the DanswerBot has been added to.
@@ -60,9 +96,13 @@ startsecs=60
 command=tail -qF
     /var/log/document_indexing_info.log
     /var/log/celery_beat_supervisor.log
-    /var/log/celery_worker_supervisor.log
+    /var/log/celery_worker_primary_supervisor.log
+    /var/log/celery_worker_light_supervisor.log
+    /var/log/celery_worker_heavy_supervisor.log
     /var/log/celery_beat_debug.log
-    /var/log/celery_worker_debug.log
+    /var/log/celery_worker_primary_debug.log
+    /var/log/celery_worker_light_debug.log
+    /var/log/celery_worker_heavy_debug.log
     /var/log/slack_bot_debug.log
 stdout_logfile=/dev/stdout
 stdout_logfile_maxbytes=0