Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/index_attempt_query

This commit is contained in:
Richard Kuo (Danswer) 2025-01-13 13:23:30 -08:00
commit 46cfaa96b7
35 changed files with 1413 additions and 143 deletions

View File

@ -28,6 +28,7 @@
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@ -51,7 +52,8 @@
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat"
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@ -269,6 +271,31 @@
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",

View File

@ -0,0 +1,72 @@
"""Add SyncRecord
Revision ID: 97dbb53fa8c8
Revises: be2ab2aa50ee
Create Date: 2025-01-11 19:39:50.426302
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``sync_record`` table together with its two lookup indexes."""
    # Both enum columns are declared as non-native enums (native_enum=False)
    # with a length of 40.
    sync_type_enum = sa.Enum(
        "DOCUMENT_SET",
        "USER_GROUP",
        "CONNECTOR_DELETION",
        name="synctype",
        native_enum=False,
        length=40,
    )
    sync_status_enum = sa.Enum(
        "IN_PROGRESS",
        "SUCCESS",
        "FAILED",
        "CANCELED",
        name="syncstatus",
        native_enum=False,
        length=40,
    )

    op.create_table(
        "sync_record",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("entity_id", sa.Integer(), nullable=False),
        sa.Column("sync_type", sync_type_enum, nullable=False),
        sa.Column("sync_status", sync_status_enum, nullable=False),
        sa.Column("num_docs_synced", sa.Integer(), nullable=False),
        sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
        # the end time stays NULL until one is recorded
        sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )

    index_specs = (
        # Add index for fetch_latest_sync_record query
        (
            "ix_sync_record_entity_id_sync_type_sync_start_time",
            ["entity_id", "sync_type", "sync_start_time"],
        ),
        # Add index for cleanup_sync_records query
        (
            "ix_sync_record_entity_id_sync_type_sync_status",
            ["entity_id", "sync_type", "sync_status"],
        ),
    )
    for index_name, indexed_columns in index_specs:
        op.create_index(index_name, "sync_record", indexed_columns)
def downgrade() -> None:
    """Reverse upgrade(): drop both indexes first, then the ``sync_record`` table."""
    for index_name in (
        "ix_sync_record_entity_id_sync_type_sync_status",
        "ix_sync_record_entity_id_sync_type_sync_start_time",
    ):
        op.drop_index(index_name)
    op.drop_table("sync_record")

View File

@ -0,0 +1,38 @@
"""fix_capitalization
Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Lowercase every entry of ``document.external_user_group_ids``.

    Only rows whose array is non-NULL and differs from its lowercased form are
    touched; for those rows ``last_modified`` is also set to NOW().
    """
    lowercase_group_ids_sql = """
UPDATE document
SET
external_user_group_ids = ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
),
last_modified = NOW()
WHERE
external_user_group_ids IS NOT NULL
AND external_user_group_ids::text[] <> ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
)::text[]
"""
    op.execute(lowercase_group_ids_sql)
def downgrade() -> None:
    """Intentionally a no-op: the lowercasing done by upgrade() is not reverted."""
    # No way to cleanly persist the bad state through an upgrade/downgrade
    # cycle, so we just pass
    pass

View File

@ -0,0 +1,41 @@
"""Add time_updated to UserGroup and DocumentSet
Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add ``time_last_modified_by_user`` to ``document_set`` and ``user_group``.

    The column is a non-nullable, timezone-aware timestamp whose server default
    is now().
    """
    for table_name in ("document_set", "user_group"):
        # a fresh Column object is built for each table
        op.add_column(
            table_name,
            sa.Column(
                "time_last_modified_by_user",
                sa.DateTime(timezone=True),
                nullable=False,
                server_default=sa.func.now(),
            ),
        )
def downgrade() -> None:
    """Drop ``time_last_modified_by_user`` from ``user_group`` and ``document_set``."""
    for table_name in ("user_group", "document_set"):
        op.drop_column(table_name, "time_last_modified_by_user")

View File

@ -8,6 +8,9 @@ from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@ -43,24 +46,59 @@ def monitor_usergroup_taskset(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
try:
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
except Exception as e:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.FAILED,
num_docs_synced=initial_count,
)
raise e
rug.reset()

View File

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
prefix_group_w_source(
build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)
@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
prefix_group_w_source(
build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)

View File

@ -6,8 +6,9 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members = batch_add_ext_perm_user_if_not_exists(
db_session=db_session, emails=list(all_group_member_emails)
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
)
delete_user__ext_group_for_cc_pair__no_commit(
@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
)
)

View File

@ -374,7 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(name=user_group.name)
db_user_group = UserGroup(name=user_group.name, time_updated=func.now())
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@ -630,6 +630,10 @@ def update_user_group(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
db_session.commit()
return db_user_group

View File

@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
"""
External groups may collide across sources, every source needs its own prefix.
NOTE: the name is lowercased to handle case sensitivity for group names
"""
return f"{source.value}_{ext_group_name}".lower()

View File

@ -161,9 +161,34 @@ def on_task_postrun(
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
# But something is blocking set_start_method from working in the cloud unless
# force=True. so we use force=True as a fallback.
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
try:
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method exceptioned. Trying force=True..."
)
try:
multiprocessing.set_start_method(
"spawn", force=True
) # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method force=True exceptioned even with force=True."
)
logger.info(
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)
def wait_for_redis(sender: Any, **kwargs: Any) -> None:

