mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-25 19:37:29 +02:00
Feature/background processing (#2275)
* first cut at redis * some new helper functions for the db * ignore kombu tables in alembic migrations (used by celery) * multiline commands for readability, add vespa_metadata_sync queue to worker * typo fix * fix returning tuple fields * add constants * fix _get_access_for_document * docstrings! * fix double function declaration and typing * fix type hinting * add a global redis pool * Add get_document function * use task_logger in various celery tasks * add celeryconfig.py to simplify configuration. Will be used in a subsequent commit * Add celery redis helper. used in a subsequent PR * kombu warning getting spammy since celery is not self managing its queue in Postgres any more * add last_modified and last_synced to documents * fix task naming convention * use celeryconfig.py * the big one. adds queues and tasks, updates functions to use the queues with priorities, etc * change vespa index log line to debug * mypy fixes * update alembic migration * fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call * mypy * switch to monotonic time * fix startup dependencies on redis * rebase alembic migration * kombu cleanup - fail silently * mypy * add redis_host environment override * update REDIS_HOST env var in docker-compose.dev.yml * update the rest of the docker files * harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now. * allow no task syncs to run because we create certain objects with no entries but initially marked as out of date * add back writing to vespa on indexing * update contributing guide * backporting fixes from background_deletion * renaming cache to cache_volume * add redis password to various deployments * try setting up pr testing for helm * fix indent * hopefully this release version actually exists * fix command line option to --chart-dirs * fetch-depth 0 * edit values.yaml * try setting ct working directory * bypass testing only on change for now * move files and lint them * update helm testing * some issues suggest using --config works * add vespa repo * add postgresql repo * increase timeout * try amd64 runner * fix redis password reference * add comment to helm chart testing workflow * rename helm testing workflow to disable it * adding clarifying comments * address code review * missed a file * remove commented warning ... just not needed --------- Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "52a219fb5233"
|
||||
down_revision = "f7e58d357687"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# last modified represents the last time anything needing syncing to vespa changed
|
||||
# including row metadata and the document itself. This obviously does not include
|
||||
# the last_synced column.
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# last synced represents the last time this document was synced to Vespa
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
# Set last_synced to the same value as last_modified for existing rows
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET last_synced = last_modified
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_modified"),
|
||||
"document",
|
||||
["last_modified"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
op.f("ix_document_last_synced"),
|
||||
"document",
|
||||
["last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
|
||||
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
|
||||
op.drop_column("document", "last_synced")
|
||||
op.drop_column("document", "last_modified")
|
@@ -3,21 +3,49 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
info = get_access_info_for_document(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> DocumentAccess:
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
document_access_info = get_access_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
}
|
||||
|
||||
|
@@ -3,66 +3,83 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery # type: ignore
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from celery.signals import beat_init
|
||||
from celery.signals import worker_init
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from redis import Redis
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_sync_doc_set
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import (
|
||||
get_connector_credential_pair,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_set_for_document
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set_paginated
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
# use this within celery tasks to get celery task specific logging
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
celery_broker_url = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
celery_backend_url = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
redis_pool = RedisPool()
|
||||
|
||||
|
||||
_SYNC_BATCH_SIZE = 100
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object(
|
||||
"danswer.background.celery.celeryconfig"
|
||||
) # Load configuration from 'celeryconfig.py'
|
||||
|
||||
|
||||
#####
|
||||
@@ -111,7 +128,10 @@ def cleanup_connector_credential_pair_task(
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"connector_id={connector_id} credential_id={credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -130,7 +150,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
|
||||
task_logger.warning(
|
||||
f"ccpair not found for {connector_id} {credential_id}"
|
||||
)
|
||||
return
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
@@ -162,12 +184,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
)
|
||||
|
||||
if len(doc_ids_to_remove) == 0:
|
||||
logger.info(
|
||||
task_logger.info(
|
||||
f"No docs to prune from {cc_pair.connector.source} connector"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
task_logger.info(
|
||||
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
|
||||
)
|
||||
delete_connector_credential_pair_batch(
|
||||
@@ -177,113 +199,202 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
document_index=document_index,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id} due to {e}"
|
||||
task_logger.exception(
|
||||
f"Failed to run pruning for connector id {connector_id}."
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_document_set_sync_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
|
||||
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
stale_doc_count = count_documents_by_needs_sync(db_session)
|
||||
if stale_doc_count == 0:
|
||||
return None
|
||||
|
||||
# update Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
task_logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=set(document_set_map.get(document_id, [])),
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try:
|
||||
cursor = None
|
||||
while True:
|
||||
document_batch, cursor = fetch_documents_for_document_set_paginated(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
last_document_id=cursor,
|
||||
limit=_SYNC_BATCH_SIZE,
|
||||
)
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
db_session=db_session,
|
||||
)
|
||||
if cursor is None:
|
||||
break
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if not document_set.connector_credential_pairs:
|
||||
delete_document_set(
|
||||
document_set_row=document_set, db_session=db_session
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(
|
||||
document_set_id=document_set_id, db_session=db_session
|
||||
)
|
||||
logger.info(f"Document set sync for '{document_set_id}' complete!")
|
||||
if tasks_generated is None:
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
raise
|
||||
if tasks_generated == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
return total_tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
|
||||
# don't generate document set sync tasks if tasks are still pending
|
||||
if r.exists(rds.fence_key):
|
||||
return None
|
||||
|
||||
# don't generate sync tasks if we're up to date
|
||||
if document_set.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rds.taskset_key)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rds.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rug.fence_key):
|
||||
return None
|
||||
|
||||
if usergroup.is_up_to_date:
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
|
||||
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rug.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_for_document_sets_sync_task",
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
"""Runs periodically to check if any sync tasks should be run and adds them
|
||||
to the queue"""
|
||||
def check_for_vespa_sync_task() -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
r = redis_pool.get_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
|
||||
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if should_sync_doc_set(document_set, db_session):
|
||||
logger.info(f"Syncing the {document_set.name} document set")
|
||||
sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
try_generate_document_set_sync_tasks(
|
||||
document_set, db_session, r, lock_beat
|
||||
)
|
||||
|
||||
# check if any user groups are not synced
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_groups"
|
||||
)
|
||||
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
for usergroup in user_groups:
|
||||
try_generate_user_group_sync_tasks(
|
||||
usergroup, db_session, r, lock_beat
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
# Always exceptions on the MIT version, which is expected
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_cc_pair_deletion_task",
|
||||
@@ -292,11 +403,13 @@ def check_for_document_sets_sync_task() -> None:
|
||||
def check_for_cc_pair_deletion_task() -> None:
|
||||
"""Runs periodically to check if any deletion tasks should be run"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
# check if any cc pairs are up for deletion
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
|
||||
logger.notice(f"Deleting the {cc_pair.name} connector credential pair")
|
||||
task_logger.info(
|
||||
f"Deleting the {cc_pair.name} connector credential pair"
|
||||
)
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
@@ -343,7 +456,9 @@ def kombu_message_cleanup_task(self: Any) -> int:
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
logger.info(f"Deleted {ctx['deleted']} orphaned messages from kombu_message.")
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
@@ -417,12 +532,6 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
else:
|
||||
task_name = payload["headers"]["task"]
|
||||
logger.warning(
|
||||
f"Message found for task older than {ctx['cleanup_age']} days. "
|
||||
f"id={task_id} name={task_name}"
|
||||
)
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
@@ -446,7 +555,7 @@ def check_for_prune_task() -> None:
|
||||
credential=cc_pair.credential,
|
||||
db_session=db_session,
|
||||
):
|
||||
logger.info(f"Pruning the {cc_pair.connector.name} connector")
|
||||
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
|
||||
|
||||
prune_documents_task.apply_async(
|
||||
kwargs=dict(
|
||||
@@ -456,19 +565,331 @@ def check_for_prune_task() -> None:
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_set_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def celery_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
"""We handle this signal in order to remove completed tasks
|
||||
from their respective tasksets. This allows us to track the progress of document set
|
||||
and user group syncs.
|
||||
|
||||
This function runs after any task completes (both success and failure)
|
||||
Note that this signal does not fire on a task that failed to complete and is going
|
||||
to be retried.
|
||||
"""
|
||||
if not task:
|
||||
return
|
||||
|
||||
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
|
||||
# logger.debug(f"Result: {retval}")
|
||||
|
||||
if state not in READY_STATES:
|
||||
return
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisDocumentSet.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
|
||||
if document_set_id is not None:
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
r.srem(rds.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisUserGroup.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
|
||||
if usergroup_id is not None:
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
|
||||
task_logger.info(f"Stale documents: remaining={count} initial={initial_count}")
|
||||
if count == 0:
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
task_logger.warning("could not parse document set id from {key}")
|
||||
return
|
||||
|
||||
rds = RedisDocumentSet(document_set_id)
|
||||
|
||||
fence_value = r.get(rds.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
|
||||
key = key_bytes.decode("utf-8")
|
||||
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
|
||||
if not usergroup_id:
|
||||
task_logger.warning("Could not parse usergroup id from {key}")
|
||||
return
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
fence_value = r.get(rug.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
try:
|
||||
initial_count = int(cast(int, fence_value))
|
||||
except ValueError:
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rug.taskset_key))
|
||||
task_logger.info(
|
||||
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
try:
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_group"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
task_logger.exception(
|
||||
"fetch_versioned_implementation failed to look up fetch_user_group."
|
||||
)
|
||||
return
|
||||
|
||||
user_group: UserGroup | None = fetch_user_group(
|
||||
db_session=db_session, user_group_id=usergroup_id
|
||||
)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "delete_user_group", noop_fallback
|
||||
)
|
||||
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f" Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
|
||||
)
|
||||
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
|
||||
r.delete(rug.taskset_key)
|
||||
r.delete(rug.fence_key)
|
||||
|
||||
|
||||
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
|
||||
def monitor_vespa_sync() -> None:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
"""
|
||||
r = redis_pool.get_client()
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
|
||||
#
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
|
||||
|
||||
# TODO(rkuo): this is singleton work that should be done on startup exactly once
|
||||
# if we run multiple workers, we'll need to centralize where this cleanup happens
|
||||
r = redis_pool.get_client()
|
||||
|
||||
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(RedisConnectorCredentialPair.get_taskset_key())
|
||||
r.delete(RedisConnectorCredentialPair.get_fence_key())
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-document-set-sync": {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"check-for-vespa-sync": {
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
"check-for-cc-pair-deletion": {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
@@ -476,6 +897,7 @@ celery_app.conf.beat_schedule.update(
|
||||
"check-for-prune": {
|
||||
"task": "check_for_prune_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -484,6 +906,16 @@ celery_app.conf.beat_schedule.update(
|
||||
"kombu-message-cleanup": {
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"monitor-vespa-sync": {
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
299
backend/danswer/background/celery/celery_redis.py
Normal file
299
backend/danswer/background/celery/celery_redis.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class RedisObjectHelper(ABC):
|
||||
PREFIX = "base"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, id: int):
|
||||
self._id: int = id
|
||||
|
||||
@property
|
||||
def task_id_prefix(self) -> str:
|
||||
return f"{self.PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def fence_key(self) -> str:
|
||||
# example: documentset_fence_1
|
||||
return f"{self.FENCE_PREFIX}_{self._id}"
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
# example: documentset_taskset_1
|
||||
return f"{self.TASKSET_PREFIX}_{self._id}"
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
|
||||
|
||||
Args:
|
||||
key (str): The fence key string.
|
||||
|
||||
Returns:
|
||||
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
|
||||
"""
|
||||
parts = key.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[2])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_task_id(task_id: str) -> int | None:
|
||||
"""
|
||||
Extracts the object ID from a task ID string.
|
||||
|
||||
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
|
||||
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
|
||||
- `objectid` is the ID you want to extract,
|
||||
- `suffix` is another arbitrary string (e.g., a UUID).
|
||||
|
||||
Example:
|
||||
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
|
||||
this method will return the string `"1"`.
|
||||
|
||||
Args:
|
||||
task_id (str): The task ID string from which to extract the object ID.
|
||||
|
||||
Returns:
|
||||
str | None: The extracted object ID if the task ID is in the correct format, otherwise None.
|
||||
"""
|
||||
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
|
||||
parts = task_id.split("_")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
object_id = int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
return object_id
|
||||
|
||||
@abstractmethod
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
"danswer.db.user_group",
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(self._id)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the set BEFORE creating the task.
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
|
||||
@classmethod
|
||||
def get_taskset_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.TASKSET_PREFIX
|
||||
|
||||
@property
|
||||
def taskset_key(self) -> str:
|
||||
"""Notice that this is intentionally reusing the same taskset for all
|
||||
connector syncs"""
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic."""
|
||||
total_length = 0
|
||||
for i in range(len(DanswerCeleryPriority)):
|
||||
queue_name = queue
|
||||
if i > 0:
|
||||
queue_name += CELERY_SEPARATOR
|
||||
queue_name += str(i)
|
||||
|
||||
length = r.llen(queue_name)
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
@@ -22,7 +21,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
@@ -81,21 +79,6 @@ def should_kick_off_deletion_of_cc_pair(
|
||||
return True
|
||||
|
||||
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now.")
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
|
35
backend/danswer/background/celery/celeryconfig.py
Normal file
35
backend/danswer/background/celery/celeryconfig.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
|
||||
CELERY_SEPARATOR = ":"
|
||||
|
||||
CELERY_PASSWORD_PART = ""
|
||||
if REDIS_PASSWORD:
|
||||
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
|
||||
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
|
||||
result_backend = (
|
||||
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
|
||||
)
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
# can stall other tasks.
|
||||
worker_prefetch_multiplier = 4
|
||||
|
||||
broker_transport_options = {
|
||||
"priority_steps": list(range(len(DanswerCeleryPriority))),
|
||||
"sep": CELERY_SEPARATOR,
|
||||
"queue_order_strategy": "priority",
|
||||
}
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
@@ -61,6 +61,8 @@ KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
@@ -167,3 +169,23 @@ class FileOrigin(str, Enum):
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
|
||||
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
|
||||
class DanswerRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
MEDIUM = auto()
|
||||
LOW = auto()
|
||||
LOWEST = auto()
|
||||
|
@@ -3,6 +3,7 @@ import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -10,6 +11,7 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.engine.util import TransactionalContext
|
||||
@@ -38,6 +40,68 @@ def check_docs_exist(db_session: Session) -> bool:
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def count_documents_by_needs_sync(session: Session) -> int:
|
||||
"""Get the count of all documents where:
|
||||
1. last_modified is newer than last_synced
|
||||
2. last_synced is null (meaning we've never synced)
|
||||
|
||||
This function executes the query and returns the count of
|
||||
documents matching the criteria."""
|
||||
|
||||
count = (
|
||||
session.query(func.count())
|
||||
.select_from(DbDocument)
|
||||
.filter(
|
||||
or_(
|
||||
DbDocument.last_modified > DbDocument.last_synced,
|
||||
DbDocument.last_synced.is_(None),
|
||||
)
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
connector_id: int, credential_id: int
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.where(
|
||||
DbDocument.id.in_(initial_doc_ids_stmt),
|
||||
or_(
|
||||
DbDocument.last_modified
|
||||
> DbDocument.last_synced, # last_modified is newer than last_synced
|
||||
DbDocument.last_synced.is_(None), # never synced
|
||||
),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair(
|
||||
connector_id: int, credential_id: int | None = None
|
||||
) -> Select:
|
||||
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
|
||||
return stmt
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -108,7 +172,29 @@ def get_document_cnts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_acccess_info_for_documents(
def get_access_info_for_document(
    db_session: Session,
    document_id: str,
) -> tuple[str, list[UUID | None], bool] | None:
    """Gets access info for a single document by calling the get_access_info_for_documents function
    and passing a list with a single document ID.

    Args:
        db_session (Session): The database session to use.
        document_id (str): The document ID to fetch access info for.

    Returns:
        Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
        and a boolean indicating if the document is globally public, or None if no results are found.
    """
    results = get_access_info_for_documents(db_session, [document_id])
    if not results:
        return None

    return results[0]


def get_access_info_for_documents(
    db_session: Session,
    document_ids: list[str],
) -> Sequence[tuple[str, list[UUID | None], bool]]:
@@ -173,6 +259,7 @@ def upsert_documents(
            semantic_id=doc.semantic_identifier,
            link=doc.first_link,
            doc_updated_at=None,  # this is intentional
            last_modified=datetime.now(timezone.utc),
            primary_owners=doc.primary_owners,
            secondary_owners=doc.secondary_owners,
        )
@@ -214,7 +301,7 @@ def upsert_document_by_connector_credential_pair(
    db_session.commit()


def update_docs_updated_at(
def update_docs_updated_at__no_commit(
    ids_to_new_updated_at: dict[str, datetime],
    db_session: Session,
) -> None:
@@ -226,6 +313,28 @@ def update_docs_updated_at(
    for document in documents_to_update:
        document.doc_updated_at = ids_to_new_updated_at[document.id]


def update_docs_last_modified__no_commit(
    document_ids: list[str],
    db_session: Session,
) -> None:
    documents_to_update = (
        db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
    )

    now = datetime.now(timezone.utc)
    for doc in documents_to_update:
        doc.last_modified = now


def mark_document_as_synced(document_id: str, db_session: Session) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    doc = db_session.scalar(stmt)
    if doc is None:
        raise ValueError(f"No document with ID: {document_id}")

    # update last_synced
    doc.last_synced = datetime.now(timezone.utc)
    db_session.commit()


@@ -379,3 +488,12 @@ def get_documents_by_cc_pair(
        .filter(ConnectorCredentialPair.id == cc_pair_id)
        .all()
    )


def get_document(
    document_id: str,
    db_session: Session,
) -> DbDocument | None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
    return doc
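
Taken together, the helpers above give a last_modified / last_synced lifecycle roughly like the sketch below (document ids are placeholders and the Vespa write itself is elided): writers bump last_modified without committing, the caller commits once, and the sync worker stamps last_synced only after the index write succeeds.

# Sketch only; error handling and the actual Vespa update are omitted.
from sqlalchemy.orm import Session

from danswer.db.document import mark_document_as_synced
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    # 1) something changed: flag the docs as needing a sync, commit once
    update_docs_last_modified__no_commit(
        document_ids=["doc-1", "doc-2"], db_session=db_session
    )
    db_session.commit()

    # 2) later, in the vespa_metadata_sync worker, after the index write succeeds:
    mark_document_as_synced("doc-1", db_session)  # sets last_synced and commits
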
@@ -248,6 +248,10 @@ def update_document_set(
    document_set_update_request: DocumentSetUpdateRequest,
    user: User | None = None,
) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
    """If successful, this sets document_set_row.is_up_to_date = False.
    That will be processed via Celery in check_for_vespa_sync_task
    and trigger a long running background sync to Vespa.
    """
    if not document_set_update_request.cc_pair_ids:
        # It's cc-pairs in actuality but the UI displays this error
        raise ValueError("Cannot create a document set with no Connectors")
@@ -519,6 +523,70 @@ def fetch_documents_for_document_set_paginated(
    return documents, documents[-1].id if documents else None


def construct_document_select_by_docset(
    document_set_id: int,
    current_only: bool = True,
) -> Select:
    """This returns a statement that should be executed using
    .yield_per() to minimize overhead. The primary consumers of this function
    are background processing task generators."""

    stmt = (
        select(Document)
        .join(
            DocumentByConnectorCredentialPair,
            DocumentByConnectorCredentialPair.id == Document.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                ConnectorCredentialPair.connector_id
                == DocumentByConnectorCredentialPair.connector_id,
                ConnectorCredentialPair.credential_id
                == DocumentByConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            DocumentSet__ConnectorCredentialPair,
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
            == ConnectorCredentialPair.id,
        )
        .join(
            DocumentSetDBModel,
            DocumentSetDBModel.id
            == DocumentSet__ConnectorCredentialPair.document_set_id,
        )
        .where(DocumentSetDBModel.id == document_set_id)
        .order_by(Document.id)
    )

    if current_only:
        stmt = stmt.where(
            DocumentSet__ConnectorCredentialPair.is_current == True  # noqa: E712
        )

    stmt = stmt.distinct()
    return stmt


def fetch_document_set_for_document(
    document_id: str,
    db_session: Session,
) -> list[str]:
    """
    Fetches the document set names for a single document ID.

    :param document_id: The ID of the document to fetch sets for.
    :param db_session: The SQLAlchemy session to use for the query.
    :return: A list of document set names, or an empty list if no result is found.
    """
    result = fetch_document_sets_for_documents([document_id], db_session)
    if not result:
        return []

    return result[0][1]


def fetch_document_sets_for_documents(
    document_ids: list[str],
    db_session: Session,
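
As the construct_document_select_by_docset docstring above suggests, a background task generator would stream this statement rather than load it all at once; a minimal sketch, assuming the helper lives in danswer.db.document_set and using a placeholder document set id:

# Sketch: stream the documents of a document set in batches for task generation.
from sqlalchemy.orm import Session

from danswer.db.document_set import construct_document_select_by_docset
from danswer.db.engine import get_sqlalchemy_engine

stmt = construct_document_select_by_docset(document_set_id=42)  # placeholder id
with Session(get_sqlalchemy_engine()) as db_session:
    for doc in db_session.scalars(stmt.execution_options(yield_per=100)):
        ...  # emit one sync task per document id
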
@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID

from fastapi import HTTPException
@@ -24,7 +26,6 @@ from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -123,12 +124,11 @@ def update_document_boost(
    db_session: Session,
    document_id: str,
    boost: int,
    document_index: DocumentIndex,
    user: User | None = None,
) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    stmt = _add_user_filters(stmt, user, get_editable=True)
    result = db_session.execute(stmt).scalar_one_or_none()
    result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
    if result is None:
        raise HTTPException(
            status_code=400, detail="Document is not editable by this user"
@@ -136,13 +136,9 @@ def update_document_boost(

    result.boost = boost

    update = UpdateRequest(
        document_ids=[document_id],
        boost=boost,
    )

    document_index.update(update_requests=[update])

    # updating last_modified triggers sync
    # TODO: Should this submit to the queue directly so that the UI can update?
    result.last_modified = datetime.now(timezone.utc)
    db_session.commit()


@@ -163,13 +159,9 @@ def update_document_hidden(

    result.hidden = hidden

    update = UpdateRequest(
        document_ids=[document_id],
        hidden=hidden,
    )

    document_index.update(update_requests=[update])

    # updating last_modified triggers sync
    # TODO: Should this submit to the queue directly so that the UI can update?
    result.last_modified = datetime.now(timezone.utc)
    db_session.commit()


@@ -210,11 +202,9 @@ def create_doc_retrieval_feedback(
        SearchFeedbackType.REJECT,
        SearchFeedbackType.HIDE,
    ]:
        update = UpdateRequest(
            document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
        )
        # Updates are generally batched for efficiency, this case only 1 doc/value is updated
        document_index.update(update_requests=[update])
        # updating last_modified triggers sync
        # TODO: Should this submit to the queue directly so that the UI can update?
        db_doc.last_modified = datetime.now(timezone.utc)

    db_session.add(retrieval_feedback)
    db_session.commit()
@@ -428,12 +428,27 @@ class Document(Base):
    semantic_id: Mapped[str] = mapped_column(String)
    # First Section's link
    link: Mapped[str | None] = mapped_column(String, nullable=True)

    # The updated time is also used as a measure of the last successful state of the doc
    # pulled from the source (to help skip reindexing already updated docs in case of
    # connector retries)
    # TODO: rename this column because it conflates the time of the source doc
    # with the local last modified time of the doc and any associated metadata
    # it should just be the server timestamp of the source doc
    doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # last time any vespa relevant row metadata or the doc changed.
    # does not include last_synced
    last_modified: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=False, index=True, default=func.now()
    )

    # last successful sync to vespa
    last_synced: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, index=True
    )
    # The following are not attached to User because the account/email may not be known
    # within Danswer
    # Something like the document creator
@@ -282,7 +282,7 @@ class VespaIndex(DocumentIndex):
            raise requests.HTTPError(failure_msg) from e

    def update(self, update_requests: list[UpdateRequest]) -> None:
        logger.info(f"Updating {len(update_requests)} documents in Vespa")
        logger.debug(f"Updating {len(update_requests)} documents in Vespa")

        # Handle Vespa character limitations
        # Mutating update_requests but it's not used later anyway
@@ -162,14 +162,16 @@ def _index_vespa_chunk(
        METADATA_SUFFIX: chunk.metadata_suffix_keyword,
        EMBEDDINGS: embeddings_name_vector_map,
        TITLE_EMBEDDING: chunk.title_embedding,
        BOOST: chunk.boost,
        DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
        PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners),
        SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners),
        # the only `set` vespa has is `weightedset`, so we have to give each
        # element an arbitrary weight
        # rkuo: acl, docset and boost metadata are also updated through the metadata sync queue
        # which only calls VespaIndex.update
        ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
        DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
        BOOST: chunk.boost,
    }

    vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
@@ -18,7 +18,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_updated_at
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
@@ -264,7 +265,7 @@ def index_doc_batch(
    Note that the documents should already be batched at this point so that it does not inflate the
    memory requirements"""

    no_access = DocumentAccess.build([], [], False)
    no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)

    ctx = index_doc_batch_prepare(
        document_batch=document_batch,
@@ -295,9 +296,6 @@ def index_doc_batch(
    # NOTE: don't need to acquire till here, since this is when the actual race condition
    # with Vespa can occur.
    with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
        # Attach the latest status from Postgres (source of truth for access) to each
        # chunk. This access status will be attached to each chunk in the document index
        # TODO: attach document sets to the chunk based on the status of Postgres as well
        document_id_to_access_info = get_access_for_documents(
            document_ids=updatable_ids, db_session=db_session
        )
@@ -307,6 +305,12 @@ def index_doc_batch(
                document_ids=updatable_ids, db_session=db_session
            )
        }

        # we're concerned about race conditions where multiple simultaneous indexings might result
        # in one set of metadata overwriting another one in vespa.
        # we still write data here for immediate and most likely correct sync, but
        # to resolve this, an update of the last modified field at the end of this loop
        # always triggers a final metadata sync
        access_aware_chunks = [
            DocMetadataAwareIndexChunk.from_index_chunk(
                index_chunk=chunk,
@@ -338,17 +342,25 @@ def index_doc_batch(
            doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
        ]

        # Update the time of latest version of the doc successfully indexed
        last_modified_ids = []
        ids_to_new_updated_at = {}
        for doc in successful_docs:
            last_modified_ids.append(doc.id)
            # doc_updated_at is the connector source's idea of when the doc was last modified
            if doc.doc_updated_at is None:
                continue
            ids_to_new_updated_at[doc.id] = doc.doc_updated_at

        update_docs_updated_at(
        update_docs_updated_at__no_commit(
            ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
        )

        update_docs_last_modified__no_commit(
            document_ids=last_modified_ids, db_session=db_session
        )

        db_session.commit()

    return len([r for r in insertion_records if r.already_existed is False]), len(
        access_aware_chunks
    )
@@ -61,6 +61,8 @@ class IndexChunk(DocAwareChunk):
    title_embedding: Embedding | None


# TODO(rkuo): currently, this extra metadata sent during indexing is just for speed,
# but full consistency happens on background sync
class DocMetadataAwareIndexChunk(IndexChunk):
    """An `IndexChunk` that contains all necessary metadata to be indexed. This includes
    the following:
backend/danswer/redis/redis_pool.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import threading
from typing import Optional

import redis
from redis.client import Redis
from redis.connection import ConnectionPool

from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT

REDIS_POOL_MAX_CONNECTIONS = 10


class RedisPool:
    _instance: Optional["RedisPool"] = None
    _lock: threading.Lock = threading.Lock()
    _pool: ConnectionPool

    def __new__(cls) -> "RedisPool":
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super(RedisPool, cls).__new__(cls)
                    cls._instance._init_pool()
        return cls._instance

    def _init_pool(self) -> None:
        self._pool = redis.ConnectionPool(
            host=REDIS_HOST,
            port=REDIS_PORT,
            db=REDIS_DB_NUMBER,
            password=REDIS_PASSWORD,
            max_connections=REDIS_POOL_MAX_CONNECTIONS,
        )

    def get_client(self) -> Redis:
        return redis.Redis(connection_pool=self._pool)


# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()

# # Example of setting and getting a value
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode())  # Output: 'value'
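
One property worth calling out: because __new__ uses double-checked locking, every caller in a process shares the same pool, and get_client() hands out lightweight clients on top of it. A small sketch (the key name is illustrative):

# All RedisPool() calls return the same instance, so clients share one connection pool.
from danswer.redis.redis_pool import RedisPool

pool_a = RedisPool()
pool_b = RedisPool()
assert pool_a is pool_b  # same singleton

client = pool_a.get_client()  # cheap to create; borrows connections from the shared pool
client.set("da_lock:example", "1", ex=30)  # illustrative key with a 30s TTL
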
@@ -77,16 +77,10 @@ def document_boost_update(
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    curr_ind_name, sec_ind_name = get_both_index_names(db_session)
    document_index = get_default_document_index(
        primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
    )

    update_document_boost(
        db_session=db_session,
        document_id=boost_update.document_id,
        boost=boost_update.boost,
        document_index=document_index,
        user=user,
    )
    return StatusResponse(success=True, message="Updated document boost")
@@ -31,6 +31,28 @@ def set_is_ee_based_on_env_variable() -> None:

@functools.lru_cache(maxsize=128)
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
    """
    Fetches a versioned implementation of a specified attribute from a given module.
    This function first checks if the application is running in an Enterprise Edition (EE)
    context. If so, it attempts to import the attribute from the EE-specific module.
    If the module or attribute is not found, it falls back to the default module or
    raises the appropriate exception depending on the context.

    Args:
        module (str): The name of the module from which to fetch the attribute.
        attribute (str): The name of the attribute to fetch from the module.

    Returns:
        Any: The fetched implementation of the attribute.

    Raises:
        ModuleNotFoundError: If the module cannot be found and the error is not related to
            the Enterprise Edition fallback logic.

    Logs:
        Logs debug information about the fetching process and warnings if the versioned
        implementation cannot be found or loaded.
    """
    logger.debug("Fetching versioned implementation for %s.%s", module, attribute)
    is_ee = global_version.get_is_ee_version()

@@ -66,6 +88,19 @@ T = TypeVar("T")
def fetch_versioned_implementation_with_fallback(
    module: str, attribute: str, fallback: T
) -> T:
    """
    Attempts to fetch a versioned implementation of a specified attribute from a given module.
    If the attempt fails (e.g., due to an import error or missing attribute), the function logs
    a warning and returns the provided fallback implementation.

    Args:
        module (str): The name of the module from which to fetch the attribute.
        attribute (str): The name of the attribute to fetch from the module.
        fallback (T): The fallback implementation to return if fetching the attribute fails.

    Returns:
        T: The fetched implementation if successful, otherwise the provided fallback.
    """
    try:
        return fetch_versioned_implementation(module, attribute)
    except Exception:
@@ -73,4 +108,14 @@ def fetch_versioned_implementation_with_fallback(


def noop_fallback(*args: Any, **kwargs: Any) -> None:
    pass
    """
    A no-op (no operation) fallback function that accepts any arguments but does nothing.
    This is often used as a default or placeholder callback function.

    Args:
        *args (Any): Positional arguments, which are ignored.
        **kwargs (Any): Keyword arguments, which are ignored.

    Returns:
        None
    """
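
A hedged example of how the fallback variant pairs with noop_fallback; the module and attribute names below are placeholders, not real Danswer symbols:

# Sketch: resolve an EE-only hook if present, otherwise fall back to a harmless no-op.
from danswer.utils.variable_functionality import fetch_versioned_implementation_with_fallback
from danswer.utils.variable_functionality import noop_fallback

post_index_hook = fetch_versioned_implementation_with_fallback(
    "danswer.some_module",  # placeholder module
    "post_index_hook",      # placeholder attribute
    noop_fallback,
)
post_index_hook()  # either the EE implementation or the no-op fallback
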
@@ -11,6 +11,17 @@ from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user


def _get_access_for_document(
    document_id: str,
    db_session: Session,
) -> DocumentAccess:
    id_to_access = _get_access_for_documents([document_id], db_session)
    if len(id_to_access) == 0:
        return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)

    return next(iter(id_to_access.values()))


def _get_access_for_documents(
    document_ids: list[str],
    db_session: Session,
@@ -1,28 +1,18 @@
from datetime import timedelta
from typing import Any

from celery.signals import beat_init
from celery.signals import worker_init
from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_sync_user_groups
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from ee.danswer.user_groups.sync import sync_user_groups

logger = setup_logger()

@@ -30,17 +20,6 @@ logger = setup_logger()
global_version.set_ee()


@build_celery_task_wrapper(name_user_group_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_user_group_task(user_group_id: int) -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        # actual sync logic
        try:
            sync_user_groups(user_group_id=user_group_id, db_session=db_session)
        except Exception as e:
            logger.exception(f"Failed to sync user group - {e}")


@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
@@ -51,8 +30,6 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
#####
# Periodic Tasks
#####


@celery_app.task(
    name="check_ttl_management_task",
    soft_time_limit=JOB_TIMEOUT,
@@ -69,24 +46,6 @@ def check_ttl_management_task() -> None:
    )


@celery_app.task(
    name="check_for_user_groups_sync_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_for_user_groups_sync_task() -> None:
    """Runs periodically to check if any user groups are out of sync
    Creates a task to sync the user group if needed"""
    with Session(get_sqlalchemy_engine()) as db_session:
        # check if any document sets are not synced
        user_groups = fetch_user_groups(db_session=db_session, only_current=False)
        for user_group in user_groups:
            if should_sync_user_groups(user_group, db_session):
                logger.info(f"User Group {user_group.id} is not synced. Syncing now!")
                sync_user_group_task.apply_async(
                    kwargs=dict(user_group_id=user_group.id),
                )


@celery_app.task(
    name="autogenerate_usage_report_task",
    soft_time_limit=JOB_TIMEOUT,
@@ -101,25 +60,11 @@ def autogenerate_usage_report_task() -> None:
    )


@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)


@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)


#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
    "check-for-user-group-sync": {
        "task": "check_for_user_groups_sync_task",
        "schedule": timedelta(seconds=5),
    },
    "autogenerate_usage_report": {
    "autogenerate-usage-report": {
        "task": "autogenerate_usage_report_task",
        "schedule": timedelta(days=30),  # TODO: change this to config flag
    },
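
The vespa-sync check referenced earlier (check_for_vespa_sync_task) is presumably registered the same way in the non-EE beat schedule; a hedged sketch of what such an entry could look like, where the cadence and priority option are assumptions and only the task name comes from this commit:

# Sketch: registering a periodic vespa-sync check; schedule and priority are assumed.
from datetime import timedelta

from danswer.background.celery.celery_app import celery_app
from danswer.configs.constants import DanswerCeleryPriority  # assumed location

celery_app.conf.beat_schedule.update(
    {
        "check-for-vespa-sync": {
            "task": "check_for_vespa_sync_task",
            "schedule": timedelta(seconds=5),
            "options": {"priority": DanswerCeleryPriority.HIGH},
        },
    }
)
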
@@ -1,27 +1,13 @@
from sqlalchemy.orm import Session

from danswer.db.models import UserGroup
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task

logger = setup_logger()


def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
    if user_group.is_up_to_date:
        return False
    task_name = name_user_group_sync_task(user_group.id)
    latest_sync = get_latest_task(task_name, db_session)

    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
        logger.info("TTL check is already being performed. Skipping.")
        return False
    return True


def should_perform_chat_ttl_check(
    retention_limit_days: int | None, db_session: Session
) -> bool:
@@ -5,6 +5,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
@@ -81,10 +82,25 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non


def fetch_user_groups(
    db_session: Session, only_current: bool = True
    db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
    """
    Fetches user groups from the database.

    This function retrieves a sequence of `UserGroup` objects from the database.
    If `only_up_to_date` is set to `True`, it filters the user groups to return only those
    that are marked as up-to-date (`is_up_to_date` is `True`).

    Args:
        db_session (Session): The SQLAlchemy session used to query the database.
        only_up_to_date (bool, optional): Flag to determine whether to filter the results
            to include only up to date user groups. Defaults to `True`.

    Returns:
        Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
    """
    stmt = select(UserGroup)
    if only_current:
    if only_up_to_date:
        stmt = stmt.where(UserGroup.is_up_to_date == True)  # noqa: E712
    return db_session.scalars(stmt).all()
@@ -103,6 +119,42 @@ def fetch_user_groups_for_user(
    return db_session.scalars(stmt).all()


def construct_document_select_by_usergroup(
    user_group_id: int,
) -> Select:
    """This returns a statement that should be executed using
    .yield_per() to minimize overhead. The primary consumers of this function
    are background processing task generators."""
    stmt = (
        select(Document)
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                DocumentByConnectorCredentialPair.connector_id
                == ConnectorCredentialPair.connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == ConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            UserGroup__ConnectorCredentialPair,
            UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            UserGroup,
            UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
        )
        .where(UserGroup.id == user_group_id)
        .order_by(Document.id)
    )
    stmt = stmt.distinct()
    return stmt


def fetch_documents_for_user_group_paginated(
    db_session: Session,
    user_group_id: int,
@@ -361,6 +413,10 @@ def update_user_group(
    user_group_id: int,
    user_group_update: UserGroupUpdate,
) -> UserGroup:
    """If successful, this can set db_user_group.is_up_to_date = False.
    That will be processed by check_for_vespa_user_groups_sync_task and trigger
    a long running background sync to Vespa.
    """
    stmt = select(UserGroup).where(UserGroup.id == user_group_id)
    db_user_group = db_session.scalar(stmt)
    if db_user_group is None:
@@ -32,7 +32,7 @@ def list_user_groups(
    db_session: Session = Depends(get_session),
) -> list[UserGroup]:
    if user is None or user.role == UserRole.ADMIN:
        user_groups = fetch_user_groups(db_session, only_current=False)
        user_groups = fetch_user_groups(db_session, only_up_to_date=False)
    else:
        user_groups = fetch_user_groups_for_user(
            db_session=db_session,
@@ -1,87 +0,0 @@
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced

logger = setup_logger()

_SYNC_BATCH_SIZE = 100


def _sync_user_group_batch(
    document_ids: list[str], document_index: DocumentIndex, db_session: Session
) -> None:
    logger.debug(f"Syncing document sets for: {document_ids}")

    # Acquires a lock on the documents so that no other process can modify them
    with prepare_to_modify_documents(db_session=db_session, document_ids=document_ids):
        # get current state of document sets for these documents
        document_id_to_access = get_access_for_documents(
            document_ids=document_ids, db_session=db_session
        )

        # update Vespa
        document_index.update(
            update_requests=[
                UpdateRequest(
                    document_ids=[document_id],
                    access=document_id_to_access[document_id],
                )
                for document_id in document_ids
            ]
        )

        # Finish the transaction and release the locks
        db_session.commit()


def sync_user_groups(user_group_id: int, db_session: Session) -> None:
    """Sync the status of Postgres for the specified user group"""
    search_settings = get_current_search_settings(db_session)
    secondary_search_settings = get_secondary_search_settings(db_session)

    document_index = get_default_document_index(
        primary_index_name=search_settings.index_name,
        secondary_index_name=secondary_search_settings.index_name
        if secondary_search_settings
        else None,
    )

    user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id)
    if user_group is None:
        raise ValueError(f"User group '{user_group_id}' does not exist")

    cursor = None
    while True:
        # NOTE: this may miss some documents, but that is okay. Any new documents added
        # will be added with the correct group membership
        document_batch, cursor = fetch_documents_for_user_group_paginated(
            db_session=db_session,
            user_group_id=user_group_id,
            last_document_id=cursor,
            limit=_SYNC_BATCH_SIZE,
        )

        _sync_user_group_batch(
            document_ids=[document.id for document in document_batch],
            document_index=document_index,
            db_session=db_session,
        )

        if cursor is None:
            break

    if user_group.is_up_for_deletion:
        delete_user_group(db_session=db_session, user_group=user_group)
    else:
        mark_user_group_as_synced(db_session=db_session, user_group=user_group)
@@ -24,14 +24,21 @@ autorestart=true
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=6 --loglevel=INFO --logfile=/var/log/celery_worker_supervisor.log
command=celery -A danswer.background.celery.celery_run:celery_app worker
    --pool=threads
    --concurrency=6
    --loglevel=INFO
    --logfile=/var/log/celery_worker_supervisor.log
    -Q celery,vespa_metadata_sync
environment=LOG_FILE_NAME=celery_worker
redirect_stderr=true
autorestart=true

# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat --loglevel=INFO --logfile=/var/log/celery_beat_supervisor.log
command=celery -A danswer.background.celery.celery_run:celery_app beat
    --loglevel=INFO
    --logfile=/var/log/celery_beat_supervisor.log
environment=LOG_FILE_NAME=celery_beat
redirect_stderr=true
autorestart=true
@@ -211,7 +211,7 @@ export function Explorer({
      )}
      {!query && (
        <div className="flex text-emphasis mt-3">
          Search for a document above to modify it's boost or hide it from
          Search for a document above to modify its boost or hide it from
          searches.
        </div>
      )}