Connector checkpointing (#3876)

* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
This commit is contained in:
Chris Weaver
2025-02-15 18:34:39 -08:00
committed by GitHub
parent bc087fc20e
commit f1fc8ac19b
68 changed files with 3333 additions and 1102 deletions

View File

@@ -1,14 +1,15 @@
import requests
from sqlalchemy.orm import Session
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
def test_create_chat_session_and_send_messages(db_session: Session) -> None:
def test_create_chat_session_and_send_messages() -> None:
# Create a test user
test_user = User(email="test@example.com", hashed_password="dummy_hash")
db_session.add(test_user)
db_session.commit()
with get_session_context_manager() as db_session:
test_user = User(email="test@example.com", hashed_password="dummy_hash")
db_session.add(test_user)
db_session.commit()
base_url = "http://localhost:8080" # Adjust this to your API's base URL
headers = {"Authorization": f"Bearer {test_user.id}"}

View File

@@ -1,5 +1,7 @@
import os
ADMIN_USER_NAME = "admin_user"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
@@ -9,3 +11,6 @@ MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}
NUM_DOCS = 5
MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost"
MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001

View File

@@ -223,12 +223,13 @@ class CCPairManager:
@staticmethod
def run_once(
cc_pair: DATestCCPair,
from_beginning: bool,
user_performing_action: DATestUser | None = None,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
"credential_ids": [cc_pair.credential_id],
"from_beginning": True,
"from_beginning": from_beginning,
}
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",

View File

@@ -1,9 +1,14 @@
from uuid import uuid4
import requests
from sqlalchemy import and_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
@@ -186,3 +191,39 @@ class DocumentManager:
group_names,
doc_creating_user,
)
@staticmethod
def fetch_documents_for_cc_pair(
    cc_pair_id: int,
    db_session: Session,
    vespa_client: vespa_fixture,
) -> list[SimpleTestDocument]:
    """Fetch all documents attached to a connector-credential pair.

    Looks up document IDs in Postgres via the
    DocumentByConnectorCredentialPair association table, then retrieves the
    corresponding contents from Vespa.

    Args:
        cc_pair_id: ID of the ConnectorCredentialPair to fetch documents for.
        db_session: Active SQLAlchemy session for the Postgres lookup.
        vespa_client: Vespa test client used to fetch document contents.

    Returns:
        One SimpleTestDocument per retrieved Vespa entry; an empty list if
        the cc-pair has no documents.
    """
    # Join the association table against ConnectorCredentialPair on the
    # (connector_id, credential_id) composite key so we can filter by the
    # cc-pair's surrogate ID.
    stmt = (
        select(DocumentByConnectorCredentialPair)
        .join(
            ConnectorCredentialPair,
            and_(
                DocumentByConnectorCredentialPair.connector_id
                == ConnectorCredentialPair.connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == ConnectorCredentialPair.credential_id,
            ),
        )
        .where(ConnectorCredentialPair.id == cc_pair_id)
    )
    documents = db_session.execute(stmt).scalars().all()
    if not documents:
        return []
    doc_ids = [document.id for document in documents]
    retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
    final_docs: list[SimpleTestDocument] = []
    # NOTE: they are really chunks, but we're assuming that for these tests
    # we only have one chunk per document for now
    for doc_dict in retrieved_docs_dict:
        doc_id = doc_dict["fields"]["document_id"]
        doc_content = doc_dict["fields"]["content"]
        final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
    return final_docs

View File