View File

@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@ -49,17 +49,16 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
@ -50,22 +50,21 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@ -49,17 +48,18 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@ -0,0 +1,95 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
# Module-level logger used by the signal handlers below
logger = setup_logger()

# Celery application for the monitoring worker; its settings are loaded from
# the monitoring config module named in config_from_object
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handle celery's task_prerun signal by forwarding every argument to app_base."""
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Handle celery's task_postrun signal by forwarding every argument to app_base."""
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """Handle celery's celeryd_init signal by delegating to the shared app_base handler."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Prepare the monitoring worker when celery fires worker_init.

    Names and initializes the SQL engine (pool sized to the worker's
    concurrency, overflow of 3), waits for redis and the db through app_base,
    and runs the secondary-worker init unless MULTI_TENANT is set.
    """
    logger.info("worker_init signal received.")
    start_method = multiprocessing.get_start_method()
    logger.info(f"Multiprocessing start method: {start_method}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if not MULTI_TENANT:
        app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Handle celery's worker_ready signal by delegating to the shared app_base handler."""
    app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Handle celery's worker_shutdown signal by delegating to the shared app_base handler."""
    app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Handle celery's setup_logging signal by delegating to the shared app_base handler."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Only the monitoring task package is passed to autodiscover_tasks for this app
celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.monitoring",
    ]
)

View File

@ -1,5 +1,4 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@ -7,6 +6,7 @@ from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
@ -73,14 +73,13 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
sender.primary_worker_lock = lock # type: ignore
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)

View File

@ -0,0 +1,21 @@
import onyx.background.celery.configs.base as shared_config
# Broker, redis, result-backend and task defaults are taken unchanged from the
# shared base config.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1  # Single worker is sufficient for monitoring
worker_pool = "solo"
worker_prefetch_multiplier = 1

View File

@ -3,6 +3,7 @@ from typing import Any
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
# choosing 15 minutes because it roughly gives us enough time to process many tasks
@ -68,6 +69,16 @@ tasks_to_schedule = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,

View File

@ -17,7 +17,10 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncType
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
@ -118,6 +121,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
return None
# set a basic fence to start
@ -126,6 +136,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
submitted=datetime.now(timezone.utc),
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
redis_connector.delete.set_fence(fence_payload)
try:

View File

@ -391,5 +391,7 @@ def update_external_document_permissions_task(
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
)
return False

View File

