commit 5232aeacad

Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter

# Conflicts:
#   backend/onyx/background/celery/tasks/vespa/tasks.py
#   backend/onyx/redis/redis_connector_doc_perm_sync.py
@@ -0,0 +1,80 @@
"""foreign key input prompts

Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Safely drop constraints if exists
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
        """
    )
    op.execute(
        """
        ALTER TABLE inputprompt__user
        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
        """
    )

    # Recreate with ON DELETE CASCADE
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
        ondelete="CASCADE",
    )

    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop the new FKs with ondelete
    op.drop_constraint(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )
    op.drop_constraint(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        type_="foreignkey",
    )

    # Recreate them without cascading
    op.create_foreign_key(
        "inputprompt__user_input_prompt_id_fkey",
        "inputprompt__user",
        "inputprompt",
        ["input_prompt_id"],
        ["id"],
    )
    op.create_foreign_key(
        "inputprompt__user_user_id_fkey",
        "inputprompt__user",
        "user",
        ["user_id"],
        ["id"],
    )
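For context, a minimal sketch (not part of this commit) of how the new ON DELETE CASCADE behavior could be checked against a local database; the DSN and the id value are assumptions.

# Hypothetical verification, assuming a local Postgres DSN and an existing prompt id.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://localhost/onyx")  # assumed DSN

with engine.begin() as conn:
    # Deleting the parent row should now cascade to the association table.
    conn.execute(text("DELETE FROM inputprompt WHERE id = :id"), {"id": 42})
    remaining = conn.execute(
        text("SELECT count(*) FROM inputprompt__user WHERE input_prompt_id = :id"),
        {"id": 42},
    ).scalar_one()
    assert remaining == 0  # association rows were removed by ON DELETE CASCADE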
@@ -0,0 +1,29 @@
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )
@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
    is_cloud: bool,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
    document_restrictions: list[DocExternalAccess] = []

    for slim_doc in slim_docs:
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)

        if slim_doc.perm_sync_data is None:
            raise ValueError(
                f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(


def confluence_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
    logger.debug("Fetching all slim documents from confluence")
    for doc_batch in confluence_connector.retrieve_all_slim_documents():
        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
        if callback:
            if callback.should_stop():
                raise RuntimeError("confluence_doc_sync: Stop signal detected")

            callback.progress("confluence_doc_sync", 1)

        slim_docs.extend(doc_batch)

    logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
        is_cloud=is_cloud,
        callback=callback,
    )
@@ -14,6 +14,8 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        logger.debug(f"Processing groups for user: {user_result}")

        user = user_result.get("user", {})
        if not user:
            logger.warning(f"user result missing user field: {user_result}")
@@ -33,10 +35,17 @@ def _build_group_member_email_map(
            logger.warning(f"user result missing email field: {user_result}")
            continue

        all_users_groups: set[str] = set()
        for group in confluence_client.paginated_groups_by_user_retrieval(user):
            # group name uniqueness is enforced by Confluence, so we can use it as a group ID
            group_id = group["name"]
            group_member_emails.setdefault(group_id, set()).add(email)
            all_users_groups.add(group_id)

        if not group_member_emails:
            logger.warning(f"No groups found for user with email: {email}")
        else:
            logger.debug(f"Found groups {all_users_groups} for user with email {email}")

    return group_member_emails

@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(


def gmail_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
    document_external_access: list[DocExternalAccess] = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError("gmail_doc_sync: Stop signal detected")

                callback.progress("gmail_doc_sync", 1)

            if slim_doc.perm_sync_data is None:
                logger.warning(f"No permissions found for document {slim_doc.id}")
                continue
@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc(


def gdrive_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -146,6 +147,12 @@ def gdrive_doc_sync(
    document_external_accesses = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError("gdrive_doc_sync: Stop signal detected")

                callback.progress("gdrive_doc_sync", 1)

            ext_access = _get_permissions_from_slim_doc(
                google_drive_connector=google_drive_connector,
                slim_doc=slim_doc,
@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger


@@ -14,7 +15,7 @@ logger = setup_logger()


def _get_slack_document_ids_and_channels(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
    slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
    channel_doc_map: dict[str, list[str]] = {}
    for doc_metadata_batch in slim_doc_generator:
        for doc_metadata in doc_metadata_batch:
            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "_get_slack_document_ids_and_channels: Stop signal detected"
                    )

                callback.progress("_get_slack_document_ids_and_channels", 1)

            if doc_metadata.perm_sync_data is None:
                continue
            channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(


def slack_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
    )
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
    channel_doc_map = _get_slack_document_ids_and_channels(
-        cc_pair=cc_pair,
+        cc_pair=cc_pair, callback=callback
    )
    workspace_permissions = _fetch_workspace_permissions(
        user_id_to_email_map=user_id_to_email_map,
@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
    [
        ConnectorCredentialPair,
        IndexingHeartbeatInterface | None,
    ],
    list[DocExternalAccess],
]
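To make the new signature concrete, here is a minimal sketch of a function conforming to DocSyncFuncType and honoring the callback protocol used by the connectors above; the empty generator is a stand-in, not a real connector.

# Minimal sketch of a DocSyncFuncType-conforming function (illustrative only).
def example_doc_sync(
    cc_pair: ConnectorCredentialPair,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    accesses: list[DocExternalAccess] = []
    for doc_batch in []:  # stand-in for a connector's slim-document generator
        if callback:
            if callback.should_stop():
                raise RuntimeError("example_doc_sync: Stop signal detected")
            callback.progress("example_doc_sync", 1)
        accesses.extend(doc_batch)
    return accesses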
@@ -111,6 +111,7 @@ async def login_as_anonymous_user(
    token = generate_anonymous_user_jwt_token(tenant_id)

    response = Response()
    response.delete_cookie("fastapiusersauth")
    response.set_cookie(
        key=ANONYMOUS_USER_COOKIE_NAME,
        value=token,
@@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:

def wait_for_redis(sender: Any, **kwargs: Any) -> None:
    """Waits for redis to become ready subject to a hardcoded timeout.
-    Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
+    Will raise WorkerShutdown to kill the celery worker if the timeout
+    is reached."""

    r = get_redis_client(tenant_id=None)

@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
    return False


def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
    """This is a redis specific way to build a list of tasks in a queue.

    This helps us read the queue once and then efficiently look for missing tasks
    in the queue.
    """

    task_set: set[str] = set()

    for priority in range(len(OnyxCeleryPriority)):
        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue

        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
        for task in tasks:
            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
            task_id = task_dict.get("headers", {}).get("id")
            if task_id:
                task_set.add(task_id)

    return task_set


def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
    """Returns a list of current workers containing name_filter, or all workers if
    name_filter is None.
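A short usage sketch of the helper above, as it is consumed by the fence validation later in this commit: snapshot the queue once, then test membership cheaply. The taskset contents here are placeholders.

# Illustrative usage: snapshot the celery queue once, then do cheap membership checks.
queued_ids = celery_get_queued_task_ids(OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery)
taskset_members = {"task-a", "task-b"}  # placeholder ids tracked in our own Redis taskset
missing = [task_id for task_id in taskset_members if task_id not in queued_ids]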
@@ -3,13 +3,16 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from typing import cast
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

@@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import (
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
@@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
-from onyx.redis.redis_connector_doc_perm_sync import (
-    RedisConnectorPermissionSyncPayload,
-)
+from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger


logger = setup_logger()


@@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
    bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
    # TODO(rkuo): merge into check function after lookup table for fences is added

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)
    r = get_redis_client(tenant_id=tenant_id)
    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
                if _is_external_doc_permissions_sync_due(cc_pair):
                    cc_pair_ids_to_sync.append(cc_pair.id)

        lock_beat.reacquire()
        for cc_pair_id in cc_pair_ids_to_sync:
-            tasks_created = try_creating_permissions_sync_task(
+            payload_id = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
-            if not tasks_created:
+            if not payload_id:
                continue

-            task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
+            task_logger.info(
+                f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
+            )

        # we want to run this less frequently than the overall task
        lock_beat.reacquire()
        if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
            # clear any permission fences that don't have associated celery tasks in progress
            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
            # or be currently executing
            try:
                validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
            except Exception:
                task_logger.exception(
                    "Exception while validating permission sync fences"
                )

            r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -152,13 +188,15 @@ def try_creating_permissions_sync_task(
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
-) -> int | None:
-    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
+) -> str | None:
+    """Returns a randomized payload id on success.
    Returns None if no syncing is required."""
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    LOCK_TIMEOUT = 30

    payload_id: str | None = None

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
        timeout=LOCK_TIMEOUT,
@@ -193,7 +231,13 @@ def try_creating_permissions_sync_task(
        )

        # set a basic fence to start
-        payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
+        redis_connector.permissions.set_active()
+        payload = RedisConnectorPermissionSyncPayload(
+            id=make_short_id(),
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
        redis_connector.permissions.set_fence(payload)

        result = app.send_task(
@@ -208,8 +252,11 @@ def try_creating_permissions_sync_task(
        )

        # fill in the celery task id
        redis_connector.permissions.set_active()
        payload.celery_task_id = result.id
        redis_connector.permissions.set_fence(payload)

        payload_id = payload.celery_task_id
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
        return None
@@ -217,7 +264,7 @@ def try_creating_permissions_sync_task(
        if lock.owned():
            lock.release()

-    return 1
+    return payload_id


@shared_task(
@@ -238,6 +285,8 @@ def connector_permission_sync_generator_task(
    This task assumes that the task has already been properly fenced
    """

    LoggerContextVars.reset()

    doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
    doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
    doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -325,12 +374,17 @@ def connector_permission_sync_generator_task(
            raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

        new_payload = RedisConnectorPermissionSyncPayload(
            id=payload.id,
            submitted=payload.submitted,
            started=datetime.now(timezone.utc),
            celery_task_id=payload.celery_task_id,
        )
        redis_connector.permissions.set_fence(new_payload)

-        document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
+        callback = PermissionSyncCallback(redis_connector, lock, r)
+        document_external_accesses: list[DocExternalAccess] = doc_sync_func(
+            cc_pair, callback
+        )

        task_logger.info(
            f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -380,6 +434,8 @@ def update_external_document_permissions_task(
    connector_id: int,
    credential_id: int,
) -> bool:
    start = time.monotonic()

    document_external_access = DocExternalAccess.from_dict(
        serialized_doc_external_access
    )
@@ -409,16 +465,268 @@ def update_external_document_permissions_task(
                document_ids=[doc_id],
            )

-            logger.debug(
-                f"Successfully synced postgres document permissions for {doc_id}"
+            elapsed = time.monotonic() - start
+            task_logger.info(
+                f"connector_id={connector_id} "
+                f"doc={doc_id} "
+                f"action=update_permissions "
+                f"elapsed={elapsed:.2f}"
            )
-            return True
    except Exception:
-        logger.exception(
-            f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
+        task_logger.exception(
+            f"Exception in update_external_document_permissions_task: "
+            f"connector_id={connector_id} "
+            f"doc_id={doc_id}"
        )
        return False

+    return True


def validate_permission_sync_fences(
    tenant_id: str | None,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    # building lookup table can be expensive, so we won't bother
    # validating until the queue is small
    PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024

    queue_len = celery_get_queue_length(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )
    if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
        return

    queued_upsert_tasks = celery_get_queued_task_ids(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )
    reserved_generator_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
    )

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorPermissionSync.FENCE_PREFIX + "*",
        count=SCAN_ITER_COUNT_DEFAULT,
    ):
        lock_beat.reacquire()
        validate_permission_sync_fence(
            tenant_id,
            key_bytes,
            queued_upsert_tasks,
            reserved_generator_tasks,
            r,
            r_celery,
        )
    return


def validate_permission_sync_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    queued_tasks: set[str],
    reserved_tasks: set[str],
    r: Redis,
    r_celery: Redis,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.

    queued_tasks: the celery queue of lightweight permission sync tasks
    reserved_tasks: prefetched tasks for sync task generator
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        task_logger.warning(
            f"validate_permission_sync_fence - could not parse id from {fence_key}"
        )
        return

    cc_pair_id = int(cc_pair_id_str)
    # parse out metadata and initialize the helper class with it
    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

    # check to see if the fence/payload exists
    if not redis_connector.permissions.fenced:
        return

    # in the cloud, the payload format may have changed ...
    # it's a little sloppy, but just reset the fence for now if that happens
    # TODO: add intentional cleanup/abort logic
    try:
        payload = redis_connector.permissions.payload
    except ValidationError:
        task_logger.exception(
            "validate_permission_sync_fence - "
            "Resetting fence because fence schema is out of date: "
            f"cc_pair={cc_pair_id} "
            f"fence={fence_key}"
        )

        redis_connector.permissions.reset()
        return

    if not payload:
        return

    if not payload.celery_task_id:
        return

    # OK, there's actually something for us to validate

    # either the generator task must be in flight or its subtasks must be
    found = celery_find_task(
        payload.celery_task_id,
        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
        r_celery,
    )
    if found:
        # the celery task exists in the redis queue
        redis_connector.permissions.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within a worker
        redis_connector.permissions.set_active()
        return

    # look up every task in the current taskset in the celery queue
    # every entry in the taskset should have an associated entry in the celery task queue
    # because we get the celery tasks first, the entries in our own permissions taskset
    # should be roughly a subset of the tasks in celery

    # this check isn't very exact, but should be sufficient over a period of time
    # A single successful check over some number of attempts is sufficient.

    # TODO: if the number of tasks in celery is much lower than than the taskset length
    # we might be able to shortcut the lookup since by definition some of the tasks
    # must not exist in celery.

    tasks_scanned = 0
    tasks_not_in_celery = 0  # a non-zero number after completing our check is bad

    for member in r.sscan_iter(redis_connector.permissions.taskset_key):
        tasks_scanned += 1

        member_bytes = cast(bytes, member)
        member_str = member_bytes.decode("utf-8")
        if member_str in queued_tasks:
            continue

        if member_str in reserved_tasks:
            continue

        tasks_not_in_celery += 1

    task_logger.info(
        "validate_permission_sync_fence task check: "
        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
    )

    if tasks_not_in_celery == 0:
        redis_connector.permissions.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    if redis_connector.permissions.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    task_logger.warning(
        "validate_permission_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
    )

    redis_connector.permissions.reset()
    return


class PermissionSyncCallback(IndexingHeartbeatInterface):
    PARENT_CHECK_INTERVAL = 60

    def __init__(
        self,
        redis_connector: RedisConnector,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.redis_connector: RedisConnector = redis_connector
        self.redis_lock: RedisLock = redis_lock
        self.redis_client = redis_client

        self.started: datetime = datetime.now(timezone.utc)
        self.redis_lock.reacquire()

        self.last_tag: str = "PermissionSyncCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
        self.last_lock_monotonic = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_connector.stop.fenced:
            return True

        return False

    def progress(self, tag: str, amount: int) -> None:
        try:
            self.redis_connector.permissions.set_active()

            current_time = time.monotonic()
            if current_time - self.last_lock_monotonic >= (
                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
            ):
                self.redis_lock.reacquire()
                self.last_lock_reacquire = datetime.now(timezone.utc)
                self.last_lock_monotonic = time.monotonic()

            self.last_tag = tag
        except LockError:
            logger.exception(
                f"PermissionSyncCallback - lock.reacquire exceptioned: "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )

            redis_lock_dump(self.redis_lock, self.redis_client)
            raise


"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""

@@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset(
    if initial is None:
        return

    try:
        payload = redis_connector.permissions.payload
    except ValidationError:
        task_logger.exception(
            "Permissions sync payload failed to validate. "
            "Schema may have been updated."
        )
        return

    if not payload:
        return

    remaining = redis_connector.permissions.get_remaining()
    task_logger.info(
-        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
+        f"Permissions sync progress: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"remaining={remaining} "
+        f"initial={initial}"
    )
    if remaining > 0:
        return

-    payload: RedisConnectorPermissionSyncPayload | None = (
-        redis_connector.permissions.payload
+    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
+    task_logger.info(
+        f"Permissions sync finished: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"num_synced={initial}"
    )
-    start_time: datetime | None = payload.started if payload else None

-    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
-    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

    update_sync_record_status(
        db_session=db_session,
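The fence validation above boils down to a small decision rule; a standalone sketch of that rule follows (the function and parameter names are illustrative, not from the codebase).

# Illustrative decision rule distilled from validate_permission_sync_fence.
def should_reset_fence(
    task_id: str | None,
    queued: set[str],
    reserved: set[str],
    active_signal_alive: bool,
) -> bool:
    if task_id is None:
        return False  # fence not fully set up yet; let it finish
    if task_id in queued or task_id in reserved:
        return False  # the celery task is still visible somewhere
    if active_signal_alive:
        return False  # transition window; the active-signal TTL bridges the gap
    return True  # orphaned fence, safe to reset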
@@ -1,3 +1,4 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -9,6 +10,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import (
    GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_ext_group_sync import (
    RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
    r = get_redis_client(tenant_id=tenant_id)

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)
    # r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
                if _is_external_group_sync_due(cc_pair):
                    cc_pair_ids_to_sync.append(cc_pair.id)

        lock_beat.reacquire()
        for cc_pair_id in cc_pair_ids_to_sync:
            tasks_created = try_creating_external_group_sync_task(
                self.app, cc_pair_id, r, tenant_id
@@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
                continue

            task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")

        # we want to run this less frequently than the overall task
        # lock_beat.reacquire()
        # if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
        #     # clear any indexing fences that don't have associated celery tasks in progress
        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
        #     # or be currently executing
        #     try:
        #         validate_external_group_sync_fences(
        #             tenant_id, self.app, r, r_celery, lock_beat
        #         )
        #     except Exception:
        #         task_logger.exception(
        #             "Exception while validating external group sync fences"
        #         )

        #     r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -186,6 +215,12 @@ def try_creating_external_group_sync_task(
        redis_connector.external_group_sync.generator_clear()
        redis_connector.external_group_sync.taskset_clear()

+        payload = RedisConnectorExternalGroupSyncPayload(
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
+
        custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

        result = app.send_task(
@@ -199,11 +234,6 @@ def try_creating_external_group_sync_task(
            priority=OnyxCeleryPriority.HIGH,
        )

-        payload = RedisConnectorExternalGroupSyncPayload(
-            started=datetime.now(timezone.utc),
-            celery_task_id=result.id,
-        )
-
        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        with get_session_with_tenant(tenant_id) as db_session:
@@ -213,8 +243,8 @@ def try_creating_external_group_sync_task(
                sync_type=SyncType.EXTERNAL_GROUP,
            )

+        payload.celery_task_id = result.id
        redis_connector.external_group_sync.set_fence(payload)

    except Exception:
        task_logger.exception(
            f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task(
    tenant_id: str | None,
) -> None:
    """
-    Permission sync task that handles external group syncing for a given connector credential pair
+    External group sync task for a given connector credential pair
    This task assumes that the task has already been properly fenced
    """

@@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task(

    r = get_redis_client(tenant_id=tenant_id)

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.external_group_sync.fence_key}"
            )

        if not redis_connector.external_group_sync.fenced:  # The fence must exist
            raise ValueError(
                f"connector_external_group_sync_generator_task - fence not found: "
                f"fence={redis_connector.external_group_sync.fence_key}"
            )

        payload = redis_connector.external_group_sync.payload  # The payload must exist
        if not payload:
            raise ValueError(
                "connector_external_group_sync_generator_task: payload invalid or not found"
            )

        if payload.celery_task_id is None:
            logger.info(
                f"connector_external_group_sync_generator_task - Waiting for fence: "
                f"fence={redis_connector.external_group_sync.fence_key}"
            )
            time.sleep(1)
            continue

        logger.info(
            f"connector_external_group_sync_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.external_group_sync.fence_key}"
        )
        break

    lock: RedisLock = r.lock(
        OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
    )

-    acquired = lock.acquire(blocking=False)
-    if not acquired:
-        task_logger.warning(
-            f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
-        )
-        return None

    try:
+        acquired = lock.acquire(blocking=False)
+        if not acquired:
+            task_logger.warning(
+                f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
+            )
+            return None
+
+        payload.started = datetime.now(timezone.utc)
+        redis_connector.external_group_sync.set_fence(payload)

        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(
@@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task(
        redis_connector.external_group_sync.set_fence(None)
        if lock.owned():
            lock.release()


def validate_external_group_sync_fences(
    tenant_id: str | None,
    celery_app: Celery,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_sync_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
        count=SCAN_ITER_COUNT_DEFAULT,
    ):
        lock_beat.reacquire()
        with get_session_with_tenant(tenant_id) as db_session:
            validate_external_group_sync_fence(
                tenant_id,
                key_bytes,
                reserved_sync_tasks,
                r_celery,
                db_session,
            )
    return


def validate_external_group_sync_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        task_logger.warning(
            f"validate_external_group_sync_fence - could not parse id from {fence_key}"
        )
        return

    cc_pair_id = int(cc_pair_id_str)

    # parse out metadata and initialize the helper class with it
    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

    # check to see if the fence/payload exists
    if not redis_connector.external_group_sync.fenced:
        return

    payload = redis_connector.external_group_sync.payload
    if not payload:
        return

    # OK, there's actually something for us to validate

    if payload.celery_task_id is None:
        # the fence is just barely set up.
        # if redis_connector_index.active():
        #     return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
            "validate_external_group_sync_fence - "
            f"Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector.external_group_sync.reset()
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        # redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        # redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    # if redis_connector_index.active():
    #     return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        "validate_external_group_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
    )

    redis_connector.external_group_sync.reset()
    return
@@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger

@@ -251,6 +252,8 @@ def connector_pruning_generator_task(
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list"""

    LoggerContextVars.reset()

    pruning_ctx_dict = pruning_ctx.get()
    pruning_ctx_dict["cc_pair_id"] = cc_pair_id
    pruning_ctx_dict["request_id"] = self.request.id
@@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset(

    mark_ccpair_as_pruned(int(cc_pair_id), db_session)
    task_logger.info(
-        f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
+        f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
    )

    update_sync_record_status(
@@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task(
    """
    task_logger.debug(f"Task start: doc={document_id}")

    start = time.monotonic()

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            action = "skip"
@@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task(

            db_session.commit()

            elapsed = time.monotonic() - start
            task_logger.info(
                f"doc={document_id} "
                f"action={action} "
                f"refcount={count} "
-                f"chunks={chunks_affected}"
+                f"chunks={chunks_affected} "
+                f"elapsed={elapsed:.2f}"
            )
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
@@ -932,6 +932,9 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
            "Soft time limit exceeded, task is being terminated gracefully."
        )
        return False
    except Exception:
        task_logger.exception("monitor_vespa_sync exceptioned.")
        return False
    finally:
        if lock_beat.owned():
            lock_beat.release()
@@ -1021,6 +1024,7 @@ def vespa_metadata_sync_task(
        )
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
        return False
    except Exception as ex:
        if isinstance(ex, RetryError):
            task_logger.warning(
@@ -478,6 +478,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
    os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
)

# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
@@ -300,6 +300,8 @@ class OnyxRedisLocks:

class OnyxRedisSignals:
    VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
    VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
    VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"


class OnyxRedisConstants:
@@ -1,3 +1,5 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any

@@ -274,6 +276,11 @@ class AirtableConnector(LoadConnector):
            field_val = fields.get(field_name)
            field_type = field_schema.type

            logger.debug(
                f"Processing field '{field_name}' of type '{field_type}' "
                f"for record '{record_id}'."
            )

            field_sections, field_metadata = self._process_field(
                field_id=field_schema.id,
                field_name=field_name,
@@ -327,19 +334,45 @@ class AirtableConnector(LoadConnector):
                primary_field_name = field.name
                break

-        record_documents: list[Document] = []
-        for record in records:
-            document = self._process_record(
-                record=record,
-                table_schema=table_schema,
-                primary_field_name=primary_field_name,
-            )
-            if document:
-                record_documents.append(document)
+        logger.info(f"Starting to process Airtable records for {table.name}.")

+        # Process records in parallel batches using ThreadPoolExecutor
+        PARALLEL_BATCH_SIZE = 16
+        max_workers = min(PARALLEL_BATCH_SIZE, len(records))
+
+        # Process records in batches
+        for i in range(0, len(records), PARALLEL_BATCH_SIZE):
+            batch_records = records[i : i + PARALLEL_BATCH_SIZE]
+            record_documents: list[Document] = []
+
+            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+                # Submit batch tasks
+                future_to_record = {
+                    executor.submit(
+                        self._process_record,
+                        record=record,
+                        table_schema=table_schema,
+                        primary_field_name=primary_field_name,
+                    ): record
+                    for record in batch_records
+                }
+
+                # Wait for all tasks in this batch to complete
+                for future in as_completed(future_to_record):
+                    record = future_to_record[future]
+                    try:
+                        document = future.result()
+                        if document:
+                            record_documents.append(document)
+                    except Exception as e:
+                        logger.exception(f"Failed to process record {record['id']}")
+                        raise e

            # After batch is complete, yield if we've hit the batch size
            if len(record_documents) >= self.batch_size:
                yield record_documents
                record_documents = []

        # Yield any remaining records
        if record_documents:
            yield record_documents
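The same batched ThreadPoolExecutor pattern in isolation, with a dummy work function; all names and values here are illustrative, not from the connector.

# Standalone sketch of the batched ThreadPoolExecutor pattern used above.
from concurrent.futures import ThreadPoolExecutor, as_completed

def process(item: int) -> int:
    return item * 2  # stand-in for self._process_record

items = list(range(100))
BATCH_SIZE = 16
results: list[int] = []

for i in range(0, len(items), BATCH_SIZE):
    batch = items[i : i + BATCH_SIZE]
    with ThreadPoolExecutor(max_workers=min(BATCH_SIZE, len(items))) as executor:
        futures = {executor.submit(process, item): item for item in batch}
        for future in as_completed(futures):
            results.append(future.result())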
@@ -1,4 +1,5 @@
import sys
import time
from datetime import datetime

from onyx.connectors.interfaces import BaseConnector
@@ -45,7 +46,17 @@ class ConnectorRunner:
    def run(self) -> GenerateDocumentsOutput:
        """Adds additional exception logging to the connector."""
        try:
-            yield from self.doc_batch_generator
+            start = time.monotonic()
+            for batch in self.doc_batch_generator:
+                # to know how long connector is taking
+                logger.debug(
+                    f"Connector took {time.monotonic() - start} seconds to build a batch."
+                )
+
+                yield batch
+
+                start = time.monotonic()

        except Exception:
            exc_type, _, exc_traceback = sys.exc_info()

@@ -150,6 +150,16 @@ class Document(DocumentBase):
    id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
    source: DocumentSource

    def get_total_char_length(self) -> int:
        """Calculate the total character length of the document including sections, metadata, and identifiers."""
        section_length = sum(len(section.text) for section in self.sections)
        identifier_length = len(self.semantic_identifier) + len(self.title or "")
        metadata_length = sum(
            len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
            for k, v in self.metadata.items()
        )
        return section_length + identifier_length + metadata_length

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a document"""
        return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
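The metadata term in get_total_char_length counts each key plus either the string value's length or the summed lengths of a list value; a standalone restatement with toy values (not real document metadata) follows.

# Toy restatement of the metadata length arithmetic used above.
metadata = {"author": "alice", "tags": ["a", "bb", "ccc"]}
metadata_length = sum(
    len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
    for k, v in metadata.items()
)
# "author"(6) + "alice"(5) + "tags"(4) + (1 + 2 + 3) = 21
assert metadata_length == 21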
@@ -127,13 +127,6 @@ class SharepointConnector(LoadConnector, PollConnector):
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> list[tuple[DriveItem, str]]:
-        filter_str = ""
-        if start is not None and end is not None:
-            filter_str = (
-                f"last_modified_datetime ge {start.isoformat()} and "
-                f"last_modified_datetime le {end.isoformat()}"
-            )

        final_driveitems: list[tuple[DriveItem, str]] = []
        try:
            site = self.graph_client.sites.get_by_url(site_descriptor.url)
@@ -167,9 +160,10 @@ class SharepointConnector(LoadConnector, PollConnector):
                            root_folder = root_folder.get_by_path(folder_part)

                    # Get all items recursively
-                    query = root_folder.get_files(True, 1000)
-                    if filter_str:
-                        query = query.filter(filter_str)
+                    query = root_folder.get_files(
+                        recursive=True,
+                        page_size=1000,
+                    )
                    driveitems = query.execute_query()
                    logger.debug(
                        f"Found {len(driveitems)} items in drive '{drive.name}'"
@@ -180,11 +174,12 @@ class SharepointConnector(LoadConnector, PollConnector):
                        "Shared Documents" if drive.name == "Documents" else drive.name
                    )

                    # Filter items based on folder path if specified
                    if site_descriptor.folder_path:
                        # Filter items to ensure they're in the specified folder or its subfolders
                        # The path will be in format: /drives/{drive_id}/root:/folder/path
-                        filtered_driveitems = [
-                            (item, drive_name)
+                        driveitems = [
+                            item
                            for item in driveitems
                            if any(
                                path_part == site_descriptor.folder_path
@@ -196,7 +191,7 @@ class SharepointConnector(LoadConnector, PollConnector):
                                )[1].split("/")
                            )
                        ]
-                        if len(filtered_driveitems) == 0:
+                        if len(driveitems) == 0:
                            all_paths = [
                                item.parent_reference.path for item in driveitems
                            ]
@@ -204,11 +199,23 @@ class SharepointConnector(LoadConnector, PollConnector):
                                f"Nothing found for folder '{site_descriptor.folder_path}' "
                                f"in; any of valid paths: {all_paths}"
                            )
-                        final_driveitems.extend(filtered_driveitems)
-                    else:
-                        final_driveitems.extend(
-                            [(item, drive_name) for item in driveitems]

+                    # Filter items based on time window if specified
+                    if start is not None and end is not None:
+                        driveitems = [
+                            item
+                            for item in driveitems
+                            if start
+                            <= item.last_modified_datetime.replace(tzinfo=timezone.utc)
+                            <= end
+                        ]
+                        logger.debug(
+                            f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
+                        )
+
+                    for item in driveitems:
+                        final_driveitems.append((item, drive_name))

        except Exception as e:
            # Some drives might not be accessible
            logger.warning(f"Failed to process drive: {str(e)}")
@ -161,9 +161,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
recent_assistants: Mapped[list[dict]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
|
||||
)
|
||||
|
||||
pinned_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@ -291,8 +291,9 @@ def get_personas_for_user(
|
||||
include_deleted: bool = False,
|
||||
joinedload_all: bool = False,
|
||||
) -> Sequence[Persona]:
|
||||
stmt = select(Persona).distinct()
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
stmt = select(Persona)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
@ -302,14 +303,16 @@ def get_personas_for_user(
|
||||
|
||||
if joinedload_all:
|
||||
stmt = stmt.options(
|
||||
joinedload(Persona.prompts),
|
||||
joinedload(Persona.tools),
|
||||
joinedload(Persona.document_sets),
|
||||
joinedload(Persona.groups),
|
||||
joinedload(Persona.users),
|
||||
selectinload(Persona.prompts),
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.groups),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.labels),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).unique().scalars().all()
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return results
|
||||
|
||||
|
||||
def get_personas(db_session: Session) -> Sequence[Persona]:
|
||||
|
@ -380,6 +380,15 @@ def index_doc_batch(
            new_docs=0, total_docs=len(filtered_documents), total_chunks=0
        )

    doc_descriptors = [
        {
            "doc_id": doc.id,
            "doc_length": doc.get_total_char_length(),
        }
        for doc in ctx.updatable_docs
    ]
    logger.debug(f"Starting indexing process for documents: {doc_descriptors}")

    logger.debug("Starting chunking")
    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@ -1,6 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
@ -11,6 +13,7 @@ from requests import RequestException
|
||||
from requests import Response
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
from onyx.configs.app_configs import SKIP_WARM_UP
|
||||
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
@ -155,6 +158,7 @@ class EmbeddingModel:
|
||||
text_type: EmbedTextType,
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
|
||||
) -> list[Embedding]:
|
||||
text_batches = batch_list(texts, batch_size)
|
||||
|
||||
@ -163,12 +167,15 @@ class EmbeddingModel:
|
||||
)
|
||||
|
||||
embeddings: list[Embedding] = []
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
|
||||
def process_batch(
|
||||
batch_idx: int, text_batch: list[str]
|
||||
) -> tuple[int, list[Embedding]]:
|
||||
if self.callback:
|
||||
if self.callback.should_stop():
|
||||
raise RuntimeError("_batch_encode_texts detected stop signal")
|
||||
|
||||
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
|
||||
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=text_batch,
|
||||
@ -185,10 +192,43 @@ class EmbeddingModel:
|
||||
)
|
||||
|
||||
response = self._make_model_server_request(embed_request)
|
||||
embeddings.extend(response.embeddings)
|
||||
return batch_idx, response.embeddings
|
||||
|
||||
# only multi thread if:
|
||||
# 1. num_threads is greater than 1
|
||||
# 2. we are using an API-based embedding model (provider_type is not None)
|
||||
# 3. there are more than 1 batch (no point in threading if only 1)
|
||||
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
future_to_batch = {
|
||||
executor.submit(process_batch, idx, batch): idx
|
||||
for idx, batch in enumerate(text_batches, start=1)
|
||||
}
|
||||
|
||||
# Collect results in order
|
||||
batch_results: list[tuple[int, list[Embedding]]] = []
|
||||
for future in as_completed(future_to_batch):
|
||||
try:
|
||||
result = future.result()
|
||||
batch_results.append(result)
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
except Exception as e:
|
||||
logger.exception("Embedding model failed to process batch")
|
||||
raise e
|
||||
|
||||
# Sort by batch index and extend embeddings
|
||||
batch_results.sort(key=lambda x: x[0])
|
||||
for _, batch_embeddings in batch_results:
|
||||
embeddings.extend(batch_embeddings)
|
||||
else:
|
||||
# Original sequential processing
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
_, batch_embeddings = process_batch(idx, text_batch)
|
||||
embeddings.extend(batch_embeddings)
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
return embeddings
|
||||
|
||||
def encode(
|
||||
|
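A hedged sketch of the ordering technique in the threaded embedding path above: each batch is submitted together with its index, results are gathered with `as_completed`, then sorted by index so the combined output matches the input order even when batches finish out of order. The function and parameter names here are illustrative and not the `EmbeddingModel` API.

```python
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed


def embed_in_order(
    batches: list[list[str]],
    process_batch: Callable[[int, list[str]], tuple[int, list[list[float]]]],
    num_threads: int = 4,
) -> list[list[float]]:
    """Run process_batch over all batches in parallel, then stitch the per-batch
    results back together in the original submission order."""
    indexed_results: list[tuple[int, list[list[float]]]] = []
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future_to_idx = {
            executor.submit(process_batch, idx, batch): idx
            for idx, batch in enumerate(batches, start=1)
        }
        for future in as_completed(future_to_idx):
            indexed_results.append(future.result())  # re-raises any worker exception

    indexed_results.sort(key=lambda pair: pair[0])
    ordered: list[list[float]] = []
    for _, batch_result in indexed_results:
        ordered.extend(batch_result)
    return ordered
```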
@ -18,6 +18,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
id: str
|
||||
submitted: datetime
|
||||
started: datetime | None
|
||||
celery_task_id: str | None
|
||||
|
||||
@ -42,6 +44,12 @@ class RedisConnectorPermissionSync:
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
@ -55,6 +63,7 @@ class RedisConnectorPermissionSync:
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
@ -110,6 +119,20 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
    def set_active(self) -> None:
        """This sets a signal to keep the permissioning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed because simply checking the celery queue and task
        status at a single point in time could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
"""the fence payload is an int representing the starting number of
|
||||
@ -177,6 +200,7 @@ class RedisConnectorPermissionSync:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
@ -191,6 +215,9 @@ class RedisConnectorPermissionSync:
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
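A minimal sketch of the TTL-based "active" signal idea used by `RedisConnectorPermissionSync.set_active` / `active` above, assuming a plain `redis-py` client; the key name and TTL value are illustrative.

```python
import redis

ACTIVE_TTL = 3600  # seconds; the signal silently expires if nobody refreshes it


def mark_active(r: redis.Redis, active_key: str) -> None:
    # Refresh the liveness signal; callers are expected to do this periodically
    # while the workflow is genuinely making progress.
    r.set(active_key, 0, ex=ACTIVE_TTL)


def is_active(r: redis.Redis, active_key: str) -> bool:
    # True while the key has not yet expired, bridging gaps between status checks.
    return bool(r.exists(active_key))
```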
@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
|
||||
|
||||
class RedisConnectorExternalGroupSyncPayload(BaseModel):
|
||||
submitted: datetime
|
||||
started: datetime | None
|
||||
celery_task_id: str | None
|
||||
|
||||
@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync:
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
|
||||
|
@ -35,8 +35,8 @@ class RedisConnectorIndex:
|
||||
TERMINATE_TTL = 600
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# there are gaps in time between states where we need some slack
|
||||
# to correctly transition
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
|
@ -123,7 +123,7 @@ class TenantRedis(redis.Redis):
            "ttl",
        ]  # Regular methods that need simple prefixing

        if item == "scan_iter":
        if item == "scan_iter" or item == "sscan_iter":
            return self._prefix_scan_iter(original_attr)
        elif item in methods_to_wrap and callable(original_attr):
            return self._prefix_method(original_attr)
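An illustrative sketch (not the actual `TenantRedis` internals) of why `scan_iter` and `sscan_iter` need dedicated wrapping rather than simple argument prefixing: the match pattern must gain the tenant prefix on the way in, and returned keys should have it stripped on the way out. The `tenant_id:` prefix scheme here is an assumption.

```python
from collections.abc import Iterator

import redis


def tenant_scan_iter(
    r: redis.Redis, tenant_id: str, match: str = "*", count: int = 1000
) -> Iterator[str]:
    prefix = f"{tenant_id}:"
    # Prefix the pattern so only this tenant's keys are scanned.
    for key in r.scan_iter(match=prefix + match, count=count):
        key_str = key.decode() if isinstance(key, bytes) else key
        # Strip the prefix so callers see the same key names they wrote.
        yield key_str.removeprefix(prefix)
```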
@ -422,27 +422,29 @@ def sync_cc_pair(
|
||||
if redis_connector.permissions.fenced:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Doc permissions sync task already in progress.",
|
||||
detail="Permissions sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Doc permissions sync cc_pair={cc_pair_id} "
|
||||
f"Permissions sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not tasks_created:
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Doc permissions sync task creation failed.",
|
||||
detail="Permissions sync task creation failed.",
|
||||
)
|
||||
|
||||
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the doc permissions sync task.",
|
||||
message="Successfully created the permissions sync task.",
|
||||
)
|
||||
|
||||
|
||||
|
@ -44,7 +44,6 @@ class UserPreferences(BaseModel):
|
||||
chosen_assistants: list[int] | None = None
|
||||
hidden_assistants: list[int] = []
|
||||
visible_assistants: list[int] = []
|
||||
recent_assistants: list[int] | None = None
|
||||
default_model: str | None = None
|
||||
auto_scroll: bool | None = None
|
||||
pinned_assistants: list[int] | None = None
|
||||
|
@ -572,59 +572,6 @@ class ChosenDefaultModelRequest(BaseModel):
|
||||
default_model: str | None = None
|
||||
|
||||
|
||||
class RecentAssistantsRequest(BaseModel):
    current_assistant: int


def update_recent_assistants(
    recent_assistants: list[int] | None, current_assistant: int
) -> list[int]:
    if recent_assistants is None:
        recent_assistants = []
    else:
        recent_assistants = [x for x in recent_assistants if x != current_assistant]

    # Add current assistant to start of list
    recent_assistants.insert(0, current_assistant)

    # Keep only the 5 most recent assistants
    recent_assistants = recent_assistants[:5]
    return recent_assistants
|
||||
|
||||
@router.patch("/user/recent-assistants")
|
||||
def update_user_recent_assistants(
|
||||
request: RecentAssistantsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
preferences = no_auth_user.preferences
|
||||
recent_assistants = preferences.recent_assistants
|
||||
updated_preferences = update_recent_assistants(
|
||||
recent_assistants, request.current_assistant
|
||||
)
|
||||
preferences.recent_assistants = updated_preferences
|
||||
set_no_auth_user_preferences(store, preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
recent_assistants = UserInfo.from_model(user).preferences.recent_assistants
|
||||
updated_recent_assistants = update_recent_assistants(
|
||||
recent_assistants, request.current_assistant
|
||||
)
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(recent_assistants=updated_recent_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.patch("/shortcut-enabled")
|
||||
def update_user_shortcut_enabled(
|
||||
shortcut_enabled: bool,
|
||||
@ -731,30 +678,6 @@ class ChosenAssistantsRequest(BaseModel):
|
||||
chosen_assistants: list[int]
|
||||
|
||||
|
||||
@router.patch("/user/assistant-list")
|
||||
def update_user_assistant_list(
|
||||
request: ChosenAssistantsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
|
||||
set_no_auth_user_preferences(store, no_auth_user.preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(chosen_assistants=request.chosen_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_assistant_visibility(
|
||||
preferences: UserPreferences, assistant_id: int, show: bool
|
||||
) -> UserPreferences:
|
||||
|
@ -1,4 +1,6 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
)
|
||||
|
||||
return masked_creds
|
||||
|
||||
|
||||
def make_short_id() -> str:
    """Fast way to generate a random 8 character id ... useful for tagging data
    to trace it through a flow. This is definitely not guaranteed to be unique and is
    targeted at the stated use case."""
    return base64.b32encode(os.urandom(5)).decode("utf-8")[:8]  # 5 bytes → 8 chars
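Example use of the short-id helper above to tag related log lines so a single request can be traced through a multi-step flow; the logger name and messages are illustrative, and the helper is redefined inline so the snippet stands alone.

```python
import base64
import logging
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("example")


def make_short_id() -> str:
    # Same idea as the helper above: 5 random bytes -> 8 base32 characters.
    return base64.b32encode(os.urandom(5)).decode("utf-8")[:8]


request_id = make_short_id()  # e.g. "K3QZ7A2B" -- random, not guaranteed unique
logger.info("flow %s: starting credential masking", request_id)
logger.info("flow %s: finished credential masking", request_id)
```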
@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[
|
||||
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
|
||||
|
||||
|
||||
class LoggerContextVars:
|
||||
@staticmethod
|
||||
def reset() -> None:
|
||||
pruning_ctx.set(dict())
|
||||
doc_permission_sync_ctx.set(dict())
|
||||
|
||||
|
||||
class TaskAttemptSingleton:
|
||||
"""Used to tell if this process is an indexing job, and if so what is the
|
||||
unique identifier for this indexing attempt. For things like the API server,
|
||||
@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
# If this is an indexing job, add the attempt ID to the log message
|
||||
# This helps filter the logs for this specific indexing
|
||||
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
|
||||
while True:
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
if len(pruning_ctx_dict) > 0:
|
||||
if "request_id" in pruning_ctx_dict:
|
||||
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
if len(pruning_ctx_dict) > 0:
|
||||
if "request_id" in pruning_ctx_dict:
|
||||
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
|
||||
if "cc_pair_id" in pruning_ctx_dict:
|
||||
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
|
||||
break
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
if len(doc_permission_sync_ctx_dict) > 0:
|
||||
if "request_id" in doc_permission_sync_ctx_dict:
|
||||
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
|
||||
break
|
||||
|
||||
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
|
||||
|
||||
if "cc_pair_id" in pruning_ctx_dict:
|
||||
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
|
||||
elif len(doc_permission_sync_ctx_dict) > 0:
|
||||
if "request_id" in doc_permission_sync_ctx_dict:
|
||||
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
|
||||
else:
|
||||
if index_attempt_id is not None:
|
||||
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
|
||||
|
||||
if cc_pair_id is not None:
|
||||
msg = f"[CC Pair: {cc_pair_id}] {msg}"
|
||||
|
||||
break
|
||||
# Add tenant information if it differs from default
|
||||
# This will always be the case for authenticated API requests
|
||||
if MULTI_TENANT:
|
||||
|
@ -81,6 +81,7 @@ hubspot-api-client==8.1.0
|
||||
asana==5.0.8
|
||||
dropbox==11.36.2
|
||||
boto3-stubs[s3]==1.34.133
|
||||
shapely==2.0.6
|
||||
stripe==10.12.0
|
||||
urllib3==2.2.3
|
||||
mistune==0.8.4
|
||||
|
@ -176,3 +176,35 @@ def test_sharepoint_connector_other_library(
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_poll(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector._fetch_from_sharepoint(start=start, end=end))
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
|
||||
# Should only find test1.docx
|
||||
assert len(found_documents) == 1, "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_metadata(doc)
|
||||
verify_document_content(
|
||||
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
|
||||
)
|
||||
|
@ -1,75 +0,0 @@
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-heavy-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-heavy
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 60
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-light-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-light
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-indexing-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-indexing
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-monitoring-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-indexing
|
||||
minReplicas: 1
|
||||
maxReplicas: 4
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
@ -1,13 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: TriggerAuthentication
|
||||
metadata:
|
||||
name: celery-worker-auth
|
||||
namespace: onyx
|
||||
spec:
|
||||
secretTargetRef:
|
||||
- parameter: host
|
||||
name: keda-redis-secret
|
||||
key: host
|
||||
- parameter: password
|
||||
name: keda-redis-secret
|
||||
key: password
|
@ -1,53 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-indexing-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-indexing
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-indexing
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 30
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
- type: memory
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
@ -1,58 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-light-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-light
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-light
|
||||
minReplicaCount: 5
|
||||
maxReplicaCount: 20
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
@ -1,70 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-primary-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-primary
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-primary
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 4
|
||||
maxReplicaCount: 4
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:1
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
@ -1,19 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: indexing-model-server-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: indexing-model-server
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: indexing-model-server-deployment
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 10
|
||||
maxReplicaCount: 10
|
||||
triggers:
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
@ -1,9 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: keda-redis-secret
|
||||
namespace: onyx
|
||||
type: Opaque
|
||||
data:
|
||||
host: { base64 encoded host here }
|
||||
password: { base64 encoded password here }
|
@ -1,44 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-beat
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-beat
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-beat
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
resources:
|
||||
requests:
|
||||
cpu: "250m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
@ -1,60 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-heavy
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-heavy
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-heavy
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
limits:
|
||||
cpu: "2000m"
|
||||
memory: "4Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-indexing
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-indexing
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-indexing
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "4Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "8Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-light
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-light
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-light
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-monitoring
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-monitoring
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-monitoring
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-monitoring
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
"--prefetch-multiplier=8",
|
||||
"--concurrency=8",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "1000m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "1Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-primary
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-primary
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-primary
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery,periodic_tasks",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
@ -1,41 +1,16 @@
|
||||
import { defineConfig, devices } from "@playwright/test";
|
||||
|
||||
export default defineConfig({
|
||||
workers: 1, // temporary change to see if single threaded testing stabilizes the tests
|
||||
testDir: "./tests/e2e", // Folder for test files
|
||||
reporter: "list",
|
||||
// Configure paths for screenshots
|
||||
// expect: {
|
||||
// toMatchSnapshot: {
|
||||
// threshold: 0.2, // Adjust the threshold for visual diffs
|
||||
// },
|
||||
// },
|
||||
// reporter: [["html", { outputFolder: "test-results/output/report" }]], // HTML report location
|
||||
// outputDir: "test-results/output/screenshots", // Set output folder for test artifacts
|
||||
globalSetup: require.resolve("./tests/e2e/global-setup"),
|
||||
|
||||
projects: [
|
||||
{
|
||||
// dependency for admin workflows
|
||||
name: "admin_setup",
|
||||
testMatch: /.*\admin_auth\.setup\.ts/,
|
||||
},
|
||||
{
|
||||
// tests admin workflows
|
||||
name: "chromium-admin",
|
||||
grep: /@admin/,
|
||||
name: "admin",
|
||||
use: {
|
||||
...devices["Desktop Chrome"],
|
||||
// Use prepared auth state.
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
dependencies: ["admin_setup"],
|
||||
},
|
||||
{
|
||||
// tests logged out / guest workflows
|
||||
name: "chromium-guest",
|
||||
grep: /@guest/,
|
||||
use: {
|
||||
...devices["Desktop Chrome"],
|
||||
},
|
||||
testIgnore: ["**/codeUtils.test.ts"],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
@ -232,11 +232,9 @@ export function AssistantEditor({
|
||||
existingPersona?.llm_model_provider_override ?? null,
|
||||
llm_model_version_override:
|
||||
existingPersona?.llm_model_version_override ?? null,
|
||||
starter_messages: existingPersona?.starter_messages ?? [
|
||||
{
|
||||
message: "",
|
||||
},
|
||||
],
|
||||
starter_messages: existingPersona?.starter_messages?.length
|
||||
? existingPersona.starter_messages
|
||||
: [{ message: "" }],
|
||||
enabled_tools_map: enabledToolsMap,
|
||||
icon_color: existingPersona?.icon_color ?? defautIconColor,
|
||||
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
|
||||
@ -1099,7 +1097,9 @@ export function AssistantEditor({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="w-full flex flex-col">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<div className="block font-medium text-sm">
|
||||
@ -1110,6 +1110,7 @@ export function AssistantEditor({
|
||||
<SubLabel>
|
||||
Sample messages that help users understand what this
|
||||
assistant can do and how to interact with it effectively.
|
||||
New input fields will appear automatically as you type.
|
||||
</SubLabel>
|
||||
|
||||
<div className="w-full">
|
||||
|
@ -64,19 +64,16 @@ export default function StarterMessagesList({
|
||||
size="icon"
|
||||
onClick={() => {
|
||||
arrayHelpers.remove(index);
|
||||
if (
|
||||
index === values.length - 2 &&
|
||||
!values[values.length - 1].message
|
||||
) {
|
||||
arrayHelpers.pop();
|
||||
}
|
||||
}}
|
||||
className={`text-gray-400 hover:text-red-500 ${
|
||||
index === values.length - 1 && !starterMessage.message
|
||||
? "opacity-50 cursor-not-allowed"
|
||||
: ""
|
||||
}`}
|
||||
disabled={index === values.length - 1 && !starterMessage.message}
|
||||
disabled={
|
||||
(index === values.length - 1 && !starterMessage.message) ||
|
||||
(values.length === 1 && index === 0) // should never happen, but just in case
|
||||
}
|
||||
>
|
||||
<FiTrash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
|
@ -111,6 +111,7 @@ import {
|
||||
import AssistantModal from "../assistants/mine/AssistantModal";
|
||||
import { getSourceMetadata } from "@/lib/sources";
|
||||
import { UserSettingsModal } from "./modal/UserSettingsModal";
|
||||
import { AlignStartVertical } from "lucide-react";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@ -189,7 +190,11 @@ export function ChatPage({
|
||||
|
||||
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
|
||||
|
||||
const { assistants: availableAssistants, finalAssistants } = useAssistants();
|
||||
const {
|
||||
assistants: availableAssistants,
|
||||
finalAssistants,
|
||||
pinnedAssistants,
|
||||
} = useAssistants();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(
|
||||
!shouldShowWelcomeModal
|
||||
@ -272,16 +277,6 @@ export function ChatPage({
|
||||
SEARCH_PARAM_NAMES.TEMPERATURE
|
||||
);
|
||||
|
||||
const defaultTemperature = search_param_temperature
|
||||
? parseFloat(search_param_temperature)
|
||||
: selectedAssistant?.tools.some(
|
||||
(tool) =>
|
||||
tool.in_code_tool_id === SEARCH_TOOL_ID ||
|
||||
tool.in_code_tool_id === INTERNET_SEARCH_TOOL_ID
|
||||
)
|
||||
? 0
|
||||
: 0.7;
|
||||
|
||||
const setSelectedAssistantFromId = (assistantId: number) => {
|
||||
// NOTE: also intentionally look through available assistants here, so that
|
||||
// even if the user has hidden an assistant they can still go back to it
|
||||
@ -297,20 +292,21 @@ export function ChatPage({
|
||||
const [presentingDocument, setPresentingDocument] =
|
||||
useState<OnyxDocument | null>(null);
|
||||
|
||||
const { recentAssistants, refreshRecentAssistants } = useAssistants();
|
||||
|
||||
// Current assistant is decided based on this ordering
|
||||
// 1. Alternative assistant (assistant selected explicitly by user)
|
||||
// 2. Selected assistant (assistant default in this chat session)
|
||||
// 3. First pinned assistants (ordered list of pinned assistants)
|
||||
// 4. Available assistants (ordered list of available assistants)
|
||||
const liveAssistant: Persona | undefined = useMemo(
|
||||
() =>
|
||||
alternativeAssistant ||
|
||||
selectedAssistant ||
|
||||
recentAssistants[0] ||
|
||||
finalAssistants[0] ||
|
||||
pinnedAssistants[0] ||
|
||||
availableAssistants[0],
|
||||
[
|
||||
alternativeAssistant,
|
||||
selectedAssistant,
|
||||
recentAssistants,
|
||||
finalAssistants,
|
||||
pinnedAssistants,
|
||||
availableAssistants,
|
||||
]
|
||||
);
|
||||
@ -816,7 +812,6 @@ export function ChatPage({
|
||||
setMaxTokens(maxTokens);
|
||||
}
|
||||
}
|
||||
refreshRecentAssistants(liveAssistant?.id);
|
||||
fetchMaxTokens();
|
||||
}, [liveAssistant]);
|
||||
|
||||
|
@ -19,9 +19,7 @@ import {
|
||||
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { ChatSession } from "../interfaces";
|
||||
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
|
||||
import { Folder } from "../folders/interfaces";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
|
||||
import { DocumentIcon2, NewChatIcon } from "@/components/icons/icons";
|
||||
@ -251,9 +249,11 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
|
||||
const handleNewChat = () => {
|
||||
reset();
|
||||
console.log("currentChatSession", currentChatSession);
|
||||
|
||||
const newChatUrl =
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && currentChatSession
|
||||
(currentChatSession
|
||||
? `?assistantId=${currentChatSession.persona_id}`
|
||||
: "");
|
||||
router.push(newChatUrl);
|
||||
@ -275,7 +275,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
flex-col relative
|
||||
h-screen
|
||||
pt-2
|
||||
|
||||
transition-transform
|
||||
`}
|
||||
>
|
||||
@ -294,8 +293,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
className="w-full px-2 py-1 rounded-md items-center hover:bg-hover cursor-pointer transition-all duration-150 flex gap-x-2"
|
||||
href={
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
|
||||
currentChatSession?.persona_id
|
||||
(currentChatSession
|
||||
? `?assistantId=${currentChatSession?.persona_id}`
|
||||
: "")
|
||||
}
|
||||
@ -320,14 +318,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
<Link
|
||||
className="w-full px-2 py-1 rounded-md items-center hover:bg-hover cursor-pointer transition-all duration-150 flex gap-x-2"
|
||||
href="/chat/input-prompts"
|
||||
onClick={(e) => {
|
||||
if (e.metaKey || e.ctrlKey) {
|
||||
return;
|
||||
}
|
||||
if (handleNewChat) {
|
||||
handleNewChat();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<DocumentIcon2
|
||||
size={20}
|
||||
|
@ -2,7 +2,6 @@
|
||||
import { UserDropdown } from "../UserDropdown";
|
||||
import { FiShare2 } from "react-icons/fi";
|
||||
import { SetStateAction, useContext, useEffect } from "react";
|
||||
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
import Link from "next/link";
|
||||
import { pageType } from "@/app/chat/sessionSidebar/types";
|
||||
@ -42,8 +41,7 @@ export default function FunctionalHeader({
|
||||
event.preventDefault();
|
||||
window.open(
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
|
||||
currentChatSession
|
||||
(currentChatSession
|
||||
? `?assistantId=${currentChatSession.persona_id}`
|
||||
: ""),
|
||||
"_self"
|
||||
@ -63,7 +61,7 @@ export default function FunctionalHeader({
|
||||
reset();
|
||||
const newChatUrl =
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && currentChatSession
|
||||
(currentChatSession
|
||||
? `?assistantId=${currentChatSession.persona_id}`
|
||||
: "");
|
||||
router.push(newChatUrl);
|
||||
@ -128,25 +126,6 @@ export default function FunctionalHeader({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* <div
|
||||
className={`absolute
|
||||
${
|
||||
documentSidebarToggled && !sidebarToggled
|
||||
? "left-[calc(50%-125px)]"
|
||||
: !documentSidebarToggled && sidebarToggled
|
||||
? "left-[calc(50%+125px)]"
|
||||
: "left-1/2"
|
||||
}
|
||||
${
|
||||
documentSidebarToggled || sidebarToggled
|
||||
? "mobile:w-[40vw] max-w-[50vw]"
|
||||
: "mobile:w-[50vw] max-w-[60vw]"
|
||||
}
|
||||
top-1/2 transform -translate-x-1/2 -translate-y-1/2 transition-all duration-300`}
|
||||
>
|
||||
<ChatBanner />
|
||||
</div> */}
|
||||
|
||||
<div className="invisible">
|
||||
<LogoWithText
|
||||
page={page}
|
||||
@ -156,8 +135,6 @@ export default function FunctionalHeader({
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* className="fixed cursor-pointer flex z-40 left-4 bg-black top-3 h-8" */}
|
||||
|
||||
<div className="absolute right-2 mobile:top-1 desktop:top-1 h-8 flex">
|
||||
{setSharingModalVisible && !hideUserDropdown && (
|
||||
<div
|
||||
@ -179,8 +156,7 @@ export default function FunctionalHeader({
|
||||
className="desktop:hidden ml-2 my-auto"
|
||||
href={
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
|
||||
currentChatSession
|
||||
(currentChatSession
|
||||
? `?assistantId=${currentChatSession.persona_id}`
|
||||
: "")
|
||||
}
|
||||
|
@ -25,8 +25,6 @@ interface AssistantsContextProps {
|
||||
ownedButHiddenAssistants: Persona[];
|
||||
refreshAssistants: () => Promise<void>;
|
||||
isImageGenerationAvailable: boolean;
|
||||
recentAssistants: Persona[];
|
||||
refreshRecentAssistants: (currentAssistant: number) => Promise<void>;
|
||||
// Admin only
|
||||
editablePersonas: Persona[];
|
||||
allAssistants: Persona[];
|
||||
@ -56,35 +54,28 @@ export const AssistantsProvider: React.FC<{
|
||||
const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
|
||||
const [allAssistants, setAllAssistants] = useState<Persona[]>([]);
|
||||
|
||||
const [pinnedAssistants, setPinnedAssistants] = useState<Persona[]>(
|
||||
user?.preferences.pinned_assistants
|
||||
? assistants.filter((assistant) =>
|
||||
user?.preferences?.pinned_assistants?.includes(assistant.id)
|
||||
)
|
||||
: assistants.filter((a) => a.builtin_persona)
|
||||
);
|
||||
const [pinnedAssistants, setPinnedAssistants] = useState<Persona[]>(() => {
|
||||
if (user?.preferences.pinned_assistants) {
|
||||
return user.preferences.pinned_assistants
|
||||
.map((id) => assistants.find((assistant) => assistant.id === id))
|
||||
.filter((assistant): assistant is Persona => assistant !== undefined);
|
||||
} else {
|
||||
return assistants.filter((a) => a.builtin_persona);
|
||||
}
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
setPinnedAssistants(
|
||||
user?.preferences.pinned_assistants
|
||||
? assistants.filter((assistant) =>
|
||||
user?.preferences?.pinned_assistants?.includes(assistant.id)
|
||||
)
|
||||
: assistants.filter((a) => a.builtin_persona)
|
||||
);
|
||||
setPinnedAssistants(() => {
|
||||
if (user?.preferences.pinned_assistants) {
|
||||
return user.preferences.pinned_assistants
|
||||
.map((id) => assistants.find((assistant) => assistant.id === id))
|
||||
.filter((assistant): assistant is Persona => assistant !== undefined);
|
||||
} else {
|
||||
return assistants.filter((a) => a.builtin_persona);
|
||||
}
|
||||
});
|
||||
}, [user?.preferences?.pinned_assistants, assistants]);
|
||||
|
||||
const [recentAssistants, setRecentAssistants] = useState<Persona[]>(
|
||||
user?.preferences.recent_assistants
|
||||
?.filter((assistantId) =>
|
||||
assistants.find((assistant) => assistant.id === assistantId)
|
||||
)
|
||||
.map(
|
||||
(assistantId) =>
|
||||
assistants.find((assistant) => assistant.id === assistantId)!
|
||||
) || []
|
||||
);
|
||||
|
||||
const [isImageGenerationAvailable, setIsImageGenerationAvailable] =
|
||||
useState<boolean>(false);
|
||||
|
||||
@ -135,28 +126,6 @@ export const AssistantsProvider: React.FC<{
|
||||
fetchPersonas();
|
||||
}, [isAdmin, isCurator]);
|
||||
|
||||
const refreshRecentAssistants = async (currentAssistant: number) => {
|
||||
const response = await fetch("/api/user/recent-assistants", {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
current_assistant: currentAssistant,
|
||||
}),
|
||||
});
|
||||
if (!response.ok) {
|
||||
return;
|
||||
}
|
||||
setRecentAssistants((recentAssistants) => [
|
||||
assistants.find((assistant) => assistant.id === currentAssistant)!,
|
||||
|
||||
...recentAssistants.filter(
|
||||
(assistant) => assistant.id !== currentAssistant
|
||||
),
|
||||
]);
|
||||
};
|
||||
|
||||
const refreshAssistants = async () => {
|
||||
try {
|
||||
const response = await fetch("/api/persona", {
|
||||
@ -181,13 +150,6 @@ export const AssistantsProvider: React.FC<{
|
||||
} catch (error) {
|
||||
console.error("Error refreshing assistants:", error);
|
||||
}
|
||||
|
||||
setRecentAssistants(
|
||||
assistants.filter(
|
||||
(assistant) =>
|
||||
user?.preferences.recent_assistants?.includes(assistant.id) || false
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const {
|
||||
@ -230,8 +192,6 @@ export const AssistantsProvider: React.FC<{
|
||||
editablePersonas,
|
||||
allAssistants,
|
||||
isImageGenerationAvailable,
|
||||
recentAssistants,
|
||||
refreshRecentAssistants,
|
||||
setPinnedAssistants,
|
||||
pinnedAssistants,
|
||||
}}
|
||||
|
@ -2,7 +2,6 @@
|
||||
import { useContext } from "react";
|
||||
import { FiSidebar } from "react-icons/fi";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
|
||||
import { LeftToLineIcon, NewChatIcon, RightToLineIcon } from "../icons/icons";
|
||||
import {
|
||||
Tooltip,
|
||||
@ -90,9 +89,7 @@ export default function LogoWithText({
|
||||
className="my-auto mobile:hidden"
|
||||
href={
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && assistantId
|
||||
? `?assistantId=${assistantId}`
|
||||
: "")
|
||||
(assistantId ? `?assistantId=${assistantId}` : "")
|
||||
}
|
||||
onClick={(e) => {
|
||||
if (e.metaKey || e.ctrlKey) {
|
||||
|
@ -18,10 +18,6 @@ export const NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED =
|
||||
process.env.NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED?.toLowerCase() ===
|
||||
"true";
|
||||
|
||||
export const NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA =
|
||||
process.env.NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA?.toLowerCase() ===
|
||||
"true";
|
||||
|
||||
export const GMAIL_AUTH_IS_ADMIN_COOKIE_NAME = "gmail_auth_is_admin";
|
||||
|
||||
export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =
|
||||
|
@ -1,24 +1,9 @@
|
||||
// dependency for all admin user tests
|
||||
import { test as setup } from "@playwright/test";
|
||||
|
||||
import { test as setup, expect } from "@playwright/test";
|
||||
import { TEST_CREDENTIALS } from "./constants";
|
||||
|
||||
setup("authenticate", async ({ page }) => {
|
||||
const { email, password } = TEST_CREDENTIALS;
|
||||
|
||||
setup("authenticate as admin", async ({ browser }) => {
|
||||
const context = await browser.newContext({ storageState: "admin_auth.json" });
|
||||
const page = await context.newPage();
|
||||
await page.goto("http://localhost:3000/chat");
|
||||
|
||||
await page.waitForURL("http://localhost:3000/auth/login?next=%2Fchat");
|
||||
|
||||
await expect(page).toHaveTitle("Onyx");
|
||||
|
||||
await page.fill("#email", email);
|
||||
await page.fill("#password", password);
|
||||
|
||||
// Click the login button
|
||||
await page.click('button[type="submit"]');
|
||||
|
||||
await page.waitForURL("http://localhost:3000/chat");
|
||||
|
||||
await page.context().storageState({ path: "admin_auth.json" });
|
||||
});
|
||||
|
@ -1,65 +1,43 @@
import { test, expect } from "@chromatic-com/playwright";
import { test, expect } from "@playwright/test";

test(
"Admin - OAuth Redirect - Missing Code",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?state=xyz"
);
test.use({ storageState: "admin_auth.json" });

await expect(page.locator("p.text-text-500")).toHaveText(
"Missing authorization code."
);
}
);
test("Admin - OAuth Redirect - Missing Code", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?state=xyz"
);

test(
"Admin - OAuth Redirect - Missing State",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing authorization code."
);
});

await expect(page.locator("p.text-text-500")).toHaveText(
"Missing state parameter."
);
}
);
test("Admin - OAuth Redirect - Missing State", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123"
);

test(
"Admin - OAuth Redirect - Invalid Connector",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/invalid-connector/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing state parameter."
);
});

await expect(page.locator("p.text-text-500")).toHaveText(
"invalid_connector is not a valid source type."
);
}
);
test("Admin - OAuth Redirect - Invalid Connector", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/invalid-connector/oauth/callback?code=123&state=xyz"
);

test(
"Admin - OAuth Redirect - No Session",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"invalid_connector is not a valid source type."
);
});

await expect(page.locator("p.text-text-500")).toHaveText(
"An error occurred during the OAuth process. Please try again."
);
}
);
test("Admin - OAuth Redirect - No Session", async ({ page }) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123&state=xyz"
);

await expect(page.locator("p.text-text-500")).toHaveText(
"An error occurred during the OAuth process. Please try again."
);
});
@ -2,6 +2,8 @@ import { test, expect } from "@playwright/test";
import chromaticSnpashots from "./chromaticSnpashots.json";
import type { Page } from "@playwright/test";

test.use({ storageState: "admin_auth.json" });

async function verifyAdminPageNavigation(
page: Page,
path: string,
@ -13,7 +15,10 @@ async function verifyAdminPageNavigation(
}
) {
await page.goto(`http://localhost:3000/admin/${path}`);
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle);

await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
timeout: 2000,
});

if (options?.paragraphText) {
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
@ -35,18 +40,12 @@ async function verifyAdminPageNavigation(
}

for (const chromaticSnapshot of chromaticSnpashots) {
test(
`Admin - ${chromaticSnapshot.name}`,
{
tag: "@admin",
},
async ({ page }) => {
await verifyAdminPageNavigation(
page,
chromaticSnapshot.path,
chromaticSnapshot.pageTitle,
chromaticSnapshot.options
);
}
);
test(`Admin - ${chromaticSnapshot.name}`, async ({ page }) => {
await verifyAdminPageNavigation(
page,
chromaticSnapshot.path,
chromaticSnapshot.pageTitle,
chromaticSnapshot.options
);
});
}
web/tests/e2e/assisant_ordering.spec.ts (new file, 54 lines)
@ -0,0 +1,54 @@
import { test, expect } from "@playwright/test";

// Use pre-signed in "admin" storage state
test.use({
  storageState: "admin_auth.json",
});

test("Chat workflow", async ({ page }) => {
  // Initial setup
  await page.goto("http://localhost:3000/chat", { timeout: 3000 });

  // Interact with Art assistant
  await page.locator("button").filter({ hasText: "Art" }).click();
  await page.getByPlaceholder("Message Art assistant...").fill("Hi");
  await page.keyboard.press("Enter");
  await page.waitForTimeout(3000);

  // Start a new chat
  await page.getByRole("link", { name: "Start New Chat" }).click();
  await page.waitForNavigation({ waitUntil: "networkidle" });

  // Check for expected text
  await expect(page.getByText("Assistant for generating")).toBeVisible();

  // Interact with General assistant
  await page.locator("button").filter({ hasText: "General" }).click();

  // Check URL after clicking General assistant
  await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1", {
    timeout: 5000,
  });

  // Create a new assistant
  await page.getByRole("button", { name: "Explore Assistants" }).click();
  await page.getByRole("button", { name: "Create" }).click();
  await page.getByTestId("name").click();
  await page.getByTestId("name").fill("Test Assistant");
  await page.getByTestId("description").click();
  await page.getByTestId("description").fill("Test Assistant Description");
  await page.getByTestId("system_prompt").click();
  await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
  await page.getByRole("button", { name: "Create" }).click();

  // Verify new assistant creation
  await expect(page.getByText("Test Assistant Description")).toBeVisible({
    timeout: 5000,
  });

  // Start another new chat
  await page.getByRole("link", { name: "Start New Chat" }).click();
  await expect(page.getByText("Assistant with access to")).toBeVisible({
    timeout: 5000,
  });
});
@ -1,4 +1,9 @@
export const TEST_CREDENTIALS = {
export const TEST_USER_CREDENTIALS = {
  email: "user1@test.com",
  password: "User1Password123!",
};

export const TEST_ADMIN_CREDENTIALS = {
  email: "admin_user@test.com",
  password: "TestPassword123!",
};
web/tests/e2e/global-setup.ts (new file, 22 lines)
@ -0,0 +1,22 @@
import { chromium, FullConfig } from "@playwright/test";
import { loginAs } from "./utils/auth";

async function globalSetup(config: FullConfig) {
  const browser = await chromium.launch();

  const adminContext = await browser.newContext();
  const adminPage = await adminContext.newPage();
  await loginAs(adminPage, "admin");
  await adminContext.storageState({ path: "admin_auth.json" });
  await adminContext.close();

  const userContext = await browser.newContext();
  const userPage = await userContext.newPage();
  await loginAs(userPage, "user");
  await userContext.storageState({ path: "user_auth.json" });
  await userContext.close();

  await browser.close();
}

export default globalSetup;
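The corresponding playwright.config.ts changes are not shown in this diff. Below is a minimal, hypothetical sketch, assuming standard Playwright conventions, of how a global setup like the one above is typically registered so the saved admin_auth.json and user_auth.json storage states are reused by tests; the config path and project names are assumptions, not part of this commit.

// Hypothetical playwright.config.ts excerpt (illustration only, not part of this diff).
import { defineConfig } from "@playwright/test";

export default defineConfig({
  // Run global-setup.ts once before the suite to produce the storage-state files.
  globalSetup: "./tests/e2e/global-setup.ts",
  projects: [
    // Each project reuses a pre-authenticated session via storageState.
    { name: "admin", use: { storageState: "admin_auth.json" } },
    { name: "user", use: { storageState: "user_auth.json" } },
  ],
});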
@ -1,35 +0,0 @@
// ➕ Add this line
import { test, expect, takeSnapshot } from "@chromatic-com/playwright";
import { TEST_CREDENTIALS } from "./constants";

// Then use as normal 👇
test(
  "Homepage",
  {
    tag: "@guest",
  },
  async ({ page }, testInfo) => {
    // Test redirect to login, and redirect to search after login
    const { email, password } = TEST_CREDENTIALS;

    await page.goto("http://localhost:3000/chat");

    await page.waitForURL("http://localhost:3000/auth/login?next=%2Fchat");

    await expect(page).toHaveTitle("Onyx");

    await takeSnapshot(page, "Before login", testInfo);

    await page.fill("#email", email);
    await page.fill("#password", password);

    // Click the login button
    await page.click('button[type="submit"]');

    await page.waitForURL("http://localhost:3000/chat");

    await page.getByPlaceholder("Send a message or try using @ or /");

    await expect(page.locator("body")).not.toContainText("Initializing Onyx");
  }
);
web/tests/e2e/utils/auth.ts (new file, 37 lines)
@ -0,0 +1,37 @@
import { Page } from "@playwright/test";
import { TEST_ADMIN_CREDENTIALS, TEST_USER_CREDENTIALS } from "../constants";

// Basic function which logs in a user (either admin or regular user) to the application.
// It handles both successful login attempts and potential timeouts, with a retry mechanism.
export async function loginAs(page: Page, userType: "admin" | "user") {
  const { email, password } =
    userType === "admin" ? TEST_ADMIN_CREDENTIALS : TEST_USER_CREDENTIALS;
  await page.goto("http://localhost:3000/auth/login", { timeout: 1000 });

  await page.fill("#email", email);
  await page.fill("#password", password);

  // Click the login button
  await page.click('button[type="submit"]');

  try {
    await page.waitForURL("http://localhost:3000/chat", { timeout: 4000 });
  } catch (error) {
    console.log(`Timeout occurred. Current URL: ${page.url()}`);

    // If the redirect to /chat doesn't happen, retry by registering via /auth/signup
    await page.goto("http://localhost:3000/auth/signup", { timeout: 1000 });

    await page.fill("#email", email);
    await page.fill("#password", password);

    // Click the login button
    await page.click('button[type="submit"]');

    try {
      await page.waitForURL("http://localhost:3000/chat", { timeout: 4000 });
    } catch (error) {
      console.log(`Timeout occurred again. Current URL: ${page.url()}`);
    }
  }
}
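For illustration only (not part of this commit), an individual spec could also call loginAs directly instead of reusing a saved storage-state file; a minimal sketch, assuming the spec sits alongside the utils directory above:

// Hypothetical spec using loginAs directly (illustration only).
import { test, expect } from "@playwright/test";
import { loginAs } from "./utils/auth";

test("logs in as a regular user before exercising chat", async ({ page }) => {
  // Authenticate inside the test instead of relying on user_auth.json
  await loginAs(page, "user");
  await expect(page).toHaveURL("http://localhost:3000/chat");
});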
web/user_auth.json (new file, 15 lines)
@ -0,0 +1,15 @@
{
  "cookies": [
    {
      "name": "fastapiusersauth",
      "value": "n_EMYYKHn4tQbuPTEbtN1gJ6dQTGek9omJPhO2GhHoA",
      "domain": "localhost",
      "path": "/",
      "expires": 1738801376.508558,
      "httpOnly": true,
      "secure": false,
      "sameSite": "Lax"
    }
  ],
  "origins": []
}