mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/index_attempt_query
This commit is contained in:
commit
46cfaa96b7
29
.vscode/launch.template.jsonc
vendored
29
.vscode/launch.template.jsonc
vendored
@ -28,6 +28,7 @@
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@ -51,7 +52,8 @@
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery beat"
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1",
|
||||
@ -269,6 +271,31 @@
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
"type": "debugpy",
|
||||
|
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
72
backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""Add SyncRecord
|
||||
|
||||
Revision ID: 97dbb53fa8c8
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-11 19:39:50.426302
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "97dbb53fa8c8"
|
||||
down_revision = "be2ab2aa50ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"sync_record",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("entity_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"sync_type",
|
||||
sa.Enum(
|
||||
"DOCUMENT_SET",
|
||||
"USER_GROUP",
|
||||
"CONNECTOR_DELETION",
|
||||
name="synctype",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"sync_status",
|
||||
sa.Enum(
|
||||
"IN_PROGRESS",
|
||||
"SUCCESS",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
name="syncstatus",
|
||||
native_enum=False,
|
||||
length=40,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
|
||||
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Add index for fetch_latest_sync_record query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_start_time",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_start_time"],
|
||||
)
|
||||
|
||||
# Add index for cleanup_sync_records query
|
||||
op.create_index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_status",
|
||||
"sync_record",
|
||||
["entity_id", "sync_type", "sync_status"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
|
||||
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
|
||||
op.drop_table("sync_record")
|
38
backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py
Normal file
38
backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""fix_capitalization
|
||||
|
||||
Revision ID: be2ab2aa50ee
|
||||
Revises: 369644546676
|
||||
Create Date: 2025-01-10 13:13:26.228960
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "be2ab2aa50ee"
|
||||
down_revision = "369644546676"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE document
|
||||
SET
|
||||
external_user_group_ids = ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
),
|
||||
last_modified = NOW()
|
||||
WHERE
|
||||
external_user_group_ids IS NOT NULL
|
||||
AND external_user_group_ids::text[] <> ARRAY(
|
||||
SELECT LOWER(unnest(external_user_group_ids))
|
||||
)::text[]
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No way to cleanly persist the bad state through an upgrade/downgrade
|
||||
# cycle, so we just pass
|
||||
pass
|
@ -0,0 +1,41 @@
|
||||
"""Add time_updated to UserGroup and DocumentSet
|
||||
|
||||
Revision ID: fec3db967bf7
|
||||
Revises: 97dbb53fa8c8
|
||||
Create Date: 2025-01-12 15:49:02.289100
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fec3db967bf7"
|
||||
down_revision = "97dbb53fa8c8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"document_set",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"time_last_modified_by_user",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_group", "time_last_modified_by_user")
|
||||
op.drop_column("document_set", "time_last_modified_by_user")
|
@ -8,6 +8,9 @@ from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import mark_user_group_as_synced
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -43,24 +46,59 @@ def monitor_usergroup_taskset(
|
||||
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
usergroup_name = user_group.name
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
try:
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
raise e
|
||||
|
||||
rug.reset()
|
||||
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Document as DbDocument
|
||||
|
||||
@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups = [
|
||||
prefix_group_w_source(
|
||||
build_ext_group_name_for_onyx(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
@ -66,7 +66,7 @@ def upsert_document_external_perms(
|
||||
).first()
|
||||
|
||||
prefixed_external_groups: set[str] = {
|
||||
prefix_group_w_source(
|
||||
build_ext_group_name_for_onyx(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
|
@ -6,8 +6,9 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session, emails=list(all_group_member_emails)
|
||||
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
# NOTE: this function handles case sensitivity for emails
|
||||
emails=list(all_group_member_emails),
|
||||
)
|
||||
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
external_group_id = build_ext_group_name_for_onyx(
|
||||
ext_group_name=external_group.id,
|
||||
source=source,
|
||||
)
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(
|
||||
external_group.id, source
|
||||
),
|
||||
external_user_group_id=external_group_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
@ -374,7 +374,7 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(name=user_group.name)
|
||||
db_user_group = UserGroup(name=user_group.name, time_updated=func.now())
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
@ -630,6 +630,10 @@ def update_user_group(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
|
@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""
|
||||
External groups may collide across sources, every source needs its own prefix.
|
||||
NOTE: the name is lowercased to handle case sensitivity for group names
|
||||
"""
|
||||
return f"{source.value}_{ext_group_name}".lower()
|
||||
|
@ -161,9 +161,34 @@ def on_task_postrun(
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
|
||||
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
|
||||
# But something is blocking set_start_method from working in the cloud unless
|
||||
# force=True. so we use force=True as a fallback.
|
||||
|
||||
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
|
||||
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
|
||||
|
||||
try:
|
||||
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method exceptioned. Trying force=True..."
|
||||
)
|
||||
try:
|
||||
multiprocessing.set_start_method(
|
||||
"spawn", force=True
|
||||
) # fork is unsafe, set to spawn
|
||||
except Exception:
|
||||
logger.info(
|
||||
"Multiprocessing set_start_method force=True exceptioned even with force=True."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
|
@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@ -49,17 +49,16 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
@ -50,22 +50,21 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
@ -1,9 +1,9 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@ -49,17 +48,18 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
95
backend/onyx/background/celery/apps/monitoring.py
Normal file
95
backend/onyx/background/celery/apps/monitoring.py
Normal file
@ -0,0 +1,95 @@
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@ -7,6 +6,7 @@ from celery import bootsteps # type: ignore
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
@ -73,14 +73,13 @@ def on_task_postrun(
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
raise WorkerShutdown("Primary worker lock could not be acquired!")
|
||||
|
||||
# tacking on our own user data to the sender
|
||||
sender.primary_worker_lock = lock
|
||||
sender.primary_worker_lock = lock # type: ignore
|
||||
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
|
21
backend/onyx/background/celery/configs/monitoring.py
Normal file
21
backend/onyx/background/celery/configs/monitoring.py
Normal file
@ -0,0 +1,21 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_pool = "solo"
|
||||
worker_prefetch_multiplier = 1
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@ -68,6 +69,16 @@ tasks_to_schedule = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
|
@ -17,7 +17,10 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@ -118,6 +121,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
@ -126,6 +136,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
try:
|
||||
|
@ -391,5 +391,7 @@ def update_external_document_permissions_task(
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
logger.exception(
|
||||
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
@ -1,3 +1,4 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@ -862,11 +863,14 @@ def connector_indexing_proxy_task(
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
"""celery tasks are forked, but forking is unstable.
|
||||
This is a thread that proxies work to a spawned task."""
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
f"search_settings={search_settings_id} "
|
||||
f"mp_start_method={multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
|
427
backend/onyx/background/celery/tasks/monitoring/tasks.py
Normal file
427
backend/onyx/background/celery/tasks/monitoring/tasks.py
Normal file
@ -0,0 +1,427 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
|
||||
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
|
||||
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
|
||||
|
||||
|
||||
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
|
||||
"""Check if a metric has been emitted by checking for existence of Redis key"""
|
||||
return bool(redis_std.exists(key))
|
||||
|
||||
|
||||
class Metric(BaseModel):
|
||||
key: str | None # only required if we need to store that we have emitted this metric
|
||||
name: str
|
||||
value: Any
|
||||
tags: dict[str, str]
|
||||
|
||||
def log(self) -> None:
|
||||
"""Log the metric in a standardized format"""
|
||||
data = {
|
||||
"metric": self.name,
|
||||
"value": self.value,
|
||||
"tags": self.tags,
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self) -> None:
|
||||
# Convert value to appropriate type
|
||||
float_value = (
|
||||
float(self.value) if isinstance(self.value, (int, float)) else None
|
||||
)
|
||||
int_value = int(self.value) if isinstance(self.value, int) else None
|
||||
string_value = str(self.value) if isinstance(self.value, str) else None
|
||||
bool_value = bool(self.value) if isinstance(self.value, bool) else None
|
||||
|
||||
if (
|
||||
float_value is None
|
||||
and int_value is None
|
||||
and string_value is None
|
||||
and bool_value is None
|
||||
):
|
||||
task_logger.error(
|
||||
f"Invalid metric value type: {type(self.value)} "
|
||||
f"({self.value}) for metric {self.name}."
|
||||
)
|
||||
return
|
||||
|
||||
# don't send None values over the wire
|
||||
data = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"metric_name": self.name,
|
||||
"float_value": float_value,
|
||||
"int_value": int_value,
|
||||
"string_value": string_value,
|
||||
"bool_value": bool_value,
|
||||
"tags": self.tags,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
optional_telemetry(
|
||||
record_type=RecordType.METRIC,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
|
||||
"""Collect metrics about queue lengths for different Celery queues"""
|
||||
metrics = []
|
||||
queue_mappings = {
|
||||
"celery_queue_length": "celery",
|
||||
"indexing_queue_length": "indexing",
|
||||
"sync_queue_length": "sync",
|
||||
"deletion_queue_length": "deletion",
|
||||
"pruning_queue_length": "pruning",
|
||||
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
}
|
||||
|
||||
for name, queue in queue_mappings.items():
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name=name,
|
||||
value=celery_get_queue_length(queue, redis_celery),
|
||||
tags={"queue": name},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _build_connector_start_latency_metric(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempt: IndexAttempt,
|
||||
second_most_recent_attempt: IndexAttempt | None,
|
||||
redis_std: Redis,
|
||||
) -> Metric | None:
|
||||
if not recent_attempt.time_started:
|
||||
return None
|
||||
|
||||
# check if we already emitted a metric for this index attempt
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
# Connector start latency
|
||||
# first run case - we should start as soon as it's created
|
||||
if not second_most_recent_attempt:
|
||||
desired_start_time = cc_pair.connector.time_created
|
||||
else:
|
||||
if not cc_pair.connector.refresh_freq:
|
||||
task_logger.error(
|
||||
"Found non-initial index attempt for connector "
|
||||
"without refresh_freq. This should never happen."
|
||||
)
|
||||
return None
|
||||
|
||||
desired_start_time = second_most_recent_attempt.time_updated + timedelta(
|
||||
seconds=cc_pair.connector.refresh_freq
|
||||
)
|
||||
|
||||
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
|
||||
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_start_latency",
|
||||
value=start_latency,
|
||||
tags={},
|
||||
)
|
||||
|
||||
|
||||
def _build_run_success_metric(
|
||||
cc_pair: ConnectorCredentialPair, recent_attempt: IndexAttempt, redis_std: Redis
|
||||
) -> Metric | None:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
if recent_attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=recent_attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about connector runs from the past hour"""
|
||||
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all connector credential pairs
|
||||
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
|
||||
|
||||
metrics = []
|
||||
for cc_pair in cc_pairs:
|
||||
# Get most recent attempt in the last hour
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.time_created >= one_hour_ago,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
# if no metric to emit, skip
|
||||
if not recent_attempt:
|
||||
continue
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
run_success_metric = _build_run_success_metric(
|
||||
cc_pair, recent_attempt, redis_std
|
||||
)
|
||||
if run_success_metric:
|
||||
metrics.append(run_success_metric)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about document set and group syncing speed"""
|
||||
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records from the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_start_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_start_time.desc())
|
||||
).all()
|
||||
|
||||
metrics = []
|
||||
for sync_record in recent_sync_records:
|
||||
# Skip if no end time (sync still in progress)
|
||||
if not sync_record.sync_end_time:
|
||||
continue
|
||||
|
||||
# Check if we already emitted a metric for this sync record
|
||||
metric_key = (
|
||||
f"sync_speed:{sync_record.sync_type}:"
|
||||
f"{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.debug(
|
||||
f"Skipping metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate sync duration in minutes
|
||||
sync_duration_mins = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds() / 60.0
|
||||
|
||||
# Calculate sync speed (docs/min) - avoid division by zero
|
||||
sync_speed = (
|
||||
sync_record.num_docs_synced / sync_duration_mins
|
||||
if sync_duration_mins > 0
|
||||
else None
|
||||
)
|
||||
|
||||
if sync_speed is None:
|
||||
task_logger.error(
|
||||
"Something went wrong with sync speed calculation. "
|
||||
f"Sync record: {sync_record.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add sync start latency metric
|
||||
start_latency_key = (
|
||||
f"sync_start_latency:{sync_record.sync_type}"
|
||||
f":{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
task_logger.debug(
|
||||
f"Skipping start latency metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Skip other sync types
|
||||
task_logger.debug(
|
||||
f"Skipping sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} "
|
||||
f"and id {sync_record.entity_id} "
|
||||
"because it is not a document set or user group"
|
||||
)
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Start latency is negative for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
|
||||
"This is likely because the entity was updated between the time the "
|
||||
"time the sync finished and this job ran. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
- Connector run metrics (start latency, success rate)
|
||||
- Syncing speed metrics
|
||||
- Worker status and task counts
|
||||
"""
|
||||
task_logger.info("Starting background process monitoring")
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect and log each metric
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
metric.log()
|
||||
metric.emit()
|
||||
if metric.key:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background process metrics")
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception("Error collecting background process metrics")
|
||||
raise e
|
@ -1,6 +1,7 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
@ -53,10 +54,16 @@ from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@ -283,6 +290,13 @@ def try_generate_document_set_sync_tasks(
|
||||
return None
|
||||
|
||||
if document_set.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@ -311,6 +325,13 @@ def try_generate_document_set_sync_tasks(
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rds.set_fence(tasks_generated)
|
||||
return tasks_generated
|
||||
@ -332,8 +353,9 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
# race condition with the monitor/cleanup function if we use a cached result!
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_group"
|
||||
fetch_user_group = cast(
|
||||
Callable[[Session, int], UserGroup | None],
|
||||
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
|
||||
)
|
||||
|
||||
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||
@ -341,6 +363,13 @@ def try_generate_user_group_sync_tasks(
|
||||
return None
|
||||
|
||||
if usergroup.is_up_to_date:
|
||||
# there should be no in-progress sync records if this is up to date
|
||||
# clean it up just in case things got into a bad state
|
||||
cleanup_sync_records(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
@ -368,8 +397,16 @@ def try_generate_user_group_sync_tasks(
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=usergroup_id,
|
||||
sync_type=SyncType.USER_GROUP,
|
||||
)
|
||||
# set this only after all tasks have been added
|
||||
rug.set_fence(tasks_generated)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
|
||||
@ -419,6 +456,13 @@ def monitor_document_set_taskset(
|
||||
f"remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=count,
|
||||
)
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
@ -437,6 +481,13 @@ def monitor_document_set_taskset(
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
|
||||
@ -470,6 +521,14 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@ -545,11 +604,29 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
|
15
backend/onyx/background/celery/versioned_apps/monitoring.py
Normal file
15
backend/onyx/background/celery/versioned_apps/monitoring.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
@ -4,9 +4,10 @@ not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing as mp
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.context import SpawnProcess
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
@ -63,7 +64,7 @@ class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: Optional["Process"] = None
|
||||
process: Optional["SpawnProcess"] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@ -131,7 +132,10 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
# this approach allows us to always "spawn" a new process regardless of
|
||||
# get_start_method's current setting
|
||||
ctx = mp.get_context("spawn")
|
||||
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
@ -47,6 +47,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
@ -260,6 +261,9 @@ class OnyxCeleryQueues:
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@ -308,6 +312,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
@ -135,32 +135,6 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@ -363,6 +337,9 @@ class OnyxConfluence(Confluence):
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@ -381,6 +358,32 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
return response.get("result", [])
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
|
@ -218,6 +218,7 @@ def insert_document_set(
|
||||
description=document_set_creation_request.description,
|
||||
user_id=user_id,
|
||||
is_public=document_set_creation_request.is_public,
|
||||
time_updated=func.now(),
|
||||
)
|
||||
db_session.add(new_document_set_row)
|
||||
db_session.flush() # ensure the new document set gets assigned an ID
|
||||
@ -293,7 +294,7 @@ def update_document_set(
|
||||
document_set_row.description = document_set_update_request.description
|
||||
document_set_row.is_up_to_date = False
|
||||
document_set_row.is_public = document_set_update_request.is_public
|
||||
|
||||
document_set_row.time_last_modified_by_user = func.now()
|
||||
versioned_private_doc_set_fn = fetch_versioned_implementation(
|
||||
"onyx.db.document_set", "make_doc_set_private"
|
||||
)
|
||||
|
@ -24,12 +24,27 @@ class IndexingMode(str, PyEnum):
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
# these may differ in the future, which is why we're okay with this duplication
|
||||
class DeletionStatus(str, PyEnum):
|
||||
NOT_STARTED = "not_started"
|
||||
class SyncType(str, PyEnum):
|
||||
DOCUMENT_SET = "document_set"
|
||||
USER_GROUP = "user_group"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class SyncStatus(str, PyEnum):
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
CANCELED = "canceled"
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
terminal_states = {
|
||||
SyncStatus.SUCCESS,
|
||||
SyncStatus.FAILED,
|
||||
}
|
||||
return self in terminal_states
|
||||
|
||||
|
||||
# Consistent with Celery task statuses
|
||||
|
@ -44,7 +44,7 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import AccessType, IndexingMode
|
||||
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
@ -881,6 +881,46 @@ class IndexAttemptError(Base):
|
||||
)
|
||||
|
||||
|
||||
class SyncRecord(Base):
|
||||
"""
|
||||
Represents the status of a "sync" operation (e.g. document set, user group, deletion).
|
||||
|
||||
A "sync" operation is an operation which needs to update a set of documents within
|
||||
Vespa, usually to match the state of Postgres.
|
||||
"""
|
||||
|
||||
__tablename__ = "sync_record"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# document set id, user group id, or deletion id
|
||||
entity_id: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
sync_type: Mapped[SyncType] = mapped_column(Enum(SyncType, native_enum=False))
|
||||
sync_status: Mapped[SyncStatus] = mapped_column(Enum(SyncStatus, native_enum=False))
|
||||
|
||||
num_docs_synced: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
sync_start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
sync_end_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_start_time",
|
||||
"entity_id",
|
||||
"sync_type",
|
||||
"sync_start_time",
|
||||
),
|
||||
Index(
|
||||
"ix_sync_record_entity_id_sync_type_sync_status",
|
||||
"entity_id",
|
||||
"sync_type",
|
||||
"sync_status",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentByConnectorCredentialPair(Base):
|
||||
"""Represents an indexing of a document by a specific connector / credential pair"""
|
||||
|
||||
@ -1284,6 +1324,11 @@ class DocumentSet(Base):
|
||||
# given access to it either via the `users` or `groups` relationships
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
# Last time a user updated this document set
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
secondary=DocumentSet__ConnectorCredentialPair.__table__,
|
||||
@ -1763,6 +1808,11 @@ class UserGroup(Base):
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
users: Mapped[list[User]] = relationship(
|
||||
"User",
|
||||
secondary=User__UserGroup.__table__,
|
||||
|
110
backend/onyx/db/sync_record.py
Normal file
110
backend/onyx/db/sync_record.py
Normal file
@ -0,0 +1,110 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import SyncRecord
|
||||
|
||||
|
||||
def insert_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord:
|
||||
"""Insert a new sync record into the database.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
|
||||
sync_type: The type of sync operation
|
||||
"""
|
||||
sync_record = SyncRecord(
|
||||
entity_id=entity_id,
|
||||
sync_type=sync_type,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=0,
|
||||
sync_start_time=func.now(),
|
||||
)
|
||||
db_session.add(sync_record)
|
||||
db_session.commit()
|
||||
|
||||
return sync_record
|
||||
|
||||
|
||||
def fetch_latest_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord | None:
|
||||
"""Fetch the most recent sync record for a given entity ID and status.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
entity_id: The ID of the entity to fetch sync record for
|
||||
sync_type: The type of sync operation
|
||||
"""
|
||||
stmt = (
|
||||
select(SyncRecord)
|
||||
.where(
|
||||
and_(
|
||||
SyncRecord.entity_id == entity_id,
|
||||
SyncRecord.sync_type == sync_type,
|
||||
)
|
||||
)
|
||||
.order_by(desc(SyncRecord.sync_start_time))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def update_sync_record_status(
|
||||
db_session: Session,
|
||||
entity_id: int,
|
||||
sync_type: SyncType,
|
||||
sync_status: SyncStatus,
|
||||
num_docs_synced: int | None = None,
|
||||
) -> None:
|
||||
"""Update the status of a sync record.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
entity_id: The ID of the entity being synced
|
||||
sync_type: The type of sync operation
|
||||
sync_status: The new status to set
|
||||
num_docs_synced: Optional number of documents synced to update
|
||||
"""
|
||||
sync_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
|
||||
if sync_record is None:
|
||||
raise ValueError(
|
||||
f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
|
||||
)
|
||||
|
||||
sync_record.sync_status = sync_status
|
||||
if num_docs_synced is not None:
|
||||
sync_record.num_docs_synced = num_docs_synced
|
||||
|
||||
if sync_status.is_terminal():
|
||||
sync_record.sync_end_time = func.now() # type: ignore
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_sync_records(
|
||||
db_session: Session, entity_id: int, sync_type: SyncType
|
||||
) -> None:
|
||||
"""Cleanup sync records for a given entity ID and sync type by marking them as failed."""
|
||||
stmt = (
|
||||
update(SyncRecord)
|
||||
.where(SyncRecord.entity_id == entity_id)
|
||||
.where(SyncRecord.sync_type == sync_type)
|
||||
.where(SyncRecord.sync_status == SyncStatus.IN_PROGRESS)
|
||||
.values(sync_status=SyncStatus.CANCELED, sync_end_time=func.now())
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
@ -33,6 +33,7 @@ class RecordType(str, Enum):
|
||||
USAGE = "usage"
|
||||
LATENCY = "latency"
|
||||
FAILURE = "failure"
|
||||
METRIC = "metric"
|
||||
|
||||
|
||||
def get_or_generate_uuid() -> str:
|
||||
|
@ -65,6 +65,18 @@ autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
[program:celery_worker_monitoring]
|
||||
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
|
||||
--loglevel=INFO
|
||||
--hostname=monitoring@%%n
|
||||
-Q monitoring
|
||||
stdout_logfile=/var/log/celery_worker_monitoring.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
# Job scheduler for periodic tasks
|
||||
[program:celery_beat]
|
||||
command=celery -A onyx.background.celery.versioned_apps.beat beat
|
||||
|
@ -63,57 +63,57 @@ def _run_migrations(
|
||||
logging.getLogger("alembic").setLevel(logging.INFO)
|
||||
|
||||
|
||||
def reset_postgres(
|
||||
database: str = "postgres", config_name: str = "alembic", setup_onyx: bool = True
|
||||
def downgrade_postgres(
|
||||
database: str = "postgres",
|
||||
config_name: str = "alembic",
|
||||
revision: str = "base",
|
||||
clear_data: bool = False,
|
||||
) -> None:
|
||||
"""Reset the Postgres database."""
|
||||
"""Downgrade Postgres database to base state."""
|
||||
if clear_data:
|
||||
if revision != "base":
|
||||
logger.warning("Clearing data without rolling back to base state")
|
||||
# Delete all rows to allow migrations to be rolled back
|
||||
conn = psycopg2.connect(
|
||||
dbname=database,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
)
|
||||
cur = conn.cursor()
|
||||
|
||||
# NOTE: need to delete all rows to allow migrations to be rolled back
|
||||
# as there are a few downgrades that don't properly handle data in tables
|
||||
conn = psycopg2.connect(
|
||||
dbname=database,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
)
|
||||
cur = conn.cursor()
|
||||
# Disable triggers to prevent foreign key constraints from being checked
|
||||
cur.execute("SET session_replication_role = 'replica';")
|
||||
|
||||
# Disable triggers to prevent foreign key constraints from being checked
|
||||
cur.execute("SET session_replication_role = 'replica';")
|
||||
|
||||
# Fetch all table names in the current database
|
||||
cur.execute(
|
||||
# Fetch all table names in the current database
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT tablename
|
||||
FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
"""
|
||||
SELECT tablename
|
||||
FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
tables = cur.fetchall()
|
||||
tables = cur.fetchall()
|
||||
|
||||
for table in tables:
|
||||
table_name = table[0]
|
||||
for table in tables:
|
||||
table_name = table[0]
|
||||
|
||||
# Don't touch migration history
|
||||
if table_name == "alembic_version":
|
||||
continue
|
||||
# Don't touch migration history or Kombu
|
||||
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
|
||||
continue
|
||||
|
||||
# Don't touch Kombu
|
||||
if table_name == "kombu_message" or table_name == "kombu_queue":
|
||||
continue
|
||||
cur.execute(f'DELETE FROM "{table_name}"')
|
||||
|
||||
cur.execute(f'DELETE FROM "{table_name}"')
|
||||
# Re-enable triggers
|
||||
cur.execute("SET session_replication_role = 'origin';")
|
||||
|
||||
# Re-enable triggers
|
||||
cur.execute("SET session_replication_role = 'origin';")
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
# downgrade to base + upgrade back to head
|
||||
# Downgrade to base
|
||||
conn_str = build_connection_string(
|
||||
db=database,
|
||||
user=POSTGRES_USER,
|
||||
@ -126,20 +126,43 @@ def reset_postgres(
|
||||
conn_str,
|
||||
config_name,
|
||||
direction="downgrade",
|
||||
revision="base",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
|
||||
def upgrade_postgres(
|
||||
database: str = "postgres", config_name: str = "alembic", revision: str = "head"
|
||||
) -> None:
|
||||
"""Upgrade Postgres database to latest version."""
|
||||
conn_str = build_connection_string(
|
||||
db=database,
|
||||
user=POSTGRES_USER,
|
||||
password=POSTGRES_PASSWORD,
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
db_api=SYNC_DB_API,
|
||||
)
|
||||
_run_migrations(
|
||||
conn_str,
|
||||
config_name,
|
||||
direction="upgrade",
|
||||
revision="head",
|
||||
revision=revision,
|
||||
)
|
||||
if not setup_onyx:
|
||||
return
|
||||
|
||||
# do the same thing as we do on API server startup
|
||||
with get_session_context_manager() as db_session:
|
||||
setup_postgres(db_session)
|
||||
|
||||
def reset_postgres(
|
||||
database: str = "postgres",
|
||||
config_name: str = "alembic",
|
||||
setup_onyx: bool = True,
|
||||
) -> None:
|
||||
"""Reset the Postgres database."""
|
||||
downgrade_postgres(
|
||||
database=database, config_name=config_name, revision="base", clear_data=True
|
||||
)
|
||||
upgrade_postgres(database=database, config_name=config_name, revision="head")
|
||||
if setup_onyx:
|
||||
with get_session_context_manager() as db_session:
|
||||
setup_postgres(db_session)
|
||||
|
||||
|
||||
def reset_vespa() -> None:
|
||||
|
125
backend/tests/integration/tests/migrations/test_migrations.py
Normal file
125
backend/tests/integration/tests/migrations/test_migrations.py
Normal file
@ -0,0 +1,125 @@
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from tests.integration.common_utils.reset import downgrade_postgres
|
||||
from tests.integration.common_utils.reset import upgrade_postgres
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Migration test no longer needed - migration has been applied to production"
|
||||
)
|
||||
def test_fix_capitalization_migration() -> None:
|
||||
"""Test that the be2ab2aa50ee migration correctly lowercases external_user_group_ids"""
|
||||
# Reset the database and run migrations up to the second to last migration
|
||||
downgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="base", clear_data=True
|
||||
)
|
||||
upgrade_postgres(
|
||||
database="postgres",
|
||||
config_name="alembic",
|
||||
# Upgrade it to the migration before the fix
|
||||
revision="369644546676",
|
||||
)
|
||||
|
||||
# Insert test data with mixed case group IDs
|
||||
test_data = [
|
||||
{
|
||||
"id": "test_doc_1",
|
||||
"external_user_group_ids": ["Group1", "GROUP2", "group3"],
|
||||
"semantic_id": "test_doc_1",
|
||||
"boost": DEFAULT_BOOST,
|
||||
"hidden": False,
|
||||
"from_ingestion_api": False,
|
||||
"last_modified": "NOW()",
|
||||
},
|
||||
{
|
||||
"id": "test_doc_2",
|
||||
"external_user_group_ids": ["UPPER1", "upper2", "UPPER3"],
|
||||
"semantic_id": "test_doc_2",
|
||||
"boost": DEFAULT_BOOST,
|
||||
"hidden": False,
|
||||
"from_ingestion_api": False,
|
||||
"last_modified": "NOW()",
|
||||
},
|
||||
]
|
||||
|
||||
# Insert the test data
|
||||
with get_session_context_manager() as db_session:
|
||||
for doc in test_data:
|
||||
db_session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO document (
|
||||
id,
|
||||
external_user_group_ids,
|
||||
semantic_id,
|
||||
boost,
|
||||
hidden,
|
||||
from_ingestion_api,
|
||||
last_modified
|
||||
)
|
||||
VALUES (
|
||||
:id,
|
||||
:group_ids,
|
||||
:semantic_id,
|
||||
:boost,
|
||||
:hidden,
|
||||
:from_ingestion_api,
|
||||
:last_modified
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": doc["id"],
|
||||
"group_ids": doc["external_user_group_ids"],
|
||||
"semantic_id": doc["semantic_id"],
|
||||
"boost": doc["boost"],
|
||||
"hidden": doc["hidden"],
|
||||
"from_ingestion_api": doc["from_ingestion_api"],
|
||||
"last_modified": doc["last_modified"],
|
||||
},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the data was inserted correctly
|
||||
with get_session_context_manager() as db_session:
|
||||
results = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, external_user_group_ids
|
||||
FROM document
|
||||
WHERE id IN ('test_doc_1', 'test_doc_2')
|
||||
ORDER BY id
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Verify initial state
|
||||
assert len(results) == 2
|
||||
assert results[0].external_user_group_ids == ["Group1", "GROUP2", "group3"]
|
||||
assert results[1].external_user_group_ids == ["UPPER1", "upper2", "UPPER3"]
|
||||
|
||||
# Run migrations again to apply the fix
|
||||
upgrade_postgres(
|
||||
database="postgres", config_name="alembic", revision="be2ab2aa50ee"
|
||||
)
|
||||
|
||||
# Verify the fix was applied
|
||||
with get_session_context_manager() as db_session:
|
||||
results = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, external_user_group_ids
|
||||
FROM document
|
||||
WHERE id IN ('test_doc_1', 'test_doc_2')
|
||||
ORDER BY id
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Verify all group IDs are lowercase
|
||||
assert len(results) == 2
|
||||
assert results[0].external_user_group_ids == ["group1", "group2", "group3"]
|
||||
assert results[1].external_user_group_ids == ["upper1", "upper2", "upper3"]
|
@ -102,12 +102,13 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
// meta models
|
||||
"llama-3.2-90b-vision-instruct",
|
||||
"llama-3.2-11b-vision-instruct",
|
||||
"Llama-3-2-11B-Vision-Instruct-yb",
|
||||
];
|
||||
|
||||
export function checkLLMSupportsImageInput(model: string) {
|
||||
// Original exact match check
|
||||
const exactMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some(
|
||||
(modelName) => modelName === model
|
||||
(modelName) => modelName.toLowerCase() === model.toLowerCase()
|
||||
);
|
||||
|
||||
if (exactMatch) {
|
||||
@ -116,12 +117,13 @@ export function checkLLMSupportsImageInput(model: string) {
|
||||
|
||||
// Additional check for the last part of the model name
|
||||
const modelParts = model.split(/[/.]/);
|
||||
const lastPart = modelParts[modelParts.length - 1];
|
||||
const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();
|
||||
|
||||
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
|
||||
const modelNameParts = modelName.split(/[/.]/);
|
||||
const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
|
||||
return modelNameLastPart === lastPart;
|
||||
// lastPart is already lowercased above for tiny performance gain
|
||||
return modelNameLastPart?.toLowerCase() === lastPart;
|
||||
});
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user