@ -1,3 +1,4 @@
import multiprocessing
import os
import sys
import time
@ -862,11 +863,14 @@ def connector_indexing_proxy_task(
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
"""celery tasks are forked, but forking is unstable.
This is a thread that proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
f"search_settings={search_settings_id} "
f"mp_start_method={multiprocessing.get_start_method()}"
)
if not self.request.id:

View File

@ -0,0 +1,427 @@
import json
from collections.abc import Callable
from datetime import timedelta
from typing import Any
from celery import shared_task
from celery import Task
from pydantic import BaseModel
from redis import Redis
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Redis key templates, one key per (cc_pair, index attempt). The metric builders
# in this module format them and check/store the key so that a given metric is
# not emitted twice for the same index attempt.
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
    "monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
)
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
    "monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
"""Check if a metric has been emitted by checking for existence of Redis key"""
return bool(redis_std.exists(key))
class Metric(BaseModel):
    """A named measurement with tags that can be logged or sent as telemetry."""

    # only required if we need to store that we have emitted this metric
    key: str | None
    name: str
    value: Any
    tags: dict[str, str]

    def log(self) -> None:
        """Log the metric in a standardized format"""
        data = {
            "metric": self.name,
            "value": self.value,
            "tags": self.tags,
        }
        task_logger.info(json.dumps(data))

    def emit(self) -> None:
        """Send the metric through optional_telemetry as a METRIC record.

        The value is placed in the typed field(s) matching its Python type:
        bool -> bool_value; int -> int_value and float_value;
        float -> float_value; str -> string_value. A value of any other type
        is logged as an error and nothing is sent.
        """
        value = self.value

        # bool is a subclass of int, so it has to be recognised first.
        # Otherwise True/False would also be reported as int_value 1/0 and
        # float_value 1.0/0.0 next to bool_value.
        is_bool = isinstance(value, bool)
        is_int = isinstance(value, int) and not is_bool
        is_number = is_int or isinstance(value, float)

        # Convert value to appropriate type
        float_value = float(value) if is_number else None
        int_value = int(value) if is_int else None
        string_value = str(value) if isinstance(value, str) else None
        bool_value = bool(value) if is_bool else None

        if (
            float_value is None
            and int_value is None
            and string_value is None
            and bool_value is None
        ):
            task_logger.error(
                f"Invalid metric value type: {type(self.value)} "
                f"({self.value}) for metric {self.name}."
            )
            return

        # don't send None values over the wire
        data = {
            k: v
            for k, v in {
                "metric_name": self.name,
                "float_value": float_value,
                "int_value": int_value,
                "string_value": string_value,
                "bool_value": bool_value,
                "tags": self.tags,
            }.items()
            if v is not None
        }
        optional_telemetry(
            record_type=RecordType.METRIC,
            data=data,
        )
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
    """Build one queue-length metric for each monitored Celery queue."""
    # metric name -> name of the Celery queue whose length is measured
    queue_by_metric_name = {
        "celery_queue_length": "celery",
        "indexing_queue_length": "indexing",
        "sync_queue_length": "sync",
        "deletion_queue_length": "deletion",
        "pruning_queue_length": "pruning",
        "permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
        "external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
        "permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
    }

    return [
        Metric(
            key=None,  # queue lengths are not de-duplicated, so no Redis key
            name=metric_name,
            value=celery_get_queue_length(queue_name, redis_celery),
            # the "queue" tag carries the metric name, not the raw queue name
            tags={"queue": metric_name},
        )
        for metric_name, queue_name in queue_by_metric_name.items()
    ]
def _build_connector_start_latency_metric(
    cc_pair: ConnectorCredentialPair,
    recent_attempt: IndexAttempt,
    second_most_recent_attempt: IndexAttempt | None,
    redis_std: Redis,
) -> Metric | None:
    """Build the "connector_start_latency" metric for `recent_attempt`.

    The latency is the number of seconds between the desired start time and
    the attempt's actual start. The desired start is the connector's creation
    time when there is no previous attempt, otherwise the previous attempt's
    `time_updated` plus the connector's `refresh_freq` seconds.

    Returns None when the attempt has not started, when the metric's Redis key
    already exists, or when a previous attempt exists but the connector has no
    `refresh_freq`.
    """
    started_at = recent_attempt.time_started
    if not started_at:
        return None

    # check if we already emitted a metric for this index attempt
    dedupe_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
        cc_pair_id=cc_pair.id,
        index_attempt_id=recent_attempt.id,
    )
    if _has_metric_been_emitted(redis_std, dedupe_key):
        task_logger.info(
            f"Skipping metric for connector {cc_pair.connector.id} "
            f"index attempt {recent_attempt.id} because it has already been "
            "emitted"
        )
        return None

    if second_most_recent_attempt:
        refresh_freq = cc_pair.connector.refresh_freq
        if not refresh_freq:
            task_logger.error(
                "Found non-initial index attempt for connector "
                "without refresh_freq. This should never happen."
            )
            return None
        desired_start_time = second_most_recent_attempt.time_updated + timedelta(
            seconds=refresh_freq
        )
    else:
        # first run case - we should start as soon as it's created
        desired_start_time = cc_pair.connector.time_created

    latency_seconds = (started_at - desired_start_time).total_seconds()
    return Metric(
        key=dedupe_key,
        name="connector_start_latency",
        value=latency_seconds,
        tags={},
    )
def _build_run_success_metric(
    cc_pair: ConnectorCredentialPair, recent_attempt: IndexAttempt, redis_std: Redis
) -> Metric | None:
    """Build the "connector_run_succeeded" metric for `recent_attempt`.

    The value is True when the attempt's status is SUCCESS and False when it
    is FAILED or CANCELED. Returns None when the metric's Redis key already
    exists or when the attempt has any other status.
    """
    dedupe_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
        cc_pair_id=cc_pair.id,
        index_attempt_id=recent_attempt.id,
    )

    if _has_metric_been_emitted(redis_std, dedupe_key):
        task_logger.info(
            f"Skipping metric for connector {cc_pair.connector.id} "
            f"index attempt {recent_attempt.id} because it has already been "
            "emitted"
        )
        return None

    reportable_statuses = (
        IndexingStatus.SUCCESS,
        IndexingStatus.FAILED,
        IndexingStatus.CANCELED,
    )
    if recent_attempt.status not in reportable_statuses:
        return None

    return Metric(
        key=dedupe_key,
        name="connector_run_succeeded",
        value=recent_attempt.status == IndexingStatus.SUCCESS,
        tags={"source": str(cc_pair.connector.source)},
    )
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
    """Collect start-latency and run-success metrics for recent connector runs.

    For every connector credential pair, the two newest index attempts created
    within the last hour are loaded. Pairs with no such attempt are skipped.
    """
    # NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
    window_start = get_db_current_time(db_session) - timedelta(hours=1)

    all_cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()

    collected: list[Metric] = []
    for cc_pair in all_cc_pairs:
        # NOTE(review): the "previous" attempt is only found when it also falls
        # inside the one-hour window; otherwise the latency helper receives
        # None for it. Confirm that this is intended.
        attempts_in_window = (
            db_session.query(IndexAttempt)
            .filter(
                IndexAttempt.connector_credential_pair_id == cc_pair.id,
                IndexAttempt.time_created >= window_start,
            )
            .order_by(IndexAttempt.time_created.desc())
            .limit(2)
            .all()
        )

        # if no metric to emit, skip
        if not attempts_in_window:
            continue

        latest_attempt = attempts_in_window[0]
        previous_attempt = (
            attempts_in_window[1] if len(attempts_in_window) > 1 else None
        )

        candidate_metrics = (
            # Connector start latency
            _build_connector_start_latency_metric(
                cc_pair, latest_attempt, previous_attempt, redis_std
            ),
            # Connector run success/failure
            _build_run_success_metric(cc_pair, latest_attempt, redis_std),
        )
        collected.extend(metric for metric in candidate_metrics if metric)

    return collected
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
    """Collect sync speed and sync start latency metrics for recent sync records.

    Looks at every SyncRecord whose sync_start_time falls within the last hour
    (newest first) and, for each finished record, builds up to two metrics:

    - "sync_speed_docs_per_min": num_docs_synced divided by the sync duration
      in minutes.
    - "sync_start_latency_seconds": time between the entity's
      time_last_modified_by_user and the sync start. Only computed for
      document set and user group syncs.

    This function only checks whether a metric key was emitted before; it does
    not mark keys as emitted itself.

    Args:
        db_session: Session used to read sync records and their entities.
        redis_std: Redis client used to check which metric keys were emitted.

    Returns:
        The list of metrics that still need to be emitted.
    """
    # NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
    one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)

    # Get all sync records from the last hour
    recent_sync_records = db_session.scalars(
        select(SyncRecord)
        .where(SyncRecord.sync_start_time >= one_hour_ago)
        .order_by(SyncRecord.sync_start_time.desc())
    ).all()

    metrics = []
    for sync_record in recent_sync_records:
        # Skip if no end time (sync still in progress)
        if not sync_record.sync_end_time:
            continue

        # Check if we already emitted a metric for this sync record.
        # An already-emitted speed metric skips the whole record, including
        # its start latency metric.
        metric_key = (
            f"sync_speed:{sync_record.sync_type}:"
            f"{sync_record.entity_id}:{sync_record.id}"
        )
        if _has_metric_been_emitted(redis_std, metric_key):
            task_logger.debug(
                f"Skipping metric for sync record {sync_record.id} "
                "because it has already been emitted"
            )
            continue

        # Calculate sync duration in minutes
        sync_duration_mins = (
            sync_record.sync_end_time - sync_record.sync_start_time
        ).total_seconds() / 60.0

        # Calculate sync speed (docs/min) - avoid division by zero
        sync_speed = (
            sync_record.num_docs_synced / sync_duration_mins
            if sync_duration_mins > 0
            else None
        )

        if sync_speed is None:
            task_logger.error(
                "Something went wrong with sync speed calculation. "
                f"Sync record: {sync_record.id}"
            )
            continue

        metrics.append(
            Metric(
                key=metric_key,
                name="sync_speed_docs_per_min",
                value=sync_speed,
                tags={
                    "sync_type": str(sync_record.sync_type),
                    "status": str(sync_record.sync_status),
                },
            )
        )

        # Add sync start latency metric
        start_latency_key = (
            f"sync_start_latency:{sync_record.sync_type}"
            f":{sync_record.entity_id}:{sync_record.id}"
        )
        if _has_metric_been_emitted(redis_std, start_latency_key):
            task_logger.debug(
                f"Skipping start latency metric for sync record {sync_record.id} "
                "because it has already been emitted"
            )
            continue

        # Get the entity's last update time based on sync type
        entity: DocumentSet | UserGroup | None = None
        if sync_record.sync_type == SyncType.DOCUMENT_SET:
            entity = db_session.scalar(
                select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
            )
        elif sync_record.sync_type == SyncType.USER_GROUP:
            entity = db_session.scalar(
                select(UserGroup).where(UserGroup.id == sync_record.entity_id)
            )
        else:
            # Skip other sync types
            task_logger.debug(
                f"Skipping sync record {sync_record.id} "
                f"with type {sync_record.sync_type} "
                f"and id {sync_record.entity_id} "
                "because it is not a document set or user group"
            )
            continue

        if entity is None:
            task_logger.error(
                f"Could not find entity for sync record {sync_record.id} "
                f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
            )
            continue

        # Calculate start latency in seconds
        start_latency = (
            sync_record.sync_start_time - entity.time_last_modified_by_user
        ).total_seconds()
        if start_latency < 0:
            # The entity was modified after this sync started, so the value
            # would not describe this sync; drop it instead of emitting it.
            task_logger.error(
                f"Start latency is negative for sync record {sync_record.id} "
                f"with type {sync_record.sync_type} and id {sync_record.entity_id}. "
                "This is likely because the entity was updated between the time the "
                "sync finished and this job ran. Skipping."
            )
            continue

        metrics.append(
            Metric(
                key=start_latency_key,
                name="sync_start_latency_seconds",
                value=start_latency,
                tags={
                    "sync_type": str(sync_record.sync_type),
                },
            )
        )

    return metrics
@shared_task(
    name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
    soft_time_limit=JOB_TIMEOUT,
    queue=OnyxCeleryQueues.MONITORING,
    bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
    """Collect and emit metrics about background processes.

    This task runs periodically to gather metrics about:
    - Queue lengths for different Celery queues
    - Connector run metrics (start latency, success rate)
    - Syncing speed metrics

    Every collected metric is logged and emitted; metrics that carry a key are
    then marked as emitted in Redis.

    Args:
        tenant_id: Tenant whose Redis client and DB session are used.

    Raises:
        Exception: any error raised while collecting or emitting metrics is
            logged and re-raised.
    """
    task_logger.info("Starting background process monitoring")

    try:
        # Get Redis client for Celery broker
        redis_celery = self.app.broker_connection().channel().client  # type: ignore
        redis_std = get_redis_client(tenant_id=tenant_id)

        with get_session_with_tenant(tenant_id) as db_session:
            # Define the collectors only once the session exists, so the
            # closures never refer to a not-yet-bound `db_session`.
            metric_functions: list[Callable[[], list[Metric]]] = [
                lambda: _collect_queue_metrics(redis_celery),
                lambda: _collect_connector_metrics(db_session, redis_std),
                lambda: _collect_sync_metrics(db_session, redis_std),
            ]

            # Collect and log each metric
            for metric_fn in metric_functions:
                for metric in metric_fn():
                    metric.log()
                    metric.emit()
                    # Remember keyed metrics so later runs can skip them
                    if metric.key:
                        _mark_metric_as_emitted(redis_std, metric.key)

        task_logger.info("Successfully collected background process metrics")
    except Exception:
        task_logger.exception("Error collecting background process metrics")
        # Bare raise re-raises the active exception
        raise

View File

@ -1,6 +1,7 @@
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
@ -53,10 +54,16 @@ from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@ -283,6 +290,13 @@ def try_generate_document_set_sync_tasks(
return None
if document_set.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@ -311,6 +325,13 @@ def try_generate_document_set_sync_tasks(
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
return tasks_generated
@ -332,8 +353,9 @@ def try_generate_user_group_sync_tasks(
return None
# race condition with the monitor/cleanup function if we use a cached result!
fetch_user_group = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_group"
fetch_user_group = cast(
Callable[[Session, int], UserGroup | None],
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
)
usergroup = fetch_user_group(db_session, usergroup_id)
@ -341,6 +363,13 @@ def try_generate_user_group_sync_tasks(
return None
if usergroup.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@ -368,8 +397,16 @@ def try_generate_user_group_sync_tasks(
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
return tasks_generated
@ -419,6 +456,13 @@ def monitor_document_set_taskset(
f"remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
document_set = cast(
@ -437,6 +481,13 @@ def monitor_document_set_taskset(
task_logger.info(
f"Successfully synced document set: document_set={document_set_id}"
)
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
rds.reset()
@ -470,6 +521,14 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
@ -545,11 +604,29 @@ def monitor_connector_deletion_taskset(
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"

View File

@ -0,0 +1,15 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
    """Return the Celery app defined in onyx.background.celery.apps.monitoring."""
    # NOTE(review): the import is kept inside the function, presumably so the
    # app module is only loaded after set_is_ee_based_on_env_variable() has run
    # at module import time - confirm before hoisting it.
    from onyx.background.celery.apps.monitoring import celery_app as monitoring_app

    return monitoring_app
app = get_app()

View File

@ -4,9 +4,10 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from multiprocessing.context import SpawnProcess
from typing import Any
from typing import Literal
from typing import Optional
@ -63,7 +64,7 @@ class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
id: int
process: Optional["Process"] = None
process: Optional["SpawnProcess"] = None
def cancel(self) -> bool:
return self.release()
@ -131,7 +132,10 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_run_in_process, args=(func, args), daemon=True)
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@ -47,6 +47,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@ -260,6 +261,9 @@ class OnyxCeleryQueues:
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
# Monitoring queue
MONITORING = "monitoring"
class OnyxRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
@ -308,6 +312,7 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@ -135,32 +135,6 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
    """
    Fetch details about the user this client is authenticated as.

    The third party client has no such method, so the request is made here
    directly against "rest/api/user/current".

    :param expand: OPTIONAL expand for get status of user.
        Possible param is "status". Results are "Active, Deactivated"
    :return: Returns the user details
    :raises ApiPermissionError: when the request fails with HTTP 403
    """
    from atlassian.errors import ApiPermissionError  # type:ignore

    # Only send the expand parameter when the caller asked for it
    query_params = {"expand": expand} if expand else {}

    try:
        user_details = self.get("rest/api/user/current", params=query_params)
    except HTTPError as e:
        # Any status other than 403 is passed through untouched
        if e.response.status_code != 403:
            raise
        raise ApiPermissionError(
            "The calling user does not have permission", reason=e
        )
    return user_details
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@ -363,6 +337,9 @@ class OnyxConfluence(Confluence):
fetch the permissions of a space.
This is better logging than calling the get_space_permissions method
because it returns a jsonrpc response.
TODO: Make this call these endpoints for newer confluence versions:
- /rest/api/space/{spaceKey}/permissions
- /rest/api/space/{spaceKey}/permissions/anonymous
"""
url = "rpc/json-rpc/confluenceservice-v2"
data = {
@ -381,6 +358,32 @@ class OnyxConfluence(Confluence):
return response.get("result", [])
def get_current_user(self, expand: str | None = None) -> Any:
    """
    Fetch details about the user this client is authenticated as.

    The third party client has no such method, so the request is made here
    directly against "rest/api/user/current".

    :param expand: OPTIONAL expand for get status of user.
        Possible param is "status". Results are "Active, Deactivated"
    :return: Returns the user details
    :raises ApiPermissionError: when the request fails with HTTP 403
    """
    from atlassian.errors import ApiPermissionError  # type:ignore

    # Only send the expand parameter when the caller asked for it
    query_params = {"expand": expand} if expand else {}

    try:
        user_details = self.get("rest/api/user/current", params=query_params)
    except HTTPError as e:
        # Any status other than 403 is passed through untouched
        if e.response.status_code != 403:
            raise
        raise ApiPermissionError(
            "The calling user does not have permission", reason=e
        )
    return user_details
def _validate_connector_configuration(
credentials: dict[str, Any],

View File

@ -218,6 +218,7 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
time_updated=func.now(),
)
db_session.add(new_document_set_row)
db_session.flush() # ensure the new document set gets assigned an ID
@ -293,7 +294,7 @@ def update_document_set(
document_set_row.description = document_set_update_request.description
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(
"onyx.db.document_set", "make_doc_set_private"
)

View File

@ -24,12 +24,27 @@ class IndexingMode(str, PyEnum):
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"
class SyncType(str, PyEnum):
    """Kinds of sync operations that are tracked."""

    DOCUMENT_SET = "document_set"
    USER_GROUP = "user_group"
    CONNECTOR_DELETION = "connector_deletion"

    def __str__(self) -> str:
        # Render as the raw value (e.g. "document_set") rather than the
        # default "SyncType.DOCUMENT_SET" form.
        return str(self.value)
class SyncStatus(str, PyEnum):
    """Lifecycle states of a sync operation."""

    IN_PROGRESS = "in_progress"
    SUCCESS = "success"
    FAILED = "failed"
    CANCELED = "canceled"

    def is_terminal(self) -> bool:
        """Return True when the sync has finished and will not change again.

        CANCELED counts as terminal alongside SUCCESS and FAILED: a canceled
        sync is over, so code that stamps an end time for terminal states
        must do so for cancellations as well.
        """
        terminal_states = {
            SyncStatus.SUCCESS,
            SyncStatus.FAILED,
            SyncStatus.CANCELED,
        }
        return self in terminal_states
# Consistent with Celery task statuses

View File

@ -44,7 +44,7 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import AccessType, IndexingMode
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@ -881,6 +881,46 @@ class IndexAttemptError(Base):
)
class SyncRecord(Base):
    """
    One row per run of a "sync" operation (e.g. document set, user group, deletion).

    A "sync" operation is an operation which needs to update a set of documents
    within Vespa, usually to match the state of Postgres.
    """

    __tablename__ = "sync_record"

    # Both indexes are keyed on (entity_id, sync_type) first: one continues
    # with the start time, the other with the status.
    __table_args__ = (
        Index(
            "ix_sync_record_entity_id_sync_type_sync_start_time",
            "entity_id",
            "sync_type",
            "sync_start_time",
        ),
        Index(
            "ix_sync_record_entity_id_sync_type_sync_status",
            "entity_id",
            "sync_type",
            "sync_status",
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    # ID of the thing being synced; which table it refers to depends on
    # sync_type (document set id, user group id, or deletion id)
    entity_id: Mapped[int] = mapped_column(Integer)

    # Stored as non-native (string-backed) enums
    sync_type: Mapped[SyncType] = mapped_column(
        Enum(SyncType, native_enum=False)
    )
    sync_status: Mapped[SyncStatus] = mapped_column(
        Enum(SyncStatus, native_enum=False)
    )

    num_docs_synced: Mapped[int] = mapped_column(Integer, default=0)

    sync_start_time: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True)
    )
    # Stays NULL until an end time is recorded for the sync
    sync_end_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential pair"""
@ -1284,6 +1324,11 @@ class DocumentSet(Base):
# given access to it either via the `users` or `groups` relationships
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Last time a user updated this document set
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
"ConnectorCredentialPair",
secondary=DocumentSet__ConnectorCredentialPair.__table__,
@ -1763,6 +1808,11 @@ class UserGroup(Base):
Boolean, nullable=False, default=False
)
# Last time a user updated this user group
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
users: Mapped[list[User]] = relationship(
"User",
secondary=User__UserGroup.__table__,

View File

@ -0,0 +1,110 @@
from sqlalchemy import and_
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
def insert_sync_record(
    db_session: Session,
    entity_id: int | None,
    sync_type: SyncType,
) -> SyncRecord:
    """Create and commit a fresh IN_PROGRESS sync record.

    The record starts with zero synced documents and a start time taken from
    the database (func.now()).

    Args:
        db_session: The database session to use
        entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
        sync_type: The type of sync operation

    Returns:
        The newly added SyncRecord.
    """
    new_record = SyncRecord(
        sync_type=sync_type,
        entity_id=entity_id,
        sync_status=SyncStatus.IN_PROGRESS,
        sync_start_time=func.now(),
        num_docs_synced=0,
    )
    db_session.add(new_record)
    db_session.commit()
    return new_record
def fetch_latest_sync_record(
    db_session: Session,
    entity_id: int,
    sync_type: SyncType,
) -> SyncRecord | None:
    """Fetch the most recent sync record for a given entity ID and sync type.

    "Most recent" is decided by sync_start_time; the record's status is not
    part of the lookup.

    Args:
        db_session: The database session to use
        entity_id: The ID of the entity to fetch sync record for
        sync_type: The type of sync operation

    Returns:
        The newest matching SyncRecord, or None when there is no match.
    """
    matches_entity = and_(
        SyncRecord.entity_id == entity_id,
        SyncRecord.sync_type == sync_type,
    )
    newest_first = desc(SyncRecord.sync_start_time)

    latest_stmt = select(SyncRecord).where(matches_entity).order_by(newest_first).limit(1)
    return db_session.execute(latest_stmt).scalar_one_or_none()
def update_sync_record_status(
    db_session: Session,
    entity_id: int,
    sync_type: SyncType,
    sync_status: SyncStatus,
    num_docs_synced: int | None = None,
) -> None:
    """Set the status of the latest sync record for an entity and commit.

    The record updated is whatever fetch_latest_sync_record returns for the
    entity/type pair. When the new status is terminal, the end time is
    stamped with the database's current time.

    Args:
        db_session: The database session to use
        entity_id: The ID of the entity being synced
        sync_type: The type of sync operation
        sync_status: The new status to set
        num_docs_synced: Optional number of documents synced to update

    Raises:
        ValueError: if no sync record exists for the entity/type pair.
    """
    latest_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
    if latest_record is None:
        raise ValueError(
            f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
        )

    # Leave the stored doc count alone unless the caller supplied one
    if num_docs_synced is not None:
        latest_record.num_docs_synced = num_docs_synced

    latest_record.sync_status = sync_status
    if sync_status.is_terminal():
        latest_record.sync_end_time = func.now()  # type: ignore

    db_session.commit()
def cleanup_sync_records(
    db_session: Session, entity_id: int, sync_type: SyncType
) -> None:
    """Cancel any still-running sync records for an entity and sync type.

    Every matching record that is IN_PROGRESS is set to CANCELED and given an
    end time of the database's current time. The change is committed.
    """
    cancel_in_progress = (
        update(SyncRecord)
        .where(
            SyncRecord.entity_id == entity_id,
            SyncRecord.sync_type == sync_type,
            SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
        )
        .values(sync_status=SyncStatus.CANCELED, sync_end_time=func.now())
    )
    db_session.execute(cancel_in_progress)
    db_session.commit()

View File

@ -33,6 +33,7 @@ class RecordType(str, Enum):
USAGE = "usage"
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
def get_or_generate_uuid() -> str:

View File

@ -65,6 +65,18 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
--loglevel=INFO
--hostname=monitoring@%%n
-Q monitoring
stdout_logfile=/var/log/celery_worker_monitoring.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A onyx.background.celery.versioned_apps.beat beat

View File

@ -63,57 +63,57 @@ def _run_migrations(
logging.getLogger("alembic").setLevel(logging.INFO)
def reset_postgres(
database: str = "postgres", config_name: str = "alembic", setup_onyx: bool = True
def downgrade_postgres(
database: str = "postgres",
config_name: str = "alembic",
revision: str = "base",
clear_data: bool = False,
) -> None:
"""Reset the Postgres database."""
"""Downgrade Postgres database to base state."""
if clear_data:
if revision != "base":
logger.warning("Clearing data without rolling back to base state")
# Delete all rows to allow migrations to be rolled back
conn = psycopg2.connect(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
)
cur = conn.cursor()
# NOTE: need to delete all rows to allow migrations to be rolled back
# as there are a few downgrades that don't properly handle data in tables
conn = psycopg2.connect(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
)
cur = conn.cursor()
# Disable triggers to prevent foreign key constraints from being checked
cur.execute("SET session_replication_role = 'replica';")
# Disable triggers to prevent foreign key constraints from being checked
cur.execute("SET session_replication_role = 'replica';")
# Fetch all table names in the current database
cur.execute(
# Fetch all table names in the current database
cur.execute(
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
"""
)
)
tables = cur.fetchall()
tables = cur.fetchall()
for table in tables:
table_name = table[0]
for table in tables:
table_name = table[0]
# Don't touch migration history
if table_name == "alembic_version":
continue
# Don't touch migration history or Kombu
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
continue
# Don't touch Kombu
if table_name == "kombu_message" or table_name == "kombu_queue":
continue
cur.execute(f'DELETE FROM "{table_name}"')
cur.execute(f'DELETE FROM "{table_name}"')
# Re-enable triggers
cur.execute("SET session_replication_role = 'origin';")
# Re-enable triggers
cur.execute("SET session_replication_role = 'origin';")
conn.commit()
cur.close()
conn.close()
conn.commit()
cur.close()
conn.close()
# downgrade to base + upgrade back to head
# Downgrade to base
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
@ -126,20 +126,43 @@ def reset_postgres(
conn_str,
config_name,
direction="downgrade",
revision="base",
revision=revision,
)
def upgrade_postgres(
database: str = "postgres", config_name: str = "alembic", revision: str = "head"
) -> None:
"""Upgrade Postgres database to latest version."""
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
direction="upgrade",
revision="head",
revision=revision,
)
if not setup_onyx:
return
# do the same thing as we do on API server startup
with get_session_context_manager() as db_session:
setup_postgres(db_session)
def reset_postgres(
    database: str = "postgres",
    config_name: str = "alembic",
    setup_onyx: bool = True,
) -> None:
    """Reset the Postgres database.

    Clears the data and downgrades to "base", upgrades back to "head", and
    then optionally runs setup_postgres on a fresh session.
    """
    downgrade_postgres(
        database=database,
        config_name=config_name,
        revision="base",
        clear_data=True,
    )
    upgrade_postgres(
        database=database,
        config_name=config_name,
        revision="head",
    )

    if not setup_onyx:
        return

    with get_session_context_manager() as db_session:
        setup_postgres(db_session)
