Feature/background processing (#2275)

* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. Used in a subsequent PR

* kombu warning was getting spammy since celery is no longer self-managing its queue in Postgres

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now.

* allow syncs with zero tasks to run, since we create certain objects with no entries that are initially marked as out of date

* add back writing to vespa on indexing

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Author: rkuo-danswer
Date: 2024-09-10 09:28:19 -07:00
Committed by: GitHub
parent b7ad810d83
commit f1c5e80f17
26 changed files with 1428 additions and 350 deletions

View File

@@ -0,0 +1,66 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func
# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last modified represents the last time anything needing syncing to vespa changed
# including row metadata and the document itself. This obviously does not include
# the last_synced column.
op.add_column(
"document",
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
server_default=func.now(),
),
)
# last synced represents the last time this document was synced to Vespa
op.add_column(
"document",
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
)
# Set last_synced to the same value as last_modified for existing rows
op.execute(
"""
UPDATE document
SET last_synced = last_modified
"""
)
op.create_index(
op.f("ix_document_last_modified"),
"document",
["last_modified"],
unique=False,
)
op.create_index(
op.f("ix_document_last_synced"),
"document",
["last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document_last_synced"), table_name="document")
op.drop_index(op.f("ix_document_last_modified"), table_name="document")
op.drop_column("document", "last_synced")
op.drop_column("document", "last_modified")
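Together these two columns define the "needs sync" condition that the new background workers query. A minimal sketch of that predicate (illustrative only; the real helper, count_documents_by_needs_sync, appears later in this commit):

# Illustrative only: the "needs sync" condition implied by last_modified/last_synced.
from sqlalchemy import or_
from sqlalchemy import select

from danswer.db.models import Document

needs_sync_stmt = select(Document).where(
    or_(
        Document.last_modified > Document.last_synced,  # modified since the last sync
        Document.last_synced.is_(None),  # never synced at all
    )
)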

View File

@@ -3,21 +3,49 @@ from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
from danswer.db.models import User
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
info = get_access_info_for_document(
db_session=db_session,
document_id=document_id,
)
if not info:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
def get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
document_access_info = get_access_info_for_documents(
db_session=db_session,
document_ids=document_ids,
)
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
document_id: DocumentAccess.build(
user_ids=user_ids, user_groups=[], is_public=is_public
)
for document_id, user_ids, is_public in document_access_info
}

View File

@@ -3,66 +3,83 @@ from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery # type: ignore
import redis
from celery import Celery
from celery import signals
from celery import Task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import SoftTimeLimitExceeded
from celery.exceptions import TaskRevokedError
from celery.signals import beat_init
from celery.signals import worker_init
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import (
get_connector_credential_pair,
)
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_set_for_document
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
# example celery_broker_url: "redis://:password@localhost:6379/15"
celery_broker_url = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
celery_backend_url = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
redis_pool = RedisPool()
_SYNC_BATCH_SIZE = 100
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
#####
@@ -111,7 +128,10 @@ def cleanup_connector_credential_pair_task(
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={connector_id} credential_id={credential_id}"
)
raise e
@@ -130,7 +150,9 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
)
if not cc_pair:
logger.warning(f"ccpair not found for {connector_id} {credential_id}")
task_logger.warning(
f"ccpair not found for {connector_id} {credential_id}"
)
return
runnable_connector = instantiate_connector(
@@ -162,12 +184,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
)
if len(doc_ids_to_remove) == 0:
logger.info(
task_logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
logger.info(
task_logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
)
delete_connector_credential_pair_batch(
@@ -177,113 +199,202 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
document_index=document_index,
)
except Exception as e:
logger.exception(
f"Failed to run pruning for connector id {connector_id} due to {e}"
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
def _sync_document_batch(document_ids: list[str], db_session: Session) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
# update Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine()) as db_session:
try:
cursor = None
while True:
document_batch, cursor = fetch_documents_for_document_set_paginated(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
if cursor is None:
break
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(cc_pair.id)
tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat)
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
if tasks_generated is None:
continue
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += tasks_generated
task_logger.info(
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
# don't generate document set sync tasks if tasks are still pending
if r.exists(rds.fence_key):
return None
# don't generate sync tasks if we're up to date
if document_set.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rds.taskset_key)
task_logger.info(
f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}"
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rds.fence_key, tasks_generated)
return tasks_generated
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rug.fence_key):
return None
if usergroup.is_up_to_date:
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rug.fence_key, tasks_generated)
return tasks_generated
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
name="check_for_vespa_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
def check_for_vespa_sync_task() -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
try_generate_stale_document_sync_tasks(db_session, r, lock_beat)
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if should_sync_doc_set(document_set, db_session):
logger.info(f"Syncing the {document_set.name} document set")
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
try_generate_document_set_sync_tasks(
document_set, db_session, r, lock_beat
)
# check if any user groups are not synced
try:
fetch_user_groups = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_groups"
)
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
for usergroup in user_groups:
try_generate_user_group_sync_tasks(
usergroup, db_session, r, lock_beat
)
except ModuleNotFoundError:
# this always raises on the MIT version, which is expected
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
@celery_app.task(
name="check_for_cc_pair_deletion_task",
@@ -292,11 +403,13 @@ def check_for_document_sets_sync_task() -> None:
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
# check if any cc pairs are up for deletion
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
logger.notice(f"Deleting the {cc_pair.name} connector credential pair")
task_logger.info(
f"Deleting the {cc_pair.name} connector credential pair"
)
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
@@ -343,7 +456,9 @@ def kombu_message_cleanup_task(self: Any) -> int:
db_session.commit()
if ctx["deleted"] > 0:
logger.info(f"Deleted {ctx['deleted']} orphaned messages from kombu_message.")
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
@@ -417,12 +532,6 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
else:
task_name = payload["headers"]["task"]
logger.warning(
f"Message found for task older than {ctx['cleanup_age']} days. "
f"id={task_id} name={task_name}"
)
ctx["last_processed_id"] = msg[0]
@@ -446,7 +555,7 @@ def check_for_prune_task() -> None:
credential=cc_pair.credential,
db_session=db_session,
):
logger.info(f"Pruning the {cc_pair.connector.name} connector")
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
prune_documents_task.apply_async(
kwargs=dict(
@@ -456,19 +565,331 @@ def check_for_prune_task() -> None:
)
@celery_app.task(
name="vespa_metadata_sync_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
doc = get_document(document_id, db_session)
if not doc:
return False
# document set sync
doc_sets = fetch_document_set_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
@signals.task_postrun.connect
def celery_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
"""We handle this signal in order to remove completed tasks
from their respective tasksets. This allows us to track the progress of document set
and user group syncs.
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
"""
if not task:
return
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
# logger.debug(f"Result: {retval}")
if state not in READY_STATES:
return
if not task_id:
return
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r = redis_pool.get_client()
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
r = redis_pool.get_client()
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(document_set_id)
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
r = redis_pool.get_client()
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(usergroup_id)
r.srem(rug.taskset_key, task_id)
return
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = r.scard(RedisConnectorCredentialPair.get_taskset_key())
task_logger.info(f"Stale documents: remaining={count} initial={initial_count}")
if count == 0:
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning(f"Could not parse usergroup id from {key}")
return
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
if fence_value is None:
return
try:
initial_count = int(cast(int, fence_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rug.taskset_key))
task_logger.info(
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
try:
fetch_user_group = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_group"
)
except ModuleNotFoundError:
task_logger.exception(
"fetch_versioned_implementation failed to look up fetch_user_group."
)
return
user_group: UserGroup | None = fetch_user_group(
db_session=db_session, user_group_id=usergroup_id
)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "delete_user_group", noop_fallback
)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
)
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
"""
r = redis_pool.get_client()
lock_beat = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset(key_bytes, r, db_session)
#
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
finally:
if lock_beat.owned():
lock_beat.release()
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
# TODO(rkuo): this is singleton work that should be done on startup exactly once
# if we run multiple workers, we'll need to centralize where this cleanup happens
r = redis_pool.get_client()
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"check-for-vespa-sync": {
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
celery_app.conf.beat_schedule.update(
@@ -476,6 +897,7 @@ celery_app.conf.beat_schedule.update(
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
@@ -484,6 +906,16 @@ celery_app.conf.beat_schedule.update(
"kombu-message-cleanup": {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
}
)
celery_app.conf.beat_schedule.update(
{
"monitor-vespa-sync": {
"task": "monitor_vespa_sync",
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
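Taken together, the task generators, the celery_task_postrun handler, and monitor_vespa_sync implement a small fence/taskset protocol in Redis: a generator registers every task id in a set and then writes the task count to a fence key, each finished task removes its own id from the set, and the monitor tears both keys down once the set is empty. A minimal standalone sketch of that protocol (plain redis-py, hypothetical key names, assumes a local Redis; not the production code):

# Illustrative sketch of the fence/taskset protocol used above.
import uuid

import redis

r = redis.Redis(host="localhost", port=6379, db=0)

taskset_key = "example_taskset"  # hypothetical stand-in for e.g. documentset_taskset_1
fence_key = "example_fence"  # hypothetical stand-in for e.g. documentset_fence_1

# 1. generator: register each task id BEFORE dispatching it, then raise the fence
task_ids = [f"example_1_{uuid.uuid4()}" for _ in range(3)]
for task_id in task_ids:
    r.sadd(taskset_key, task_id)
r.set(fence_key, len(task_ids))  # fence value = number of tasks generated

# 2. task_postrun: every finished task removes its own id from the taskset
for task_id in task_ids:
    r.srem(taskset_key, task_id)

# 3. monitor: once the taskset is empty, the batch is complete; clean up both keys
if r.exists(fence_key) and r.scard(taskset_key) == 0:
    r.delete(taskset_key)
    r.delete(fence_key)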

View File

@@ -0,0 +1,299 @@
# These are helper objects for tracking the keys we need to write in redis
import time
from abc import ABC
from abc import abstractmethod
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.document_set import construct_document_select_by_docset
from danswer.utils.variable_functionality import fetch_versioned_implementation
class RedisObjectHelper(ABC):
PREFIX = "base"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int):
self._id: int = id
@property
def task_id_prefix(self) -> str:
return f"{self.PREFIX}_{self._id}"
@property
def fence_key(self) -> str:
# example: documentset_fence_1
return f"{self.FENCE_PREFIX}_{self._id}"
@property
def taskset_key(self) -> str:
# example: documentset_taskset_1
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> int | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
Args:
key (str): The fence key string.
Returns:
Optional[int]: The extracted ID if the key is in the correct format, otherwise None.
"""
parts = key.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[2])
except ValueError:
return None
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> int | None:
"""
Extracts the object ID from a task ID string.
This method assumes the task ID is formatted as `prefix_objectid_suffix`, where:
- `prefix` is an arbitrary string (e.g., the name of the task or entity),
- `objectid` is the ID you want to extract,
- `suffix` is another arbitrary string (e.g., a UUID).
Example:
If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`,
this method will return the integer `1`.
Args:
task_id (str): The task ID string from which to extract the object ID.
Returns:
int | None: The extracted object ID if the task ID is in the correct format, otherwise None.
"""
# example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc
parts = task_id.split("_")
if len(parts) != 3:
return None
try:
object_id = int(parts[1])
except ValueError:
return None
return object_id
@abstractmethod
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
pass
class RedisDocumentSet(RedisObjectHelper):
PREFIX = "documentset"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisUserGroup(RedisObjectHelper):
PREFIX = "usergroup"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
"danswer.db.user_group",
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
async_results.append(result)
return len(async_results)
class RedisConnectorCredentialPair(RedisObjectHelper):
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@classmethod
def get_taskset_key(cls) -> str:
return RedisConnectorCredentialPair.TASKSET_PREFIX
@property
def taskset_key(self) -> str:
"""Notice that this is intentionally reusing the same taskset for all
connector syncs"""
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"vespa_metadata_sync_task",
kwargs=dict(document_id=doc.id),
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
used to implement task prioritization.
This operation is not atomic."""
total_length = 0
for i in range(len(DanswerCeleryPriority)):
queue_name = queue
if i > 0:
queue_name += CELERY_SEPARATOR
queue_name += str(i)
length = r.llen(queue_name)
total_length += cast(int, length)
return total_length
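The custom task id prefix is what lets celery_task_postrun map a finished task back to its taskset, and the fence key encodes the same object id. A small round-trip sketch using the helpers above (illustrative only):

# Illustrative round trip for the custom task id and fence key conventions.
from uuid import uuid4

from danswer.background.celery.celery_redis import RedisDocumentSet

rds = RedisDocumentSet(1)

# task ids are "<prefix>_<object id>_<uuid>", e.g. "documentset_1_<uuid>"
custom_task_id = f"{rds.task_id_prefix}_{uuid4()}"
assert RedisDocumentSet.get_id_from_task_id(custom_task_id) == 1

# fence keys are "<prefix>_fence_<object id>", e.g. "documentset_fence_1"
assert rds.fence_key == "documentset_fence_1"
assert RedisDocumentSet.get_id_from_fence_key(rds.fence_key) == 1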

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
@@ -22,7 +21,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
@@ -81,21 +79,6 @@ def should_kick_off_deletion_of_cc_pair(
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now.")
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:

View File

@@ -0,0 +1,35 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
from danswer.configs.constants import DanswerCeleryPriority
CELERY_SEPARATOR = ":"
CELERY_PASSWORD_PART = ""
if REDIS_PASSWORD:
CELERY_PASSWORD_PART = f":{REDIS_PASSWORD}@"
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
result_backend = (
f"redis://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}"
)
# NOTE: prefetch 4 is significantly faster than prefetch 1
# however, prefetching is bad when tasks are lengthy as those tasks
# can stall other tasks.
worker_prefetch_multiplier = 4
broker_transport_options = {
"priority_steps": list(range(len(DanswerCeleryPriority))),
"sep": CELERY_SEPARATOR,
"queue_order_strategy": "priority",
}
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
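With queue_order_strategy set to "priority", the Redis transport fans each logical queue out into one list per priority step, joined by the separator above; this is the key layout that celery_get_queue_length (earlier in this commit) sums over. A small sketch of the resulting broker list names (illustrative only):

# Illustrative only: broker list names implied by the settings above.
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues

queue = DanswerCeleryQueues.VESPA_METADATA_SYNC  # "vespa_metadata_sync"
list_names = [
    queue if step == 0 else f"{queue}{CELERY_SEPARATOR}{step}"
    for step in range(len(DanswerCeleryPriority))
]
# -> ["vespa_metadata_sync", "vespa_metadata_sync:1", ..., "vespa_metadata_sync:4"]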

View File

@@ -61,6 +61,8 @@ KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -167,3 +169,23 @@ class FileOrigin(str, Enum):
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class DanswerCeleryQueues:
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
class DanswerRedisLocks:
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()
MEDIUM = auto()
LOW = auto()
LOWEST = auto()

View File

@@ -3,6 +3,7 @@ import time
from collections.abc import Generator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy import and_
@@ -10,6 +11,7 @@ from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
@@ -38,6 +40,68 @@ def check_docs_exist(db_session: Session) -> bool:
return result.scalar() or False
def count_documents_by_needs_sync(session: Session) -> int:
"""Get the count of all documents where:
1. last_modified is newer than last_synced
2. last_synced is null (meaning we've never synced)
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count())
.select_from(DbDocument)
.filter(
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
)
)
.scalar()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def construct_document_select_for_connector_credential_pair(
connector_id: int, credential_id: int | None = None
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
return stmt
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -108,7 +172,29 @@ def get_document_cnts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
def get_acccess_info_for_documents(
def get_access_info_for_document(
db_session: Session,
document_id: str,
) -> tuple[str, list[UUID | None], bool] | None:
"""Gets access info for a single document by calling the get_access_info_for_documents function
and passing a list with a single document ID.
Args:
db_session (Session): The database session to use.
document_id (str): The document ID to fetch access info for.
Returns:
Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
and a boolean indicating if the document is globally public, or None if no results are found.
"""
results = get_access_info_for_documents(db_session, [document_id])
if not results:
return None
return results[0]
def get_access_info_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[UUID | None], bool]]:
@@ -173,6 +259,7 @@ def upsert_documents(
semantic_id=doc.semantic_identifier,
link=doc.first_link,
doc_updated_at=None, # this is intentional
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
)
@@ -214,7 +301,7 @@ def upsert_document_by_connector_credential_pair(
db_session.commit()
def update_docs_updated_at(
def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,
) -> None:
@@ -226,6 +313,28 @@ def update_docs_updated_at(
for document in documents_to_update:
document.doc_updated_at = ids_to_new_updated_at[document.id]
def update_docs_last_modified__no_commit(
document_ids: list[str],
db_session: Session,
) -> None:
documents_to_update = (
db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
)
now = datetime.now(timezone.utc)
for doc in documents_to_update:
doc.last_modified = now
def mark_document_as_synced(document_id: str, db_session: Session) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc = db_session.scalar(stmt)
if doc is None:
raise ValueError(f"No document with ID: {document_id}")
# update last_synced
doc.last_synced = datetime.now(timezone.utc)
db_session.commit()
@@ -379,3 +488,12 @@ def get_documents_by_cc_pair(
.filter(ConnectorCredentialPair.id == cc_pair_id)
.all()
)
def get_document(
document_id: str,
db_session: Session,
) -> DbDocument | None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
return doc
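These helpers imply a simple document-level lifecycle: any metadata edit bumps last_modified, the beat task counts documents whose last_modified is newer than last_synced (or that have never been synced), and the per-document Vespa task records the sync time when it finishes. A sketch of that flow using the functions above (illustrative only; assumes an open session and an existing document id):

# Illustrative only: the sync lifecycle implied by these helpers.
from sqlalchemy.orm import Session

from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import mark_document_as_synced
from danswer.db.document import update_docs_last_modified__no_commit


def example_sync_lifecycle(db_session: Session, document_id: str) -> None:
    # 1. a metadata change bumps last_modified (boost/hidden/doc set/group edits do this)
    update_docs_last_modified__no_commit(
        document_ids=[document_id], db_session=db_session
    )
    db_session.commit()

    # 2. the document now counts as stale: last_modified > last_synced or last_synced IS NULL
    assert count_documents_by_needs_sync(db_session) >= 1

    # 3. after the Vespa update succeeds, the worker records the sync time
    mark_document_as_synced(document_id, db_session)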

View File

@@ -248,6 +248,10 @@ def update_document_set(
document_set_update_request: DocumentSetUpdateRequest,
user: User | None = None,
) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]:
"""If successful, this sets document_set_row.is_up_to_date = False.
That will be processed via Celery in check_for_vespa_sync_task
and trigger a long running background sync to Vespa.
"""
if not document_set_update_request.cc_pair_ids:
# It's cc-pairs in actuality but the UI displays this error
raise ValueError("Cannot create a document set with no Connectors")
@@ -519,6 +523,70 @@ def fetch_documents_for_document_set_paginated(
return documents, documents[-1].id if documents else None
def construct_document_select_by_docset(
document_set_id: int,
current_only: bool = True,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
)
.join(
ConnectorCredentialPair,
and_(
ConnectorCredentialPair.connector_id
== DocumentByConnectorCredentialPair.connector_id,
ConnectorCredentialPair.credential_id
== DocumentByConnectorCredentialPair.credential_id,
),
)
.join(
DocumentSet__ConnectorCredentialPair,
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
== ConnectorCredentialPair.id,
)
.join(
DocumentSetDBModel,
DocumentSetDBModel.id
== DocumentSet__ConnectorCredentialPair.document_set_id,
)
.where(DocumentSetDBModel.id == document_set_id)
.order_by(Document.id)
)
if current_only:
stmt = stmt.where(
DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
)
stmt = stmt.distinct()
return stmt
def fetch_document_set_for_document(
document_id: str,
db_session: Session,
) -> list[str]:
"""
Fetches the document set names for a single document ID.
:param document_id: The ID of the document to fetch sets for.
:param db_session: The SQLAlchemy session to use for the query.
:return: A list of document set names, or an empty list if no result is found.
"""
result = fetch_document_sets_for_documents([document_id], db_session)
if not result:
return []
return result[0][1]
def fetch_document_sets_for_documents(
document_ids: list[str],
db_session: Session,

View File

@@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from fastapi import HTTPException
@@ -24,7 +26,6 @@ from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -123,12 +124,11 @@ def update_document_boost(
db_session: Session,
document_id: str,
boost: int,
document_index: DocumentIndex,
user: User | None = None,
) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
stmt = _add_user_filters(stmt, user, get_editable=True)
result = db_session.execute(stmt).scalar_one_or_none()
result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
if result is None:
raise HTTPException(
status_code=400, detail="Document is not editable by this user"
@@ -136,13 +136,9 @@ def update_document_boost(
result.boost = boost
update = UpdateRequest(
document_ids=[document_id],
boost=boost,
)
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
result.last_modified = datetime.now(timezone.utc)
db_session.commit()
@@ -163,13 +159,9 @@ def update_document_hidden(
result.hidden = hidden
update = UpdateRequest(
document_ids=[document_id],
hidden=hidden,
)
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
result.last_modified = datetime.now(timezone.utc)
db_session.commit()
@@ -210,11 +202,9 @@ def create_doc_retrieval_feedback(
SearchFeedbackType.REJECT,
SearchFeedbackType.HIDE,
]:
update = UpdateRequest(
document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
)
# Updates are generally batched for efficiency, this case only 1 doc/value is updated
document_index.update(update_requests=[update])
# updating last_modified triggers sync
# TODO: Should this submit to the queue directly so that the UI can update?
db_doc.last_modified = datetime.now(timezone.utc)
db_session.add(retrieval_feedback)
db_session.commit()

View File

@@ -428,12 +428,27 @@ class Document(Base):
semantic_id: Mapped[str] = mapped_column(String)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
# connector retries)
# TODO: rename this column because it conflates the time of the source doc
# with the local last modified time of the doc and any associated metadata
# it should just be the server timestamp of the source doc
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# last time any vespa relevant row metadata or the doc changed.
# does not include last_synced
last_modified: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
# last successful sync to vespa
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
# The following are not attached to User because the account/email may not be known
# within Danswer
# Something like the document creator

View File

@@ -282,7 +282,7 @@ class VespaIndex(DocumentIndex):
raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(f"Updating {len(update_requests)} documents in Vespa")
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
# Handle Vespa character limitations
# Mutating update_requests but it's not used later anyway

View File

@@ -162,14 +162,16 @@ def _index_vespa_chunk(
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
BOOST: chunk.boost,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners),
SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners),
# the only `set` vespa has is `weightedset`, so we have to give each
# element an arbitrary weight
# rkuo: acl, docset and boost metadata are also updated through the metadata sync queue
# which only calls VespaIndex.update
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
BOOST: chunk.boost,
}
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"

View File

@@ -18,7 +18,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_updated_at
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
@@ -264,7 +265,7 @@ def index_doc_batch(
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
no_access = DocumentAccess.build([], [], False)
no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
ctx = index_doc_batch_prepare(
document_batch=document_batch,
@@ -295,9 +296,6 @@ def index_doc_batch(
# NOTE: don't need to acquire till here, since this is when the actual race condition
# with Vespa can occur.
with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
# Attach the latest status from Postgres (source of truth for access) to each
# chunk. This access status will be attached to each chunk in the document index
# TODO: attach document sets to the chunk based on the status of Postgres as well
document_id_to_access_info = get_access_for_documents(
document_ids=updatable_ids, db_session=db_session
)
@@ -307,6 +305,12 @@ def index_doc_batch(
document_ids=updatable_ids, db_session=db_session
)
}
# We're concerned about race conditions where multiple simultaneous indexings
# could result in one set of metadata overwriting another in Vespa.
# We still write the data here for an immediate and most likely correct sync,
# but to cover the race, updating the last_modified field at the end of this
# loop always triggers a final metadata sync.
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -338,17 +342,25 @@ def index_doc_batch(
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
]
# Update the time of the latest version of each successfully indexed doc
last_modified_ids = []
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the connector source's idea of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
update_docs_updated_at(
update_docs_updated_at__no_commit(
ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
)
update_docs_last_modified__no_commit(
document_ids=last_modified_ids, db_session=db_session
)
db_session.commit()
return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
)
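Because the loop above only bumps last_modified and lets the background queue do the authoritative metadata sync, the helper it calls is essentially a timestamp update. The real update_docs_last_modified__no_commit lives in danswer.db.document and is not shown in this hunk; a hedged sketch of what it might look like, with the Document import path assumed:

    from datetime import datetime, timezone

    from sqlalchemy import update
    from sqlalchemy.orm import Session

    from danswer.db.models import Document  # assumed model location

    def update_docs_last_modified__no_commit(
        document_ids: list[str], db_session: Session
    ) -> None:
        # Bumping last_modified marks the rows as dirty so the periodic
        # vespa_metadata_sync check picks them up; the caller owns the commit.
        db_session.execute(
            update(Document)
            .where(Document.id.in_(document_ids))
            .values(last_modified=datetime.now(timezone.utc))
        )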

View File

@@ -61,6 +61,8 @@ class IndexChunk(DocAwareChunk):
title_embedding: Embedding | None
# TODO(rkuo): currently, this extra metadata sent during indexing is just for speed,
# but full consistency happens on background sync
class DocMetadataAwareIndexChunk(IndexChunk):
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
the following:

View File

@@ -0,0 +1,49 @@
import threading
from typing import Optional
import redis
from redis.client import Redis
from redis.connection import ConnectionPool
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
REDIS_POOL_MAX_CONNECTIONS = 10
class RedisPool:
"""Thread-safe singleton that owns a single shared Redis connection pool."""
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: ConnectionPool
def __new__(cls) -> "RedisPool":
# Double-checked locking so only one pool is ever created across threads
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(RedisPool, cls).__new__(cls)
cls._instance._init_pool()
return cls._instance
def _init_pool(self) -> None:
self._pool = redis.ConnectionPool(
host=REDIS_HOST,
port=REDIS_PORT,
db=REDIS_DB_NUMBER,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
# # Usage example
# redis_pool = RedisPool()
# redis_client = redis_pool.get_client()
# # Example of setting and getting a value
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'

View File

@@ -77,16 +77,10 @@ def document_boost_update(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
update_document_boost(
db_session=db_session,
document_id=boost_update.document_id,
boost=boost_update.boost,
document_index=document_index,
user=user,
)
return StatusResponse(success=True, message="Updated document boost")

View File

@@ -31,6 +31,28 @@ def set_is_ee_based_on_env_variable() -> None:
@functools.lru_cache(maxsize=128)
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
"""
Fetches a versioned implementation of a specified attribute from a given module.
This function first checks if the application is running in an Enterprise Edition (EE)
context. If so, it attempts to import the attribute from the EE-specific module.
If the module or attribute is not found, it falls back to the default module or
raises the appropriate exception depending on the context.
Args:
module (str): The name of the module from which to fetch the attribute.
attribute (str): The name of the attribute to fetch from the module.
Returns:
Any: The fetched implementation of the attribute.
Raises:
ModuleNotFoundError: If the module cannot be found and the error is not related to
the Enterprise Edition fallback logic.
Logs:
Logs debug information about the fetching process and warnings if the versioned
implementation cannot be found or loaded.
"""
logger.debug("Fetching versioned implementation for %s.%s", module, attribute)
is_ee = global_version.get_is_ee_version()
@@ -66,6 +88,19 @@ T = TypeVar("T")
def fetch_versioned_implementation_with_fallback(
module: str, attribute: str, fallback: T
) -> T:
"""
Attempts to fetch a versioned implementation of a specified attribute from a given module.
If the attempt fails (e.g., due to an import error or missing attribute), the function logs
a warning and returns the provided fallback implementation.
Args:
module (str): The name of the module from which to fetch the attribute.
attribute (str): The name of the attribute to fetch from the module.
fallback (T): The fallback implementation to return if fetching the attribute fails.
Returns:
T: The fetched implementation if successful, otherwise the provided fallback.
"""
try:
return fetch_versioned_implementation(module, attribute)
except Exception:
@@ -73,4 +108,14 @@ def fetch_versioned_implementation_with_fallback(
def noop_fallback(*args: Any, **kwargs: Any) -> None:
"""
A no-op (no operation) fallback function that accepts any arguments but does nothing.
This is often used as a default or placeholder callback function.
Args:
*args (Any): Positional arguments, which are ignored.
**kwargs (Any): Keyword arguments, which are ignored.
Returns:
None
"""

View File

@@ -11,6 +11,17 @@ from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user
def _get_access_for_document(
document_id: str,
db_session: Session,
) -> DocumentAccess:
id_to_access = _get_access_for_documents([document_id], db_session)
if len(id_to_access) == 0:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return next(iter(id_to_access.values()))
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,

View File

@@ -1,28 +1,18 @@
from datetime import timedelta
from typing import Any
from celery.signals import beat_init
from celery.signals import worker_init
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_sync_user_groups
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from ee.danswer.user_groups.sync import sync_user_groups
logger = setup_logger()
@@ -30,17 +20,6 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_user_group_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_user_group_task(user_group_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
# actual sync logic
try:
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
except Exception as e:
logger.exception(f"Failed to sync user group - {e}")
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
@@ -51,8 +30,6 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
@@ -69,24 +46,6 @@ def check_ttl_management_task() -> None:
)
@celery_app.task(
name="check_for_user_groups_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_user_groups_sync_task() -> None:
"""Runs periodically to check if any user groups are out of sync
Creates a task to sync the user group if needed"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
user_groups = fetch_user_groups(db_session=db_session, only_current=False)
for user_group in user_groups:
if should_sync_user_groups(user_group, db_session):
logger.info(f"User Group {user_group.id} is not synced. Syncing now!")
sync_user_group_task.apply_async(
kwargs=dict(user_group_id=user_group.id),
)
@celery_app.task(
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
@@ -101,25 +60,11 @@ def autogenerate_usage_report_task() -> None:
)
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-user-group-sync": {
"task": "check_for_user_groups_sync_task",
"schedule": timedelta(seconds=5),
},
"autogenerate_usage_report": {
"autogenerate-usage-report": {
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},

View File

@@ -1,27 +1,13 @@
from sqlalchemy.orm import Session
from danswer.db.models import UserGroup
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task
logger = setup_logger()
def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
if user_group.is_up_to_date:
return False
task_name = name_user_group_sync_task(user_group.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info("TTL check is already being performed. Skipping.")
return False
return True
def should_perform_chat_ttl_check(
retention_limit_days: int | None, db_session: Session
) -> bool:

View File

@@ -5,6 +5,7 @@ from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
@@ -81,10 +82,25 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
def fetch_user_groups(
db_session: Session, only_current: bool = True
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
This function retrieves a sequence of `UserGroup` objects from the database.
If `only_up_to_date` is set to `True`, it filters the user groups to return only those
that are marked as up-to-date (`is_up_to_date` is `True`).
Args:
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up-to-date user groups. Defaults to `True`.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
"""
stmt = select(UserGroup)
if only_current:
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
return db_session.scalars(stmt).all()
@@ -103,6 +119,42 @@ def fetch_user_groups_for_user(
return db_session.scalars(stmt).all()
def construct_document_select_by_usergroup(
user_group_id: int,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(
UserGroup__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
)
.join(
UserGroup,
UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
)
.where(UserGroup.id == user_group_id)
.order_by(Document.id)
)
stmt = stmt.distinct()
return stmt
def fetch_documents_for_user_group_paginated(
db_session: Session,
user_group_id: int,
@@ -361,6 +413,10 @@ def update_user_group(
user_group_id: int,
user_group_update: UserGroupUpdate,
) -> UserGroup:
"""If successful, this can set db_user_group.is_up_to_date = False.
That will be processed by check_for_vespa_user_groups_sync_task and trigger
a long running background sync to Vespa.
"""
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
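Following the docstring on construct_document_select_by_usergroup, a hedged sketch of how a background task generator might stream the results instead of loading them all at once; the consumer function is hypothetical:

    stmt = construct_document_select_by_usergroup(user_group_id=user_group.id)
    # Stream rows in fixed-size batches to keep memory flat for large groups.
    for doc in db_session.scalars(stmt.execution_options(yield_per=100)):
        enqueue_vespa_metadata_sync(doc.id)  # hypothetical consumer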

View File

@@ -32,7 +32,7 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user is None or user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_current=False)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,

View File

@@ -1,87 +0,0 @@
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
logger = setup_logger()
_SYNC_BATCH_SIZE = 100
def _sync_user_group_batch(
document_ids: list[str], document_index: DocumentIndex, db_session: Session
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# Acquires a lock on the documents so that no other process can modify them
with prepare_to_modify_documents(db_session=db_session, document_ids=document_ids):
# get current state of document sets for these documents
document_id_to_access = get_access_for_documents(
document_ids=document_ids, db_session=db_session
)
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
access=document_id_to_access[document_id],
)
for document_id in document_ids
]
)
# Finish the transaction and release the locks
db_session.commit()
def sync_user_groups(user_group_id: int, db_session: Session) -> None:
"""Sync the status of Postgres for the specified user group"""
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
user_group = fetch_user_group(db_session=db_session, user_group_id=user_group_id)
if user_group is None:
raise ValueError(f"User group '{user_group_id}' does not exist")
cursor = None
while True:
# NOTE: this may miss some documents, but that is okay. Any new documents added
# will be added with the correct group membership
document_batch, cursor = fetch_documents_for_user_group_paginated(
db_session=db_session,
user_group_id=user_group_id,
last_document_id=cursor,
limit=_SYNC_BATCH_SIZE,
)
_sync_user_group_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
db_session=db_session,
)
if cursor is None:
break
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)

View File

@@ -24,14 +24,21 @@ autorestart=true
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=6 --loglevel=INFO --logfile=/var/log/celery_worker_supervisor.log
command=celery -A danswer.background.celery.celery_run:celery_app worker
--pool=threads
--concurrency=6
--loglevel=INFO
--logfile=/var/log/celery_worker_supervisor.log
-Q celery,vespa_metadata_sync
environment=LOG_FILE_NAME=celery_worker
redirect_stderr=true
autorestart=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat --loglevel=INFO --logfile=/var/log/celery_beat_supervisor.log
command=celery -A danswer.background.celery.celery_run:celery_app beat
--loglevel=INFO
--logfile=/var/log/celery_beat_supervisor.log
environment=LOG_FILE_NAME=celery_beat
redirect_stderr=true
autorestart=true
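With the worker now consuming both the default celery queue and vespa_metadata_sync, producers can route work to the new queue explicitly. A hedged sketch (the task name is illustrative, not taken from this commit):

    from danswer.background.celery.celery_app import celery_app  # import path used earlier in this diff

    celery_app.send_task(
        "vespa_metadata_sync_task",        # illustrative task name
        kwargs={"document_id": document_id},
        queue="vespa_metadata_sync",       # matches the -Q option in the worker command
    )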

View File

@@ -211,7 +211,7 @@ export function Explorer({
)}
{!query && (
<div className="flex text-emphasis mt-3">
Search for a document above to modify it&apos;s boost or hide it from
Search for a document above to modify its boost or hide it from
searches.
</div>
)}