@@ -4,6 +4,7 @@ from urllib.parse import urlencode
import requests
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
@@ -13,6 +14,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
@@ -92,8 +94,12 @@ class IndexAttemptManager:
"page_size": page_size,
}
url = (
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
f"?{urlencode(query_params, doseq=True)}"
)
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
url=url,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -104,3 +110,125 @@ class IndexAttemptManager:
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
total_items=data["total_items"],
)
@staticmethod
def get_latest_index_attempt_for_cc_pair(
    cc_pair_id: int,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
    """Return the most recently started IndexAttempt for a cc-pair.

    Only inspects the first page of index attempts; returns None if the
    cc-pair has no attempts at all.
    """
    index_attempts = IndexAttemptManager.get_index_attempt_page(
        cc_pair_id, user_performing_action=user_performing_action
    ).items
    if not index_attempts:
        return None
    # Sort newest-first. NOTE(review): the `or "0"` fallback assumes
    # time_started is a string timestamp (or None); if it were ever a
    # datetime, mixing it with "0" would raise on comparison — confirm.
    index_attempts = sorted(
        index_attempts, key=lambda x: x.time_started or "0", reverse=True
    )
    return index_attempts[0]
@staticmethod
def wait_for_index_attempt_start(
    cc_pair_id: int,
    index_attempts_to_ignore: list[int] | None = None,
    timeout: float = MAX_DELAY,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
    """Poll until a new IndexAttempt for the cc-pair has started.

    Args:
        cc_pair_id: cc-pair whose index attempts are polled.
        index_attempts_to_ignore: attempt IDs to skip (e.g. previous runs).
        timeout: maximum seconds to wait before raising.
        user_performing_action: user whose auth headers are used.

    Returns:
        The first attempt (not in the ignore list) with time_started set.

    Raises:
        TimeoutError: if no matching attempt starts within `timeout` seconds.
    """
    from time import sleep  # local import to avoid touching module imports

    start = datetime.now()
    index_attempts_to_ignore = index_attempts_to_ignore or []
    while True:
        index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
            cc_pair_id=cc_pair_id,
            user_performing_action=user_performing_action,
        )
        if (
            index_attempt
            and index_attempt.time_started
            and index_attempt.id not in index_attempts_to_ignore
        ):
            return index_attempt
        elapsed = (datetime.now() - start).total_seconds()
        if elapsed > timeout:
            raise TimeoutError(
                f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
            )
        # back off between polls instead of busy-waiting and hammering the API
        sleep(1)
@staticmethod
def get_index_attempt_by_id(
    index_attempt_id: int,
    cc_pair_id: int,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
    """Page through a cc-pair's index attempts until the given ID is found.

    Raises:
        ValueError: if the attempt does not exist for this cc-pair.
    """
    PAGE_SIZE = 10
    current_page = 0
    has_more = True
    while has_more:
        result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair_id,
            page=current_page,
            page_size=PAGE_SIZE,
            user_performing_action=user_performing_action,
        )
        match = next(
            (item for item in result.items if item.id == index_attempt_id), None
        )
        if match is not None:
            return match
        # a short page means we've seen every attempt
        has_more = len(result.items) >= PAGE_SIZE
        current_page += 1
    raise ValueError(f"IndexAttempt {index_attempt_id} not found")
@staticmethod
def wait_for_index_attempt_completion(
    index_attempt_id: int,
    cc_pair_id: int,
    timeout: float = MAX_DELAY,
    user_performing_action: DATestUser | None = None,
) -> None:
    """Poll until the IndexAttempt reaches a terminal status.

    Args:
        index_attempt_id: attempt to wait for.
        cc_pair_id: cc-pair the attempt belongs to.
        timeout: maximum seconds to wait before raising.
        user_performing_action: user whose auth headers are used.

    Raises:
        TimeoutError: if the attempt is not terminal within `timeout` seconds.
    """
    from time import sleep  # local import to avoid touching module imports

    start = datetime.now()
    while True:
        index_attempt = IndexAttemptManager.get_index_attempt_by_id(
            index_attempt_id=index_attempt_id,
            cc_pair_id=cc_pair_id,
            user_performing_action=user_performing_action,
        )
        if index_attempt.status and index_attempt.status.is_terminal():
            print(f"IndexAttempt {index_attempt_id} completed")
            return
        elapsed = (datetime.now() - start).total_seconds()
        if elapsed > timeout:
            raise TimeoutError(
                f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
            )
        print(
            f"Waiting for IndexAttempt {index_attempt_id} to complete. "
            f"elapsed={elapsed:.2f} timeout={timeout}"
        )
        # back off between polls instead of busy-waiting and hammering the API
        sleep(1)
@staticmethod
def get_index_attempt_errors_for_cc_pair(
    cc_pair_id: int,
    include_resolved: bool = True,
    user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
    """Fetch up to 100 indexing errors recorded for the given cc-pair."""
    base_url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors"
    resolved_suffix = "&include_resolved=true" if include_resolved else ""
    request_headers = (
        user_performing_action.headers if user_performing_action else GENERAL_HEADERS
    )
    response = requests.get(
        url=f"{base_url}?page_size=100{resolved_suffix}",
        headers=request_headers,
    )
    response.raise_for_status()
    payload = response.json()
    return [IndexAttemptErrorPydantic(**item) for item in payload["items"]]

View File

@@ -25,6 +25,7 @@ from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout
logger = setup_logger()
@@ -66,6 +67,7 @@ def _run_migrations(
def downgrade_postgres(
database: str = "postgres",
schema: str = "public",
config_name: str = "alembic",
revision: str = "base",
clear_data: bool = False,
@@ -73,8 +75,8 @@ def downgrade_postgres(
"""Downgrade Postgres database to base state."""
if clear_data:
if revision != "base":
logger.warning("Clearing data without rolling back to base state")
# Delete all rows to allow migrations to be rolled back
raise ValueError("Clearing data without rolling back to base state")
conn = psycopg2.connect(
dbname=database,
user=POSTGRES_USER,
@@ -82,38 +84,33 @@ def downgrade_postgres(
host=POSTGRES_HOST,
port=POSTGRES_PORT,
)
conn.autocommit = True # Need autocommit for dropping schema
cur = conn.cursor()
# Disable triggers to prevent foreign key constraints from being checked
cur.execute("SET session_replication_role = 'replica';")
# Fetch all table names in the current database
# Close any existing connections to the schema before dropping
cur.execute(
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND pg_stat_activity.state = 'idle in transaction'
AND pid <> pg_backend_pid();
"""
)
tables = cur.fetchall()
# Drop and recreate the public schema - this removes ALL objects
cur.execute(f"DROP SCHEMA {schema} CASCADE;")
cur.execute(f"CREATE SCHEMA {schema};")
for table in tables:
table_name = table[0]
# Restore default privileges
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")
# Don't touch migration history or Kombu
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
continue
cur.execute(f'DELETE FROM "{table_name}"')
# Re-enable triggers
cur.execute("SET session_replication_role = 'origin';")
conn.commit()
cur.close()
conn.close()
return
# Downgrade to base
conn_str = build_connection_string(
db=database,
@@ -157,11 +154,37 @@ def reset_postgres(
setup_onyx: bool = True,
) -> None:
"""Reset the Postgres database."""
downgrade_postgres(
database=database, config_name=config_name, revision="base", clear_data=True
)
# this seems to hang due to locking issues, so run with a timeout with a few retries
NUM_TRIES = 10
TIMEOUT = 10
success = False
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout(
downgrade_postgres,
TIMEOUT,
kwargs={
"database": database,
"config_name": config_name,
"revision": "base",
"clear_data": True,
},
)
success = True
break
except TimeoutError:
logger.warning(
f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
)
if not success:
raise RuntimeError("Postgres downgrade failed after 10 timeouts.")
logger.info("Upgrading Postgres...")
upgrade_postgres(database=database, config_name=config_name, revision="head")
if setup_onyx:
logger.info("Setting up Postgres...")
with get_session_context_manager() as db_session:
setup_postgres(db_session)

View File

@@ -0,0 +1,57 @@
import uuid
from datetime import datetime
from datetime import timezone
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import Section
def create_test_document(
    doc_id: str | None = None,
    text: str = "Test content",
    link: str = "http://test.com",
    source: DocumentSource = DocumentSource.MOCK_CONNECTOR,
    metadata: dict | None = None,
) -> Document:
    """Build a single-section Document suitable for integration tests.

    Args:
        doc_id: Explicit document ID; a random `test-doc-<uuid>` is generated
            when omitted.
        text: Section text content.
        link: Section link.
        source: Document source; defaults to the mock connector source.
        metadata: Optional metadata mapping; defaults to an empty dict.
    """
    doc_id = doc_id or f"test-doc-{uuid.uuid4()}"
    section = Section(text=text, link=link)
    return Document(
        id=doc_id,
        sections=[section],
        source=source,
        # reuse the ID as the human-readable identifier for simplicity
        semantic_identifier=doc_id,
        doc_updated_at=datetime.now(timezone.utc),
        metadata=metadata or {},
    )
def create_test_document_failure(
    doc_id: str,
    failure_message: str = "Simulated failure",
    document_link: str | None = None,
) -> ConnectorFailure:
    """Build a ConnectorFailure wrapping a DocumentFailure for tests.

    Args:
        doc_id: ID of the document that failed to index.
        failure_message: Human-readable failure description.
        document_link: Optional link to the failed document.
    """
    failed_doc = DocumentFailure(
        document_id=doc_id,
        document_link=document_link,
    )
    return ConnectorFailure(
        failed_document=failed_doc,
        failure_message=failure_message,
    )

View File

@@ -0,0 +1,18 @@
import multiprocessing
from collections.abc import Callable
from typing import Any
from typing import TypeVar
T = TypeVar("T")


def run_with_timeout(
    task: Callable[..., T],
    timeout: float,
    kwargs: dict[str, Any],
    *,
    args: tuple = (),
) -> T:
    """Run `task` in a separate process, aborting after `timeout` seconds.

    A process (rather than a thread) is used so a hung task cannot block the
    main interpreter; `task` and its arguments must therefore be picklable.

    Args:
        task: Callable to execute.
        timeout: Maximum seconds to wait for completion.
        kwargs: Keyword arguments forwarded to `task`.
        args: Optional positional arguments forwarded to `task`.

    Returns:
        Whatever `task` returns.

    Raises:
        TimeoutError: if `task` does not finish in time. The Pool context
            manager terminates the stuck worker process on exit.
    """
    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(task, args=args, kwds=kwargs)
        try:
            # Wait at most `timeout` seconds for the function to complete
            return async_result.get(timeout=timeout)
        except multiprocessing.TimeoutError:
            # suppress the unhelpful multiprocessing.TimeoutError context
            raise TimeoutError(
                f"Function timed out after {timeout} seconds"
            ) from None

View File

@@ -1,12 +1,11 @@
import os
from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
@@ -36,16 +35,24 @@ def load_env_vars(env_file: str = ".env") -> None:
load_env_vars()
@pytest.fixture
def db_session() -> Generator[Session, None, None]:
with get_session_context_manager() as session:
yield session
"""NOTE: for some reason using this seems to lead to misc
`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly`
errors.
Commenting out till we can get to the bottom of it. For now, just using
instantiate the session directly within the test.
"""
# @pytest.fixture
# def db_session() -> Generator[Session, None, None]:
# with get_session_context_manager() as session:
# yield session
@pytest.fixture
def vespa_client(db_session: Session) -> vespa_fixture:
search_settings = get_current_search_settings(db_session)
return vespa_fixture(index_name=search_settings.index_name)
def vespa_client() -> vespa_fixture:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
return vespa_fixture(index_name=search_settings.index_name)
@pytest.fixture
@@ -56,20 +63,27 @@ def reset() -> None:
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
return UserManager.create(name=ADMIN_USER_NAME)
except Exception:
return None
@pytest.fixture
def admin_user() -> DATestUser | None:
def admin_user() -> DATestUser:
try:
return UserManager.create(name="admin_user")
except Exception:
pass
user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
# if there are other users for some reason, reset and try again
if not UserManager.is_role(user, UserRole.ADMIN):
print("Trying to reset")
reset_all()
user = UserManager.create(name=ADMIN_USER_NAME)
return user
except Exception as e:
print(f"Failed to create admin user: {e}")
try:
return UserManager.login_as_user(
user = UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
@@ -79,10 +93,16 @@ def admin_user() -> DATestUser | None:
is_active=True,
)
)
except Exception:
pass
if not UserManager.is_role(user, UserRole.ADMIN):
reset_all()
user = UserManager.create(name=ADMIN_USER_NAME)
return user
return None
return user
except Exception as e:
print(f"Failed to create or login as admin user: {e}")
raise RuntimeError("Failed to create or login as admin user")
@pytest.fixture

View File

@@ -138,7 +138,9 @@ def test_google_permission_sync(
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)
# run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair, after=before, user_performing_action=admin_user
)
@@ -184,7 +186,9 @@ def test_google_permission_sync(
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -113,7 +113,9 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
@@ -305,7 +307,9 @@ def test_slack_group_permission_sync(
)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -111,7 +111,9 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -0,0 +1,20 @@
# Compose stack that runs the mock connector control server used by the
# checkpointing integration tests.
version: '3.8'

services:
  mock_connector_server:
    build:
      context: ./mock_connector_server
      dockerfile: Dockerfile
    ports:
      - "8001:8001"
    # Mark the container healthy once the FastAPI /health endpoint responds
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
      interval: 10s
      timeout: 5s
      retries: 5
    networks:
      - onyx-stack_default

# Join the already-running onyx stack's network (created elsewhere) so the
# API server and background workers can reach this container by name
networks:
  onyx-stack_default:
    name: onyx-stack_default
    external: true

View File

@@ -0,0 +1,9 @@
# Minimal image for the mock connector control server used in integration tests.
FROM python:3.11.7-slim-bookworm

WORKDIR /app

# Install only what the single-file FastAPI app needs; pulling in the full
# project would make this image much slower to build.
RUN pip install fastapi uvicorn

COPY ./main.py /app/main.py

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]

View File

@@ -0,0 +1,76 @@
from fastapi import FastAPI
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
# We would like to import these, but it makes building this so much harder/slower
# from onyx.connectors.mock_connector.connector import SingleConnectorYield
# from onyx.connectors.models import ConnectorCheckpoint
app = FastAPI()
# Global state to store connector behavior configuration
class ConnectorBehavior(BaseModel):
    """In-memory record of what the mock connector should do / has done."""

    # Batches the connector will yield on its next run
    connector_yields: list[dict] = Field(
        default_factory=list
    )  # really list[SingleConnectorYield]
    # Checkpoints the MockConnector reported back via /add-checkpoint
    called_with_checkpoints: list[dict] = Field(
        default_factory=list
    )  # really list[ConnectorCheckpoint]


# Module-level singleton mutated by the endpoints below
current_behavior: ConnectorBehavior = ConnectorBehavior()
@app.post("/set-behavior")
async def set_behavior(behavior: list[dict]) -> None:
    """Set the behavior for the next connector run.

    Each entry is a serialized SingleConnectorYield; previously recorded
    checkpoints are discarded by replacing the singleton wholesale.
    """
    global current_behavior  # rebinds the module-level singleton
    current_behavior = ConnectorBehavior(connector_yields=behavior)
@app.get("/get-documents")
async def get_documents() -> list[dict]:
    """Return the configured connector yields, then reset the behavior.

    Raises a 400 if no behavior has been configured yet.
    """
    global current_behavior
    pending_yields = current_behavior.connector_yields
    if not pending_yields:
        raise HTTPException(
            status_code=400, detail="No documents or failures configured"
        )
    # One-shot semantics: the configured behavior is consumed by this call
    current_behavior = ConnectorBehavior()
    return pending_yields
@app.post("/add-checkpoint")
async def add_checkpoint(checkpoint: dict) -> None:
    """Add a checkpoint to the list of checkpoints. Called by the MockConnector."""
    global current_behavior  # NOTE: redundant — in-place mutation needs no global
    current_behavior.called_with_checkpoints.append(checkpoint)
@app.get("/get-checkpoints")
async def get_checkpoints() -> list[dict]:
    """Get the list of checkpoints. Used by the test to verify the
    proper checkpoint ordering."""
    global current_behavior  # NOTE: redundant — read-only access needs no global
    return current_behavior.called_with_checkpoints
@app.post("/reset")
async def reset() -> None:
    """Reset the connector behavior to default (no yields, no checkpoints)."""
    global current_behavior  # rebinds the module-level singleton
    current_behavior = ConnectorBehavior()
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Liveness probe; backs the docker-compose healthcheck."""
    status_payload = {"status": "healthy"}
    return status_payload

View File

@@ -9,6 +9,8 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import create_index_attempt
@@ -101,10 +103,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
create_index_attempt_error(
index_attempt_id=new_attempt.id,
batch=1,
docs=[],
exception_msg="",
exception_traceback="",
connector_credential_pair_id=cc_pair_1.id,
failure=ConnectorFailure(
failure_message="Test error",
failed_document=DocumentFailure(
document_id=cc_pair_1.documents[0].id,
document_link=None,
),
failed_entity=None,
),
db_session=db_session,
)
@@ -127,10 +134,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
create_index_attempt_error(
index_attempt_id=attempt_id,
batch=1,
docs=[],
exception_msg="",
exception_traceback="",
connector_credential_pair_id=cc_pair_1.id,
failure=ConnectorFailure(
failure_message="Test error",
failed_document=DocumentFailure(
document_id=cc_pair_1.documents[0].id,
document_link=None,
),
failed_entity=None,
),
db_session=db_session,
)

View File

@@ -0,0 +1,518 @@
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import httpx
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import InputType
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_document_utils import (
create_test_document_failure,
)
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.fixture
def mock_server_client() -> httpx.Client:
    """Provide an httpx client pointed at the mock connector control server."""
    print(
        f"Initializing mock server client with host: "
        f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
        f"{MOCK_CONNECTOR_SERVER_PORT}"
    )
    server_base_url = (
        f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}"
    )
    return httpx.Client(base_url=server_base_url, timeout=5.0)
def test_mock_connector_basic_flow(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that the mock connector can successfully process documents and failures"""
    # Set up mock server behavior: a single yield with one document, a
    # terminal checkpoint (has_more=False), and no failures
    doc_uuid = uuid.uuid4()
    test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [test_doc.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200
    # create CC Pair + index attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )
    # wait for index attempt to start
    index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    # wait for index attempt to finish
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    # validate status
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.SUCCESS
    # Verify results: the document made it to Vespa and no errors were logged
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == test_doc.id
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 0
def test_mock_connector_with_failures(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that the mock connector processes both successes and failures properly."""
    # One successful document (doc1) and one simulated per-document failure (doc2)
    doc1 = create_test_document()
    doc2 = create_test_document()
    doc2_failure = create_test_document_failure(doc_id=doc2.id)
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc1.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [doc2_failure.model_dump(mode="json")],
            }
        ],
    )
    assert response.status_code == 200
    # Create a CC Pair for the mock connector
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-failure-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )
    # Wait for the index attempt to start and then complete
    index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    # validate status: a failure was hit, so the attempt must not be a clean SUCCESS
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
    # Verify results: doc1 should be indexed and doc2 should have an error entry
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == doc1.id
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 1
    error = errors[0]
    assert error.failure_message == doc2_failure.failure_message
    assert error.document_id == doc2.id
def test_mock_connector_failure_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that a failed document can be successfully indexed in a subsequent attempt
    while maintaining previously successful documents."""
    # Create test documents and failure
    doc1 = create_test_document()
    doc2 = create_test_document()
    doc2_failure = create_test_document_failure(doc_id=doc2.id)
    entity_id = "test-entity-id"
    entity_failure_msg = "Simulated unhandled error"
    # First run: doc1 succeeds, doc2 fails, plus one entity-level failure
    # covering the last day
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc1.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [
                    doc2_failure.model_dump(mode="json"),
                    ConnectorFailure(
                        failed_entity=EntityFailure(
                            entity_id=entity_id,
                            missed_time_range=(
                                datetime.now(timezone.utc) - timedelta(days=1),
                                datetime.now(timezone.utc),
                            ),
                        ),
                        failure_message=entity_failure_msg,
                    ).model_dump(mode="json"),
                ],
            }
        ],
    )
    assert response.status_code == 200
    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )
    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    # validate status
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
    # Verify initial state: doc1 indexed, doc2 failed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == doc1.id
    # Both the document failure and the entity failure should be recorded,
    # and both should be unresolved at this point
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    assert error_doc2.failure_message == doc2_failure.failure_message
    assert not error_doc2.is_resolved
    error_entity = next(error for error in errors if error.entity_id == entity_id)
    assert error_entity.failure_message == entity_failure_msg
    assert not error_entity.is_resolved
    # Update mock server to return success for both documents
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [
                    doc1.model_dump(mode="json"),
                    doc2.model_dump(mode="json"),
                ],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200
    # Trigger another indexing attempt
    # NOTE: must be from beginning to handle the entity failure
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_second_index_attempt.status == IndexingStatus.SUCCESS
    # Verify both documents are now indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert doc1.id in document_ids
    # Verify original failures were marked as resolved
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    error_entity = next(error for error in errors if error.entity_id == entity_id)
    assert error_doc2.is_resolved
    assert error_entity.is_resolved
def test_mock_connector_checkpoint_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that checkpointing works correctly when an unhandled exception occurs
    and that subsequent runs pick up from the last successful checkpoint.

    Flow:
      1. Initial run yields 100 docs, then doc2, then raises an unhandled
         exception -> the attempt is marked FAILED, but everything yielded
         before the exception is indexed and checkpoints were recorded.
      2. Recovery run (``from_beginning=False``) resumes from the last valid
         checkpoint, so it only needs a single yield to index doc3.
    """
    # Create test documents.
    # Create 100 docs for the first batch, this is needed to get past the
    # `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
    docs_batch_1 = [create_test_document() for _ in range(100)]
    doc2 = create_test_document()
    doc3 = create_test_document()

    # Set up mock server behavior for initial run:
    # - First yield: 100 docs with checkpoint1 (has_more=True)
    # - Second yield: doc2 with checkpoint2 (has_more=True)
    # - Third yield: unhandled exception before any docs are returned
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [doc2.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [],
                # should never hit this, unhandled exception happens first
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
                "unhandled_exception": "Simulated unhandled error",
            },
        ],
    )
    assert response.status_code == 200

    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-checkpoint-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status: the unhandled exception must fail the attempt
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.FAILED

    # Verify initial state: everything yielded before the exception
    # (100 batch docs + doc2 = 101) should already be indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 101  # 100 docs from first batch + doc2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints that were sent to the mock server
    response = mock_server_client.get("/get-checkpoints")
    assert response.status_code == 200
    initial_checkpoints = response.json()

    # Verify we got the expected checkpoints in order. We index the first
    # three entries below, so require at least three up front — a bare
    # non-empty check would surface as an IndexError instead of a clean
    # assertion failure when fewer checkpoints were recorded.
    assert len(initial_checkpoints) >= 3
    assert (
        initial_checkpoints[0]["checkpoint_content"] == {}
    )  # Initial empty checkpoint
    assert initial_checkpoints[1]["checkpoint_content"] == {}
    assert initial_checkpoints[2]["checkpoint_content"] == {}

    # Reset the mock server for the next run
    response = mock_server_client.post("/reset")
    assert response.status_code == 200

    # Set up mock server behavior for recovery run - should succeed fully this time
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc3.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200

    # Trigger another indexing attempt. NOT from the beginning, so the run
    # should resume from the last successful checkpoint.
    CCPairManager.run_once(
        cc_pair, from_beginning=False, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_recovery_attempt.status == IndexingStatus.SUCCESS

    # Verify results: all docs from both runs should now be indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 102  # 100 docs from first batch + doc2 + doc3
    document_ids = {doc.id for doc in documents}
    assert doc3.id in document_ids
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints from the recovery run
    response = mock_server_client.get("/get-checkpoints")
    assert response.status_code == 200
    recovery_checkpoints = response.json()

    # Verify the recovery run started from the last successful checkpoint:
    # a single yield finished the run, so exactly one checkpoint was recorded
    assert len(recovery_checkpoints) == 1
    assert recovery_checkpoints[0]["checkpoint_content"] == {}