def reset_vespa() -> None:

View File

@ -0,0 +1,125 @@
import pytest
from sqlalchemy import text
from onyx.configs.constants import DEFAULT_BOOST
from onyx.db.engine import get_session_context_manager
from tests.integration.common_utils.reset import downgrade_postgres
from tests.integration.common_utils.reset import upgrade_postgres
@pytest.mark.skip(
    reason="Migration test no longer needed - migration has been applied to production"
)
def test_fix_capitalization_migration() -> None:
    """Check that migration be2ab2aa50ee lowercases external_user_group_ids.

    Flow: wipe and downgrade to base, upgrade to the revision just before the
    fix, insert documents with mixed-case group IDs, confirm they were stored
    as-is, apply the fix migration, then confirm every group ID is lowercase.
    """
    insert_stmt = text(
        """
        INSERT INTO document (
        id,
        external_user_group_ids,
        semantic_id,
        boost,
        hidden,
        from_ingestion_api,
        last_modified
        )
        VALUES (
        :id,
        :group_ids,
        :semantic_id,
        :boost,
        :hidden,
        :from_ingestion_api,
        :last_modified
        )
        """
    )

    def _fetch_test_docs() -> list:
        # Read both test documents back, ordered by id for stable indexing
        with get_session_context_manager() as db_session:
            return db_session.execute(
                text(
                    """
                    SELECT id, external_user_group_ids
                    FROM document
                    WHERE id IN ('test_doc_1', 'test_doc_2')
                    ORDER BY id
                    """
                )
            ).fetchall()

    # Start from a clean base state, then stop at the migration before the fix
    downgrade_postgres(
        database="postgres", config_name="alembic", revision="base", clear_data=True
    )
    upgrade_postgres(
        database="postgres",
        config_name="alembic",
        revision="369644546676",
    )

    # Mixed-case group IDs keyed by document id
    mixed_case_groups = {
        "test_doc_1": ["Group1", "GROUP2", "group3"],
        "test_doc_2": ["UPPER1", "upper2", "UPPER3"],
    }

    with get_session_context_manager() as db_session:
        for doc_id, group_ids in mixed_case_groups.items():
            db_session.execute(
                insert_stmt,
                {
                    "id": doc_id,
                    "group_ids": group_ids,
                    "semantic_id": doc_id,
                    "boost": DEFAULT_BOOST,
                    "hidden": False,
                    "from_ingestion_api": False,
                    "last_modified": "NOW()",
                },
            )
        db_session.commit()

    # Before the fix the values must come back exactly as inserted
    before_fix = _fetch_test_docs()
    assert len(before_fix) == 2
    assert before_fix[0].external_user_group_ids == ["Group1", "GROUP2", "group3"]
    assert before_fix[1].external_user_group_ids == ["UPPER1", "upper2", "UPPER3"]

    # Apply the fix migration
    upgrade_postgres(
        database="postgres", config_name="alembic", revision="be2ab2aa50ee"
    )

    # After the fix every group ID must be lowercase
    after_fix = _fetch_test_docs()
    assert len(after_fix) == 2
    assert after_fix[0].external_user_group_ids == ["group1", "group2", "group3"]
    assert after_fix[1].external_user_group_ids == ["upper1", "upper2", "upper3"]

View File

@ -102,12 +102,13 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
// meta models
"llama-3.2-90b-vision-instruct",
"llama-3.2-11b-vision-instruct",
"Llama-3-2-11B-Vision-Instruct-yb",
];
export function checkLLMSupportsImageInput(model: string) {
// Original exact match check
const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
(modelName) => modelName === model
(modelName) => modelName.toLowerCase() === model.toLowerCase()
);
if (exactMatch) {
@ -116,12 +117,13 @@ export function checkLLMSupportsImageInput(model: string) {
// Additional check for the last part of the model name
const modelParts = model.split(/[/.]/);
const lastPart = modelParts[modelParts.length - 1];
const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
const modelNameParts = modelName.split(/[/.]/);
const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
return modelNameLastPart === lastPart;
// lastPart is already lowercased above for tiny performance gain
return modelNameLastPart?.toLowerCase() === lastPart;
});
}