Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter

# Conflicts:
#	backend/onyx/background/celery/tasks/vespa/tasks.py
#	backend/onyx/redis/redis_connector_doc_perm_sync.py
Richard Kuo (Danswer) 2025-01-31 10:38:10 -08:00
commit 5232aeacad
68 changed files with 1281 additions and 1133 deletions

View File

@ -0,0 +1,80 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop the constraints if they exist
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)
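
For context, a minimal sketch of the behavior this migration enables, assuming the tables above (the DSN is hypothetical). With ON DELETE CASCADE in place, deleting a prompt or user now removes its join rows automatically instead of raising a foreign key violation:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # hypothetical DSN

with engine.begin() as conn:
    # the database deletes matching inputprompt__user rows as part of
    # the same statement; no application-level cleanup code runs
    conn.execute(text("DELETE FROM inputprompt WHERE id = :id"), {"id": 42})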

View File

@ -0,0 +1,29 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@ -359,6 +367,12 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@ -367,4 +381,5 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)
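
The callback threaded through here is the central pattern of this change: long-running permission syncs periodically report progress and poll for a stop signal. A minimal sketch of a conforming callback, assuming the interface exposes just the two methods used above (should_stop and progress):

from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

class CountingCallback(IndexingHeartbeatInterface):
    """Hypothetical callback that just counts progress ticks."""

    def __init__(self) -> None:
        self.ticks = 0
        self.stop_requested = False

    def should_stop(self) -> bool:
        # real implementations check external state (e.g. a redis stop fence)
        return self.stop_requested

    def progress(self, tag: str, amount: int) -> None:
        self.ticks += amount

# usage sketch: confluence_doc_sync(cc_pair, CountingCallback())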

View File

@ -14,6 +14,8 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@ -33,10 +35,17 @@ def _build_group_member_email_map(
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not all_users_groups:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails

View File

@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -28,7 +29,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@ -44,6 +45,12 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@ -146,6 +147,12 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@ -14,7 +15,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@ -114,7 +123,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@ -127,7 +136,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair,
cc_pair=cc_pair, callback=callback
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]

View File

@ -111,6 +111,7 @@ async def login_as_anonymous_user(
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie("fastapiusersauth")
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,

View File

@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=None)

View File

@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return False
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
"""
task_set: set[str] = set()
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
task_id = task_dict.get("headers", {}).get("id")
if task_id:
task_set.add(task_id)
return task_set
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.
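
A sketch of how celery_get_queued_task_ids supports the fence validation later in this commit: snapshot the queue once, then do cheap set-membership checks against it (r_celery and taskset_members are assumed to come from the caller):

# hypothetical usage of the snapshot
queued = celery_get_queued_task_ids(OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery)

for task_id in taskset_members:
    if task_id not in queued:
        # candidate for cleanup -- see validate_permission_sync_fence below
        ...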

View File

@ -3,13 +3,16 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import (
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
payload_id = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
if not payload_id:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
task_logger.info(
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
except Exception:
task_logger.exception(
"Exception while validating permission sync fences"
)
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@ -152,13 +188,15 @@ def try_creating_permissions_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
@ -193,7 +231,13 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
redis_connector.permissions.set_active()
payload = RedisConnectorPermissionSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.permissions.set_fence(payload)
result = app.send_task(
@ -208,8 +252,11 @@ def try_creating_permissions_sync_task(
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
payload_id = payload.celery_task_id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@ -217,7 +264,7 @@ def try_creating_permissions_sync_task(
if lock.owned():
lock.release()
return 1
return payload_id
@shared_task(
@ -238,6 +285,8 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
@ -325,12 +374,17 @@ def connector_permission_sync_generator_task(
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPermissionSyncPayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@ -380,6 +434,8 @@ def update_external_document_permissions_task(
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@ -409,16 +465,268 @@ def update_external_document_permissions_task(
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
elapsed = time.monotonic() - start
task_logger.info(
f"connector_id={connector_id} "
f"doc={doc_id} "
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
return True
except Exception:
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} "
f"doc_id={doc_id}"
)
return False
return True
def validate_permission_sync_fences(
tenant_id: str | None,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building the lookup table can be expensive, so we don't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
# validate all existing permission sync jobs
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
validate_permission_sync_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
reserved_generator_tasks,
r,
r_celery,
)
return
def validate_permission_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_permission_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.permissions.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"validate_permission_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.permissions.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.permissions.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_permission_sync_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
if tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.permissions.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
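
The docstring's point about Celery state can be seen concretely: AsyncResult reports PENDING for task ids it has never heard of, so state alone cannot distinguish a queued task from a nonexistent one. A quick sketch, assuming a configured celery_app:

from celery.result import AsyncResult

# an id that was never submitted still reports PENDING
res = AsyncResult("00000000-0000-0000-0000-000000000000", app=celery_app)
print(res.state)  # "PENDING" -- indistinguishable from a genuinely queued task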
class PermissionSyncCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_connector.permissions.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"PermissionSyncCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset(
if initial is None:
return
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"Permissions sync payload failed to validate. "
"Schema may have been updated."
)
return
if not payload:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Permissions sync progress: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"remaining={remaining} "
f"initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
task_logger.info(
f"Permissions sync finished: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"num_synced={initial}"
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
update_sync_record_status(
db_session=db_session,

View File

@ -1,3 +1,4 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -9,6 +10,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
# we want to run this less frequently than the overall task
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@ -186,6 +215,12 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
payload = RedisConnectorExternalGroupSyncPayload(
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
@ -199,11 +234,6 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
@ -213,8 +243,8 @@ def try_creating_external_group_sync_task(
sync_type=SyncType.EXTERNAL_GROUP,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
External group sync task for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
if not redis_connector.external_group_sync.fenced: # The fence must exist
raise ValueError(
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_external_group_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
time.sleep(1)
continue
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task(
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing external group sync jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
def validate_external_group_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.external_group_sync.fenced:
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
if found:
# the celery task exists in the redis queue
# redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
# redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
# if redis_connector_index.active():
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return

View File

@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@ -251,6 +252,8 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset(
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
)
update_sync_record_status(

View File

@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task(
"""
task_logger.debug(f"Task start: doc={document_id}")
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")

View File

@ -932,6 +932,9 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
@ -1021,6 +1024,7 @@ def vespa_metadata_sync_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(

View File

@ -478,6 +478,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)

View File

@ -300,6 +300,8 @@ class OnyxRedisLocks:
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
class OnyxRedisConstants:

View File

@ -1,3 +1,5 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
@ -274,6 +276,11 @@ class AirtableConnector(LoadConnector):
field_val = fields.get(field_name)
field_type = field_schema.type
logger.debug(
f"Processing field '{field_name}' of type '{field_type}' "
f"for record '{record_id}'."
)
field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
@ -327,19 +334,45 @@ class AirtableConnector(LoadConnector):
primary_field_name = field.name
break
record_documents: list[Document] = []
for record in records:
document = self._process_record(
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
if document:
record_documents.append(document)
logger.info(f"Starting to process Airtable records for {table.name}.")
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 16
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record = {
executor.submit(
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
): record
for record in batch_records
}
# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
record = future_to_record[future]
try:
document = future.result()
if document:
record_documents.append(document)
except Exception as e:
logger.exception(f"Failed to process record {record['id']}")
raise e
# After batch is complete, yield if we've hit the batch size
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
# Yield any remaining records
if record_documents:
yield record_documents
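
One property of the batching above worth noting: as_completed yields futures in completion order, so documents inside a PARALLEL_BATCH_SIZE window may be emitted out of record order. If ordering mattered, a sketch of an order-preserving variant (same future_to_record mapping as above):

# dicts preserve insertion order, and future.result() blocks until that
# particular future finishes, so this walks records in submission order
for future, record in future_to_record.items():
    document = future.result()
    if document:
        record_documents.append(document)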

View File

@ -1,4 +1,5 @@
import sys
import time
from datetime import datetime
from onyx.connectors.interfaces import BaseConnector
@ -45,7 +46,17 @@ class ConnectorRunner:
def run(self) -> GenerateDocumentsOutput:
"""Adds additional exception logging to the connector."""
try:
yield from self.doc_batch_generator
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long the connector takes to build each batch
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
yield batch
start = time.monotonic()
except Exception:
exc_type, _, exc_traceback = sys.exc_info()

View File

@ -150,6 +150,16 @@ class Document(DocumentBase):
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"

View File

@ -127,13 +127,6 @@ class SharepointConnector(LoadConnector, PollConnector):
start: datetime | None = None,
end: datetime | None = None,
) -> list[tuple[DriveItem, str]]:
filter_str = ""
if start is not None and end is not None:
filter_str = (
f"last_modified_datetime ge {start.isoformat()} and "
f"last_modified_datetime le {end.isoformat()}"
)
final_driveitems: list[tuple[DriveItem, str]] = []
try:
site = self.graph_client.sites.get_by_url(site_descriptor.url)
@ -167,9 +160,10 @@ class SharepointConnector(LoadConnector, PollConnector):
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
query = root_folder.get_files(True, 1000)
if filter_str:
query = query.filter(filter_str)
query = root_folder.get_files(
recursive=True,
page_size=1000,
)
driveitems = query.execute_query()
logger.debug(
f"Found {len(driveitems)} items in drive '{drive.name}'"
@ -180,11 +174,12 @@ class SharepointConnector(LoadConnector, PollConnector):
"Shared Documents" if drive.name == "Documents" else drive.name
)
# Filter items based on folder path if specified
if site_descriptor.folder_path:
# Filter items to ensure they're in the specified folder or its subfolders
# The path will be in format: /drives/{drive_id}/root:/folder/path
filtered_driveitems = [
(item, drive_name)
driveitems = [
item
for item in driveitems
if any(
path_part == site_descriptor.folder_path
@ -196,7 +191,7 @@ class SharepointConnector(LoadConnector, PollConnector):
)[1].split("/")
)
]
if len(filtered_driveitems) == 0:
if len(driveitems) == 0:
all_paths = [
item.parent_reference.path for item in driveitems
]
@ -204,11 +199,23 @@ class SharepointConnector(LoadConnector, PollConnector):
f"Nothing found for folder '{site_descriptor.folder_path}' "
f"in; any of valid paths: {all_paths}"
)
final_driveitems.extend(filtered_driveitems)
else:
final_driveitems.extend(
[(item, drive_name) for item in driveitems]
# Filter items based on time window if specified
if start is not None and end is not None:
driveitems = [
item
for item in driveitems
if start
<= item.last_modified_datetime.replace(tzinfo=timezone.utc)
<= end
]
logger.debug(
f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
)
for item in driveitems:
final_driveitems.append((item, drive_name))
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")

View File

@ -161,9 +161,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
pinned_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)

View File

@ -11,7 +11,7 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@ -291,8 +291,9 @@ def get_personas_for_user(
include_deleted: bool = False,
joinedload_all: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
@ -302,14 +303,16 @@ def get_personas_for_user(
if joinedload_all:
stmt = stmt.options(
joinedload(Persona.prompts),
joinedload(Persona.tools),
joinedload(Persona.document_sets),
joinedload(Persona.groups),
joinedload(Persona.users),
selectinload(Persona.prompts),
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
)
return db_session.execute(stmt).unique().scalars().all()
results = db_session.execute(stmt).scalars().all()
return results
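
For background on the joinedload-to-selectinload switch: joinedload pulls every relationship through LEFT OUTER JOINs in one statement, multiplying rows across collections, while selectinload issues one extra SELECT ... WHERE ... IN (...) per collection with no row explosion, which is also why the .distinct() and .unique() calls above could be dropped. A self-contained sketch of the pattern (Parent/Child are stand-in models, not onyx ones):

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
    selectinload,
)

class Base(DeclarativeBase):
    pass

class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship()

class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))

# emits: SELECT ... FROM parent; then SELECT ... FROM child
# WHERE child.parent_id IN (...) -- no duplicated parent rows
stmt = select(Parent).options(selectinload(Parent.children))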
def get_personas(db_session: Session) -> Sequence[Persona]:

View File

@ -380,6 +380,15 @@ def index_doc_batch(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)
doc_descriptors = [
{
"doc_id": doc.id,
"doc_length": doc.get_total_char_length(),
}
for doc in ctx.updatable_docs
]
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

View File

@ -1,6 +1,8 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any
@ -11,6 +13,7 @@ from requests import RequestException
from requests import Response
from retry import retry
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@ -155,6 +158,7 @@ class EmbeddingModel:
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
@ -163,12 +167,15 @@ class EmbeddingModel:
)
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
def process_batch(
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@ -185,10 +192,43 @@ class EmbeddingModel:
)
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return batch_idx, response.embeddings
# only multi thread if:
# 1. num_threads is greater than 1
# 2. we are using an API-based embedding model (provider_type is not None)
# 3. there is more than one batch (no point in threading otherwise)
if num_threads > 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}
# Collect results in order
batch_results: list[tuple[int, list[Embedding]]] = []
for future in as_completed(future_to_batch):
try:
result = future.result()
batch_results.append(result)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
except Exception as e:
logger.exception("Embedding model failed to process batch")
raise e
# Sort by batch index and extend embeddings
batch_results.sort(key=lambda x: x[0])
for _, batch_embeddings in batch_results:
embeddings.extend(batch_embeddings)
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings
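
The sort-by-index step above matters because as_completed yields results in completion order, while embeddings must line up one-to-one with the input texts. A stripped-down sketch of the same pattern with a stand-in work function:

from concurrent.futures import ThreadPoolExecutor, as_completed

def work(idx: int, item: str) -> tuple[int, str]:
    return idx, item.upper()  # stand-in for the model-server call

items = ["a", "b", "c", "d"]
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(work, i, item) for i, item in enumerate(items)]
    results = [future.result() for future in as_completed(futures)]

results.sort(key=lambda pair: pair[0])     # restore submission order
ordered = [value for _, value in results]  # lines up with items again
assert ordered == ["A", "B", "C", "D"]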
def encode(

View File

@ -18,6 +18,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPermissionSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
@ -42,6 +44,12 @@ class RedisConnectorPermissionSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@ -55,6 +63,7 @@ class RedisConnectorPermissionSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@ -110,6 +119,20 @@ class RedisConnectorPermissionSync:
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@ -177,6 +200,7 @@ class RedisConnectorPermissionSync:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@ -191,6 +215,9 @@ class RedisConnectorPermissionSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
r.delete(key)
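
The active signal is just a Redis value with a TTL, so the whole liveness protocol reduces to SET ... EX on progress and EXISTS during validation. A minimal sketch of the lifecycle with a bare redis-py client (the key name mirrors ACTIVE_PREFIX above; the real code goes through RedisConnector):

import redis

r = redis.Redis()  # hypothetical client
active_key = "connectorpermissions_active_1"

# workers refresh the signal whenever they make progress
r.set(active_key, 0, ex=3600)

# the validator treats a missing key (TTL lapsed with no refresh) as a dead
# sync whose fence is safe to reset
if not r.exists(active_key):
    print("no recent activity -- fence can be cleaned up")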

View File

@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
submitted: datetime
started: datetime | None
celery_task_id: str | None
@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync:
) -> int | None:
pass
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"

View File

@ -35,8 +35,8 @@ class RedisConnectorIndex:
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600


@ -123,7 +123,7 @@ class TenantRedis(redis.Redis):
"ttl",
] # Regular methods that need simple prefixing
if item == "scan_iter":
if item == "scan_iter" or item == "sscan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
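The _prefix_scan_iter wrapper itself is outside this hunk; what such a wrapper typically does is prefix the match pattern before delegating, sketched here under that assumption (names are illustrative; scan_iter shown, sscan_iter additionally takes the set name first):

from typing import Any, Callable, Iterator

def make_prefix_scan_iter(
    tenant_prefix: str, original: Callable[..., Iterator[Any]]
) -> Callable[..., Iterator[Any]]:
    def wrapped(match: str = "*", **kwargs: Any) -> Iterator[Any]:
        # Constrain the scan to this tenant's namespace by prefixing
        # the pattern before calling the real scan_iter / sscan_iter.
        return original(match=tenant_prefix + match, **kwargs)

    return wrapped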


@ -422,27 +422,29 @@ def sync_cc_pair(
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Doc permissions sync task already in progress.",
detail="Permissions sync task already in progress.",
)
logger.info(
f"Doc permissions sync cc_pair={cc_pair_id} "
f"Permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_permissions_sync_task(
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Doc permissions sync task creation failed.",
detail="Permissions sync task creation failed.",
)
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the doc permissions sync task.",
message="Successfully created the permissions sync task.",
)


@ -44,7 +44,6 @@ class UserPreferences(BaseModel):
chosen_assistants: list[int] | None = None
hidden_assistants: list[int] = []
visible_assistants: list[int] = []
recent_assistants: list[int] | None = None
default_model: str | None = None
auto_scroll: bool | None = None
pinned_assistants: list[int] | None = None


@ -572,59 +572,6 @@ class ChosenDefaultModelRequest(BaseModel):
default_model: str | None = None
class RecentAssistantsRequest(BaseModel):
current_assistant: int
def update_recent_assistants(
recent_assistants: list[int] | None, current_assistant: int
) -> list[int]:
if recent_assistants is None:
recent_assistants = []
else:
recent_assistants = [x for x in recent_assistants if x != current_assistant]
# Add current assistant to start of list
recent_assistants.insert(0, current_assistant)
# Keep only the 5 most recent assistants
recent_assistants = recent_assistants[:5]
return recent_assistants
@router.patch("/user/recent-assistants")
def update_user_recent_assistants(
request: RecentAssistantsRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
preferences = no_auth_user.preferences
recent_assistants = preferences.recent_assistants
updated_preferences = update_recent_assistants(
recent_assistants, request.current_assistant
)
preferences.recent_assistants = updated_preferences
set_no_auth_user_preferences(store, preferences)
return
else:
raise RuntimeError("This should never happen")
recent_assistants = UserInfo.from_model(user).preferences.recent_assistants
updated_recent_assistants = update_recent_assistants(
recent_assistants, request.current_assistant
)
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(recent_assistants=updated_recent_assistants)
)
db_session.commit()
@router.patch("/shortcut-enabled")
def update_user_shortcut_enabled(
shortcut_enabled: bool,
@ -731,30 +678,6 @@ class ChosenAssistantsRequest(BaseModel):
chosen_assistants: list[int]
@router.patch("/user/assistant-list")
def update_user_assistant_list(
request: ChosenAssistantsRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(chosen_assistants=request.chosen_assistants)
)
db_session.commit()
def update_assistant_visibility(
preferences: UserPreferences, assistant_id: int, show: bool
) -> UserPreferences:


@ -1,4 +1,6 @@
import base64
import json
import os
from datetime import datetime
from typing import Any
@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
)
return masked_creds
def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is
targeted at the stated use case."""
return base64.b32encode(os.urandom(5)).decode("utf-8")[:8] # 5 bytes → 8 chars
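A quick check of the sizing: 5 random bytes are 40 bits, and base32 emits one character per 5 bits, so the encoding is exactly 8 characters with no padding and the [:8] slice is just a safeguard:

import base64
import os

raw = os.urandom(5)  # 40 bits of randomness
encoded = base64.b32encode(raw).decode("utf-8")
assert len(encoded) == 8  # 40 bits / 5 bits per char = 8 chars, no "=" padding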


@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
class LoggerContextVars:
@staticmethod
def reset() -> None:
pruning_ctx.set(dict())
doc_permission_sync_ctx.set(dict())
class TaskAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
) -> tuple[str, MutableMapping[str, Any]]:
# If this is an indexing job, add the attempt ID to the log message
# This helps filter the logs for this specific indexing
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
while True:
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
break
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
if len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
break
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
elif len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
else:
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
if cc_pair_id is not None:
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:
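Stripped of the surrounding diff, the tagging technique is just a contextvar consulted at log time; a minimal sketch (the contextvar name and tag text here are illustrative, not the adapter's):

import contextvars

request_ctx: contextvars.ContextVar[dict] = contextvars.ContextVar(
    "request_ctx", default=dict()
)

def tag_message(msg: str) -> str:
    # Prepend a tracing tag when the current context carries a request_id.
    ctx = request_ctx.get()
    if "request_id" in ctx:
        return f"[Request: {ctx['request_id']}] {msg}"
    return msg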


@ -81,6 +81,7 @@ hubspot-api-client==8.1.0
asana==5.0.8
dropbox==11.36.2
boto3-stubs[s3]==1.34.133
shapely==2.0.6
stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4


@ -176,3 +176,35 @@ def test_sharepoint_connector_other_library(
for expected in expected_documents:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)
def test_sharepoint_connector_poll(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the base site URL
connector = SharepointConnector(
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
start = datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc) # 12 seconds before
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
# Get documents within the time window
document_batches = list(connector._fetch_from_sharepoint(start=start, end=end))
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
# Should only find test1.docx
assert len(found_documents) == 1, "Should only find one document in the time window"
doc = found_documents[0]
assert doc.semantic_identifier == "test1.docx"
verify_document_metadata(doc)
verify_document_content(
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
)
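The window here is deliberately tight: 12 seconds before and 8 seconds after the known modification time of test1.docx. The underlying filter is a plain timestamp comparison, sketched with the test's own values (the half-open interval is an assumption about the connector's semantics):

from datetime import datetime, timezone

def in_window(modified: datetime, start: datetime, end: datetime) -> bool:
    # Poll-style connectors return only items modified within the window.
    return start <= modified < end

assert in_window(
    datetime(2025, 1, 28, 20, 51, 42, tzinfo=timezone.utc),  # test1.docx mtime
    datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc),  # window start
    datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc),  # window end
)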


@ -1,75 +0,0 @@
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: celery-worker-heavy-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: celery-worker-heavy
minReplicas: 1
maxReplicas: 5
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 60
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: celery-worker-light-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: celery-worker-light
minReplicas: 1
maxReplicas: 10
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: celery-worker-indexing-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: celery-worker-indexing
minReplicas: 1
maxReplicas: 10
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: celery-worker-monitoring-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: celery-worker-indexing
minReplicas: 1
maxReplicas: 4
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70


@ -1,13 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: TriggerAuthentication
metadata:
name: celery-worker-auth
namespace: onyx
spec:
secretTargetRef:
- parameter: host
name: keda-redis-secret
key: host
- parameter: password
name: keda-redis-secret
key: password


@ -1,53 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-indexing-scaledobject
namespace: onyx
labels:
app: celery-worker-indexing
spec:
scaleTargetRef:
name: celery-worker-indexing
minReplicaCount: 1
maxReplicaCount: 30
triggers:
- type: redis
metadata:
sslEnabled: "true"
port: "6379"
enableTLS: "true"
listName: connector_indexing
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
port: "6379"
enableTLS: "true"
listName: connector_indexing:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
port: "6379"
enableTLS: "true"
listName: connector_indexing:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: cpu
metadata:
type: Utilization
value: "70"
- type: memory
metadata:
type: Utilization
value: "70"


@ -1,58 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-light-scaledobject
namespace: onyx
labels:
app: celery-worker-light
spec:
scaleTargetRef:
name: celery-worker-light
minReplicaCount: 5
maxReplicaCount: 20
triggers:
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: connector_deletion
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: connector_deletion:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth


@ -1,70 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-primary-scaledobject
namespace: onyx
labels:
app: celery-worker-primary
spec:
scaleTargetRef:
name: celery-worker-primary
pollingInterval: 15 # Check every 15 seconds
cooldownPeriod: 30 # Wait 30 seconds before scaling down
minReplicaCount: 4
maxReplicaCount: 4
triggers:
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: celery
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: celery:1
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: celery:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: celery:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: periodic_tasks
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
port: "6379"
enableTLS: "true"
listName: periodic_tasks:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth


@ -1,19 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: indexing-model-server-scaledobject
namespace: onyx
labels:
app: indexing-model-server
spec:
scaleTargetRef:
name: indexing-model-server-deployment
pollingInterval: 15 # Check every 15 seconds
cooldownPeriod: 30 # Wait 30 seconds before scaling down
minReplicaCount: 10
maxReplicaCount: 10
triggers:
- type: cpu
metadata:
type: Utilization
value: "70"


@ -1,9 +0,0 @@
apiVersion: v1
kind: Secret
metadata:
name: keda-redis-secret
namespace: onyx
type: Opaque
data:
host: { base64 encoded host here }
password: { base64 encoded password here }


@ -1,44 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-beat
spec:
replicas: 1
selector:
matchLabels:
app: celery-beat
template:
metadata:
labels:
app: celery-beat
spec:
containers:
- name: celery-beat
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
resources:
requests:
cpu: "250m"
memory: "512Mi"
limits:
cpu: "500m"
memory: "1Gi"


@ -1,60 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-worker-heavy
spec:
replicas: 2
selector:
matchLabels:
app: celery-worker-heavy
template:
metadata:
labels:
app: celery-worker-heavy
spec:
containers:
- name: celery-worker-heavy
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: vespa-certificates
mountPath: "/app/certs"
readOnly: true
resources:
requests:
cpu: "1000m"
memory: "2Gi"
limits:
cpu: "2000m"
memory: "4Gi"
volumes:
- name: vespa-certificates
secret:
secretName: vespa-certificates
items:
- key: cert.pem
path: cert.pem
- key: key.pem
path: key.pem


@ -1,62 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-worker-indexing
spec:
replicas: 1
selector:
matchLabels:
app: celery-worker-indexing
template:
metadata:
labels:
app: celery-worker-indexing
spec:
containers:
- name: celery-worker-indexing
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
"--prefetch-multiplier=1",
"--concurrency=10",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: vespa-certificates
mountPath: "/app/certs"
readOnly: true
resources:
requests:
cpu: "500m"
memory: "4Gi"
limits:
cpu: "1000m"
memory: "8Gi"
volumes:
- name: vespa-certificates
secret:
secretName: vespa-certificates
items:
- key: cert.pem
path: cert.pem
- key: key.pem
path: key.pem


@ -1,62 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-worker-light
spec:
replicas: 1
selector:
matchLabels:
app: celery-worker-light
template:
metadata:
labels:
app: celery-worker-light
spec:
containers:
- name: celery-worker-light
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"--prefetch-multiplier=1",
"--concurrency=10",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: vespa-certificates
mountPath: "/app/certs"
readOnly: true
resources:
requests:
cpu: "500m"
memory: "1Gi"
limits:
cpu: "1000m"
memory: "2Gi"
volumes:
- name: vespa-certificates
secret:
secretName: vespa-certificates
items:
- key: cert.pem
path: cert.pem
- key: key.pem
path: key.pem


@ -1,62 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-worker-monitoring
spec:
replicas: 2
selector:
matchLabels:
app: celery-worker-monitoring
template:
metadata:
labels:
app: celery-worker-monitoring
spec:
containers:
- name: celery-worker-monitoring
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
"--prefetch-multiplier=8",
"--concurrency=8",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: vespa-certificates
mountPath: "/app/certs"
readOnly: true
resources:
requests:
cpu: "1000m"
memory: "1Gi"
limits:
cpu: "1000m"
memory: "1Gi"
volumes:
- name: vespa-certificates
secret:
secretName: vespa-certificates
items:
- key: cert.pem
path: cert.pem
- key: key.pem
path: key.pem


@ -1,62 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: celery-worker-primary
spec:
replicas: 1
selector:
matchLabels:
app: celery-worker-primary
template:
metadata:
labels:
app: celery-worker-primary
spec:
containers:
- name: celery-worker-primary
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[
"celery",
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery,periodic_tasks",
"--prefetch-multiplier=1",
"--concurrency=10",
]
env:
- name: REDIS_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: redis_password
- name: ONYX_VERSION
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: vespa-certificates
mountPath: "/app/certs"
readOnly: true
resources:
requests:
cpu: "500m"
memory: "1Gi"
limits:
cpu: "1000m"
memory: "2Gi"
volumes:
- name: vespa-certificates
secret:
secretName: vespa-certificates
items:
- key: cert.pem
path: cert.pem
- key: key.pem
path: key.pem


@ -1,41 +1,16 @@
import { defineConfig, devices } from "@playwright/test";
export default defineConfig({
workers: 1, // temporary change to see if single threaded testing stabilizes the tests
testDir: "./tests/e2e", // Folder for test files
reporter: "list",
// Configure paths for screenshots
// expect: {
// toMatchSnapshot: {
// threshold: 0.2, // Adjust the threshold for visual diffs
// },
// },
// reporter: [["html", { outputFolder: "test-results/output/report" }]], // HTML report location
// outputDir: "test-results/output/screenshots", // Set output folder for test artifacts
globalSetup: require.resolve("./tests/e2e/global-setup"),
projects: [
{
// dependency for admin workflows
name: "admin_setup",
testMatch: /.*\admin_auth\.setup\.ts/,
},
{
// tests admin workflows
name: "chromium-admin",
grep: /@admin/,
name: "admin",
use: {
...devices["Desktop Chrome"],
// Use prepared auth state.
storageState: "admin_auth.json",
},
dependencies: ["admin_setup"],
},
{
// tests logged out / guest workflows
name: "chromium-guest",
grep: /@guest/,
use: {
...devices["Desktop Chrome"],
},
testIgnore: ["**/codeUtils.test.ts"],
},
],
});


@ -232,11 +232,9 @@ export function AssistantEditor({
existingPersona?.llm_model_provider_override ?? null,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? [
{
message: "",
},
],
starter_messages: existingPersona?.starter_messages?.length
? existingPersona.starter_messages
: [{ message: "" }],
enabled_tools_map: enabledToolsMap,
icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
@ -1099,7 +1097,9 @@ export function AssistantEditor({
)}
</div>
</div>
<Separator />
<div className="w-full flex flex-col">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-sm">
@ -1110,6 +1110,7 @@ export function AssistantEditor({
<SubLabel>
Sample messages that help users understand what this
assistant can do and how to interact with it effectively.
New input fields will appear automatically as you type.
</SubLabel>
<div className="w-full">


@ -64,19 +64,16 @@ export default function StarterMessagesList({
size="icon"
onClick={() => {
arrayHelpers.remove(index);
if (
index === values.length - 2 &&
!values[values.length - 1].message
) {
arrayHelpers.pop();
}
}}
className={`text-gray-400 hover:text-red-500 ${
index === values.length - 1 && !starterMessage.message
? "opacity-50 cursor-not-allowed"
: ""
}`}
disabled={index === values.length - 1 && !starterMessage.message}
disabled={
(index === values.length - 1 && !starterMessage.message) ||
(values.length === 1 && index === 0) // should never happen, but just in case
}
>
<FiTrash2 className="h-4 w-4" />
</Button>


@ -111,6 +111,7 @@ import {
import AssistantModal from "../assistants/mine/AssistantModal";
import { getSourceMetadata } from "@/lib/sources";
import { UserSettingsModal } from "./modal/UserSettingsModal";
import { AlignStartVertical } from "lucide-react";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@ -189,7 +190,11 @@ export function ChatPage({
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const { assistants: availableAssistants, finalAssistants } = useAssistants();
const {
assistants: availableAssistants,
finalAssistants,
pinnedAssistants,
} = useAssistants();
const [showApiKeyModal, setShowApiKeyModal] = useState(
!shouldShowWelcomeModal
@ -272,16 +277,6 @@ export function ChatPage({
SEARCH_PARAM_NAMES.TEMPERATURE
);
const defaultTemperature = search_param_temperature
? parseFloat(search_param_temperature)
: selectedAssistant?.tools.some(
(tool) =>
tool.in_code_tool_id === SEARCH_TOOL_ID ||
tool.in_code_tool_id === INTERNET_SEARCH_TOOL_ID
)
? 0
: 0.7;
const setSelectedAssistantFromId = (assistantId: number) => {
// NOTE: also intentionally look through available assistants here, so that
// even if the user has hidden an assistant they can still go back to it
@ -297,20 +292,21 @@ export function ChatPage({
const [presentingDocument, setPresentingDocument] =
useState<OnyxDocument | null>(null);
const { recentAssistants, refreshRecentAssistants } = useAssistants();
// Current assistant is decided based on this ordering
// 1. Alternative assistant (assistant selected explicitly by user)
// 2. Selected assistant (assistant default in this chat session)
// 3. First pinned assistant (from the ordered list of pinned assistants)
// 4. First available assistant (from the ordered list of available assistants)
const liveAssistant: Persona | undefined = useMemo(
() =>
alternativeAssistant ||
selectedAssistant ||
recentAssistants[0] ||
finalAssistants[0] ||
pinnedAssistants[0] ||
availableAssistants[0],
[
alternativeAssistant,
selectedAssistant,
recentAssistants,
finalAssistants,
pinnedAssistants,
availableAssistants,
]
);
@ -816,7 +812,6 @@ export function ChatPage({
setMaxTokens(maxTokens);
}
}
refreshRecentAssistants(liveAssistant?.id);
fetchMaxTokens();
}, [liveAssistant]);


@ -19,9 +19,7 @@ import {
import { useRouter, useSearchParams } from "next/navigation";
import { ChatSession } from "../interfaces";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { Folder } from "../folders/interfaces";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { DocumentIcon2, NewChatIcon } from "@/components/icons/icons";
@ -251,9 +249,11 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
const handleNewChat = () => {
reset();
console.log("currentChatSession", currentChatSession);
const newChatUrl =
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && currentChatSession
(currentChatSession
? `?assistantId=${currentChatSession.persona_id}`
: "");
router.push(newChatUrl);
@ -275,7 +275,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
flex-col relative
h-screen
pt-2
transition-transform
`}
>
@ -294,8 +293,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
className="w-full px-2 py-1 rounded-md items-center hover:bg-hover cursor-pointer transition-all duration-150 flex gap-x-2"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
currentChatSession?.persona_id
(currentChatSession
? `?assistantId=${currentChatSession?.persona_id}`
: "")
}
@ -320,14 +318,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
<Link
className="w-full px-2 py-1 rounded-md items-center hover:bg-hover cursor-pointer transition-all duration-150 flex gap-x-2"
href="/chat/input-prompts"
onClick={(e) => {
if (e.metaKey || e.ctrlKey) {
return;
}
if (handleNewChat) {
handleNewChat();
}
}}
>
<DocumentIcon2
size={20}


@ -2,7 +2,6 @@
import { UserDropdown } from "../UserDropdown";
import { FiShare2 } from "react-icons/fi";
import { SetStateAction, useContext, useEffect } from "react";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { ChatSession } from "@/app/chat/interfaces";
import Link from "next/link";
import { pageType } from "@/app/chat/sessionSidebar/types";
@ -42,8 +41,7 @@ export default function FunctionalHeader({
event.preventDefault();
window.open(
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
currentChatSession
(currentChatSession
? `?assistantId=${currentChatSession.persona_id}`
: ""),
"_self"
@ -63,7 +61,7 @@ export default function FunctionalHeader({
reset();
const newChatUrl =
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && currentChatSession
(currentChatSession
? `?assistantId=${currentChatSession.persona_id}`
: "");
router.push(newChatUrl);
@ -128,25 +126,6 @@ export default function FunctionalHeader({
</div>
)}
{/* <div
className={`absolute
${
documentSidebarToggled && !sidebarToggled
? "left-[calc(50%-125px)]"
: !documentSidebarToggled && sidebarToggled
? "left-[calc(50%+125px)]"
: "left-1/2"
}
${
documentSidebarToggled || sidebarToggled
? "mobile:w-[40vw] max-w-[50vw]"
: "mobile:w-[50vw] max-w-[60vw]"
}
top-1/2 transform -translate-x-1/2 -translate-y-1/2 transition-all duration-300`}
>
<ChatBanner />
</div> */}
<div className="invisible">
<LogoWithText
page={page}
@ -156,8 +135,6 @@ export default function FunctionalHeader({
/>
</div>
{/* className="fixed cursor-pointer flex z-40 left-4 bg-black top-3 h-8" */}
<div className="absolute right-2 mobile:top-1 desktop:top-1 h-8 flex">
{setSharingModalVisible && !hideUserDropdown && (
<div
@ -179,8 +156,7 @@ export default function FunctionalHeader({
className="desktop:hidden ml-2 my-auto"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
currentChatSession
(currentChatSession
? `?assistantId=${currentChatSession.persona_id}`
: "")
}


@ -25,8 +25,6 @@ interface AssistantsContextProps {
ownedButHiddenAssistants: Persona[];
refreshAssistants: () => Promise<void>;
isImageGenerationAvailable: boolean;
recentAssistants: Persona[];
refreshRecentAssistants: (currentAssistant: number) => Promise<void>;
// Admin only
editablePersonas: Persona[];
allAssistants: Persona[];
@ -56,35 +54,28 @@ export const AssistantsProvider: React.FC<{
const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
const [allAssistants, setAllAssistants] = useState<Persona[]>([]);
const [pinnedAssistants, setPinnedAssistants] = useState<Persona[]>(
user?.preferences.pinned_assistants
? assistants.filter((assistant) =>
user?.preferences?.pinned_assistants?.includes(assistant.id)
)
: assistants.filter((a) => a.builtin_persona)
);
const [pinnedAssistants, setPinnedAssistants] = useState<Persona[]>(() => {
if (user?.preferences.pinned_assistants) {
return user.preferences.pinned_assistants
.map((id) => assistants.find((assistant) => assistant.id === id))
.filter((assistant): assistant is Persona => assistant !== undefined);
} else {
return assistants.filter((a) => a.builtin_persona);
}
});
useEffect(() => {
setPinnedAssistants(
user?.preferences.pinned_assistants
? assistants.filter((assistant) =>
user?.preferences?.pinned_assistants?.includes(assistant.id)
)
: assistants.filter((a) => a.builtin_persona)
);
setPinnedAssistants(() => {
if (user?.preferences.pinned_assistants) {
return user.preferences.pinned_assistants
.map((id) => assistants.find((assistant) => assistant.id === id))
.filter((assistant): assistant is Persona => assistant !== undefined);
} else {
return assistants.filter((a) => a.builtin_persona);
}
});
}, [user?.preferences?.pinned_assistants, assistants]);
const [recentAssistants, setRecentAssistants] = useState<Persona[]>(
user?.preferences.recent_assistants
?.filter((assistantId) =>
assistants.find((assistant) => assistant.id === assistantId)
)
.map(
(assistantId) =>
assistants.find((assistant) => assistant.id === assistantId)!
) || []
);
const [isImageGenerationAvailable, setIsImageGenerationAvailable] =
useState<boolean>(false);
@ -135,28 +126,6 @@ export const AssistantsProvider: React.FC<{
fetchPersonas();
}, [isAdmin, isCurator]);
const refreshRecentAssistants = async (currentAssistant: number) => {
const response = await fetch("/api/user/recent-assistants", {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
current_assistant: currentAssistant,
}),
});
if (!response.ok) {
return;
}
setRecentAssistants((recentAssistants) => [
assistants.find((assistant) => assistant.id === currentAssistant)!,
...recentAssistants.filter(
(assistant) => assistant.id !== currentAssistant
),
]);
};
const refreshAssistants = async () => {
try {
const response = await fetch("/api/persona", {
@ -181,13 +150,6 @@ export const AssistantsProvider: React.FC<{
} catch (error) {
console.error("Error refreshing assistants:", error);
}
setRecentAssistants(
assistants.filter(
(assistant) =>
user?.preferences.recent_assistants?.includes(assistant.id) || false
)
);
};
const {
@ -230,8 +192,6 @@ export const AssistantsProvider: React.FC<{
editablePersonas,
allAssistants,
isImageGenerationAvailable,
recentAssistants,
refreshRecentAssistants,
setPinnedAssistants,
pinnedAssistants,
}}


@ -2,7 +2,6 @@
import { useContext } from "react";
import { FiSidebar } from "react-icons/fi";
import { SettingsContext } from "../settings/SettingsProvider";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { LeftToLineIcon, NewChatIcon, RightToLineIcon } from "../icons/icons";
import {
Tooltip,
@ -90,9 +89,7 @@ export default function LogoWithText({
className="my-auto mobile:hidden"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && assistantId
? `?assistantId=${assistantId}`
: "")
(assistantId ? `?assistantId=${assistantId}` : "")
}
onClick={(e) => {
if (e.metaKey || e.ctrlKey) {


@ -18,10 +18,6 @@ export const NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED =
process.env.NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED?.toLowerCase() ===
"true";
export const NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA =
process.env.NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA?.toLowerCase() ===
"true";
export const GMAIL_AUTH_IS_ADMIN_COOKIE_NAME = "gmail_auth_is_admin";
export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =


@ -1,24 +1,9 @@
// dependency for all admin user tests
import { test as setup } from "@playwright/test";
import { test as setup, expect } from "@playwright/test";
import { TEST_CREDENTIALS } from "./constants";
setup("authenticate", async ({ page }) => {
const { email, password } = TEST_CREDENTIALS;
setup("authenticate as admin", async ({ browser }) => {
const context = await browser.newContext({ storageState: "admin_auth.json" });
const page = await context.newPage();
await page.goto("http://localhost:3000/chat");
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fchat");
await expect(page).toHaveTitle("Onyx");
await page.fill("#email", email);
await page.fill("#password", password);
// Click the login button
await page.click('button[type="submit"]');
await page.waitForURL("http://localhost:3000/chat");
await page.context().storageState({ path: "admin_auth.json" });
});


@ -1,65 +1,43 @@
import { test, expect } from "@chromatic-com/playwright";
import { test, expect } from "@playwright/test";
test(
"Admin - OAuth Redirect - Missing Code",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?state=xyz"
);
test.use({ storageState: "admin_auth.json" });
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing authorization code."
);
}
);
test("Admin - OAuth Redirect - Missing Code", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?state=xyz"
);
test(
"Admin - OAuth Redirect - Missing State",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing authorization code."
);
});
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing state parameter."
);
}
);
test("Admin - OAuth Redirect - Missing State", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123"
);
test(
"Admin - OAuth Redirect - Invalid Connector",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/invalid-connector/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing state parameter."
);
});
await expect(page.locator("p.text-text-500")).toHaveText(
"invalid_connector is not a valid source type."
);
}
);
test("Admin - OAuth Redirect - Invalid Connector", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/invalid-connector/oauth/callback?code=123&state=xyz"
);
test(
"Admin - OAuth Redirect - No Session",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"invalid_connector is not a valid source type."
);
});
await expect(page.locator("p.text-text-500")).toHaveText(
"An error occurred during the OAuth process. Please try again."
);
}
);
test("Admin - OAuth Redirect - No Session", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"An error occurred during the OAuth process. Please try again."
);
});


@ -2,6 +2,8 @@ import { test, expect } from "@playwright/test";
import chromaticSnpashots from "./chromaticSnpashots.json";
import type { Page } from "@playwright/test";
test.use({ storageState: "admin_auth.json" });
async function verifyAdminPageNavigation(
page: Page,
path: string,
@ -13,7 +15,10 @@ async function verifyAdminPageNavigation(
}
) {
await page.goto(`http://localhost:3000/admin/${path}`);
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle);
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
timeout: 2000,
});
if (options?.paragraphText) {
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
@ -35,18 +40,12 @@ async function verifyAdminPageNavigation(
}
for (const chromaticSnapshot of chromaticSnpashots) {
test(
`Admin - ${chromaticSnapshot.name}`,
{
tag: "@admin",
},
async ({ page }) => {
await verifyAdminPageNavigation(
page,
chromaticSnapshot.path,
chromaticSnapshot.pageTitle,
chromaticSnapshot.options
);
}
);
test(`Admin - ${chromaticSnapshot.name}`, async ({ page }) => {
await verifyAdminPageNavigation(
page,
chromaticSnapshot.path,
chromaticSnapshot.pageTitle,
chromaticSnapshot.options
);
});
}


@ -0,0 +1,54 @@
import { test, expect } from "@playwright/test";
// Use pre-signed in "admin" storage state
test.use({
storageState: "admin_auth.json",
});
test("Chat workflow", async ({ page }) => {
// Initial setup
await page.goto("http://localhost:3000/chat", { timeout: 3000 });
// Interact with Art assistant
await page.locator("button").filter({ hasText: "Art" }).click();
await page.getByPlaceholder("Message Art assistant...").fill("Hi");
await page.keyboard.press("Enter");
await page.waitForTimeout(3000);
// Start a new chat
await page.getByRole("link", { name: "Start New Chat" }).click();
await page.waitForNavigation({ waitUntil: "networkidle" });
// Check for expected text
await expect(page.getByText("Assistant for generating")).toBeVisible();
// Interact with General assistant
await page.locator("button").filter({ hasText: "General" }).click();
// Check URL after clicking General assistant
await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1", {
timeout: 5000,
});
// Create a new assistant
await page.getByRole("button", { name: "Explore Assistants" }).click();
await page.getByRole("button", { name: "Create" }).click();
await page.getByTestId("name").click();
await page.getByTestId("name").fill("Test Assistant");
await page.getByTestId("description").click();
await page.getByTestId("description").fill("Test Assistant Description");
await page.getByTestId("system_prompt").click();
await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Verify new assistant creation
await expect(page.getByText("Test Assistant Description")).toBeVisible({
timeout: 5000,
});
// Start another new chat
await page.getByRole("link", { name: "Start New Chat" }).click();
await expect(page.getByText("Assistant with access to")).toBeVisible({
timeout: 5000,
});
});


@ -1,4 +1,9 @@
export const TEST_CREDENTIALS = {
export const TEST_USER_CREDENTIALS = {
email: "user1@test.com",
password: "User1Password123!",
};
export const TEST_ADMIN_CREDENTIALS = {
email: "admin_user@test.com",
password: "TestPassword123!",
};


@ -0,0 +1,22 @@
import { chromium, FullConfig } from "@playwright/test";
import { loginAs } from "./utils/auth";
async function globalSetup(config: FullConfig) {
const browser = await chromium.launch();
const adminContext = await browser.newContext();
const adminPage = await adminContext.newPage();
await loginAs(adminPage, "admin");
await adminContext.storageState({ path: "admin_auth.json" });
await adminContext.close();
const userContext = await browser.newContext();
const userPage = await userContext.newPage();
await loginAs(userPage, "user");
await userContext.storageState({ path: "user_auth.json" });
await userContext.close();
await browser.close();
}
export default globalSetup;


@ -1,35 +0,0 @@
// Add this line
import { test, expect, takeSnapshot } from "@chromatic-com/playwright";
import { TEST_CREDENTIALS } from "./constants";
// Then use as normal 👇
test(
"Homepage",
{
tag: "@guest",
},
async ({ page }, testInfo) => {
// Test redirect to login, and redirect to search after login
const { email, password } = TEST_CREDENTIALS;
await page.goto("http://localhost:3000/chat");
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fchat");
await expect(page).toHaveTitle("Onyx");
await takeSnapshot(page, "Before login", testInfo);
await page.fill("#email", email);
await page.fill("#password", password);
// Click the login button
await page.click('button[type="submit"]');
await page.waitForURL("http://localhost:3000/chat");
await page.getByPlaceholder("Send a message or try using @ or /");
await expect(page.locator("body")).not.toContainText("Initializing Onyx");
}
);


@ -0,0 +1,37 @@
import { Page } from "@playwright/test";
import { TEST_ADMIN_CREDENTIALS, TEST_USER_CREDENTIALS } from "../constants";
// Logs a user (admin or regular) into the application.
// On a login timeout it falls back to the signup flow and retries once.
export async function loginAs(page: Page, userType: "admin" | "user") {
const { email, password } =
userType === "admin" ? TEST_ADMIN_CREDENTIALS : TEST_USER_CREDENTIALS;
await page.goto("http://localhost:3000/auth/login", { timeout: 1000 });
await page.fill("#email", email);
await page.fill("#password", password);
// Click the login button
await page.click('button[type="submit"]');
try {
await page.waitForURL("http://localhost:3000/chat", { timeout: 4000 });
} catch (error) {
console.log(`Timeout occurred. Current URL: ${page.url()}`);
// If the redirect to /chat doesn't happen, fall back to the signup flow
await page.goto("http://localhost:3000/auth/signup", { timeout: 1000 });
await page.fill("#email", email);
await page.fill("#password", password);
// Click the login button
await page.click('button[type="submit"]');
try {
await page.waitForURL("http://localhost:3000/chat", { timeout: 4000 });
} catch (error) {
console.log(`Timeout occurred again. Current URL: ${page.url()}`);
}
}
}

web/user_auth.json

@ -0,0 +1,15 @@
{
"cookies": [
{
"name": "fastapiusersauth",
"value": "n_EMYYKHn4tQbuPTEbtN1gJ6dQTGek9omJPhO2GhHoA",
"domain": "localhost",
"path": "/",
"expires": 1738801376.508558,
"httpOnly": true,
"secure": false,
"sameSite": "Lax"
}
],
"origins": []
}