Connector checkpointing (#3876)
* wip checkpointing/continue on failure; more stuff for checkpointing; Basic implementation; FE stuff; More checkpointing/failure handling; rebase; rebase; initial scaffolding for IT; IT to test checkpointing; Cleanup; cleanup; Fix it; Rebase; Add todo; Fix actions IT; Test more; Pagination + fixes + cleanup; Fix IT networking; fix it
* rebase
* Address misc comments
* Address comments
* Remove unused router
* rebase
* Fix mypy
* Fixes
* fix it
* Fix tests
* Add drop index
* Add retries
* reset lock timeout
* Try hard drop of schema
* Add timeout/retries to downgrade
* rebase
* test
* test
* test
* Close all connections
* test closing idle only
* Fix it
* fix
* try using null pool
* Test
* fix
* rebase
* log
* Fix
* apply null pool
* Fix other test
* Fix quality checks
* Test not using the fixture
* Fix ordering
* fix test
* Change pooling behavior
parent bc087fc20e, commit f1fc8ac19b
.github/workflows/pr-integration-tests.yml (vendored, 33 changes)
@@ -99,7 +99,7 @@ jobs:
 DISABLE_TELEMETRY=true \
 IMAGE_TAG=test \
 DEV_MODE=true \
-docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
+docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
 id: start_docker_multi_tenant

 # In practice, `cloud` Auth type would require OAUTH credentials to be set.
@@ -108,12 +108,13 @@ jobs:
 echo "Waiting for 3 minutes to ensure API server is ready..."
 sleep 180
 echo "Running integration tests..."
-docker run --rm --network danswer-stack_default \
+docker run --rm --network onyx-stack_default \
 --name test-runner \
 -e POSTGRES_HOST=relational_db \
 -e POSTGRES_USER=postgres \
 -e POSTGRES_PASSWORD=password \
 -e POSTGRES_DB=postgres \
+-e POSTGRES_USE_NULL_POOL=true \
 -e VESPA_HOST=index \
 -e REDIS_HOST=cache \
 -e API_SERVER_HOST=api_server \
@@ -143,24 +144,27 @@ jobs:
 - name: Stop multi-tenant Docker containers
 run: |
 cd deployment/docker_compose
-docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
+docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v

+# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
 - name: Start Docker containers
 run: |
 cd deployment/docker_compose
 ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
 AUTH_TYPE=basic \
+POSTGRES_POOL_PRE_PING=true \
+POSTGRES_USE_NULL_POOL=true \
 REQUIRE_EMAIL_VERIFICATION=false \
 DISABLE_TELEMETRY=true \
 IMAGE_TAG=test \
-docker compose -f docker-compose.dev.yml -p danswer-stack up -d
+docker compose -f docker-compose.dev.yml -p onyx-stack up -d
 id: start_docker

 - name: Wait for service to be ready
 run: |
 echo "Starting wait-for-service script..."

-docker logs -f danswer-stack-api_server-1 &
+docker logs -f onyx-stack-api_server-1 &

 start_time=$(date +%s)
 timeout=300 # 5 minutes in seconds
@@ -190,15 +194,24 @@ jobs:
 done
 echo "Finished waiting for service."

+- name: Start Mock Services
+run: |
+cd backend/tests/integration/mock_services
+docker compose -f docker-compose.mock-it-services.yml \
+-p mock-it-services-stack up -d
+
+# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
 - name: Run Standard Integration Tests
 run: |
 echo "Running integration tests..."
-docker run --rm --network danswer-stack_default \
+docker run --rm --network onyx-stack_default \
 --name test-runner \
 -e POSTGRES_HOST=relational_db \
 -e POSTGRES_USER=postgres \
 -e POSTGRES_PASSWORD=password \
 -e POSTGRES_DB=postgres \
+-e POSTGRES_POOL_PRE_PING=true \
+-e POSTGRES_USE_NULL_POOL=true \
 -e VESPA_HOST=index \
 -e REDIS_HOST=cache \
 -e API_SERVER_HOST=api_server \
@@ -208,6 +221,8 @@ jobs:
 -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
 -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
 -e TEST_WEB_HOSTNAME=test-runner \
+-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
+-e MOCK_CONNECTOR_SERVER_PORT=8001 \
 onyxdotapp/onyx-integration:test \
 /app/tests/integration/tests \
 /app/tests/integration/connector_job_tests
@@ -229,13 +244,13 @@ jobs:
 if: always()
 run: |
 cd deployment/docker_compose
-docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
+docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

 - name: Dump all-container logs (optional)
 if: always()
 run: |
 cd deployment/docker_compose
-docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
+docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

 - name: Upload logs
 if: always()
@@ -249,4 +264,4 @@ jobs:
 if: always()
 run: |
 cd deployment/docker_compose
-docker compose -f docker-compose.dev.yml -p danswer-stack down -v
+docker compose -f docker-compose.dev.yml -p onyx-stack down -v
.vscode/launch.template.jsonc (vendored, 2 changes)
@@ -205,7 +205,7 @@
     "--loglevel=INFO",
     "--hostname=light@%n",
     "-Q",
-    "vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
+    "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
 ],
 "presentation": {
     "group": "2",
@@ -0,0 +1,124 @@
"""Add checkpointing/failure handling

Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "index_attempt",
        sa.Column("checkpoint_pointer", sa.String(), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
    )

    op.create_index(
        "ix_index_attempt_cc_pair_settings_poll",
        "index_attempt",
        [
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
            sa.text("time_updated DESC"),
        ],
    )

    # Drop the old IndexAttemptError table
    op.drop_index("index_attempt_id", table_name="index_attempt_errors")
    op.drop_table("index_attempt_errors")

    # Create the new version of the table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=True),
        sa.Column("document_link", sa.String(), nullable=True),
        sa.Column("entity_id", sa.String(), nullable=True),
        sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failure_message", sa.Text(), nullable=False),
        sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
    )


def downgrade() -> None:
    op.execute("SET lock_timeout = '5s'")

    # try a few times to drop the table, this has been observed to fail due to other locks
    # blocking the drop
    NUM_TRIES = 10
    for i in range(NUM_TRIES):
        try:
            op.drop_table("index_attempt_errors")
            break
        except Exception as e:
            if i == NUM_TRIES - 1:
                raise e
            print(f"Error dropping table: {e}. Retrying...")

    op.execute("SET lock_timeout = DEFAULT")

    # Recreate the old IndexAttemptError table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=True),
        sa.Column("batch", sa.Integer(), nullable=True),
        sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column("traceback", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
    )

    op.create_index(
        "index_attempt_id",
        "index_attempt_errors",
        ["time_created"],
    )

    op.drop_index("ix_index_attempt_cc_pair_settings_poll")
    op.drop_column("index_attempt", "checkpoint_pointer")
    op.drop_column("index_attempt", "poll_range_start")
    op.drop_column("index_attempt", "poll_range_end")
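Note: the composite index added above is shaped for the "recent completed attempts for this cc_pair and search settings, newest first" lookup that the checkpointing code later in this diff performs. A minimal sketch of the query pattern it serves; the exact ORM query inside get_recent_completed_attempts_for_cc_pair is an assumption, while the column names come straight from the migration:

    from sqlalchemy import select

    from onyx.db.models import IndexAttempt


    def recent_completed_attempts_stmt(
        cc_pair_id: int, search_settings_id: int, limit: int = 20
    ):
        # Filters line up with the index columns (cc_pair id, search settings
        # id, status) and the ORDER BY matches its trailing "time_updated DESC" key.
        return (
            select(IndexAttempt)
            .where(
                IndexAttempt.connector_credential_pair_id == cc_pair_id,
                IndexAttempt.search_settings_id == search_settings_id,
            )
            .order_by(IndexAttempt.time_updated.desc())
            .limit(limit)
        )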
@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
 from onyx.access.models import ExternalAccess
 from onyx.connectors.slack.connector import get_channels
 from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
-from onyx.connectors.slack.connector import SlackPollConnector
+from onyx.connectors.slack.connector import SlackConnector
 from onyx.db.models import ConnectorCredentialPair
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ logger = setup_logger()
 def _get_slack_document_ids_and_channels(
     cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> dict[str, list[str]]:
-    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
+    slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
     slack_connector.load_credentials(cc_pair.credential.credential_json)

     slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
@@ -36,6 +36,15 @@ beat_task_templates.extend(
             "expires": BEAT_EXPIRES_DEFAULT,
         },
     },
+    {
+        "name": "check-for-checkpoint-cleanup",
+        "task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
+        "schedule": timedelta(hours=1),
+        "options": {
+            "priority": OnyxCeleryPriority.LOW,
+            "expires": BEAT_EXPIRES_DEFAULT,
+        },
+    },
     {
         "name": "check-for-connector-deletion",
         "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
@@ -28,6 +28,10 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
 from onyx.background.celery.tasks.indexing.utils import IndexingCallback
 from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
 from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
+from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
+from onyx.background.indexing.checkpointing_utils import (
+    get_index_attempts_with_old_checkpoints,
+)
 from onyx.background.indexing.job_client import SimpleJob
 from onyx.background.indexing.job_client import SimpleJobClient
 from onyx.background.indexing.job_client import SimpleJobException
@@ -38,6 +42,7 @@ from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
+from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
@@ -1069,3 +1074,62 @@ def connector_indexing_proxy_task(

     redis_connector_index.set_watchdog(False)
     return
+
+
+@shared_task(
+    name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
+    soft_time_limit=300,
+)
+def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
+    """Clean up old checkpoints that are older than 7 days."""
+    locked = False
+    redis_client = get_redis_client(tenant_id=tenant_id)
+    lock: RedisLock = redis_client.lock(
+        OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
+        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
+    )
+
+    # these tasks should never overlap
+    if not lock.acquire(blocking=False):
+        return None
+
+    try:
+        locked = True
+        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+            old_attempts = get_index_attempts_with_old_checkpoints(db_session)
+            for attempt in old_attempts:
+                task_logger.info(
+                    f"Cleaning up checkpoint for index attempt {attempt.id}"
+                )
+                cleanup_checkpoint_task.apply_async(
+                    kwargs={
+                        "index_attempt_id": attempt.id,
+                        "tenant_id": tenant_id,
+                    },
+                    queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
+                )
+    except Exception:
+        task_logger.exception("Unexpected exception during checkpoint cleanup")
+        return None
+    finally:
+        if locked:
+            if lock.owned():
+                lock.release()
+            else:
+                task_logger.error(
+                    "check_for_checkpoint_cleanup - Lock not owned on completion: "
+                    f"tenant={tenant_id}"
+                )
+
+
+@shared_task(
+    name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
+    bind=True,
+)
+def cleanup_checkpoint_task(
+    self: Task, *, index_attempt_id: int, tenant_id: str | None
+) -> None:
+    """Clean up a checkpoint for a given index attempt"""
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+        cleanup_checkpoint(db_session, index_attempt_id)
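Note: a one-off cleanup can presumably be enqueued by hand the same way the beat task does it; a hedged sketch (the module path is inferred from the imports above, the attempt id is a placeholder, and tenant_id=None assumes a single-tenant deployment):

    from onyx.background.celery.tasks.indexing.tasks import cleanup_checkpoint_task
    from onyx.configs.constants import OnyxCeleryQueues

    # Placeholder attempt id; routed to the dedicated low-volume cleanup queue.
    cleanup_checkpoint_task.apply_async(
        kwargs={"index_attempt_id": 123, "tenant_id": None},
        queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
    )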
backend/onyx/background/indexing/checkpointing.py (deleted, 80 lines)
@@ -1,80 +0,0 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime

from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc


def _2010_dt() -> datetime.datetime:
    return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)


def _2020_dt() -> datetime.datetime:
    return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)


def _default_end_time(
    last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
    """If year is before 2010, go to the beginning of 2010.
    If year is 2010-2020, go in 5 year increments.
    If year > 2020, then go in 180 day increments.

    For connectors that don't support a `filter_by` and instead rely on `sort_by`
    for polling, this will cause a massive duplication of fetches. For these
    connectors, you may want to override this function to return a more reasonable
    plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
    last_successful_run = (
        datetime_to_utc(last_successful_run) if last_successful_run else None
    )
    if last_successful_run is None or last_successful_run < _2010_dt():
        return _2010_dt()

    if last_successful_run < _2020_dt():
        return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())

    return last_successful_run + datetime.timedelta(days=180)


def find_end_time_for_indexing_attempt(
    last_successful_run: datetime.datetime | None,
    # source_type can be used to override the default for certain connectors, currently unused
    source_type: DocumentSource,
) -> datetime.datetime | None:
    """Is the current time unless the connector is run over a large period, in which case it is
    split up into large time segments that become smaller as it approaches the present
    """
    # NOTE: source_type can be used to override the default for certain connectors
    end_of_window = _default_end_time(last_successful_run)
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    if end_of_window < now:
        return end_of_window

    # None signals that we should index up to current time
    return None


def get_time_windows_for_index_attempt(
    last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
    if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
        return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]

    time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
    start_of_window: datetime.datetime | None = last_successful_run
    while start_of_window:
        end_of_window = find_end_time_for_indexing_attempt(
            last_successful_run=start_of_window, source_type=source_type
        )
        time_windows.append(
            (
                start_of_window,
                end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
            )
        )
        start_of_window = end_of_window

    return time_windows
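Note: for a sense of the window plan the deleted helper produced, a worked example with an illustrative date (the arithmetic follows the code above; the dates are not from the commit):

    import datetime

    # Last successful run on 2012-01-01:
    #   _default_end_time(2012-01-01) -> min(2012-01-01 + 365*5 days, 2020-01-01)
    #                                  = 2016-12-30 (5*365 days ignores leap days)
    #   _default_end_time(2016-12-30) -> min(2016-12-30 + 1825 days, 2020-01-01)
    #                                  = 2020-01-01 (capped)
    #   _default_end_time(2020-01-01) -> 2020-01-01 + 180 days = 2020-06-29
    # ...and so on in 180-day hops until a window end passes "now", at which
    # point find_end_time_for_indexing_attempt returns None and the loop emits
    # one final (start, now) window.
    start = datetime.datetime(2012, 1, 1, tzinfo=datetime.timezone.utc)
    assert start + datetime.timedelta(days=365 * 5) == datetime.datetime(
        2016, 12, 30, tzinfo=datetime.timezone.utc
    )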
backend/onyx/background/indexing/checkpointing_utils.py (new file, 200 lines)
@@ -0,0 +1,200 @@
from datetime import datetime
from datetime import timedelta
from io import BytesIO

from sqlalchemy import and_
from sqlalchemy.orm import Session

from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof


logger = setup_logger()

_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100


def _build_checkpoint_pointer(index_attempt_id: int) -> str:
    return f"checkpoint_{index_attempt_id}.json"


def save_checkpoint(
    db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
) -> str:
    """Save a checkpoint for a given index attempt to the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)

    file_store = get_default_file_store(db_session)
    file_store.save_file(
        file_name=checkpoint_pointer,
        content=BytesIO(checkpoint.model_dump_json().encode()),
        display_name=checkpoint_pointer,
        file_origin=FileOrigin.INDEXING_CHECKPOINT,
        file_type="application/json",
    )

    index_attempt = get_index_attempt(db_session, index_attempt_id)
    if not index_attempt:
        raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
    index_attempt.checkpoint_pointer = checkpoint_pointer
    db_session.add(index_attempt)
    db_session.commit()
    return checkpoint_pointer


def load_checkpoint(
    db_session: Session, index_attempt_id: int
) -> ConnectorCheckpoint | None:
    """Load a checkpoint for a given index attempt from the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
    file_store = get_default_file_store(db_session)
    try:
        checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
        checkpoint_data = checkpoint_io.read().decode("utf-8")
        return ConnectorCheckpoint.model_validate_json(checkpoint_data)
    except RuntimeError:
        return None


def get_latest_valid_checkpoint(
    db_session: Session,
    cc_pair_id: int,
    search_settings_id: int,
    window_start: datetime,
    window_end: datetime,
) -> ConnectorCheckpoint:
    """Get the latest valid checkpoint for a given connector credential pair"""
    checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        db_session=db_session,
        limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
    )
    checkpoint_candidates = [
        candidate
        for candidate in checkpoint_candidates
        if (
            candidate.poll_range_start == window_start
            and candidate.poll_range_end == window_end
            and candidate.status == IndexingStatus.FAILED
            and candidate.checkpoint_pointer is not None
            # we want to make sure that the checkpoint is actually useful
            # if it's only gone through a few docs, it's probably not worth
            # using. This also avoids weird cases where a connector is basically
            # non-functional but still "makes progress" by slowly moving the
            # checkpoint forward run after run
            and candidate.total_docs_indexed
            and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
        )
    ]

    # don't keep using checkpoints if we've had a bunch of failed attempts in a row
    # for now, capped at _NUM_RECENT_ATTEMPTS_TO_CONSIDER (20)
    if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
        logger.warning(
            f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
            f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
            "from scratch."
        )
        return ConnectorCheckpoint.build_dummy_checkpoint()

    # assumes latest checkpoint is the furthest along. This only isn't true
    # if something else has gone wrong.
    latest_valid_checkpoint_candidate = (
        checkpoint_candidates[0] if checkpoint_candidates else None
    )

    checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
    if latest_valid_checkpoint_candidate:
        try:
            previous_checkpoint = load_checkpoint(
                db_session=db_session,
                index_attempt_id=latest_valid_checkpoint_candidate.id,
            )
        except Exception:
            logger.exception(
                f"Failed to load checkpoint from previous failed attempt with ID "
                f"{latest_valid_checkpoint_candidate.id}."
            )
            previous_checkpoint = None

        if previous_checkpoint is not None:
            logger.info(
                f"Using checkpoint from previous failed attempt with ID "
                f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
                f"{previous_checkpoint}"
            )
            save_checkpoint(
                db_session=db_session,
                index_attempt_id=latest_valid_checkpoint_candidate.id,
                checkpoint=previous_checkpoint,
            )
            checkpoint = previous_checkpoint

    return checkpoint


def get_index_attempts_with_old_checkpoints(
    db_session: Session, days_to_keep: int = 7
) -> list[IndexAttempt]:
    """Get all index attempts with checkpoints older than the specified number of days.

    Args:
        db_session: The database session
        days_to_keep: Number of days to keep checkpoints for (default: 7)

    Returns:
        The list of index attempts that still have checkpoints older than the cutoff
    """
    cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)

    # Find all index attempts with checkpoints older than cutoff_date
    old_attempts = (
        db_session.query(IndexAttempt)
        .filter(
            and_(
                IndexAttempt.checkpoint_pointer.isnot(None),
                IndexAttempt.time_created < cutoff_date,
            )
        )
        .all()
    )

    return old_attempts


def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
    """Clean up a checkpoint for a given index attempt"""
    index_attempt = get_index_attempt(db_session, index_attempt_id)
    if not index_attempt:
        raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")

    if not index_attempt.checkpoint_pointer:
        return None

    file_store = get_default_file_store(db_session)
    file_store.delete_file(index_attempt.checkpoint_pointer)

    index_attempt.checkpoint_pointer = None
    db_session.add(index_attempt)
    db_session.commit()

    return None


def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
    """Check if the checkpoint content size exceeds the limit (200MB)"""
    content_size = deep_getsizeof(checkpoint.checkpoint_content)
    if content_size > 200_000_000:  # 200MB in bytes
        raise ValueError(
            f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
        )
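Note: a hedged round-trip sketch of the two functions above (assumes a live SQLAlchemy session and an existing IndexAttempt row; everything else is named in this file):

    from onyx.background.indexing.checkpointing_utils import load_checkpoint
    from onyx.background.indexing.checkpointing_utils import save_checkpoint
    from onyx.connectors.models import ConnectorCheckpoint


    def demo_round_trip(db_session, index_attempt_id: int) -> None:
        checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
        # Writes checkpoint_<id>.json to the file store and records the
        # pointer on the IndexAttempt row.
        pointer = save_checkpoint(db_session, index_attempt_id, checkpoint)
        assert pointer == f"checkpoint_{index_attempt_id}.json"
        # Reads the JSON back; returns None if the file is missing.
        restored = load_checkpoint(db_session, index_attempt_id)
        assert restored == checkpoint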
backend/onyx/background/indexing/memory_tracer.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import tracemalloc

from onyx.utils.logger import setup_logger

logger = setup_logger()

DANSWER_TRACEMALLOC_FRAMES = 10


class MemoryTracer:
    def __init__(self, interval: int = 0, num_print_entries: int = 5):
        self.interval = interval
        self.num_print_entries = num_print_entries
        self.snapshot_first: tracemalloc.Snapshot | None = None
        self.snapshot_prev: tracemalloc.Snapshot | None = None
        self.snapshot: tracemalloc.Snapshot | None = None
        self.counter = 0

    def start(self) -> None:
        """Start the memory tracer if interval is greater than 0."""
        if self.interval > 0:
            logger.debug(f"Memory tracer starting: interval={self.interval}")
            tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
            self._take_snapshot()

    def stop(self) -> None:
        """Stop the memory tracer if it's running."""
        if self.interval > 0:
            self.log_final_diff()
            tracemalloc.stop()
            logger.debug("Memory tracer stopped.")

    def _take_snapshot(self) -> None:
        """Take a snapshot and update internal snapshot states."""
        snapshot = tracemalloc.take_snapshot()
        # Filter out irrelevant frames
        snapshot = snapshot.filter_traces(
            (
                tracemalloc.Filter(False, tracemalloc.__file__),
                tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
                tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
            )
        )

        if not self.snapshot_first:
            self.snapshot_first = snapshot

        if self.snapshot:
            self.snapshot_prev = self.snapshot

        self.snapshot = snapshot

    def _log_diff(
        self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
    ) -> None:
        """Log the memory difference between two snapshots."""
        stats = current.compare_to(previous, "traceback")
        for s in stats[: self.num_print_entries]:
            logger.debug(f"Tracer diff: {s}")
            for line in s.traceback.format():
                logger.debug(f"* {line}")

    def increment_and_maybe_trace(self) -> None:
        """Increment counter and perform trace if interval is hit."""
        if self.interval <= 0:
            return

        self.counter += 1
        if self.counter % self.interval == 0:
            logger.debug(
                f"Running trace comparison for batch {self.counter}. interval={self.interval}"
            )
            self._take_snapshot()
            if self.snapshot and self.snapshot_prev:
                self._log_diff(self.snapshot, self.snapshot_prev)

    def log_final_diff(self) -> None:
        """Log the final memory diff between start and end of indexing."""
        if self.interval <= 0:
            return

        logger.debug(
            f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
        )
        self._take_snapshot()
        if self.snapshot and self.snapshot_first:
            self._log_diff(self.snapshot, self.snapshot_first)
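Note: a usage sketch for MemoryTracer; the interval and loop body are illustrative:

    from onyx.background.indexing.memory_tracer import MemoryTracer

    tracer = MemoryTracer(interval=50)
    tracer.start()  # no-op when interval == 0, so it is safe to leave wired in
    for batch in range(500):
        # ... per-batch work would go here ...
        tracer.increment_and_maybe_trace()  # logs a top-N diff every 50th batch
    tracer.stop()  # logs a final start-vs-end diff, then stops tracemalloc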
backend/onyx/background/indexing/models.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from datetime import datetime

from pydantic import BaseModel

from onyx.db.models import IndexAttemptError


class IndexAttemptErrorPydantic(BaseModel):
    id: int
    connector_credential_pair_id: int

    document_id: str | None
    document_link: str | None

    entity_id: str | None
    failed_time_range_start: datetime | None
    failed_time_range_end: datetime | None

    failure_message: str
    is_resolved: bool = False

    time_created: datetime

    index_attempt_id: int

    @classmethod
    def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
        return cls(
            id=model.id,
            connector_credential_pair_id=model.connector_credential_pair_id,
            document_id=model.document_id,
            document_link=model.document_link,
            entity_id=model.entity_id,
            failed_time_range_start=model.failed_time_range_start,
            failed_time_range_end=model.failed_time_range_end,
            failure_message=model.failure_message,
            is_resolved=model.is_resolved,
            time_created=model.time_created,
            index_attempt_id=model.index_attempt_id,
        )
@@ -1,5 +1,6 @@
 import time
 import traceback
+from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -7,8 +8,11 @@ from datetime import timezone
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

-from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
-from onyx.background.indexing.tracer import OnyxTracer
+from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
+from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
+from onyx.background.indexing.checkpointing_utils import save_checkpoint
+from onyx.background.indexing.memory_tracer import MemoryTracer
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
 from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
 from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
@@ -17,6 +21,8 @@ from onyx.configs.constants import DocumentSource
 from onyx.configs.constants import MilestoneRecordType
 from onyx.connectors.connector_runner import ConnectorRunner
 from onyx.connectors.factory import instantiate_connector
+from onyx.connectors.models import ConnectorCheckpoint
+from onyx.connectors.models import ConnectorFailure
 from onyx.connectors.models import Document
 from onyx.connectors.models import IndexAttemptMetadata
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -24,15 +30,18 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
 from onyx.db.connector_credential_pair import update_connector_credential_pair
 from onyx.db.engine import get_session_with_tenant
 from onyx.db.enums import ConnectorCredentialPairStatus
+from onyx.db.index_attempt import create_index_attempt_error
 from onyx.db.index_attempt import get_index_attempt
+from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
 from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
 from onyx.db.index_attempt import mark_attempt_canceled
 from onyx.db.index_attempt import mark_attempt_failed
 from onyx.db.index_attempt import mark_attempt_partially_succeeded
 from onyx.db.index_attempt import mark_attempt_succeeded
 from onyx.db.index_attempt import transition_attempt_to_in_progress
 from onyx.db.index_attempt import update_docs_indexed
 from onyx.db.models import ConnectorCredentialPair
 from onyx.db.models import IndexAttempt
+from onyx.db.models import IndexAttemptError
 from onyx.db.models import IndexingStatus
 from onyx.db.models import IndexModelStatus
 from onyx.document_index.factory import get_default_document_index
@@ -53,6 +62,7 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
 def _get_connector_runner(
     db_session: Session,
     attempt: IndexAttempt,
+    batch_size: int,
     start_time: datetime,
     end_time: datetime,
     tenant_id: str | None,
@@ -100,7 +110,9 @@ def _get_connector_runner(
         raise e

     return ConnectorRunner(
-        connector=runnable_connector, time_range=(start_time, end_time)
+        connector=runnable_connector,
+        batch_size=batch_size,
+        time_range=(start_time, end_time),
     )
@@ -159,6 +171,66 @@ class RunIndexingContext(BaseModel):
     search_settings_status: IndexModelStatus


+def _check_connector_and_attempt_status(
+    db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
+) -> None:
+    """
+    Checks the status of the connector credential pair and index attempt.
+    Raises a RuntimeError if any conditions are not met.
+    """
+    cc_pair_loop = get_connector_credential_pair_from_id(
+        db_session_temp,
+        ctx.cc_pair_id,
+    )
+    if not cc_pair_loop:
+        raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
+
+    if (
+        cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
+        and ctx.search_settings_status != IndexModelStatus.FUTURE
+    ) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
+        raise RuntimeError("Connector was disabled mid run")
+
+    index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
+    if not index_attempt_loop:
+        raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
+
+    if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
+        raise RuntimeError(
+            f"Index Attempt was canceled, status is {index_attempt_loop.status}"
+        )
+
+
+def _check_failure_threshold(
+    total_failures: int,
+    document_count: int,
+    batch_num: int,
+    last_failure: ConnectorFailure | None,
+) -> None:
+    """Check if we've hit the failure threshold and raise an appropriate exception if so.
+
+    We consider the threshold hit if:
+    1. We have more than 3 failures AND
+    2. Failures account for more than 10% of processed documents
+    """
+    failure_ratio = total_failures / (document_count or 1)
+
+    FAILURE_THRESHOLD = 3
+    FAILURE_RATIO_THRESHOLD = 0.1
+    if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
+        logger.error(
+            f"Connector run failed with '{total_failures}' errors "
+            f"after '{batch_num}' batches."
+        )
+        if last_failure and last_failure.exception:
+            raise last_failure.exception from last_failure.exception
+
+        raise RuntimeError(
+            f"Connector run encountered too many errors, aborting. "
+            f"Last error: {last_failure}"
+        )
+
+
 def _run_indexing(
     db_session: Session,
     index_attempt_id: int,
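Note: a quick worked check of the `_check_failure_threshold` arithmetic above, with illustrative numbers:

    # 5 failures across 30 processed documents:
    total_failures, document_count = 5, 30
    failure_ratio = total_failures / (document_count or 1)  # ~0.167
    # 5 > FAILURE_THRESHOLD (3) and 0.167 > FAILURE_RATIO_THRESHOLD (0.1),
    # so the run aborts. With 5 failures across 500 documents the ratio is
    # 0.01, and the run keeps going despite exceeding the absolute count.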
@@ -169,11 +241,8 @@ def _run_indexing(
     1. Get documents which are either new or updated from specified application
     2. Embed and index these documents into the chosen datastore (vespa)
     3. Updates Postgres to record the indexed documents + the outcome of this run
-
-    TODO: do not change index attempt statuses here ... instead, set signals in redis
-    and allow the monitor function to clean them up
     """
-    start_time = time.time()
+    start_time = time.monotonic()  # just used for logging

     with get_session_with_tenant(tenant_id) as db_session_temp:
         index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
@@ -221,6 +290,46 @@ def _run_indexing(
                 db_session=db_session_temp,
             )
         )
+        if last_successful_index_time > POLL_CONNECTOR_OFFSET:
+            window_start = datetime.fromtimestamp(
+                last_successful_index_time, tz=timezone.utc
+            ) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
+        else:
+            # don't go into "negative" time if we've never indexed before
+            window_start = datetime.fromtimestamp(0, tz=timezone.utc)
+
+        most_recent_attempt = next(
+            iter(
+                get_recent_completed_attempts_for_cc_pair(
+                    cc_pair_id=ctx.cc_pair_id,
+                    search_settings_id=index_attempt_start.search_settings_id,
+                    db_session=db_session_temp,
+                    limit=1,
+                )
+            ),
+            None,
+        )
+        # if the last attempt failed, try and use the same window. This is necessary
+        # to ensure correctness with checkpointing. If we don't do this, things like
+        # new slack channels could be missed (since existing slack channels are
+        # cached as part of the checkpoint).
+        if (
+            most_recent_attempt
+            and most_recent_attempt.poll_range_end
+            and (
+                most_recent_attempt.status == IndexingStatus.FAILED
+                or most_recent_attempt.status == IndexingStatus.CANCELED
+            )
+        ):
+            window_end = most_recent_attempt.poll_range_end
+        else:
+            window_end = datetime.now(tz=timezone.utc)
+
+        # add start/end now that they have been set
+        index_attempt_start.poll_range_start = window_start
+        index_attempt_start.poll_range_end = window_end
+        db_session_temp.add(index_attempt_start)
+        db_session_temp.commit()

         embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
             search_settings=index_attempt_start.search_settings,
@@ -234,7 +343,6 @@ def _run_indexing(
     )

     indexing_pipeline = build_indexing_pipeline(
-        attempt_id=index_attempt_id,
         embedder=embedding_model,
         document_index=document_index,
         ignore_time_skip=(
@@ -246,63 +354,73 @@ def _run_indexing(
         callback=callback,
     )

-    tracer: OnyxTracer
-    if INDEXING_TRACER_INTERVAL > 0:
-        logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
-        tracer = OnyxTracer()
-        tracer.start()
-        tracer.snap()
+    # Initialize memory tracer. NOTE: won't actually do anything if
+    # `INDEXING_TRACER_INTERVAL` is 0.
+    memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
+    memory_tracer.start()

     index_attempt_md = IndexAttemptMetadata(
         connector_id=ctx.connector_id,
         credential_id=ctx.credential_id,
     )

+    total_failures = 0
     batch_num = 0
     net_doc_change = 0
     document_count = 0
     chunk_count = 0
     run_end_dt = None
-    tracer_counter: int
     try:
+        with get_session_with_tenant(tenant_id) as db_session_temp:
+            index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
+            if not index_attempt:
+                raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")

-        for ind, (window_start, window_end) in enumerate(
-            get_time_windows_for_index_attempt(
-                last_successful_run=datetime.fromtimestamp(
-                    last_successful_index_time, tz=timezone.utc
-                ),
-                source_type=db_connector.source,
-            )
-        ):
-            cc_pair_loop: ConnectorCredentialPair | None = None
-            index_attempt_loop: IndexAttempt | None = None
-            tracer_counter = 0
-
-            try:
-                window_start = max(
-                    window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
-                    datetime(1970, 1, 1, tzinfo=timezone.utc),
-                )
-
-                with get_session_with_tenant(tenant_id) as db_session_temp:
-                    index_attempt_loop_start = get_index_attempt(
-                        db_session_temp, index_attempt_id
-                    )
-                    if not index_attempt_loop_start:
-                        raise RuntimeError(
-                            f"Index attempt {index_attempt_id} not found in DB."
-                        )
-
-                    connector_runner = _get_connector_runner(
-                        db_session=db_session_temp,
-                        attempt=index_attempt_loop_start,
-                        start_time=window_start,
-                        end_time=window_end,
-                        tenant_id=tenant_id,
-                    )
-
-                if INDEXING_TRACER_INTERVAL > 0:
-                    tracer.snap()
-                for doc_batch in connector_runner.run():
+            connector_runner = _get_connector_runner(
+                db_session=db_session_temp,
+                attempt=index_attempt,
+                batch_size=INDEX_BATCH_SIZE,
+                start_time=window_start,
+                end_time=window_end,
+                tenant_id=tenant_id,
+            )
+
+            # don't use a checkpoint if we're explicitly indexing from
+            # the beginning in order to avoid weird interactions between
+            # checkpointing / failure handling.
+            if index_attempt.from_beginning:
+                checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
+            else:
+                checkpoint = get_latest_valid_checkpoint(
+                    db_session=db_session_temp,
+                    cc_pair_id=ctx.cc_pair_id,
+                    search_settings_id=index_attempt.search_settings_id,
+                    window_start=window_start,
+                    window_end=window_end,
+                )
+
+            unresolved_errors = get_index_attempt_errors_for_cc_pair(
+                cc_pair_id=ctx.cc_pair_id,
+                unresolved_only=True,
+                db_session=db_session_temp,
+            )
+            doc_id_to_unresolved_errors: dict[
+                str, list[IndexAttemptError]
+            ] = defaultdict(list)
+            for error in unresolved_errors:
+                if error.document_id:
+                    doc_id_to_unresolved_errors[error.document_id].append(error)
+
+            entity_based_unresolved_errors = [
+                error for error in unresolved_errors if error.entity_id
+            ]
+
+        while checkpoint.has_more:
+            logger.info(
+                f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
+            )
+            for document_batch, failure, next_checkpoint in connector_runner.run(
+                checkpoint
+            ):
                 # Check if connector is disabled mid run and stop if so unless it's the secondary
                 # index being built. We want to populate it even for paused connectors
                 # Often paused connectors are sources that aren't updated frequently but the
@@ -313,41 +431,37 @@ def _run_indexing(

                 # TODO: should we move this into the above callback instead?
                 with get_session_with_tenant(tenant_id) as db_session_temp:
-                    cc_pair_loop = get_connector_credential_pair_from_id(
-                        db_session_temp,
-                        ctx.cc_pair_id,
+                    # will exception if the connector/index attempt is marked as paused/failed
+                    _check_connector_and_attempt_status(
+                        db_session_temp, ctx, index_attempt_id
                     )
-                    if not cc_pair_loop:
-                        raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
-
-                    if (
-                        (
-                            cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
-                            and ctx.search_settings_status != IndexModelStatus.FUTURE
-                        )
-                        # if it's deleting, we don't care if this is a secondary index
-                        or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
-                    ):
-                        # let the `except` block handle this
-                        raise RuntimeError("Connector was disabled mid run")
-
-                    index_attempt_loop = get_index_attempt(
-                        db_session_temp, index_attempt_id
-                    )
-                    if not index_attempt_loop:
-                        raise RuntimeError(
-                            f"Index attempt {index_attempt_id} not found in DB."
-                        )
-
-                    if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
-                        # Likely due to user manually disabling it or model swap
-                        raise RuntimeError(
-                            f"Index Attempt was canceled, status is {index_attempt_loop.status}"
-                        )
+
+                # save record of any failures at the connector level
+                if failure is not None:
+                    total_failures += 1
+                    with get_session_with_tenant(tenant_id) as db_session_temp:
+                        create_index_attempt_error(
+                            index_attempt_id,
+                            ctx.cc_pair_id,
+                            failure,
+                            db_session_temp,
+                        )
+
+                    _check_failure_threshold(
+                        total_failures, document_count, batch_num, failure
+                    )
+
+                # save the new checkpoint (if one is provided)
+                if next_checkpoint:
+                    checkpoint = next_checkpoint
+
+                # below is all document processing logic, so if no batch we can just continue
+                if document_batch is None:
+                    continue

                 batch_description = []

-                doc_batch_cleaned = strip_null_characters(doc_batch)
+                doc_batch_cleaned = strip_null_characters(document_batch)
                 for doc in doc_batch_cleaned:
                     batch_description.append(doc.to_short_descriptor())
@@ -377,15 +491,51 @@ def _run_indexing(
                     chunk_count += index_pipeline_result.total_chunks
                     document_count += index_pipeline_result.total_docs

-                    # commit transaction so that the `update` below begins
-                    # with a brand new transaction. Postgres uses the start
-                    # of the transactions when computing `NOW()`, so if we have
-                    # a long running transaction, the `time_updated` field will
-                    # be inaccurate
-                    db_session.commit()
+                # resolve errors for documents that were successfully indexed
+                failed_document_ids = [
+                    failure.failed_document.document_id
+                    for failure in index_pipeline_result.failures
+                    if failure.failed_document
+                ]
+                successful_document_ids = [
+                    document.id
+                    for document in document_batch
+                    if document.id not in failed_document_ids
+                ]
+                for document_id in successful_document_ids:
+                    with get_session_with_tenant(tenant_id) as db_session_temp:
+                        if document_id in doc_id_to_unresolved_errors:
+                            logger.info(
+                                f"Resolving IndexAttemptError for document '{document_id}'"
+                            )
+                            for error in doc_id_to_unresolved_errors[document_id]:
+                                error.is_resolved = True
+                                db_session_temp.add(error)
+                            db_session_temp.commit()
+
+                # add brand new failures
+                if index_pipeline_result.failures:
+                    total_failures += len(index_pipeline_result.failures)
+                    with get_session_with_tenant(tenant_id) as db_session_temp:
+                        for failure in index_pipeline_result.failures:
+                            create_index_attempt_error(
+                                index_attempt_id,
+                                ctx.cc_pair_id,
+                                failure,
+                                db_session_temp,
+                            )
+
+                    _check_failure_threshold(
+                        total_failures,
+                        document_count,
+                        batch_num,
+                        index_pipeline_result.failures[-1],
+                    )

                 # This new value is updated every batch, so UI can refresh per batch update
                 with get_session_with_tenant(tenant_id) as db_session_temp:
+                    # NOTE: Postgres uses the start of the transactions when computing `NOW()`
+                    # so we need either to commit() or to use a new session
                     update_docs_indexed(
                         db_session=db_session_temp,
                         index_attempt_id=index_attempt_id,
@@ -397,126 +547,77 @@ def _run_indexing(
                 if callback:
                     callback.progress("_run_indexing", len(doc_batch_cleaned))

-                tracer_counter += 1
-                if (
-                    INDEXING_TRACER_INTERVAL > 0
-                    and tracer_counter % INDEXING_TRACER_INTERVAL == 0
-                ):
-                    logger.debug(
-                        f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
-                    )
-                    tracer.snap()
-                    tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
+                memory_tracer.increment_and_maybe_trace()

-            run_end_dt = window_end
-            if ctx.is_primary:
-                with get_session_with_tenant(tenant_id) as db_session_temp:
+                # make sure the checkpoints aren't getting too large at some regular interval
+                CHECKPOINT_SIZE_CHECK_INTERVAL = 100
+                if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
+                    check_checkpoint_size(checkpoint)
+
+            # save latest checkpoint
+            with get_session_with_tenant(tenant_id) as db_session_temp:
+                save_checkpoint(
+                    db_session=db_session_temp,
+                    index_attempt_id=index_attempt_id,
+                    checkpoint=checkpoint,
+                )
+
+    except Exception as e:
+        logger.exception(
+            "Connector run exceptioned after elapsed time: "
+            f"{time.monotonic() - start_time} seconds"
+        )
+
+        if isinstance(e, ConnectorStopSignal):
+            with get_session_with_tenant(tenant_id) as db_session_temp:
+                mark_attempt_canceled(
+                    index_attempt_id,
+                    db_session_temp,
+                    reason=str(e),
+                )
+
+                if ctx.is_primary:
                     update_connector_credential_pair(
                         db_session=db_session_temp,
                         connector_id=ctx.connector_id,
                         credential_id=ctx.credential_id,
                         net_docs=net_doc_change,
                         run_dt=run_end_dt,
                     )
-        except Exception as e:
-            logger.exception(
-                f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
-            )
-
-            if isinstance(e, ConnectorStopSignal):
-                with get_session_with_tenant(tenant_id) as db_session_temp:
-                    mark_attempt_canceled(
-                        index_attempt_id,
-                        db_session_temp,
-                        reason=str(e),
-                    )
-
-                    if ctx.is_primary:
-                        update_connector_credential_pair(
-                            db_session=db_session_temp,
-                            connector_id=ctx.connector_id,
-                            credential_id=ctx.credential_id,
-                            net_docs=net_doc_change,
-                        )
-
-                if INDEXING_TRACER_INTERVAL > 0:
-                    tracer.stop()
-                raise e
-            else:
-                # Only mark the attempt as a complete failure if this is the first indexing window.
-                # Otherwise, some progress was made - the next run will not start from the beginning.
-                # In this case, it is not accurate to mark it as a failure. When the next run begins,
-                # if that fails immediately, it will be marked as a failure.
-                #
-                # NOTE: if the connector is manually disabled, we should mark it as a failure regardless
-                # to give better clarity in the UI, as the next run will never happen.
-                if (
-                    ind == 0
-                    or (
-                        cc_pair_loop is not None and not cc_pair_loop.status.is_active()
-                    )
-                    or (
-                        index_attempt_loop is not None
-                        and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
-                    )
-                ):
-                    with get_session_with_tenant(tenant_id) as db_session_temp:
-                        mark_attempt_failed(
-                            index_attempt_id,
-                            db_session_temp,
-                            failure_reason=str(e),
-                            full_exception_trace=traceback.format_exc(),
-                        )
-
-                        if ctx.is_primary:
-                            update_connector_credential_pair(
-                                db_session=db_session_temp,
-                                connector_id=ctx.connector_id,
-                                credential_id=ctx.credential_id,
-                                net_docs=net_doc_change,
-                            )
-
-                    if INDEXING_TRACER_INTERVAL > 0:
-                        tracer.stop()
-                    raise e
-
-            # break => similar to success case. As mentioned above, if the next run fails for the same
-            # reason it will then be marked as a failure
-            break
-
-        if INDEXING_TRACER_INTERVAL > 0:
-            logger.debug(
-                f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
-            )
-            tracer.snap()
-            tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
-            tracer.stop()
-            logger.debug("Memory tracer stopped.")
-
-        if (
-            index_attempt_md.num_exceptions > 0
-            and index_attempt_md.num_exceptions >= batch_num
-        ):
-            with get_session_with_tenant(tenant_id) as db_session_temp:
-                mark_attempt_failed(
-                    index_attempt_id,
-                    db_session_temp,
-                    failure_reason="All batches exceptioned.",
-                )
-                if ctx.is_primary:
-                    update_connector_credential_pair(
-                        db_session=db_session_temp,
-                        connector_id=ctx.connector_id,
-                        credential_id=ctx.credential_id,
-                    )
-            raise Exception(
-                f"Connector failed - All batches exceptioned: batches={batch_num}"
-            )
-
-        elapsed_time = time.time() - start_time
+            memory_tracer.stop()
+            raise e
+        else:
+            with get_session_with_tenant(tenant_id) as db_session_temp:
+                mark_attempt_failed(
+                    index_attempt_id,
+                    db_session_temp,
+                    failure_reason=str(e),
+                    full_exception_trace=traceback.format_exc(),
+                )
+
+                if ctx.is_primary:
+                    update_connector_credential_pair(
+                        db_session=db_session_temp,
+                        connector_id=ctx.connector_id,
+                        credential_id=ctx.credential_id,
+                        net_docs=net_doc_change,
+                    )
+
+            memory_tracer.stop()
+            raise e
+
+    memory_tracer.stop()
+
+    elapsed_time = time.monotonic() - start_time
     with get_session_with_tenant(tenant_id) as db_session_temp:
-        if index_attempt_md.num_exceptions == 0:
+        # resolve entity-based errors
+        for error in entity_based_unresolved_errors:
+            logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
+            error.is_resolved = True
+            db_session_temp.add(error)
+            db_session_temp.commit()
+
+        if total_failures == 0:
             mark_attempt_succeeded(index_attempt_id, db_session_temp)

             create_milestone_and_report(
@@ -535,7 +636,7 @@ def _run_indexing(
                 mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
                 logger.info(
                     f"Connector completed with some errors: "
-                    f"exceptions={index_attempt_md.num_exceptions} "
+                    f"failures={total_failures} "
                     f"batches={batch_num} "
                     f"docs={document_count} "
                     f"chunks={chunk_count} "
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=run_end_dt,
|
||||
run_dt=window_end,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,77 +0,0 @@
import tracemalloc

from onyx.utils.logger import setup_logger

logger = setup_logger()

DANSWER_TRACEMALLOC_FRAMES = 10


class OnyxTracer:
    def __init__(self) -> None:
        self.snapshot_first: tracemalloc.Snapshot | None = None
        self.snapshot_prev: tracemalloc.Snapshot | None = None
        self.snapshot: tracemalloc.Snapshot | None = None

    def start(self) -> None:
        tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)

    def stop(self) -> None:
        tracemalloc.stop()

    def snap(self) -> None:
        snapshot = tracemalloc.take_snapshot()
        # Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
        snapshot = snapshot.filter_traces(
            (
                tracemalloc.Filter(False, tracemalloc.__file__),  # Exclude tracemalloc
                tracemalloc.Filter(
                    False, "<frozen importlib._bootstrap>"
                ),  # Exclude importlib
                tracemalloc.Filter(
                    False, "<frozen importlib._bootstrap_external>"
                ),  # Exclude external importlib
            )
        )

        if not self.snapshot_first:
            self.snapshot_first = snapshot

        if self.snapshot:
            self.snapshot_prev = self.snapshot

        self.snapshot = snapshot

    def log_snapshot(self, numEntries: int) -> None:
        if not self.snapshot:
            return

        stats = self.snapshot.statistics("traceback")
        for s in stats[:numEntries]:
            logger.debug(f"Tracer snap: {s}")
            for line in s.traceback:
                logger.debug(f"* {line}")

    @staticmethod
    def log_diff(
        snap_current: tracemalloc.Snapshot,
        snap_previous: tracemalloc.Snapshot,
        numEntries: int,
    ) -> None:
        stats = snap_current.compare_to(snap_previous, "traceback")
        for s in stats[:numEntries]:
            logger.debug(f"Tracer diff: {s}")
            for line in s.traceback.format():
                logger.debug(f"* {line}")

    def log_previous_diff(self, numEntries: int) -> None:
        if not self.snapshot or not self.snapshot_prev:
            return

        OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)

    def log_first_diff(self, numEntries: int) -> None:
        if not self.snapshot or not self.snapshot_first:
            return

        OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
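For reference, the tracer deleted above was driven roughly like this; a sketch assuming one `snap()` per indexing batch:

```
# Sketch of typical OnyxTracer usage (the class removed in this commit).
tracer = OnyxTracer()
tracer.start()  # begin tracemalloc with DANSWER_TRACEMALLOC_FRAMES of context
tracer.snap()   # baseline snapshot
# ... process a batch of documents ...
tracer.snap()   # new snapshot
tracer.log_previous_diff(numEntries=5)  # top 5 allocation deltas vs last snap
tracer.log_first_diff(numEntries=5)     # top 5 deltas vs the baseline
tracer.stop()
```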
@@ -169,6 +169,11 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
    os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)

# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
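A small illustration (editor-added) of what `POSTGRES_USE_NULL_POOL=true` changes in plain SQLAlchemy terms: `NullPool` opens and fully closes a real connection per checkout instead of retaining idle ones, which is what the integration-test teardown relies on. The connection string below is a placeholder:

```
# Sketch: NullPool vs. the default QueuePool behavior.
from sqlalchemy import create_engine, pool, text

engine = create_engine(
    "postgresql://postgres:password@localhost:5432/postgres",  # placeholder
    poolclass=pool.NullPool,  # no idle connections are retained
)
with engine.connect() as conn:
    conn.execute(text("SELECT 1"))
# the underlying DB connection is fully closed here, not returned to a pool
```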
@@ -165,6 +165,9 @@ class DocumentSource(str, Enum):
    EGNYTE = "egnyte"
    AIRTABLE = "airtable"

    # Special case just for integration tests
    MOCK_CONNECTOR = "mock_connector"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

@@ -243,6 +246,7 @@ class FileOrigin(str, Enum):
    CHAT_IMAGE_GEN = "chat_image_gen"
    CONNECTOR = "connector"
    GENERATED_REPORT = "generated_report"
    INDEXING_CHECKPOINT = "indexing_checkpoint"
    OTHER = "other"


@@ -274,6 +278,7 @@ class OnyxCeleryQueues:
    DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
    CONNECTOR_DELETION = "connector_deletion"
    LLM_MODEL_UPDATE = "llm_model_update"
    CHECKPOINT_CLEANUP = "checkpoint_cleanup"

    # Heavy queue
    CONNECTOR_PRUNING = "connector_pruning"
@@ -293,6 +298,7 @@ class OnyxRedisLocks:
    CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
    CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
    CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
    CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
    CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
        "da_lock:check_connector_doc_permissions_sync_beat"
    )
@@ -368,6 +374,10 @@ class OnyxCeleryTask:
    CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
    CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"

    # Connector checkpoint cleanup
    CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
    CLEANUP_CHECKPOINT = "cleanup_checkpoint"

    MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
    MONITOR_CELERY_QUEUES = "monitor_celery_queues"
@@ -1,11 +1,16 @@
import sys
import time
from collections.abc import Generator
from datetime import datetime

from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger


@@ -15,48 +20,139 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]


class CheckpointOutputWrapper:
    """
    Wraps a CheckpointOutput generator to give things back in a more digestible format.
    The connector format is easier for the connector implementor (e.g. it enforces exactly
    one new checkpoint is returned AND that the checkpoint is at the end), thus the different
    formats.
    """

    def __init__(self) -> None:
        self.next_checkpoint: ConnectorCheckpoint | None = None

    def __call__(
        self,
        checkpoint_connector_generator: CheckpointOutput,
    ) -> Generator[
        tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
        None,
        None,
    ]:
        # grabs the final return value and stores it in the `next_checkpoint` variable
        def _inner_wrapper(
            checkpoint_connector_generator: CheckpointOutput,
        ) -> CheckpointOutput:
            self.next_checkpoint = yield from checkpoint_connector_generator
            return self.next_checkpoint  # not used

        for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
            if isinstance(document_or_failure, Document):
                yield document_or_failure, None, None
            elif isinstance(document_or_failure, ConnectorFailure):
                yield None, document_or_failure, None
            else:
                raise ValueError(
                    f"Invalid document_or_failure type: {type(document_or_failure)}"
                )

        if self.next_checkpoint is None:
            raise RuntimeError(
                "Checkpoint is None. This should never happen - the connector should always return a checkpoint."
            )

        yield None, None, self.next_checkpoint
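The trick `CheckpointOutputWrapper` relies on is that `yield from` evaluates to the inner generator's `return` value. A standalone, runnable illustration (editor-added):

```
# Demo: capturing a generator's return value via `yield from`.
from collections.abc import Generator


def inner() -> Generator[int, None, str]:
    yield 1
    yield 2
    return "final checkpoint"  # becomes StopIteration.value


def outer() -> Generator[int, None, None]:
    result = yield from inner()  # re-yields 1, 2; then grabs the return value
    print(result)                # -> "final checkpoint"


list(outer())
```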
class ConnectorRunner:
    """
    Handles:
    - Batching
    - Additional exception logging
    - Combining different connector types to a single interface
    """

    def __init__(
        self,
        connector: BaseConnector,
        batch_size: int,
        time_range: TimeRange | None = None,
        fail_loudly: bool = False,
    ):
        self.connector = connector
        self.time_range = time_range
        self.batch_size = batch_size

        if isinstance(self.connector, PollConnector):
            if time_range is None:
                raise ValueError("time_range is required for PollConnector")
        self.doc_batch: list[Document] = []

            self.doc_batch_generator = self.connector.poll_source(
                time_range[0].timestamp(), time_range[1].timestamp()
            )

        elif isinstance(self.connector, LoadConnector):
            if time_range and fail_loudly:
                raise ValueError(
                    "time_range specified, but passed in connector is not a PollConnector"
                )

            self.doc_batch_generator = self.connector.load_from_state()

        else:
            raise ValueError(f"Invalid connector. type: {type(self.connector)}")

    def run(self) -> GenerateDocumentsOutput:
    def run(
        self, checkpoint: ConnectorCheckpoint
    ) -> Generator[
        tuple[
            list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
        ],
        None,
        None,
    ]:
        """Adds additional exception logging to the connector."""
        try:
            start = time.monotonic()
            for batch in self.doc_batch_generator:
                # to know how long connector is taking
                logger.debug(
                    f"Connector took {time.monotonic() - start} seconds to build a batch."
                )

                yield batch
            if isinstance(self.connector, CheckpointConnector):
                if self.time_range is None:
                    raise ValueError("time_range is required for CheckpointConnector")

                start = time.monotonic()
                checkpoint_connector_generator = self.connector.load_from_checkpoint(
                    start=self.time_range[0].timestamp(),
                    end=self.time_range[1].timestamp(),
                    checkpoint=checkpoint,
                )
                next_checkpoint: ConnectorCheckpoint | None = None
                # this is guaranteed to always run at least once with next_checkpoint being non-None
                for document, failure, next_checkpoint in CheckpointOutputWrapper()(
                    checkpoint_connector_generator
                ):
                    if document is not None:
                        self.doc_batch.append(document)

                    if failure is not None:
                        yield None, failure, None

                    if len(self.doc_batch) >= self.batch_size:
                        yield self.doc_batch, None, None
                        self.doc_batch = []

                # yield remaining documents
                if len(self.doc_batch) > 0:
                    yield self.doc_batch, None, None
                    self.doc_batch = []

                yield None, None, next_checkpoint

                logger.debug(
                    f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
                )

            else:
                finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
                finished_checkpoint.has_more = False

                if isinstance(self.connector, PollConnector):
                    if self.time_range is None:
                        raise ValueError("time_range is required for PollConnector")

                    for document_batch in self.connector.poll_source(
                        start=self.time_range[0].timestamp(),
                        end=self.time_range[1].timestamp(),
                    ):
                        yield document_batch, None, None

                    yield None, None, finished_checkpoint
                elif isinstance(self.connector, LoadConnector):
                    for document_batch in self.connector.load_from_state():
                        yield document_batch, None, None

                    yield None, None, finished_checkpoint
                else:
                    raise ValueError(f"Invalid connector. type: {type(self.connector)}")
        except Exception:
            exc_type, _, exc_traceback = sys.exc_info()
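A hedged sketch of how a caller might drain `ConnectorRunner.run` until the connector reports no more work; this loop is not shown in the diff, and `runner`, `index`, and `record_failure` are assumed names:

```
# Sketch: repeatedly invoke run() until the returned checkpoint says has_more=False.
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
while checkpoint.has_more:
    for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
        if doc_batch is not None:
            index(doc_batch)          # hypothetical indexing call
        if failure is not None:
            record_failure(failure)   # hypothetical failure bookkeeping
        if next_checkpoint is not None:
            checkpoint = next_checkpoint  # always the final item yielded
```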
@@ -30,12 +30,14 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.onyx_jira.connector import JiraConnector
@@ -43,7 +45,7 @@ from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
@@ -66,8 +68,8 @@ def identify_connector_class(
        DocumentSource.WEB: WebConnector,
        DocumentSource.FILE: LocalFileConnector,
        DocumentSource.SLACK: {
            InputType.POLL: SlackPollConnector,
            InputType.SLIM_RETRIEVAL: SlackPollConnector,
            InputType.POLL: SlackConnector,
            InputType.SLIM_RETRIEVAL: SlackConnector,
        },
        DocumentSource.GITHUB: GithubConnector,
        DocumentSource.GMAIL: GmailConnector,
@@ -109,6 +111,8 @@ def identify_connector_class(
        DocumentSource.FIREFLIES: FirefliesConnector,
        DocumentSource.EGNYTE: EgnyteConnector,
        DocumentSource.AIRTABLE: AirtableConnector,
        # just for integration tests
        DocumentSource.MOCK_CONNECTOR: MockConnector,
    }
    connector_by_source = connector_map.get(source, {})

@@ -125,10 +129,23 @@ def identify_connector_class(

    if any(
        [
            input_type == InputType.LOAD_STATE
            and not issubclass(connector, LoadConnector),
            input_type == InputType.POLL and not issubclass(connector, PollConnector),
            input_type == InputType.EVENT and not issubclass(connector, EventConnector),
            (
                input_type == InputType.LOAD_STATE
                and not issubclass(connector, LoadConnector)
            ),
            (
                input_type == InputType.POLL
                # either poll or checkpoint works for this, in the future
                # all connectors should be checkpoint connectors
                and (
                    not issubclass(connector, PollConnector)
                    and not issubclass(connector, CheckpointConnector)
                )
            ),
            (
                input_type == InputType.EVENT
                and not issubclass(connector, EventConnector)
            ),
        ]
    ):
        raise ConnectorMissingException(
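Usage-wise, the lookup plus the relaxed POLL check above behaves like this (an editor-added sketch; the return value follows the map shown in the hunk):

```
# Sketch of the factory behavior after this change.
connector_cls = identify_connector_class(DocumentSource.SLACK, InputType.POLL)
# -> SlackConnector: it is no longer a PollConnector, but the POLL input type
#    is now also satisfied by CheckpointConnector subclasses.
```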
@@ -1,10 +1,13 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any

from pydantic import BaseModel

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -14,6 +17,7 @@ SecondsSinceUnixEpoch = float

GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]


class BaseConnector(abc.ABC):
@@ -105,3 +109,33 @@ class EventConnector(BaseConnector):
    @abc.abstractmethod
    def handle_event(self, event: Any) -> GenerateDocumentsOutput:
        raise NotImplementedError


class CheckpointConnector(BaseConnector):
    @abc.abstractmethod
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Yields back documents or failures. Final return is the new checkpoint.

        Final return can be accessed via either:

        ```
        gen = connector.load_from_checkpoint(start, end, checkpoint)
        try:
            # step with next() rather than a `for` loop; a `for` loop swallows
            # the StopIteration that carries the return value
            while True:
                print(next(gen))
        except StopIteration as e:
            checkpoint = e.value  # Extracting the return value
            print(checkpoint)
        ```

        OR

        ```
        checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
        ```
        """
        raise NotImplementedError
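A minimal (hypothetical, editor-added) implementation of this contract, to make the generator-with-return shape concrete:

```
# Hedged sketch: the smallest possible CheckpointConnector.
class NoopConnector(CheckpointConnector):
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        # a real connector would yield Document / ConnectorFailure objects here;
        # the *return* value (not a yield) carries the next checkpoint
        return ConnectorCheckpoint(checkpoint_content={}, has_more=False)
        yield  # unreachable; present only so Python treats this as a generator
```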
86
backend/onyx/connectors/mock_connector/connector.py
Normal file
@@ -0,0 +1,86 @@
from typing import Any

import httpx
from pydantic import BaseModel

from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger


logger = setup_logger()


class SingleConnectorYield(BaseModel):
    documents: list[Document]
    checkpoint: ConnectorCheckpoint
    failures: list[ConnectorFailure]
    unhandled_exception: str | None = None


class MockConnector(CheckpointConnector):
    def __init__(
        self,
        mock_server_host: str,
        mock_server_port: int,
    ) -> None:
        self.mock_server_host = mock_server_host
        self.mock_server_port = mock_server_port
        self.client = httpx.Client(timeout=30.0)

        self.connector_yields: list[SingleConnectorYield] | None = None
        self.current_yield_index: int = 0

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        response = self.client.get(self._get_mock_server_url("get-documents"))
        response.raise_for_status()
        data = response.json()

        self.connector_yields = [
            SingleConnectorYield(**yield_data) for yield_data in data
        ]
        return None

    def _get_mock_server_url(self, endpoint: str) -> str:
        return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"

    def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
        response = self.client.post(
            self._get_mock_server_url("add-checkpoint"),
            json=checkpoint.model_dump(mode="json"),
        )
        response.raise_for_status()

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        if self.connector_yields is None:
            raise ValueError("No connector yields configured")

        # Save the checkpoint to the mock server
        self._save_checkpoint(checkpoint)

        yield_index = self.current_yield_index
        self.current_yield_index += 1
        current_yield = self.connector_yields[yield_index]

        # If the current yield has an unhandled exception, raise it
        # This is used to simulate an unhandled failure in the connector.
        if current_yield.unhandled_exception:
            raise RuntimeError(current_yield.unhandled_exception)

        # yield all documents
        for document in current_yield.documents:
            yield document

        for failure in current_yield.failures:
            yield failure

        return current_yield.checkpoint
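The mock server's `get-documents` endpoint is expected to serve a JSON list matching `SingleConnectorYield`, one entry per `load_from_checkpoint` call. An illustrative (hypothetical) payload as a Python literal:

```
# Illustrative payload shape for the mock server's "get-documents" endpoint.
connector_yields = [
    {
        "documents": [],  # Document dicts would go here
        "checkpoint": {"checkpoint_content": {}, "has_more": True},
        "failures": [],
        "unhandled_exception": None,
    },
    {
        "documents": [],
        "checkpoint": {"checkpoint_content": {}, "has_more": False},
        "failures": [],
        "unhandled_exception": None,
    },
]
```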
@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any

from pydantic import BaseModel
from pydantic import model_validator

from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -187,36 +188,48 @@ class SlimDocument(BaseModel):
    perm_sync_data: Any | None = None


class DocumentErrorSummary(BaseModel):
    id: str
    semantic_id: str
    section_link: str | None

    @classmethod
    def from_document(cls, doc: Document) -> "DocumentErrorSummary":
        section_link = doc.sections[0].link if len(doc.sections) > 0 else None
        return cls(
            id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
        )

    @classmethod
    def from_dict(cls, data: dict) -> "DocumentErrorSummary":
        return cls(
            id=str(data.get("id")),
            semantic_id=str(data.get("semantic_id")),
            section_link=str(data.get("section_link")),
        )

    def to_dict(self) -> dict[str, str | None]:
        return {
            "id": self.id,
            "semantic_id": self.semantic_id,
            "section_link": self.section_link,
        }


class IndexAttemptMetadata(BaseModel):
    batch_num: int | None = None
    num_exceptions: int = 0
    connector_id: int
    credential_id: int


class ConnectorCheckpoint(BaseModel):
    # TODO: maybe move this to something disk-based to handle extremely large checkpoints?
    checkpoint_content: dict
    has_more: bool

    @classmethod
    def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
        return ConnectorCheckpoint(checkpoint_content={}, has_more=True)


class DocumentFailure(BaseModel):
    document_id: str
    document_link: str | None = None


class EntityFailure(BaseModel):
    entity_id: str
    missed_time_range: tuple[datetime, datetime] | None = None


class ConnectorFailure(BaseModel):
    failed_document: DocumentFailure | None = None
    failed_entity: EntityFailure | None = None
    failure_message: str
    exception: Exception | None = None

    model_config = {"arbitrary_types_allowed": True}

    @model_validator(mode="before")
    def check_failed_fields(cls, values: dict) -> dict:
        failed_document = values.get("failed_document")
        failed_entity = values.get("failed_entity")
        if (failed_document is None and failed_entity is None) or (
            failed_document is not None and failed_entity is not None
        ):
            raise ValueError(
                "Exactly one of 'failed_document' or 'failed_entity' must be specified."
            )
        return values
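The `check_failed_fields` validator enforces an exclusive-or on the two failure kinds; a short editor-added demonstration (pydantic's `ValidationError` subclasses `ValueError`):

```
# Demonstrating the exactly-one constraint on ConnectorFailure.
ok = ConnectorFailure(
    failed_entity=EntityFailure(entity_id="C012345"),
    failure_message="channel fetch failed",
)

try:
    ConnectorFailure(failure_message="no target given")  # neither field set
except ValueError as e:
    print(e)  # "Exactly one of 'failed_document' or 'failed_entity' ..."
```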
@@ -1,10 +1,16 @@
import contextvars
import copy
import re
from collections.abc import Callable
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -12,14 +18,18 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
@@ -33,6 +43,8 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

_SLACK_LIMIT = 900


ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@@ -40,6 +52,13 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]


class SlackCheckpointContent(TypedDict):
    channel_ids: list[str]
    channel_completion_map: dict[str, str]
    current_channel: ChannelType | None
    seen_thread_ts: list[str]


def _collect_paginated_channels(
    client: WebClient,
    exclude_archived: bool,
@@ -140,6 +159,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
    return datetime.fromtimestamp(max_ts, tz=timezone.utc)


def _build_doc_id(channel_id: str, thread_ts: str) -> str:
    return f"{channel_id}__{thread_ts}"


def thread_to_doc(
    channel: ChannelType,
    thread: ThreadType,
@@ -182,7 +205,7 @@ def thread_to_doc(
    )

    return Document(
        id=f"{channel_id}__{thread[0]['ts']}",
        id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
        sections=[
            Section(
                link=get_message_link(event=m, client=client, channel_id=channel_id),
@@ -267,64 +290,97 @@ def filter_channels(
    ]


def _get_all_docs(
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
    """Get a channel by its ID.

    Args:
        client: The Slack WebClient instance
        channel_id: The ID of the channel to fetch

    Returns:
        The channel information

    Raises:
        SlackApiError: If the channel cannot be fetched
    """
    response = make_slack_api_call_w_retries(
        client.conversations_info,
        channel=channel_id,
    )
    return cast(ChannelType, response["channel"])


def _get_messages(
    channel: ChannelType,
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    oldest: str | None = None,
    latest: str | None = None,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
    """Get all documents in the workspace, channel by channel"""
    slack_cleaner = SlackTextCleaner(client=client)
) -> tuple[list[MessageType], bool]:
    """Slack goes from newest to oldest."""

    # Cache to prevent refetching via API since users
    user_cache: dict[str, BasicExpertInfo | None] = {}
    # have to be in the channel in order to read messages
    if not channel["is_member"]:
        make_slack_api_call_w_retries(
            client.conversations_join,
            channel=channel["id"],
            is_private=channel["is_private"],
        )
        logger.info(f"Successfully joined '{channel['name']}'")

    all_channels = get_channels(client)
    filtered_channels = filter_channels(
        all_channels, channels, channel_name_regex_enabled
    response = make_slack_api_call_w_retries(
        client.conversations_history,
        channel=channel["id"],
        oldest=oldest,
        latest=latest,
        limit=_SLACK_LIMIT,
    )
    response.validate()

    for channel in filtered_channels:
        channel_docs = 0
        channel_message_batches = get_channel_messages(
            client=client, channel=channel, oldest=oldest, latest=latest
    messages = cast(list[MessageType], response.get("messages", []))

    cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
        "next_cursor", ""
    )
    has_more = bool(cursor)
    return messages, has_more


def _message_to_doc(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Document | None:
    filtered_thread: ThreadType | None = None
    thread_ts = message.get("thread_ts")
    if thread_ts:
        # skip threads we've already seen, since we've already processed all
        # messages in that thread
        if thread_ts in seen_thread_ts:
            return None

        thread = get_thread(
            client=client, channel_id=channel["id"], thread_id=thread_ts
        )
        filtered_thread = [
            message for message in thread if not msg_filter_func(message)
        ]
    elif not msg_filter_func(message):
        filtered_thread = [message]

    if filtered_thread:
        return thread_to_doc(
            channel=channel,
            thread=filtered_thread,
            slack_cleaner=slack_cleaner,
            client=client,
            user_cache=user_cache,
        )

    seen_thread_ts: set[str] = set()
    for message_batch in channel_message_batches:
        for message in message_batch:
            filtered_thread: ThreadType | None = None
            thread_ts = message.get("thread_ts")
            if thread_ts:
                # skip threads we've already seen, since we've already processed all
                # messages in that thread
                if thread_ts in seen_thread_ts:
                    continue
                seen_thread_ts.add(thread_ts)
                thread = get_thread(
                    client=client, channel_id=channel["id"], thread_id=thread_ts
                )
                filtered_thread = [
                    message for message in thread if not msg_filter_func(message)
                ]
            elif not msg_filter_func(message):
                filtered_thread = [message]

            if filtered_thread:
                channel_docs += 1
                yield thread_to_doc(
                    channel=channel,
                    thread=filtered_thread,
                    slack_cleaner=slack_cleaner,
                    client=client,
                    user_cache=user_cache,
                )

    logger.info(
        f"Pulled {channel_docs} documents from slack channel {channel['name']}"
    )
    return None


def _get_all_doc_ids(
@@ -368,7 +424,7 @@ def _get_all_doc_ids(
        for message_ts in message_ts_set:
            channel_metadata_list.append(
                SlimDocument(
                    id=f"{channel_id}__{message_ts}",
                    id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
                    perm_sync_data={"channel_id": channel_id},
                )
            )
@@ -376,7 +432,51 @@ def _get_all_doc_ids(
    yield channel_metadata_list


class SlackPollConnector(PollConnector, SlimConnector):
def _process_message(
    message: MessageType,
    client: WebClient,
    channel: ChannelType,
    slack_cleaner: SlackTextCleaner,
    user_cache: dict[str, BasicExpertInfo | None],
    seen_thread_ts: set[str],
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
    thread_ts = message.get("thread_ts")
    try:
        # causes random failures for testing checkpointing / continue on failure
        # import random
        # if random.random() > 0.95:
        #     raise RuntimeError("Random failure :P")

        doc = _message_to_doc(
            message=message,
            client=client,
            channel=channel,
            slack_cleaner=slack_cleaner,
            user_cache=user_cache,
            seen_thread_ts=seen_thread_ts,
            msg_filter_func=msg_filter_func,
        )
        return (doc, thread_ts, None)
    except Exception as e:
        logger.exception(f"Error processing message {message['ts']}")
        return (
            None,
            thread_ts,
            ConnectorFailure(
                failed_document=DocumentFailure(
                    document_id=_build_doc_id(
                        channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
                    ),
                    document_link=get_message_link(message, client, channel["id"]),
                ),
                failure_message=str(e),
                exception=e,
            ),
        )


class SlackConnector(SlimConnector, CheckpointConnector):
    def __init__(
        self,
        channels: list[str] | None = None,
@@ -390,9 +490,14 @@ class SlackPollConnector(PollConnector, SlimConnector):
        self.batch_size = batch_size
        self.client: WebClient | None = None

        # just used for efficiency
        self.text_cleaner: SlackTextCleaner | None = None
        self.user_cache: dict[str, BasicExpertInfo | None] = {}

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        bot_token = credentials["slack_bot_token"]
        self.client = WebClient(token=bot_token)
        self.text_cleaner = SlackTextCleaner(client=self.client)
        return None

    def retrieve_all_slim_documents(
@@ -411,30 +516,155 @@ class SlackPollConnector(PollConnector, SlimConnector):
            callback=callback,
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if self.client is None:
    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        """Rough outline:

        Step 1: Get all channels, yield back Checkpoint.
        Step 2: Loop through each channel. For each channel:
            Step 2.1: Get messages within the time range.
            Step 2.2: Process messages in parallel, yield back docs.
            Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
                Slack returns messages from newest to oldest, so we need to keep track of
                the latest message we've seen in each channel.
            Step 2.4: If there are no more messages in the channel, switch the current
                channel to the next channel.
        """
        if self.client is None or self.text_cleaner is None:
            raise ConnectorMissingCredentialError("Slack")

        documents: list[Document] = []
        for document in _get_all_docs(
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            # NOTE: need to impute to `None` instead of using 0.0, since Slack will
            # throw an error if we use 0.0 on an account without infinite data
            # retention
            oldest=str(start) if start else None,
            latest=str(end),
        ):
            documents.append(document)
            if len(documents) >= self.batch_size:
                yield documents
                documents = []
        checkpoint_content = cast(
            SlackCheckpointContent,
            (
                copy.deepcopy(checkpoint.checkpoint_content)
                or {
                    "channel_ids": None,
                    "channel_completion_map": {},
                    "current_channel": None,
                    "seen_thread_ts": [],
                }
            ),
        )

        if documents:
            yield documents
        # if this is the very first time we've called this, need to
        # get all relevant channels and save them into the checkpoint
        if checkpoint_content["channel_ids"] is None:
            raw_channels = get_channels(self.client)
            filtered_channels = filter_channels(
                raw_channels, self.channels, self.channel_regex_enabled
            )
            if len(filtered_channels) == 0:
                return checkpoint

            checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
            checkpoint_content["current_channel"] = filtered_channels[0]
            checkpoint = ConnectorCheckpoint(
                checkpoint_content=checkpoint_content,  # type: ignore
                has_more=True,
            )
            return checkpoint

        final_channel_ids = checkpoint_content["channel_ids"]
        channel = checkpoint_content["current_channel"]
        if channel is None:
            raise ValueError("current_channel key not found in checkpoint")

        channel_id = channel["id"]
        if channel_id not in final_channel_ids:
            raise ValueError(f"Channel {channel_id} not found in checkpoint")

        oldest = str(start) if start else None
        latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
        seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
        try:
            logger.debug(
                f"Getting messages for channel {channel} within range {oldest} - {latest}"
            )
            message_batch, has_more_in_channel = _get_messages(
                channel, self.client, oldest, latest
            )
            new_latest = message_batch[-1]["ts"] if message_batch else latest

            # Process messages in parallel using ThreadPoolExecutor
            with ThreadPoolExecutor(max_workers=8) as executor:
                futures: list[Future] = []
                for message in message_batch:
                    # Capture the current context so that the thread gets the current tenant ID
                    current_context = contextvars.copy_context()
                    futures.append(
                        executor.submit(
                            current_context.run,
                            _process_message,
                            message=message,
                            client=self.client,
                            channel=channel,
                            slack_cleaner=self.text_cleaner,
                            user_cache=self.user_cache,
                            seen_thread_ts=seen_thread_ts,
                        )
                    )
                for future in as_completed(futures):
                    # _process_message returns a single optional failure
                    doc, thread_ts, failure = future.result()
                    if doc:
                        # handle race conditions here since this is single
                        # threaded. Multi-threaded _process_message reads from this
                        # but since this is single threaded, we won't run into simul
                        # writes. At worst, we can duplicate a thread, which will be
                        # deduped later on.
                        if thread_ts not in seen_thread_ts:
                            yield doc

                        if thread_ts:
                            seen_thread_ts.add(thread_ts)
                    elif failure:
                        yield failure

            checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
            checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
            if has_more_in_channel:
                checkpoint_content["current_channel"] = channel
            else:
                new_channel_id = next(
                    (
                        channel_id
                        for channel_id in final_channel_ids
                        if channel_id
                        not in checkpoint_content["channel_completion_map"]
                    ),
                    None,
                )
                if new_channel_id:
                    new_channel = _get_channel_by_id(self.client, new_channel_id)
                    checkpoint_content["current_channel"] = new_channel
                else:
                    checkpoint_content["current_channel"] = None

            checkpoint = ConnectorCheckpoint(
                checkpoint_content=checkpoint_content,  # type: ignore
                has_more=checkpoint_content["current_channel"] is not None,
            )
            return checkpoint

        except Exception as e:
            logger.exception(f"Error processing channel {channel['name']}")
            yield ConnectorFailure(
                failed_entity=EntityFailure(
                    entity_id=channel["id"],
                    missed_time_range=(
                        datetime.fromtimestamp(start, tz=timezone.utc),
                        datetime.fromtimestamp(end, tz=timezone.utc),
                    ),
                ),
                failure_message=str(e),
                exception=e,
            )
            return checkpoint


if __name__ == "__main__":
@@ -442,7 +672,7 @@ if __name__ == "__main__":
    import time

    slack_channel = os.environ.get("SLACK_CHANNEL")
    connector = SlackPollConnector(
    connector = SlackConnector(
        channels=[slack_channel] if slack_channel else None,
    )
    connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -450,6 +680,17 @@ if __name__ == "__main__":
    current = time.time()
    one_day_ago = current - 24 * 60 * 60  # 1 day

    document_batches = connector.poll_source(one_day_ago, current)
    checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()

    print(next(document_batches))
    gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
    try:
        # step manually with next(); a `for` loop would swallow the
        # StopIteration that carries the returned checkpoint
        while True:
            document_or_failure = next(gen)
            if isinstance(document_or_failure, Document):
                print(document_or_failure)
            elif isinstance(document_or_failure, ConnectorFailure):
                print(document_or_failure)
    except StopIteration as e:
        checkpoint = e.value
        print("Next checkpoint:", checkpoint)
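Building on the `__main__` snippet, resuming across checkpoints looks roughly like this (editor-added sketch; `connector`, `one_day_ago`, and `current` are as set up above):

```
# Sketch: drive the Slack connector checkpoint-by-checkpoint until done.
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
while checkpoint.has_more:
    gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
    try:
        while True:
            print(next(gen))  # Document or ConnectorFailure
    except StopIteration as e:
        checkpoint = e.value  # call 1: channel list; later calls: per-channel progress
```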
@@ -34,9 +34,14 @@ def get_message_link(
) -> str:
    channel_id = channel_id or event["channel"]
    message_ts = event["ts"]
    response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
    permalink = response["permalink"]
    return permalink
    message_ts_without_dot = message_ts.replace(".", "")
    thread_ts = event.get("thread_ts")
    base_url = get_base_url(client.token)

    link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
        f"?thread_ts={thread_ts}" if thread_ts else ""
    )
    return link


def _make_slack_api_call_paginated(
@@ -18,6 +18,7 @@ import boto3
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -39,6 +40,7 @@ from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
@@ -187,20 +189,38 @@ class SqlEngine:
    _engine: Engine | None = None
    _lock: threading.Lock = threading.Lock()
    _app_name: str = POSTGRES_UNKNOWN_APP_NAME
    DEFAULT_ENGINE_KWARGS = {
        "pool_size": 20,
        "max_overflow": 5,
        "pool_pre_ping": POSTGRES_POOL_PRE_PING,
        "pool_recycle": POSTGRES_POOL_RECYCLE,
    }

    @classmethod
    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
        connection_string = build_connection_string(
            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
        )
        merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
        engine = create_engine(connection_string, **merged_kwargs)

        # Start with base kwargs that are valid for all pool types
        final_engine_kwargs: dict[str, Any] = {}

        if POSTGRES_USE_NULL_POOL:
            # if null pool is specified, then we need to make sure that
            # we remove any passed in kwargs related to pool size that would
            # cause the initialization to fail
            final_engine_kwargs.update(engine_kwargs)

            final_engine_kwargs["poolclass"] = pool.NullPool
            if "pool_size" in final_engine_kwargs:
                del final_engine_kwargs["pool_size"]
            if "max_overflow" in final_engine_kwargs:
                del final_engine_kwargs["max_overflow"]
        else:
            final_engine_kwargs["pool_size"] = 20
            final_engine_kwargs["max_overflow"] = 5
            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

            # any passed in kwargs override the defaults
            final_engine_kwargs.update(engine_kwargs)

        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
        engine = create_engine(connection_string, **final_engine_kwargs)

        if USE_IAM_AUTH:
            event.listen(engine, "do_connect", provide_iam_token)
@@ -299,13 +319,21 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:

            connect_args["ssl"] = ssl_context

        engine_kwargs = {
            "connect_args": connect_args,
            "pool_pre_ping": POSTGRES_POOL_PRE_PING,
            "pool_recycle": POSTGRES_POOL_RECYCLE,
        }

        if POSTGRES_USE_NULL_POOL:
            engine_kwargs["poolclass"] = pool.NullPool
        else:
            engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
            engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW

        _ASYNC_ENGINE = create_async_engine(
            connection_string,
            connect_args=connect_args,
            pool_size=POSTGRES_API_SERVER_POOL_SIZE,
            max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
            pool_pre_ping=POSTGRES_POOL_PRE_PING,
            pool_recycle=POSTGRES_POOL_RECYCLE,
            **engine_kwargs,
        )

        if USE_IAM_AUTH:
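The kwargs branching above can be summarized in a self-contained sketch (editor-added; `build_engine_kwargs` is a hypothetical name): under NullPool, sizing kwargs are dropped because `NullPool` rejects them, while overrides otherwise win over the defaults.

```
# Self-contained sketch of the pool-selection logic shown above.
from typing import Any
from sqlalchemy import pool


def build_engine_kwargs(use_null_pool: bool, **overrides: Any) -> dict[str, Any]:
    kwargs: dict[str, Any] = {}
    if use_null_pool:
        kwargs.update(overrides)
        kwargs["poolclass"] = pool.NullPool
        kwargs.pop("pool_size", None)   # NullPool rejects sizing kwargs
        kwargs.pop("max_overflow", None)
    else:
        kwargs.update({"pool_size": 20, "max_overflow": 5})
        kwargs.update(overrides)        # passed-in kwargs override the defaults
    return kwargs


assert "pool_size" not in build_engine_kwargs(True, pool_size=20)
```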
@@ -11,8 +11,7 @@ from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.connectors.models import Document
from onyx.connectors.models import DocumentErrorSummary
from onyx.connectors.models import ConnectorFailure
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -41,6 +40,27 @@ def get_last_attempt_for_cc_pair(
    )


def get_recent_completed_attempts_for_cc_pair(
    cc_pair_id: int,
    search_settings_id: int,
    limit: int,
    db_session: Session,
) -> list[IndexAttempt]:
    return (
        db_session.query(IndexAttempt)
        .filter(
            IndexAttempt.connector_credential_pair_id == cc_pair_id,
            IndexAttempt.search_settings_id == search_settings_id,
            IndexAttempt.status.notin_(
                [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
            ),
        )
        .order_by(IndexAttempt.time_updated.desc())
        .limit(limit)
        .all()
    )


def get_index_attempt(
    db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -615,23 +635,32 @@ def count_unique_cc_pairs_with_successful_index_attempts(

def create_index_attempt_error(
    index_attempt_id: int | None,
    batch: int | None,
    docs: list[Document],
    exception_msg: str,
    exception_traceback: str,
    connector_credential_pair_id: int,
    failure: ConnectorFailure,
    db_session: Session,
) -> int:
    doc_summaries = []
    for doc in docs:
        doc_summary = DocumentErrorSummary.from_document(doc)
        doc_summaries.append(doc_summary.to_dict())

    new_error = IndexAttemptError(
        index_attempt_id=index_attempt_id,
        batch=batch,
        doc_summaries=doc_summaries,
        error_msg=exception_msg,
        traceback=exception_traceback,
        connector_credential_pair_id=connector_credential_pair_id,
        document_id=(
            failure.failed_document.document_id if failure.failed_document else None
        ),
        document_link=(
            failure.failed_document.document_link if failure.failed_document else None
        ),
        entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None),
        failed_time_range_start=(
            failure.failed_entity.missed_time_range[0]
            if failure.failed_entity and failure.failed_entity.missed_time_range
            else None
        ),
        failed_time_range_end=(
            failure.failed_entity.missed_time_range[1]
            if failure.failed_entity and failure.failed_entity.missed_time_range
            else None
        ),
        failure_message=failure.failure_message,
        is_resolved=False,
    )
    db_session.add(new_error)
    db_session.commit()
@@ -649,3 +678,42 @@ def get_index_attempt_errors(

    errors = db_session.scalars(stmt)
    return list(errors.all())


def count_index_attempt_errors_for_cc_pair(
    cc_pair_id: int,
    unresolved_only: bool,
    db_session: Session,
) -> int:
    stmt = (
        select(func.count())
        .select_from(IndexAttemptError)
        .where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
    )
    if unresolved_only:
        stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))

    result = db_session.scalar(stmt)
    return 0 if result is None else result


def get_index_attempt_errors_for_cc_pair(
    cc_pair_id: int,
    unresolved_only: bool,
    db_session: Session,
    page: int | None = None,
    page_size: int | None = None,
) -> list[IndexAttemptError]:
    stmt = select(IndexAttemptError).where(
        IndexAttemptError.connector_credential_pair_id == cc_pair_id
    )
    if unresolved_only:
        stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))

    # Order by most recent first
    stmt = stmt.order_by(desc(IndexAttemptError.time_created))

    if page is not None and page_size is not None:
        stmt = stmt.offset(page * page_size).limit(page_size)

    return list(db_session.scalars(stmt).all())
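Call sites now hand the whole `ConnectorFailure` over instead of a doc list plus exception strings; for instance (editor-added sketch; `attempt_id`, `cc_pair_id`, and `db_session` are assumed to be in scope):

```
# Sketch: recording a per-document failure with the new signature.
failure = ConnectorFailure(
    failed_document=DocumentFailure(document_id="doc-1", document_link=None),
    failure_message="embedding failed",
)
create_index_attempt_error(
    index_attempt_id=attempt_id,
    connector_credential_pair_id=cc_pair_id,
    failure=failure,
    db_session=db_session,
)
```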
@@ -827,6 +827,19 @@ class IndexAttempt(Base):
        nullable=True,
    )

    # for polling connectors, the start and end time of the poll window
    # will be set when the index attempt starts
    poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, default=None
    )
    poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, default=None
    )

    # Points to the last checkpoint that was saved for this run. The pointer here
    # can be taken to the FileStore to grab the actual checkpoint value
    checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
@@ -870,6 +883,13 @@ class IndexAttempt(Base):
            desc("time_updated"),
            unique=False,
        ),
        Index(
            "ix_index_attempt_cc_pair_settings_poll",
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
            desc("time_updated"),
        ),
    )

    def __repr__(self) -> str:
@@ -886,25 +906,33 @@ class IndexAttempt(Base):


class IndexAttemptError(Base):
    """
    Represents an error that was encountered during an IndexAttempt.
    """

    __tablename__ = "index_attempt_errors"

    id: Mapped[int] = mapped_column(primary_key=True)

    index_attempt_id: Mapped[int] = mapped_column(
        ForeignKey("index_attempt.id"),
        nullable=True,
        nullable=False,
    )
    connector_credential_pair_id: Mapped[int] = mapped_column(
        ForeignKey("connector_credential_pair.id"),
        nullable=False,
    )

    # The index of the batch where the error occurred (if looping thru batches)
    # Just informational.
    batch: Mapped[int | None] = mapped_column(Integer, default=None)
    doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
    error_msg: Mapped[str | None] = mapped_column(Text, default=None)
    traceback: Mapped[str | None] = mapped_column(Text, default=None)
    document_id: Mapped[str | None] = mapped_column(String, nullable=True)
    document_link: Mapped[str | None] = mapped_column(String, nullable=True)

    entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
    failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    failure_message: Mapped[str] = mapped_column(Text)
    is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
@@ -913,21 +941,6 @@ class IndexAttemptError(Base):
    # This is the reverse side of the relationship
    index_attempt = relationship("IndexAttempt", back_populates="error_rows")

    __table_args__ = (
        Index(
            "index_attempt_id",
            "time_created",
        ),
    )

    def __repr__(self) -> str:
        return (
            f"<IndexAttempt(id={self.id!r}, "
            f"index_attempt_id={self.index_attempt_id!r}, "
            f"error_msg={self.error_msg!r})>"
            f"time_created={self.time_created!r}, "
        )


class SyncRecord(Base):
    """
@@ -1,6 +1,10 @@
import time
from abc import ABC
from abc import abstractmethod
from collections import defaultdict

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import ChunkEmbedding
@@ -217,3 +221,49 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            deployment_name=search_settings.deployment_name,
            callback=callback,
        )


def embed_chunks_with_failure_handling(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
    """Tries to embed all chunks in one large batch. If that batch fails for any reason,
    goes document by document to isolate the failure(s).
    """

    # First try to embed all chunks in one batch
    try:
        return embedder.embed_chunks(chunks=chunks), []
    except Exception:
        logger.exception("Failed to embed chunk batch. Trying individual docs.")
        # wait a couple seconds to let any rate limits or temporary issues resolve
        time.sleep(2)

    # Try embedding each document's chunks individually
    chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
    for chunk in chunks:
        chunks_by_doc[chunk.source_document.id].append(chunk)

    embedded_chunks: list[IndexChunk] = []
    failures: list[ConnectorFailure] = []

    for doc_id, chunks_for_doc in chunks_by_doc.items():
        try:
            doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
            embedded_chunks.extend(doc_embedded_chunks)
        except Exception as e:
            logger.exception(f"Failed to embed chunks for document '{doc_id}'")
            failures.append(
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=doc_id,
                        document_link=(
                            chunks_for_doc[0].get_link() if chunks_for_doc else None
                        ),
                    ),
                    failure_message=str(e),
                    exception=e,
                )
            )

    return embedded_chunks, failures
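Downstream, the pipeline can keep the good documents and surface the bad ones, along the lines of this editor-added sketch (`chunks` and `embedder` assumed in scope):

```
# Sketch: batch embed with per-document fallback, then branch on failures.
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
    chunks=chunks,        # assumed list[DocAwareChunk]
    embedder=embedder,    # assumed IndexingEmbedder
)
for failure in embedding_failures:
    logger.warning(f"Skipping doc: {failure.failure_message}")
```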
@ -1,23 +1,21 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol

import httpx
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session

from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
@ -29,7 +27,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
@ -41,10 +38,12 @@ from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time

@ -67,6 +66,8 @@ class IndexingPipelineResult(BaseModel):
    # number of chunks that were inserted into Vespa
    total_chunks: int

    failures: list[ConnectorFailure]


class IndexingPipelineProtocol(Protocol):
    def __call__(
@ -156,14 +157,10 @@ def index_doc_batch_with_handler(
    document_index: DocumentIndex,
    document_batch: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
    attempt_id: int | None,
    db_session: Session,
    ignore_time_skip: bool = False,
    tenant_id: str | None = None,
) -> IndexingPipelineResult:
    index_pipeline_result = IndexingPipelineResult(
        new_docs=0, total_docs=len(document_batch), total_chunks=0
    )
    try:
        index_pipeline_result = index_doc_batch(
            chunker=chunker,
@ -176,47 +173,25 @@ def index_doc_batch_with_handler(
            tenant_id=tenant_id,
        )
    except Exception as e:
        if isinstance(e, httpx.HTTPStatusError):
            if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
                logger.error(
                    "NOTE: HTTP Status 507 Insufficient Storage indicates "
                    "you need to allocate more memory or disk space to the "
                    "Vespa/index container."
        logger.exception(f"Failed to index document batch: {document_batch}")
        index_pipeline_result = IndexingPipelineResult(
            new_docs=0,
            total_docs=len(document_batch),
            total_chunks=0,
            failures=[
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=document.id,
                        document_link=(
                            document.sections[0].link if document.sections else None
                        ),
                    ),
                    failure_message=str(e),
                    exception=e,
                )

        if INDEXING_EXCEPTION_LIMIT == 0:
            raise

        trace = traceback.format_exc()
        create_index_attempt_error(
            attempt_id,
            batch=index_attempt_metadata.batch_num,
            docs=document_batch,
            exception_msg=str(e),
            exception_traceback=trace,
            db_session=db_session,
                for document in document_batch
            ],
        )
        logger.exception(
            f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'"
        )

        index_attempt_metadata.num_exceptions += 1
        if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT:
            logger.warning(
                f"Maximum number of exceptions for this index attempt "
                f"({INDEXING_EXCEPTION_LIMIT}) has been reached. "
                f"The next exception will abort the indexing attempt."
            )
        elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT:
            logger.warning(
                f"Maximum number of exceptions for this index attempt "
                f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded."
            )
            raise RuntimeError(
                f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded."
            )
        else:
            pass

    return index_pipeline_result

@ -376,8 +351,12 @@ def index_doc_batch(
        document_ids=[doc.id for doc in filtered_documents],
        db_session=db_session,
    )
    db_session.commit()
    return IndexingPipelineResult(
        new_docs=0, total_docs=len(filtered_documents), total_chunks=0
        new_docs=0,
        total_docs=len(filtered_documents),
        total_chunks=0,
        failures=[],
    )

    doc_descriptors = [
@ -390,10 +369,19 @@ def index_doc_batch(
    logger.debug(f"Starting indexing process for documents: {doc_descriptors}")

    logger.debug("Starting chunking")
    # NOTE: no special handling for failures here, since the chunker is not
    # a common source of failure for the indexing pipeline
    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

    logger.debug("Starting embedding")
    chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
    chunks_with_embeddings, embedding_failures = (
        embed_chunks_with_failure_handling(
            chunks=chunks,
            embedder=embedder,
        )
        if chunks
        else ([], [])
    )

    updatable_ids = [doc.id for doc in ctx.updatable_docs]

@ -459,7 +447,11 @@ def index_doc_batch(
    # A document will not be spread across different batches, so all the
    # documents with chunks in this set, are fully represented by the chunks
    # in this set
    insertion_records = document_index.index(
    (
        insertion_records,
        vector_db_write_failures,
    ) = write_chunks_to_vector_db_with_backoff(
        document_index=document_index,
        chunks=access_aware_chunks,
        index_batch_params=IndexBatchParams(
            doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
@ -519,6 +511,7 @@ def index_doc_batch(
        new_docs=len([r for r in insertion_records if r.already_existed is False]),
        total_docs=len(filtered_documents),
        total_chunks=len(access_aware_chunks),
        failures=vector_db_write_failures + embedding_failures,
    )

    return result
@ -531,7 +524,6 @@ def build_indexing_pipeline(
    db_session: Session,
    chunker: Chunker | None = None,
    ignore_time_skip: bool = False,
    attempt_id: int | None = None,
    tenant_id: str | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
@ -553,7 +545,6 @@
        embedder=embedder,
        document_index=document_index,
        ignore_time_skip=ignore_time_skip,
        attempt_id=attempt_id,
        db_session=db_session,
        tenant_id=tenant_id,
    )

@ -57,6 +57,13 @@ class DocAwareChunk(BaseChunk):
        """Used when logging the identity of a chunk"""
        return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"

    def get_link(self) -> str | None:
        return (
            self.source_document.sections[0].link
            if self.source_document.sections
            else None
        )


class IndexChunk(DocAwareChunk):
    embeddings: ChunkEmbedding

99
backend/onyx/indexing/vector_db_insertion.py
Normal file
@ -0,0 +1,99 @@
import time
from collections import defaultdict
from http import HTTPStatus

import httpx

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger


logger = setup_logger()


def _log_insufficient_storage_error(e: Exception) -> None:
    if isinstance(e, httpx.HTTPStatusError):
        if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
            logger.error(
                "NOTE: HTTP Status 507 Insufficient Storage indicates "
                "you need to allocate more memory or disk space to the "
                "Vespa/index container."
            )


def write_chunks_to_vector_db_with_backoff(
    document_index: DocumentIndex,
    chunks: list[DocMetadataAwareIndexChunk],
    index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
    """Tries to insert all chunks in one large batch. If that batch fails for any reason,
    goes document by document to isolate the failure(s).

    IMPORTANT: must pass in whole documents at a time, not individual chunks, since the
    vector DB interface assumes that all chunks for a single document are present.
    """

    # first try to write the chunks to the vector db
    try:
        return (
            list(
                document_index.index(
                    chunks=chunks,
                    index_batch_params=index_batch_params,
                )
            ),
            [],
        )
    except Exception as e:
        logger.exception(
            "Failed to write chunk batch to vector db. Trying individual docs."
        )

        # give some specific logging on this common failure case.
        _log_insufficient_storage_error(e)

        # wait a couple seconds just to give the vector db a chance to recover
        time.sleep(2)

    # try writing each doc one by one
    chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
    for chunk in chunks:
        chunks_for_docs[chunk.source_document.id].append(chunk)

    insertion_records: list[DocumentInsertionRecord] = []
    failures: list[ConnectorFailure] = []
    for doc_id, chunks_for_doc in chunks_for_docs.items():
        try:
            insertion_records.extend(
                document_index.index(
                    chunks=chunks_for_doc,
                    index_batch_params=index_batch_params,
                )
            )
        except Exception as e:
            logger.exception(
                f"Failed to write document chunks for '{doc_id}' to vector db"
            )

            # give some specific logging on this common failure case.
            _log_insufficient_storage_error(e)

            failures.append(
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=doc_id,
                        document_link=(
                            chunks_for_doc[0].get_link() if chunks_for_doc else None
                        ),
                    ),
                    failure_message=str(e),
                    exception=e,
                )
            )

    return insertion_records, failures

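To make the IMPORTANT note in the docstring concrete: a caller must keep all chunks of one document within a single call, never split across calls. A minimal caller-side sketch (the chunk lists and index objects are stand-ins):

    # OK: every chunk of doc A and doc B is present in this one call.
    records, failures = write_chunks_to_vector_db_with_backoff(
        document_index=document_index,
        chunks=chunks_for_doc_a + chunks_for_doc_b,
        index_batch_params=index_batch_params,
    )
    # NOT OK: splitting chunks_for_doc_a across two calls would violate the
    # vector DB interface's assumption that a document arrives whole.
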
@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.indexing import router as indexing_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
@ -317,7 +316,6 @@ def get_application() -> FastAPI:
    include_router_with_global_prefix_prepended(
        application, token_rate_limit_settings_router
    )
    include_router_with_global_prefix_prepended(application, indexing_router)
    include_router_with_global_prefix_prepended(
        application, get_full_openai_assistants_api_router()
    )

@ -22,6 +22,7 @@ from onyx.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
@ -39,7 +40,9 @@ from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_connector
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
@ -546,6 +549,47 @@ def get_docs_sync_status(
    return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]


@router.get("/admin/cc-pair/{cc_pair_id}/errors")
def get_cc_pair_indexing_errors(
    cc_pair_id: int,
    include_resolved: bool = Query(False),
    page: int = Query(0, ge=0),
    page_size: int = Query(10, ge=1, le=100),
    _: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
    """Gives back all errors for a given CC Pair. Allows pagination based on page and page_size params.

    Args:
        cc_pair_id: ID of the connector-credential pair to get errors for
        include_resolved: Whether to include resolved errors in the results
        page: Page number for pagination, starting at 0
        page_size: Number of errors to return per page
        _: Current user, must be curator or admin
        db_session: Database session

    Returns:
        Paginated list of indexing errors for the CC pair.
    """
    total_count = count_index_attempt_errors_for_cc_pair(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        unresolved_only=not include_resolved,
    )

    index_attempt_errors = get_index_attempt_errors_for_cc_pair(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        unresolved_only=not include_resolved,
        page=page,
        page_size=page_size,
    )
    return PaginatedReturn(
        items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors],
        total_items=total_count,
    )
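For reference, a sketch of calling this new endpoint directly (the base URL and auth are assumptions; the path and query params come from the route above, and the IndexAttemptManager test helper later in this diff does essentially the same thing):

    import requests

    resp = requests.get(
        "http://localhost:8080/manage/admin/cc-pair/1/errors",
        params={"include_resolved": False, "page": 0, "page_size": 10},
    )
    resp.raise_for_status()
    print(resp.json()["total_items"])
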


@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
    connector_id: int,

@ -1,23 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.index_attempt import (
    get_index_attempt_errors,
)
from onyx.db.models import User
from onyx.server.documents.models import IndexAttemptError

router = APIRouter(prefix="/manage")


@router.get("/admin/indexing-errors/{index_attempt_id}")
def get_indexing_errors(
    index_attempt_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[IndexAttemptError]:
    indexing_errors = get_index_attempt_errors(index_attempt_id, db_session)
    return [IndexAttemptError.from_db_model(e) for e in indexing_errors]

@ -8,9 +8,9 @@ from pydantic import BaseModel
from pydantic import Field

from ee.onyx.server.query_history.models import ChatSessionMinimal
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import DocumentErrorSummary
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@ -19,7 +19,6 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
@ -150,6 +149,7 @@ class CredentialSnapshot(CredentialBase):
class IndexAttemptSnapshot(BaseModel):
    id: int
    status: IndexingStatus | None
    from_beginning: bool
    new_docs_indexed: int  # only includes completely new docs
    total_docs_indexed: int  # includes docs that are updated
    docs_removed_from_index: int
@ -166,6 +166,7 @@ class IndexAttemptSnapshot(BaseModel):
        return IndexAttemptSnapshot(
            id=index_attempt.id,
            status=index_attempt.status,
            from_beginning=index_attempt.from_beginning,
            new_docs_indexed=index_attempt.new_docs_indexed or 0,
            total_docs_indexed=index_attempt.total_docs_indexed or 0,
            docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
@ -181,31 +182,6 @@ class IndexAttemptSnapshot(BaseModel):
        )


class IndexAttemptError(BaseModel):
    id: int
    index_attempt_id: int | None
    batch_number: int | None
    doc_summaries: list[DocumentErrorSummary]
    error_msg: str | None
    traceback: str | None
    time_created: str

    @classmethod
    def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
        doc_summaries = [
            DocumentErrorSummary.from_dict(summary) for summary in error.doc_summaries
        ]
        return IndexAttemptError(
            id=error.id,
            index_attempt_id=error.index_attempt_id,
            batch_number=error.batch,
            doc_summaries=doc_summaries,
            error_msg=error.error_msg,
            traceback=error.traceback,
            time_created=error.time_created.isoformat(),
        )


# These are the types currently supported by the pagination hook
# More api endpoints can be refactored and be added here for use with the pagination hook
PaginatedType = TypeVar(
@ -214,6 +190,7 @@ PaginatedType = TypeVar(
    FullUserSnapshot,
    InvitedUserSnapshot,
    ChatSessionMinimal,
    IndexAttemptErrorPydantic,
)

26
backend/onyx/utils/object_size_check.py
Normal file
@ -0,0 +1,26 @@
import sys
from typing import TypeVar

T = TypeVar("T", dict, list, tuple, set, frozenset)


def deep_getsizeof(obj: T, seen: set[int] | None = None) -> int:
    """Recursively sum size of objects, handling circular references."""
    if seen is None:
        seen = set()

    obj_id = id(obj)
    if obj_id in seen:
        return 0  # Prevent infinite recursion for circular references

    seen.add(obj_id)
    size = sys.getsizeof(obj)

    if isinstance(obj, dict):
        size += sum(
            deep_getsizeof(k, seen) + deep_getsizeof(v, seen) for k, v in obj.items()
        )
    elif isinstance(obj, (list, tuple, set, frozenset)):
        size += sum(deep_getsizeof(i, seen) for i in obj)

    return size

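A small usage sketch of deep_getsizeof (reported sizes are machine- and Python-version-dependent; the self-referencing list exercises the `seen` guard):

    payload = {"docs": [b"x" * 1024] * 4}
    print(deep_getsizeof(payload))  # counts keys, values, and containers

    circular: list = []
    circular.append(circular)  # self-reference
    print(deep_getsizeof(circular))  # terminates thanks to the `seen` set
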
@ -42,7 +42,7 @@ def run_jobs() -> None:
        "--loglevel=INFO",
        "--hostname=light@%n",
        "-Q",
        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
    ]

    cmd_worker_heavy = [

@ -33,7 +33,7 @@ stopasgroup=true
command=celery -A onyx.background.celery.versioned_apps.light worker
    --loglevel=INFO
    --hostname=light@%%n
    -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
    -Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup
stdout_logfile=/var/log/celery_worker_light.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true

@ -1,14 +1,15 @@
import requests
from sqlalchemy.orm import Session

from onyx.db.engine import get_session_context_manager
from onyx.db.models import User


def test_create_chat_session_and_send_messages(db_session: Session) -> None:
def test_create_chat_session_and_send_messages() -> None:
    # Create a test user
    test_user = User(email="test@example.com", hashed_password="dummy_hash")
    db_session.add(test_user)
    db_session.commit()
    with get_session_context_manager() as db_session:
        test_user = User(email="test@example.com", hashed_password="dummy_hash")
        db_session.add(test_user)
        db_session.commit()

    base_url = "http://localhost:8080"  # Adjust this to your API's base URL
    headers = {"Authorization": f"Bearer {test_user.id}"}

@ -1,5 +1,7 @@
import os

ADMIN_USER_NAME = "admin_user"

API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
@ -9,3 +11,6 @@ MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}

NUM_DOCS = 5

MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost"
MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001

@ -223,12 +223,13 @@ class CCPairManager:
    @staticmethod
    def run_once(
        cc_pair: DATestCCPair,
        from_beginning: bool,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        body = {
            "connector_id": cc_pair.connector_id,
            "credential_ids": [cc_pair.credential_id],
            "from_beginning": True,
            "from_beginning": from_beginning,
        }
        result = requests.post(
            url=f"{API_SERVER_URL}/manage/admin/connector/run-once",

@ -1,9 +1,14 @@
from uuid import uuid4

import requests
from sqlalchemy import and_
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
@ -186,3 +191,39 @@ class DocumentManager:
            group_names,
            doc_creating_user,
        )

    @staticmethod
    def fetch_documents_for_cc_pair(
        cc_pair_id: int,
        db_session: Session,
        vespa_client: vespa_fixture,
    ) -> list[SimpleTestDocument]:
        stmt = (
            select(DocumentByConnectorCredentialPair)
            .join(
                ConnectorCredentialPair,
                and_(
                    DocumentByConnectorCredentialPair.connector_id
                    == ConnectorCredentialPair.connector_id,
                    DocumentByConnectorCredentialPair.credential_id
                    == ConnectorCredentialPair.credential_id,
                ),
            )
            .where(ConnectorCredentialPair.id == cc_pair_id)
        )
        documents = db_session.execute(stmt).scalars().all()
        if not documents:
            return []

        doc_ids = [document.id for document in documents]
        retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]

        final_docs: list[SimpleTestDocument] = []
        # NOTE: they are really chunks, but we're assuming that for these tests
        # we only have one chunk per document for now
        for doc_dict in retrieved_docs_dict:
            doc_id = doc_dict["fields"]["document_id"]
            doc_content = doc_dict["fields"]["content"]
            final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))

        return final_docs

@ -4,6 +4,7 @@ from urllib.parse import urlencode

import requests

from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
@ -13,6 +14,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser

@ -92,8 +94,12 @@ class IndexAttemptManager:
            "page_size": page_size,
        }

        url = (
            f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
            f"?{urlencode(query_params, doseq=True)}"
        )
        response = requests.get(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
            url=url,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
@ -104,3 +110,125 @@ class IndexAttemptManager:
            items=[IndexAttemptSnapshot(**item) for item in data["items"]],
            total_items=data["total_items"],
        )

    @staticmethod
    def get_latest_index_attempt_for_cc_pair(
        cc_pair_id: int,
        user_performing_action: DATestUser | None = None,
    ) -> IndexAttemptSnapshot | None:
        """Get the most recently started IndexAttempt for a CC Pair"""
        index_attempts = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id, user_performing_action=user_performing_action
        ).items
        if not index_attempts:
            return None

        index_attempts = sorted(
            index_attempts, key=lambda x: x.time_started or "0", reverse=True
        )
        return index_attempts[0]

    @staticmethod
    def wait_for_index_attempt_start(
        cc_pair_id: int,
        index_attempts_to_ignore: list[int] | None = None,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> IndexAttemptSnapshot:
        """Wait for an IndexAttempt to start"""
        start = datetime.now()
        index_attempts_to_ignore = index_attempts_to_ignore or []

        while True:
            index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
                cc_pair_id=cc_pair_id,
                user_performing_action=user_performing_action,
            )
            if (
                index_attempt
                and index_attempt.time_started
                and index_attempt.id not in index_attempts_to_ignore
            ):
                return index_attempt

            elapsed = (datetime.now() - start).total_seconds()
            if elapsed > timeout:
                raise TimeoutError(
                    f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
                )

    @staticmethod
    def get_index_attempt_by_id(
        index_attempt_id: int,
        cc_pair_id: int,
        user_performing_action: DATestUser | None = None,
    ) -> IndexAttemptSnapshot:
        page_num = 0
        page_size = 10
        while True:
            page = IndexAttemptManager.get_index_attempt_page(
                cc_pair_id=cc_pair_id,
                page=page_num,
                page_size=page_size,
                user_performing_action=user_performing_action,
            )
            for attempt in page.items:
                if attempt.id == index_attempt_id:
                    return attempt

            if len(page.items) < page_size:
                break

            page_num += 1

        raise ValueError(f"IndexAttempt {index_attempt_id} not found")

    @staticmethod
    def wait_for_index_attempt_completion(
        index_attempt_id: int,
        cc_pair_id: int,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Wait for an IndexAttempt to complete"""
        start = datetime.now()
        while True:
            index_attempt = IndexAttemptManager.get_index_attempt_by_id(
                index_attempt_id=index_attempt_id,
                cc_pair_id=cc_pair_id,
                user_performing_action=user_performing_action,
            )

            if index_attempt.status and index_attempt.status.is_terminal():
                print(f"IndexAttempt {index_attempt_id} completed")
                return

            elapsed = (datetime.now() - start).total_seconds()
            if elapsed > timeout:
                raise TimeoutError(
                    f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
                )

            print(
                f"Waiting for IndexAttempt {index_attempt_id} to complete. "
                f"elapsed={elapsed:.2f} timeout={timeout}"
            )

    @staticmethod
    def get_index_attempt_errors_for_cc_pair(
        cc_pair_id: int,
        include_resolved: bool = True,
        user_performing_action: DATestUser | None = None,
    ) -> list[IndexAttemptErrorPydantic]:
        url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
        if include_resolved:
            url += "&include_resolved=true"
        response = requests.get(
            url=url,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        data = response.json()
        return [IndexAttemptErrorPydantic(**item) for item in data["items"]]

@ -25,6 +25,7 @@ from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout

logger = setup_logger()

@ -66,6 +67,7 @@ def _run_migrations(

def downgrade_postgres(
    database: str = "postgres",
    schema: str = "public",
    config_name: str = "alembic",
    revision: str = "base",
    clear_data: bool = False,
@ -73,8 +75,8 @@ def downgrade_postgres(
    """Downgrade Postgres database to base state."""
    if clear_data:
        if revision != "base":
            logger.warning("Clearing data without rolling back to base state")
        # Delete all rows to allow migrations to be rolled back
            raise ValueError("Clearing data without rolling back to base state")

        conn = psycopg2.connect(
            dbname=database,
            user=POSTGRES_USER,
@ -82,38 +84,33 @@ def downgrade_postgres(
            host=POSTGRES_HOST,
            port=POSTGRES_PORT,
        )
        conn.autocommit = True  # Need autocommit for dropping schema
        cur = conn.cursor()

        # Disable triggers to prevent foreign key constraints from being checked
        cur.execute("SET session_replication_role = 'replica';")

        # Fetch all table names in the current database
        # Close any existing connections to the schema before dropping
        cur.execute(
            """
            SELECT tablename
            FROM pg_tables
            WHERE schemaname = 'public'
            f"""
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{database}'
            AND pg_stat_activity.state = 'idle in transaction'
            AND pid <> pg_backend_pid();
            """
        )

        tables = cur.fetchall()
        # Drop and recreate the public schema - this removes ALL objects
        cur.execute(f"DROP SCHEMA {schema} CASCADE;")
        cur.execute(f"CREATE SCHEMA {schema};")

        for table in tables:
            table_name = table[0]
        # Restore default privileges
        cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
        cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")

            # Don't touch migration history or Kombu
            if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
                continue

            cur.execute(f'DELETE FROM "{table_name}"')

        # Re-enable triggers
        cur.execute("SET session_replication_role = 'origin';")

        conn.commit()
        cur.close()
        conn.close()

        return

    # Downgrade to base
    conn_str = build_connection_string(
        db=database,
@ -157,11 +154,37 @@ def reset_postgres(
    setup_onyx: bool = True,
) -> None:
    """Reset the Postgres database."""
    downgrade_postgres(
        database=database, config_name=config_name, revision="base", clear_data=True
    )
    # this seems to hang due to locking issues, so run with a timeout with a few retries
    NUM_TRIES = 10
    TIMEOUT = 10
    success = False
    for _ in range(NUM_TRIES):
        logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
        try:
            run_with_timeout(
                downgrade_postgres,
                TIMEOUT,
                kwargs={
                    "database": database,
                    "config_name": config_name,
                    "revision": "base",
                    "clear_data": True,
                },
            )
            success = True
            break
        except TimeoutError:
            logger.warning(
                f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
            )

    if not success:
        raise RuntimeError("Postgres downgrade failed after 10 timeouts.")

    logger.info("Upgrading Postgres...")
    upgrade_postgres(database=database, config_name=config_name, revision="head")
    if setup_onyx:
        logger.info("Setting up Postgres...")
        with get_session_context_manager() as db_session:
            setup_postgres(db_session)

@ -0,0 +1,57 @@
import uuid
from datetime import datetime
from datetime import timezone

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import Section


def create_test_document(
    doc_id: str | None = None,
    text: str = "Test content",
    link: str = "http://test.com",
    source: DocumentSource = DocumentSource.MOCK_CONNECTOR,
    metadata: dict | None = None,
) -> Document:
    """Create a test document with the given parameters.

    Args:
        doc_id: Optional document ID. If not provided, a random UUID will be generated.
        text: The text content of the document. Defaults to "Test content".
        link: The link for the document section. Defaults to "http://test.com".
        source: The document source. Defaults to MOCK_CONNECTOR.
        metadata: Optional metadata dictionary. Defaults to empty dict.
    """
    doc_id = doc_id or f"test-doc-{uuid.uuid4()}"
    return Document(
        id=doc_id,
        sections=[Section(text=text, link=link)],
        source=source,
        semantic_identifier=doc_id,
        doc_updated_at=datetime.now(timezone.utc),
        metadata=metadata or {},
    )


def create_test_document_failure(
    doc_id: str,
    failure_message: str = "Simulated failure",
    document_link: str | None = None,
) -> ConnectorFailure:
    """Create a test document failure with the given parameters.

    Args:
        doc_id: The ID of the document that failed.
        failure_message: The failure message. Defaults to "Simulated failure".
        document_link: Optional link to the failed document.
    """
    return ConnectorFailure(
        failed_document=DocumentFailure(
            document_id=doc_id,
            document_link=document_link,
        ),
        failure_message=failure_message,
    )
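A brief sketch of how these two helpers pair up in a test (the shapes match how the checkpointing tests later in this diff use them):

    doc_ok = create_test_document(text="hello world")
    doc_bad = create_test_document()
    failure = create_test_document_failure(doc_id=doc_bad.id)
    assert failure.failed_document.document_id == doc_bad.id
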
18
backend/tests/integration/common_utils/timeout.py
Normal file
@ -0,0 +1,18 @@
import multiprocessing
from collections.abc import Callable
from typing import Any
from typing import TypeVar

T = TypeVar("T")


def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
    # Use multiprocessing to prevent a thread from blocking the main thread
    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(task, kwds=kwargs)
        try:
            # Wait at most timeout seconds for the function to complete
            result = async_result.get(timeout=timeout)
            return result
        except multiprocessing.TimeoutError:
            raise TimeoutError(f"Function timed out after {timeout} seconds")
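A minimal usage sketch of run_with_timeout (slow_task is a stand-in; since the task runs in a separate process via multiprocessing.Pool, it must be picklable, i.e. defined at module level):

    import time

    def slow_task(seconds: float) -> str:
        time.sleep(seconds)
        return "done"

    print(run_with_timeout(slow_task, 5, kwargs={"seconds": 1}))   # -> "done"
    # run_with_timeout(slow_task, 1, kwargs={"seconds": 5})        # raises TimeoutError
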
@ -1,12 +1,11 @@
import os
from collections.abc import Generator

import pytest
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
@ -36,16 +35,24 @@ def load_env_vars(env_file: str = ".env") -> None:
load_env_vars()


@pytest.fixture
def db_session() -> Generator[Session, None, None]:
    with get_session_context_manager() as session:
        yield session
"""NOTE: for some reason using this seems to lead to misc
`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly`
errors.

Commenting out till we can get to the bottom of it. For now, just
instantiate the session directly within the test.
"""
# @pytest.fixture
# def db_session() -> Generator[Session, None, None]:
#     with get_session_context_manager() as session:
#         yield session


@pytest.fixture
def vespa_client(db_session: Session) -> vespa_fixture:
    search_settings = get_current_search_settings(db_session)
    return vespa_fixture(index_name=search_settings.index_name)
def vespa_client() -> vespa_fixture:
    with get_session_context_manager() as db_session:
        search_settings = get_current_search_settings(db_session)
        return vespa_fixture(index_name=search_settings.index_name)


@pytest.fixture
@ -56,20 +63,27 @@ def reset() -> None:
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser | None:
    try:
        return UserManager.create(name="admin_user")
        return UserManager.create(name=ADMIN_USER_NAME)
    except Exception:
        return None


@pytest.fixture
def admin_user() -> DATestUser | None:
def admin_user() -> DATestUser:
    try:
        return UserManager.create(name="admin_user")
    except Exception:
        pass
        user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)

        # if there are other users for some reason, reset and try again
        if not UserManager.is_role(user, UserRole.ADMIN):
            print("Trying to reset")
            reset_all()
            user = UserManager.create(name=ADMIN_USER_NAME)
        return user
    except Exception as e:
        print(f"Failed to create admin user: {e}")

    try:
        return UserManager.login_as_user(
        user = UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email("admin_user"),
@ -79,10 +93,16 @@ def admin_user() -> DATestUser | None:
                is_active=True,
            )
        )
    except Exception:
        pass
        if not UserManager.is_role(user, UserRole.ADMIN):
            reset_all()
            user = UserManager.create(name=ADMIN_USER_NAME)
            return user

    return None
        return user
    except Exception as e:
        print(f"Failed to create or login as admin user: {e}")

    raise RuntimeError("Failed to create or login as admin user")


@pytest.fixture

@ -138,7 +138,9 @@ def test_google_permission_sync(
    GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)

    # run indexing
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair, after=before, user_performing_action=admin_user
    )
@ -184,7 +186,9 @@ def test_google_permission_sync(
    GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)

    # Run indexing
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,

@ -113,7 +113,9 @@ def test_slack_permission_sync(

    # Run indexing
    before = datetime.now(timezone.utc)
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,
@ -305,7 +307,9 @@ def test_slack_group_permission_sync(
    )

    # Run indexing
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,

@ -111,7 +111,9 @@ def test_slack_prune(

    # Run indexing
    before = datetime.now(timezone.utc)
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,

@ -0,0 +1,20 @@
version: '3.8'

services:
  mock_connector_server:
    build:
      context: ./mock_connector_server
      dockerfile: Dockerfile
    ports:
      - "8001:8001"
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
      interval: 10s
      timeout: 5s
      retries: 5
    networks:
      - onyx-stack_default
networks:
  onyx-stack_default:
    name: onyx-stack_default
    external: true
@ -0,0 +1,9 @@
FROM python:3.11.7-slim-bookworm

WORKDIR /app

RUN pip install fastapi uvicorn

COPY ./main.py /app/main.py

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
@ -0,0 +1,76 @@
from fastapi import FastAPI
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field

# We would like to import these, but it makes building this so much harder/slower
# from onyx.connectors.mock_connector.connector import SingleConnectorYield
# from onyx.connectors.models import ConnectorCheckpoint

app = FastAPI()


# Global state to store connector behavior configuration
class ConnectorBehavior(BaseModel):
    connector_yields: list[dict] = Field(
        default_factory=list
    )  # really list[SingleConnectorYield]
    called_with_checkpoints: list[dict] = Field(
        default_factory=list
    )  # really list[ConnectorCheckpoint]


current_behavior: ConnectorBehavior = ConnectorBehavior()


@app.post("/set-behavior")
async def set_behavior(behavior: list[dict]) -> None:
    """Set the behavior for the next connector run"""
    global current_behavior
    current_behavior = ConnectorBehavior(connector_yields=behavior)


@app.get("/get-documents")
async def get_documents() -> list[dict]:
    """Get the next batch of documents and update the checkpoint"""
    global current_behavior

    if not current_behavior.connector_yields:
        raise HTTPException(
            status_code=400, detail="No documents or failures configured"
        )

    connector_yields = current_behavior.connector_yields

    # Clear the current behavior after returning it
    current_behavior = ConnectorBehavior()

    return connector_yields


@app.post("/add-checkpoint")
async def add_checkpoint(checkpoint: dict) -> None:
    """Add a checkpoint to the list of checkpoints. Called by the MockConnector."""
    global current_behavior
    current_behavior.called_with_checkpoints.append(checkpoint)


@app.get("/get-checkpoints")
async def get_checkpoints() -> list[dict]:
    """Get the list of checkpoints. Used by the test to verify the
    proper checkpoint ordering."""
    global current_behavior
    return current_behavior.called_with_checkpoints


@app.post("/reset")
async def reset() -> None:
    """Reset the connector behavior to default"""
    global current_behavior
    current_behavior = ConnectorBehavior()


@app.get("/health")
async def health_check() -> dict[str, str]:
    """Health check endpoint"""
    return {"status": "healthy"}
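A sketch of the intended request flow against this mock server (the host/port are the defaults from the integration test constants; the payload shapes mirror the checkpointing tests below):

    import httpx

    client = httpx.Client(base_url="http://localhost:8001", timeout=5.0)

    # 1. The test scripts what the next connector run should yield.
    client.post(
        "/set-behavior",
        json=[{"documents": [], "checkpoint": {"checkpoint_content": {}, "has_more": False}, "failures": []}],
    )

    # 2. The MockConnector pulls its scripted yields and reports each checkpoint it was called with.
    yields = client.get("/get-documents").json()
    client.post("/add-checkpoint", json={"checkpoint_content": {}, "has_more": False})

    # 3. The test verifies checkpoint ordering, then resets for the next case.
    print(client.get("/get-checkpoints").json())
    client.post("/reset")
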
@ -9,6 +9,8 @@ from uuid import uuid4

from sqlalchemy.orm import Session

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import create_index_attempt
@ -101,10 +103,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:

    create_index_attempt_error(
        index_attempt_id=new_attempt.id,
        batch=1,
        docs=[],
        exception_msg="",
        exception_traceback="",
        connector_credential_pair_id=cc_pair_1.id,
        failure=ConnectorFailure(
            failure_message="Test error",
            failed_document=DocumentFailure(
                document_id=cc_pair_1.documents[0].id,
                document_link=None,
            ),
            failed_entity=None,
        ),
        db_session=db_session,
    )

@ -127,10 +134,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
    )
    create_index_attempt_error(
        index_attempt_id=attempt_id,
        batch=1,
        docs=[],
        exception_msg="",
        exception_traceback="",
        connector_credential_pair_id=cc_pair_1.id,
        failure=ConnectorFailure(
            failure_message="Test error",
            failed_document=DocumentFailure(
                document_id=cc_pair_1.documents[0].id,
                document_link=None,
            ),
            failed_entity=None,
        ),
        db_session=db_session,
    )

518
backend/tests/integration/tests/indexing/test_checkpointing.py
Normal file
@ -0,0 +1,518 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.test_document_utils import create_test_document
|
||||
from tests.integration.common_utils.test_document_utils import (
|
||||
create_test_document_failure,
|
||||
)
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_client() -> httpx.Client:
|
||||
print(
|
||||
f"Initializing mock server client with host: "
|
||||
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
|
||||
f"{MOCK_CONNECTOR_SERVER_PORT}"
|
||||
)
|
||||
return httpx.Client(
|
||||
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
|
||||
def test_mock_connector_basic_flow(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the mock connector can successfully process documents and failures"""
|
||||
# Set up mock server behavior
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# create CC Pair + index attempt
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for index attempt to start
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for index attempt to finish
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify results
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
def test_mock_connector_with_failures(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the mock connector processes both successes and failures properly."""
|
||||
doc1 = create_test_document()
|
||||
doc2 = create_test_document()
|
||||
doc2_failure = create_test_document_failure(doc_id=doc2.id)
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [doc2_failure.model_dump(mode="json")],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create a CC Pair for the mock connector
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-failure-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for the index attempt to start and then complete
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
|
||||
# Verify results: doc1 should be indexed and doc2 should have an error entry
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == doc1.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 1
|
||||
error = errors[0]
|
||||
assert error.failure_message == doc2_failure.failure_message
|
||||
assert error.document_id == doc2.id
|
||||
|
||||
|
||||
def test_mock_connector_failure_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that a failed document can be successfully indexed in a subsequent attempt
    while maintaining previously successful documents."""
    # Create test documents and failure
    doc1 = create_test_document()
    doc2 = create_test_document()
    doc2_failure = create_test_document_failure(doc_id=doc2.id)
    entity_id = "test-entity-id"
    entity_failure_msg = "Simulated unhandled error"

    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc1.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [
                    doc2_failure.model_dump(mode="json"),
                    ConnectorFailure(
                        failed_entity=EntityFailure(
                            entity_id=entity_id,
                            missed_time_range=(
                                datetime.now(timezone.utc) - timedelta(days=1),
                                datetime.now(timezone.utc),
                            ),
                        ),
                        failure_message=entity_failure_msg,
                    ).model_dump(mode="json"),
                ],
            }
        ],
    )
    assert response.status_code == 200

    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS

    # Verify initial state: doc1 indexed, doc2 failed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == doc1.id

    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    assert error_doc2.failure_message == doc2_failure.failure_message
    assert not error_doc2.is_resolved

    error_entity = next(error for error in errors if error.entity_id == entity_id)
    assert error_entity.failure_message == entity_failure_msg
    assert not error_entity.is_resolved

    # Update mock server to return success for both documents
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [
                    doc1.model_dump(mode="json"),
                    doc2.model_dump(mode="json"),
                ],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200

    # Trigger another indexing attempt
    # NOTE: must be from beginning to handle the entity failure
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_second_index_attempt.status == IndexingStatus.SUCCESS

    # Verify both documents are now indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert doc1.id in document_ids

    # Verify original failures were marked as resolved
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    error_entity = next(error for error in errors if error.entity_id == entity_id)

    assert error_doc2.is_resolved
    assert error_entity.is_resolved

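The test above exercises both failure shapes a connector can report. Below is a hedged illustration of constructing each kind; the EntityFailure usage is taken from the test body, while DocumentFailure and its fields are assumptions inferred from the attributes the tests assert on (document_id, document_link).

# Illustrative only. ConnectorFailure and EntityFailure are imported by the
# test file above; DocumentFailure and its fields are assumed, not confirmed.
from datetime import datetime, timedelta, timezone

# a failure scoped to a single document
doc_failure = ConnectorFailure(
    failed_document=DocumentFailure(  # assumed model
        document_id="doc-123",  # hypothetical id
        document_link="https://example.com/doc-123",  # hypothetical link
    ),
    failure_message="Permission denied while fetching document",
)

# a failure scoped to a whole entity (e.g. a space or folder), recording the
# time window whose documents were missed
entity_failure = ConnectorFailure(
    failed_entity=EntityFailure(
        entity_id="space-ENG",  # hypothetical entity
        missed_time_range=(
            datetime.now(timezone.utc) - timedelta(days=1),
            datetime.now(timezone.utc),
        ),
    ),
    failure_message="Listing documents for the space failed",
)
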
def test_mock_connector_checkpoint_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that checkpointing works correctly when an unhandled exception occurs
    and that subsequent runs pick up from the last successful checkpoint."""
    # Create test documents
    # Create 100 docs for first batch, this is needed to get past the
    # `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
    docs_batch_1 = [create_test_document() for _ in range(100)]
    doc2 = create_test_document()
    doc3 = create_test_document()

    # Set up mock server behavior for initial run:
    # - First yield: 100 docs with checkpoint1
    # - Second yield: doc2 with checkpoint2
    # - Third yield: unhandled exception
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [doc2.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [],
                # should never hit this, unhandled exception happens first
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
                "unhandled_exception": "Simulated unhandled error",
            },
        ],
    )
    assert response.status_code == 200

    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-checkpoint-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.FAILED

    # Verify initial state: both docs should be indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 101  # 100 docs from first batch + doc2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints that were sent to the mock server
    response = mock_server_client.get("/get-checkpoints")
    assert response.status_code == 200
    initial_checkpoints = response.json()

    # Verify we got the expected checkpoints in order
    assert len(initial_checkpoints) > 0
    assert (
        initial_checkpoints[0]["checkpoint_content"] == {}
    )  # Initial empty checkpoint
    assert initial_checkpoints[1]["checkpoint_content"] == {}
    assert initial_checkpoints[2]["checkpoint_content"] == {}

    # Reset the mock server for the next run
    response = mock_server_client.post("/reset")
    assert response.status_code == 200

    # Set up mock server behavior for recovery run - should succeed fully this time
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc3.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200

    # Trigger another indexing attempt
    CCPairManager.run_once(
        cc_pair, from_beginning=False, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_recovery_attempt.status == IndexingStatus.SUCCESS

    # Verify results
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 102  # 100 docs from first batch + doc2 + doc3
    document_ids = {doc.id for doc in documents}
    assert doc3.id in document_ids
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints from the recovery run
    response = mock_server_client.get("/get-checkpoints")
    assert response.status_code == 200
    recovery_checkpoints = response.json()

    # Verify the recovery run started from the last successful checkpoint
    assert len(recovery_checkpoints) == 1
    assert recovery_checkpoints[0]["checkpoint_content"] == {}
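The single checkpoint observed in the recovery run is the point of the test: the runner re-enters the connector with the last checkpoint it persisted, so only the remaining work is redone. Here is a minimal, self-contained sketch of that generator contract; the function name and dict-based checkpoint are illustrative, not the exact connector interface.

# Toy sketch of the checkpointed-connector contract exercised above; the
# names and the dict-based checkpoint are illustrative.
from collections.abc import Iterator

_FAKE_PAGES = [["doc-a", "doc-b"], ["doc-c"]]  # stand-in for a real API


def load_from_checkpoint(checkpoint: dict) -> Iterator[tuple[list[str], dict]]:
    cursor = checkpoint.get("cursor", 0)
    while cursor < len(_FAKE_PAGES):
        batch = _FAKE_PAGES[cursor]
        cursor += 1
        # each yield hands back the batch plus a checkpoint recording progress;
        # after a crash the runner re-enters with the last persisted
        # checkpoint, so earlier pages are never refetched
        yield batch, {"cursor": cursor, "has_more": cursor < len(_FAKE_PAGES)}

Resuming with {"cursor": 1} skips the first page entirely, which is why the recovery run above reports exactly one checkpoint.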
@ -61,6 +61,7 @@ services:
      # Other services
      - POSTGRES_HOST=relational_db
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
      - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
@ -174,6 +175,7 @@ services:
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
      - POSTGRES_DB=${POSTGRES_DB:-}
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
      - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
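The POSTGRES_USE_NULL_POOL variable added above (and exported in the CI workflow) is plumbed through to SQLAlchemy engine creation. A hedged sketch of the typical wiring follows; the backend's actual engine factory may differ.

# Hedged sketch: a typical way a POSTGRES_USE_NULL_POOL flag gates
# SQLAlchemy's NullPool; the backend's actual engine factory may differ.
import os

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

use_null_pool = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

engine_kwargs: dict = {"pool_pre_ping": True}
if use_null_pool:
    # NullPool opens a fresh connection per checkout and closes it on
    # release: slower, but immune to stale pooled connections being
    # dropped between test runs
    engine_kwargs["poolclass"] = NullPool

# example DSN matching the CI environment above
engine = create_engine(
    "postgresql://postgres:password@relational_db:5432/postgres",
    **engine_kwargs,
)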
@ -0,0 +1,141 @@
import { Modal } from "@/components/Modal";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";
import { IndexAttemptError } from "./types";
import { localizeAndPrettify } from "@/lib/time";
import { Button } from "@/components/ui/button";
import { useState } from "react";
import { PageSelector } from "@/components/PageSelector";

interface IndexAttemptErrorsModalProps {
  errors: {
    items: IndexAttemptError[];
    total_items: number;
  };
  onClose: () => void;
  onResolveAll: () => void;
  isResolvingErrors?: boolean;
  onPageChange: (page: number) => void;
  currentPage: number;
  pageSize?: number;
}

const DEFAULT_PAGE_SIZE = 10;

export default function IndexAttemptErrorsModal({
  errors,
  onClose,
  onResolveAll,
  isResolvingErrors = false,
  onPageChange,
  currentPage,
  pageSize = DEFAULT_PAGE_SIZE,
}: IndexAttemptErrorsModalProps) {
  const totalPages = Math.ceil(errors.total_items / pageSize);
  const hasUnresolvedErrors = errors.items.some((error) => !error.is_resolved);

  return (
    <Modal title="Indexing Errors" onOutsideClick={onClose} width="max-w-6xl">
      <div className="flex flex-col gap-4">
        <div className="flex flex-col gap-2">
          {isResolvingErrors ? (
            <div className="text-sm text-text-default">
              Currently attempting to resolve all errors by performing a full
              re-index. This may take some time to complete.
            </div>
          ) : (
            <>
              <div className="text-sm text-text-default">
                Below are the errors encountered during indexing. Each row
                represents a failed document or entity.
              </div>
              <div className="text-sm text-text-default">
                Click the button below to kick off a full re-index to try and
                resolve these errors. This full re-index may take much longer
                than a normal update.
              </div>
            </>
          )}
        </div>

        <Table>
          <TableHeader>
            <TableRow>
              <TableHead>Time</TableHead>
              <TableHead>Document ID</TableHead>
              <TableHead className="w-1/2">Error Message</TableHead>
              <TableHead>Status</TableHead>
            </TableRow>
          </TableHeader>
          <TableBody>
            {errors.items.map((error) => (
              <TableRow key={error.id}>
                <TableCell>{localizeAndPrettify(error.time_created)}</TableCell>
                <TableCell>
                  {error.document_link ? (
                    <a
                      href={error.document_link}
                      target="_blank"
                      rel="noopener noreferrer"
                      className="text-link hover:underline"
                    >
                      {error.document_id || error.entity_id || "Unknown"}
                    </a>
                  ) : (
                    error.document_id || error.entity_id || "Unknown"
                  )}
                </TableCell>
                <TableCell className="whitespace-normal">
                  {error.failure_message}
                </TableCell>
                <TableCell>
                  <span
                    className={`px-2 py-1 rounded text-xs ${
                      error.is_resolved
                        ? "bg-green-100 text-green-800"
                        : "bg-red-100 text-red-800"
                    }`}
                  >
                    {error.is_resolved ? "Resolved" : "Unresolved"}
                  </span>
                </TableCell>
              </TableRow>
            ))}
          </TableBody>
        </Table>

        <div className="mt-4">
          {totalPages > 1 && (
            <div className="flex-1 flex justify-center mb-2">
              <PageSelector
                totalPages={totalPages}
                currentPage={currentPage + 1}
                onPageChange={(page) => onPageChange(page - 1)}
              />
            </div>
          )}

          <div className="flex w-full">
            <div className="flex gap-2 ml-auto">
              {hasUnresolvedErrors && !isResolvingErrors && (
                <Button
                  onClick={onResolveAll}
                  variant="default"
                  className="ml-4 whitespace-nowrap"
                >
                  Resolve All
                </Button>
              )}
            </div>
          </div>
        </div>
      </div>
    </Modal>
  );
}
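A note on what onResolveAll relies on: the integration tests earlier in this diff show that a successful from-the-beginning run resolves prior document and entity failures. A rough sketch of that rule, inferred from the tests rather than taken from the backend source:

# Rough sketch, inferred from the tests earlier in this diff; the backend's
# actual resolution logic may differ.
def resolve_errors_after_full_reindex(
    errors: list[dict],
    successfully_indexed_doc_ids: set[str],
) -> None:
    for error in errors:
        if error.get("document_id") in successfully_indexed_doc_ids:
            # the failed document was re-indexed successfully
            error["is_resolved"] = True
        elif error.get("document_id") is None:
            # entity-level failure: a clean from-the-beginning run covers it
            error["is_resolved"] = True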
@ -34,38 +34,26 @@ import usePaginatedFetch from "@/hooks/usePaginatedFetch";
const ITEMS_PER_PAGE = 8;
const PAGES_PER_BATCH = 8;

export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
export interface IndexingAttemptsTableProps {
  ccPair: CCPairFullInfo;
  indexAttempts: IndexAttemptSnapshot[];
  currentPage: number;
  totalPages: number;
  onPageChange: (page: number) => void;
}

export function IndexingAttemptsTable({
  ccPair,
  indexAttempts,
  currentPage,
  totalPages,
  onPageChange,
}: IndexingAttemptsTableProps) {
  const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState<
    number | null
  >(null);

  const {
    currentPageData: pageOfIndexAttempts,
    isLoading,
    error,
    currentPage,
    totalPages,
    goToPage,
  } = usePaginatedFetch<IndexAttemptSnapshot>({
    itemsPerPage: ITEMS_PER_PAGE,
    pagesPerBatch: PAGES_PER_BATCH,
    endpoint: `${buildCCPairInfoUrl(ccPair.id)}/index-attempts`,
  });

  if (isLoading || !pageOfIndexAttempts) {
    return <ThreeDotsLoader />;
  }

  if (error) {
    return (
      <ErrorCallout
        errorTitle={`Failed to fetch info on Connector with ID ${ccPair.id}`}
        errorMsg={error?.toString() || "Unknown error"}
      />
    );
  }

  if (!pageOfIndexAttempts?.length) {
  if (!indexAttempts?.length) {
    return (
      <Callout
        className="mt-4"
@ -78,7 +66,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
    );
  }

  const indexAttemptToDisplayTraceFor = pageOfIndexAttempts?.find(
  const indexAttemptToDisplayTraceFor = indexAttempts?.find(
    (indexAttempt) => indexAttempt.id === indexAttemptTracePopupId
  );

@ -119,7 +107,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
          </TableRow>
        </TableHeader>
        <TableBody>
          {pageOfIndexAttempts.map((indexAttempt) => {
          {indexAttempts.map((indexAttempt) => {
            const docsPerMinute =
              getDocsProcessedPerMinute(indexAttempt)?.toFixed(2);
            return (
@ -161,18 +149,6 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
                <TableCell>{indexAttempt.total_docs_indexed}</TableCell>
                <TableCell>
                  <div>
                    {indexAttempt.error_count > 0 && (
                      <Link
                        className="cursor-pointer my-auto"
                        href={`/admin/indexing/${indexAttempt.id}`}
                      >
                        <Text className="flex flex-wrap text-link whitespace-normal">
                          <SearchIcon />
                          View Errors
                        </Text>
                      </Link>
                    )}

                    {indexAttempt.status === "success" && (
                      <Text className="flex flex-wrap whitespace-normal">
                        {"-"}
@ -209,7 +185,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
            <PageSelector
              totalPages={totalPages}
              currentPage={currentPage}
              onPageChange={goToPage}
              onPageChange={onPageChange}
            />
          </div>
        </div>
@ -1,11 +1,9 @@
"use client";

import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { runConnector } from "@/lib/connector";
import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { triggerIndexing } from "./lib";
import { useState } from "react";
import { Modal } from "@/components/Modal";
import { Separator } from "@/components/ui/separator";
@ -23,26 +21,6 @@ function ReIndexPopup({
  setPopup: (popupSpec: PopupSpec | null) => void;
  hide: () => void;
}) {
  async function triggerIndexing(fromBeginning: boolean) {
    const errorMsg = await runConnector(
      connectorId,
      [credentialId],
      fromBeginning
    );
    if (errorMsg) {
      setPopup({
        message: errorMsg,
        type: "error",
      });
    } else {
      setPopup({
        message: "Triggered connector run",
        type: "success",
      });
    }
    mutate(buildCCPairInfoUrl(ccPairId));
  }

  return (
    <Modal title="Run Indexing" onOutsideClick={hide}>
      <div>
@ -50,7 +28,13 @@ function ReIndexPopup({
            variant="submit"
            className="ml-auto"
            onClick={() => {
              triggerIndexing(false);
              triggerIndexing(
                false,
                connectorId,
                credentialId,
                ccPairId,
                setPopup
              );
              hide();
            }}
          >
@ -68,7 +52,13 @@ function ReIndexPopup({
            variant="submit"
            className="ml-auto"
            onClick={() => {
              triggerIndexing(true);
              triggerIndexing(
                true,
                connectorId,
                credentialId,
                ccPairId,
                setPopup
              );
              hide();
            }}
          >
@ -1,4 +1,7 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { runConnector } from "@/lib/connector";
import { ValidSources } from "@/lib/types";
import { mutate } from "swr";

export function buildCCPairInfoUrl(ccPairId: string | number) {
  return `/api/manage/admin/cc-pair/${ccPairId}`;
@ -11,3 +14,29 @@ export function buildSimilarCredentialInfoURL(
  const base = `/api/manage/admin/similar-credentials/${source_type}`;
  return get_editable ? `${base}?get_editable=True` : base;
}

export async function triggerIndexing(
  fromBeginning: boolean,
  connectorId: number,
  credentialId: number,
  ccPairId: number,
  setPopup: (popupSpec: PopupSpec | null) => void
) {
  const errorMsg = await runConnector(
    connectorId,
    [credentialId],
    fromBeginning
  );
  if (errorMsg) {
    setPopup({
      message: errorMsg,
      type: "error",
    });
  } else {
    setPopup({
      message: "Triggered connector run",
      type: "success",
    });
  }
  mutate(buildCCPairInfoUrl(ccPairId));
}
@ -25,13 +25,24 @@ import DeletionErrorStatus from "./DeletionErrorStatus";
import { IndexingAttemptsTable } from "./IndexingAttemptsTable";
import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster";
import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { buildCCPairInfoUrl, triggerIndexing } from "./lib";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import {
  CCPairFullInfo,
  ConnectorCredentialPairStatus,
  IndexAttemptError,
  PaginatedIndexAttemptErrors,
} from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";

import * as Yup from "yup";
import { AlertCircle } from "lucide-react";
import IndexAttemptErrorsModal from "./IndexAttemptErrorsModal";
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
import { IndexAttemptSnapshot } from "@/lib/types";
import { Spinner } from "@/components/Spinner";

// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
@ -51,43 +62,99 @@ const PruneFrequencySchema = Yup.object().shape({
    .required("Property value is required"),
});

const ITEMS_PER_PAGE = 8;
const PAGES_PER_BATCH = 8;

function Main({ ccPairId }: { ccPairId: number }) {
  const router = useRouter(); // Initialize the router
  const router = useRouter();
  const {
    data: ccPair,
    isLoading,
    error,
    isLoading: isLoadingCCPair,
    error: ccPairError,
  } = useSWR<CCPairFullInfo>(
    buildCCPairInfoUrl(ccPairId),
    errorHandlingFetcher,
    { refreshInterval: 5000 } // 5 seconds
  );

  const {
    currentPageData: indexAttempts,
    isLoading: isLoadingIndexAttempts,
    currentPage,
    totalPages,
    goToPage,
  } = usePaginatedFetch<IndexAttemptSnapshot>({
    itemsPerPage: ITEMS_PER_PAGE,
    pagesPerBatch: PAGES_PER_BATCH,
    endpoint: `${buildCCPairInfoUrl(ccPairId)}/index-attempts`,
  });

  const {
    currentPageData: indexAttemptErrorsPage,
    currentPage: errorsCurrentPage,
    totalPages: errorsTotalPages,
    goToPage: goToErrorsPage,
  } = usePaginatedFetch<IndexAttemptError>({
    itemsPerPage: 10,
    pagesPerBatch: 1,
    endpoint: `/api/manage/admin/cc-pair/${ccPairId}/errors`,
  });

  const indexAttemptErrors = indexAttemptErrorsPage
    ? {
        items: indexAttemptErrorsPage,
        total_items:
          errorsCurrentPage === errorsTotalPages &&
          indexAttemptErrorsPage.length === 0
            ? 0
            : errorsTotalPages * 10,
      }
    : null;

  const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
  const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
  const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
  const [showIndexAttemptErrors, setShowIndexAttemptErrors] = useState(false);
  const [showIsResolvingKickoffLoader, setShowIsResolvingKickoffLoader] =
    useState(false);
  const { popup, setPopup } = usePopup();

  const latestIndexAttempt = indexAttempts?.[0];
  const isResolvingErrors =
    (latestIndexAttempt?.status === "in_progress" ||
      latestIndexAttempt?.status === "not_started") &&
    latestIndexAttempt?.from_beginning &&
    // if there are errors in the latest index attempt, we don't want to show the loader
    !indexAttemptErrors?.items?.some(
      (error) => error.index_attempt_id === latestIndexAttempt?.id
    );

  const finishConnectorDeletion = useCallback(() => {
    router.push("/admin/indexing/status?message=connector-deleted");
  }, [router]);

  useEffect(() => {
    if (isLoading) {
    if (isLoadingCCPair) {
      return;
    }
    if (ccPair && !error) {
    if (ccPair && !ccPairError) {
      setHasLoadedOnce(true);
    }

    if (
      (hasLoadedOnce && (error || !ccPair)) ||
      (hasLoadedOnce && (ccPairError || !ccPair)) ||
      (ccPair?.status === ConnectorCredentialPairStatus.DELETING &&
        !ccPair.connector)
    ) {
      finishConnectorDeletion();
    }
  }, [isLoading, ccPair, error, hasLoadedOnce, finishConnectorDeletion]);
  }, [
    isLoadingCCPair,
    ccPair,
    ccPairError,
    hasLoadedOnce,
    finishConnectorDeletion,
  ]);

  const handleUpdateName = async (newName: string) => {
    try {
@ -191,15 +258,19 @@ function Main({ ccPairId }: { ccPairId: number }) {
    }
  };

  if (isLoading) {
  if (isLoadingCCPair || isLoadingIndexAttempts) {
    return <ThreeDotsLoader />;
  }

  if (!ccPair || (!hasLoadedOnce && error)) {
  if (!ccPair || (!hasLoadedOnce && ccPairError)) {
    return (
      <ErrorCallout
        errorTitle={`Failed to fetch info on Connector with ID ${ccPairId}`}
        errorMsg={error?.info?.detail || error?.toString() || "Unknown error"}
        errorMsg={
          ccPairError?.info?.detail ||
          ccPairError?.toString() ||
          "Unknown error"
        }
      />
    );
  }
@ -219,6 +290,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
  return (
    <>
      {popup}
      {showIsResolvingKickoffLoader && !isResolvingErrors && <Spinner />}

      {editingRefreshFrequency && (
        <EditPropertyModal
@ -244,6 +316,32 @@ function Main({ ccPairId }: { ccPairId: number }) {
        />
      )}

      {showIndexAttemptErrors && indexAttemptErrors && (
        <IndexAttemptErrorsModal
          errors={indexAttemptErrors}
          onClose={() => setShowIndexAttemptErrors(false)}
          onResolveAll={async () => {
            setShowIndexAttemptErrors(false);
            setShowIsResolvingKickoffLoader(true);
            await triggerIndexing(
              true,
              ccPair.connector.id,
              ccPair.credential.id,
              ccPair.id,
              setPopup
            );

            // show the loader for a max of 10 seconds
            setTimeout(() => {
              setShowIsResolvingKickoffLoader(false);
            }, 10000);
          }}
          isResolvingErrors={isResolvingErrors}
          onPageChange={goToErrorsPage}
          currentPage={errorsCurrentPage}
        />
      )}

      <BackButton
        behaviorOverride={() => router.push("/admin/indexing/status")}
      />
@ -342,13 +440,46 @@ function Main({ ccPairId }: { ccPairId: number }) {
        />
      )}

      {/* NOTE: no divider / title here for `ConfigDisplay` since it is optional and we need
      to render these conditionally.*/}
      <div className="mt-6">
        <div className="flex">
          <Title>Indexing Attempts</Title>
        </div>
        <IndexingAttemptsTable ccPair={ccPair} />
        {indexAttemptErrors && indexAttemptErrors.total_items > 0 && (
          <Alert className="border-alert bg-yellow-50 my-2">
            <AlertCircle className="h-4 w-4 text-yellow-700" />
            <AlertTitle className="text-yellow-950 font-semibold">
              Some documents failed to index
            </AlertTitle>
            <AlertDescription className="text-yellow-900">
              {isResolvingErrors ? (
                <span>
                  <span className="text-sm text-yellow-700 animate-pulse">
                    Resolving failures
                  </span>
                </span>
              ) : (
                <>
                  We ran into some issues while processing some documents.{" "}
                  <b
                    className="text-link cursor-pointer"
                    onClick={() => setShowIndexAttemptErrors(true)}
                  >
                    View details.
                  </b>
                </>
              )}
            </AlertDescription>
          </Alert>
        )}
        {indexAttempts && (
          <IndexingAttemptsTable
            ccPair={ccPair}
            indexAttempts={indexAttempts}
            currentPage={currentPage}
            totalPages={totalPages}
            onPageChange={goToPage}
          />
        )}
      </div>
      <Separator />
      <div className="flex mt-4">
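One subtlety in the page above: usePaginatedFetch exposes a page count rather than an exact item count, so total_items is estimated as errorsTotalPages * 10 and collapsed to zero when the final page comes back empty. The same logic, spelled out as a small sketch:

# The estimate used above, spelled out (page size matches itemsPerPage: 10).
def estimate_total_items(total_pages: int, current_page: int, page_len: int) -> int:
    page_size = 10
    if current_page == total_pages and page_len == 0:
        return 0  # the last page is empty, so there are no errors at all
    return total_pages * page_size  # upper-bound estimate, fine for paging UI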
@ -37,3 +37,27 @@ export interface PaginatedIndexAttempts {
  page: number;
  total_pages: number;
}

export interface IndexAttemptError {
  id: number;
  connector_credential_pair_id: number;

  document_id: string | null;
  document_link: string | null;

  entity_id: string | null;
  failed_time_range_start: string | null;
  failed_time_range_end: string | null;

  failure_message: string;
  is_resolved: boolean;

  time_created: string;

  index_attempt_id: number;
}

export interface PaginatedIndexAttemptErrors {
  items: IndexAttemptError[];
  total_items: number;
}
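For reference, a Pydantic mirror of the interfaces above, sketching the backend response shape that would produce them; it matches the TypeScript field-for-field but is assumed rather than taken from the backend source.

# Assumed backend response models mirroring the TypeScript interfaces above.
from pydantic import BaseModel


class IndexAttemptErrorModel(BaseModel):
    id: int
    connector_credential_pair_id: int
    document_id: str | None
    document_link: str | None
    entity_id: str | None
    failed_time_range_start: str | None
    failed_time_range_end: str | None
    failure_message: str
    is_resolved: bool
    time_created: str
    index_attempt_id: int


class PaginatedIndexAttemptErrorsModel(BaseModel):
    items: list[IndexAttemptErrorModel]
    total_items: int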
@ -1,189 +0,0 @@
"use client";

import { Modal } from "@/components/Modal";
import { PageSelector } from "@/components/PageSelector";
import { CheckmarkIcon, CopyIcon } from "@/components/icons/icons";
import { localizeAndPrettify } from "@/lib/time";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableRow,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { useState } from "react";
import { IndexAttemptError } from "./types";
import { TableHeader } from "@/components/ui/table";

const NUM_IN_PAGE = 8;

export function CustomModal({
  isVisible,
  onClose,
  title,
  content,
  showCopyButton = false,
}: {
  isVisible: boolean;
  onClose: () => void;
  title: string;
  content: string;
  showCopyButton?: boolean;
}) {
  const [copyClicked, setCopyClicked] = useState(false);

  if (!isVisible) return null;

  return (
    <Modal
      width="w-4/6"
      className="h-5/6 overflow-y-hidden flex flex-col"
      title={title}
      onOutsideClick={onClose}
    >
      <div className="overflow-y-auto mb-6">
        {showCopyButton && (
          <div className="mb-6">
            {!copyClicked ? (
              <div
                onClick={() => {
                  navigator.clipboard.writeText(content);
                  setCopyClicked(true);
                  setTimeout(() => setCopyClicked(false), 2000);
                }}
                className="flex w-fit cursor-pointer hover:bg-accent-background p-2 border-border border rounded"
              >
                Copy full content
                <CopyIcon className="ml-2 my-auto" />
              </div>
            ) : (
              <div className="flex w-fit hover:bg-accent-background p-2 border-border border rounded cursor-default">
                Copied to clipboard
                <CheckmarkIcon
                  className="my-auto ml-2 flex flex-shrink-0 text-success"
                  size={16}
                />
              </div>
            )}
          </div>
        )}
        <div className="whitespace-pre-wrap">{content}</div>
      </div>
    </Modal>
  );
}

export function IndexAttemptErrorsTable({
  indexAttemptErrors,
}: {
  indexAttemptErrors: IndexAttemptError[];
}) {
  const [page, setPage] = useState(1);
  const [modalData, setModalData] = useState<{
    id: number | null;
    title: string;
    content: string;
  } | null>(null);
  const closeModal = () => setModalData(null);

  return (
    <>
      {modalData && (
        <CustomModal
          isVisible={!!modalData}
          onClose={closeModal}
          title={modalData.title}
          content={modalData.content}
          showCopyButton
        />
      )}

      <Table>
        <TableHeader>
          <TableRow>
            <TableHead>Timestamp</TableHead>
            <TableHead>Batch Number</TableHead>
            <TableHead>Document Summaries</TableHead>
            <TableHead>Error Message</TableHead>
          </TableRow>
        </TableHeader>
        <TableBody>
          {indexAttemptErrors
            .slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page)
            .map((indexAttemptError) => {
              return (
                <TableRow key={indexAttemptError.id}>
                  <TableCell>
                    {indexAttemptError.time_created
                      ? localizeAndPrettify(indexAttemptError.time_created)
                      : "-"}
                  </TableCell>
                  <TableCell>{indexAttemptError.batch_number}</TableCell>
                  <TableCell>
                    {indexAttemptError.doc_summaries && (
                      <div
                        onClick={() =>
                          setModalData({
                            id: indexAttemptError.id,
                            title: "Document Summaries",
                            content: JSON.stringify(
                              indexAttemptError.doc_summaries,
                              null,
                              2
                            ),
                          })
                        }
                        className="mt-2 text-link cursor-pointer select-none"
                      >
                        View Document Summaries
                      </div>
                    )}
                  </TableCell>
                  <TableCell>
                    <div>
                      <Text className="flex flex-wrap whitespace-normal">
                        {indexAttemptError.error_msg || "-"}
                      </Text>
                      {indexAttemptError.traceback && (
                        <div
                          onClick={() =>
                            setModalData({
                              id: indexAttemptError.id,
                              title: "Exception Traceback",
                              content: indexAttemptError.traceback!,
                            })
                          }
                          className="mt-2 text-link cursor-pointer select-none"
                        >
                          View Full Trace
                        </div>
                      )}
                    </div>
                  </TableCell>
                </TableRow>
              );
            })}
        </TableBody>
      </Table>
      {indexAttemptErrors.length > NUM_IN_PAGE && (
        <div className="mt-3 flex">
          <div className="mx-auto">
            <PageSelector
              totalPages={Math.ceil(indexAttemptErrors.length / NUM_IN_PAGE)}
              currentPage={page}
              onPageChange={(newPage) => {
                setPage(newPage);
                window.scrollTo({
                  top: 0,
                  left: 0,
                  behavior: "smooth",
                });
              }}
            />
          </div>
        </div>
      )}
    </>
  );
}
@ -1,3 +0,0 @@
export function buildIndexingErrorsUrl(id: string | number) {
  return `/api/manage/admin/indexing-errors/${id}`;
}
@ -1,59 +0,0 @@
"use client";
import { use } from "react";

import { BackButton } from "@/components/BackButton";
import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { errorHandlingFetcher } from "@/lib/fetcher";
import Title from "@/components/ui/title";
import useSWR from "swr";
import { IndexAttemptErrorsTable } from "./IndexAttemptErrorsTable";
import { buildIndexingErrorsUrl } from "./lib";
import { IndexAttemptError } from "./types";

function Main({ id }: { id: number }) {
  const {
    data: indexAttemptErrors,
    isLoading,
    error,
  } = useSWR<IndexAttemptError[]>(
    buildIndexingErrorsUrl(id),
    errorHandlingFetcher
  );

  if (isLoading) {
    return <ThreeDotsLoader />;
  }

  if (error || !indexAttemptErrors) {
    return (
      <ErrorCallout
        errorTitle={`Failed to fetch errors for attempt ID ${id}`}
        errorMsg={error?.info?.detail || error.toString()}
      />
    );
  }

  return (
    <>
      <BackButton />
      <div className="mt-6">
        <div className="flex">
          <Title>Indexing Errors for Attempt {id}</Title>
        </div>
        <IndexAttemptErrorsTable indexAttemptErrors={indexAttemptErrors} />
      </div>
    </>
  );
}

export default function Page(props: { params: Promise<{ id: string }> }) {
  const params = use(props.params);
  const id = parseInt(params.id);

  return (
    <div className="mx-auto container">
      <Main id={id} />
    </div>
  );
}
@ -1,15 +0,0 @@
export interface IndexAttemptError {
  id: number;
  index_attempt_id: number;
  batch_number: number;
  doc_summaries: DocumentErrorSummary[];
  error_msg: string;
  traceback: string;
  time_created: string;
}

export interface DocumentErrorSummary {
  id: string;
  semantic_id: string;
  section_link: string;
}
@ -41,25 +41,11 @@ export function IndexAttemptStatus({
      badge = icon;
    }
  } else if (status === "completed_with_errors") {
    const icon = (
    badge = (
      <Badge variant="secondary" icon={FiAlertTriangle}>
        Completed with errors
      </Badge>
    );
    badge = (
      <HoverPopup
        mainContent={<div className="cursor-pointer">{icon}</div>}
        popupContent={
          <div className="w-64 p-2 break-words overflow-hidden whitespace-normal">
            The indexing attempt completed, but some errors were encountered
            during the run.
            <br />
            <br />
            Click View Errors for more details.
          </div>
        }
      />
    );
  } else if (status === "success") {
    badge = (
      <Badge variant="success" icon={FiCheckCircle}>
@ -7,12 +7,13 @@ import {
} from "@/lib/types";
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { PaginatedIndexAttemptErrors } from "@/app/admin/connector/[ccPairId]/types";

type PaginatedType =
  | IndexAttemptSnapshot
  | AcceptedUserSnapshot
  | InvitedUserSnapshot
  | ChatSessionMinimal;
// Any type that has an id property
type PaginatedType = {
  id: number | string;
  [key: string]: any;
};

interface PaginatedApiResponse<T extends PaginatedType> {
  items: T[];
@ -1232,6 +1232,7 @@ export interface ConnectorBase<T> {
  indexing_start: Date | null;
  access_type: string;
  groups?: number[];
  from_beginning?: boolean;
}

export interface Connector<T> extends ConnectorBase<T> {
@ -1253,6 +1254,7 @@ export interface ConnectorSnapshot {
  indexing_start: number | null;
  time_created: string;
  time_updated: string;
  from_beginning?: boolean;
}

export interface WebConfig {
@ -335,6 +335,13 @@ export const SOURCE_METADATA_MAP: SourceMap = {
    displayName: "Not Applicable",
    category: SourceCategory.Other,
  },

  // Just so integration tests don't crash the UI
  mock_connector: {
    icon: GlobeIcon,
    displayName: "Mock Connector",
    category: SourceCategory.Other,
  },
} as SourceMap;

function fillSourceMetadata(
@ -123,6 +123,7 @@ export interface FailedConnectorIndexingStatus {
export interface IndexAttemptSnapshot {
  id: number;
  status: ValidStatuses | null;
  from_beginning: boolean;
  new_docs_indexed: number;
  docs_removed_from_index: number;
  total_docs_indexed: number;