mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-03 13:43:18 +02:00
Connector checkpointing (#3876)
* wip checkpointing/continue on failure more stuff for checkpointing Basic implementation FE stuff More checkpointing/failure handling rebase rebase initial scaffolding for IT IT to test checkpointing Cleanup cleanup Fix it Rebase Add todo Fix actions IT Test more Pagination + fixes + cleanup Fix IT networking fix it * rebase * Address misc comments * Address comments * Remove unused router * rebase * Fix mypy * Fixes * fix it * Fix tests * Add drop index * Add retries * reset lock timeout * Try hard drop of schema * Add timeout/retries to downgrade * rebase * test * test * test * Close all connections * test closing idle only * Fix it * fix * try using null pool * Test * fix * rebase * log * Fix * apply null pool * Fix other test * Fix quality checks * Test not using the fixture * Fix ordering * fix test * Change pooling behavior
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def test_create_chat_session_and_send_messages(db_session: Session) -> None:
|
||||
def test_create_chat_session_and_send_messages() -> None:
|
||||
# Create a test user
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
with get_session_context_manager() as db_session:
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
|
||||
base_url = "http://localhost:8080" # Adjust this to your API's base URL
|
||||
headers = {"Authorization": f"Bearer {test_user.id}"}
|
||||
|
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
|
||||
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
|
||||
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
|
||||
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
|
||||
@@ -9,3 +11,6 @@ MAX_DELAY = 45
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
NUM_DOCS = 5
|
||||
|
||||
MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost"
|
||||
MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001
|
||||
|
@@ -223,12 +223,13 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def run_once(
|
||||
cc_pair: DATestCCPair,
|
||||
from_beginning: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
body = {
|
||||
"connector_id": cc_pair.connector_id,
|
||||
"credential_ids": [cc_pair.credential_id],
|
||||
"from_beginning": True,
|
||||
"from_beginning": from_beginning,
|
||||
}
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||
|
@@ -1,9 +1,14 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
@@ -186,3 +191,39 @@ class DocumentManager:
|
||||
group_names,
|
||||
doc_creating_user,
|
||||
)
|
||||
|
||||
@staticmethod
def fetch_documents_for_cc_pair(
    cc_pair_id: int,
    db_session: Session,
    vespa_client: vespa_fixture,
) -> list[SimpleTestDocument]:
    """Return the documents indexed for the given CC pair.

    Document IDs are looked up in Postgres via the
    DocumentByConnectorCredentialPair <-> ConnectorCredentialPair join;
    the actual content is then fetched from Vespa.
    """
    join_condition = and_(
        DocumentByConnectorCredentialPair.connector_id
        == ConnectorCredentialPair.connector_id,
        DocumentByConnectorCredentialPair.credential_id
        == ConnectorCredentialPair.credential_id,
    )
    stmt = (
        select(DocumentByConnectorCredentialPair)
        .join(ConnectorCredentialPair, join_condition)
        .where(ConnectorCredentialPair.id == cc_pair_id)
    )
    documents = db_session.execute(stmt).scalars().all()
    if not documents:
        return []

    doc_ids = [document.id for document in documents]
    retrieved_chunks = vespa_client.get_documents_by_id(doc_ids)["documents"]

    # NOTE: these are really chunks, but the tests assume exactly one
    # chunk per document for now.
    return [
        SimpleTestDocument(
            id=chunk["fields"]["document_id"],
            content=chunk["fields"]["content"],
        )
        for chunk in retrieved_chunks
    ]
|
||||
|
@@ -4,6 +4,7 @@ from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -13,6 +14,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestIndexAttempt
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -92,8 +94,12 @@ class IndexAttemptManager:
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
url = (
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
|
||||
f"?{urlencode(query_params, doseq=True)}"
|
||||
)
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
|
||||
url=url,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -104,3 +110,125 @@ class IndexAttemptManager:
|
||||
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
|
||||
total_items=data["total_items"],
|
||||
)
|
||||
|
||||
@staticmethod
def get_latest_index_attempt_for_cc_pair(
    cc_pair_id: int,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
    """Get an IndexAttempt by ID"""
    # Only the first page of attempts is inspected; assumes the latest
    # attempt appears on page 0 of get_index_attempt_page -- TODO confirm.
    index_attempts = IndexAttemptManager.get_index_attempt_page(
        cc_pair_id, user_performing_action=user_performing_action
    ).items
    if not index_attempts:
        return None

    # NOTE(review): the `or "0"` fallback assumes time_started is a string
    # (e.g. an ISO timestamp) on the snapshot; if it were a datetime,
    # sorting a mix of datetimes and "0" would raise TypeError -- confirm.
    index_attempts = sorted(
        index_attempts, key=lambda x: x.time_started or "0", reverse=True
    )
    return index_attempts[0]
|
||||
|
||||
@staticmethod
def wait_for_index_attempt_start(
    cc_pair_id: int,
    index_attempts_to_ignore: list[int] | None = None,
    timeout: float = MAX_DELAY,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
    """Wait for an IndexAttempt to start.

    Polls the latest attempt for the CC pair until one that is not in
    `index_attempts_to_ignore` has `time_started` set.

    Raises:
        TimeoutError: if no attempt starts within `timeout` seconds.
    """
    import time  # local import to avoid touching the file's import block

    start = datetime.now()
    index_attempts_to_ignore = index_attempts_to_ignore or []

    while True:
        index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
            cc_pair_id=cc_pair_id,
            user_performing_action=user_performing_action,
        )
        if (
            index_attempt
            and index_attempt.time_started
            and index_attempt.id not in index_attempts_to_ignore
        ):
            return index_attempt

        elapsed = (datetime.now() - start).total_seconds()
        if elapsed > timeout:
            raise TimeoutError(
                f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
            )

        # pause between polls so we don't hammer the API server with
        # back-to-back requests (the loop was previously a busy-wait)
        time.sleep(1)
|
||||
|
||||
@staticmethod
def get_index_attempt_by_id(
    index_attempt_id: int,
    cc_pair_id: int,
    user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
    """Find a specific IndexAttempt by paging through a CC pair's attempts.

    Raises:
        ValueError: if no attempt with the given ID exists.
    """
    PAGE_SIZE = 10
    page_num = 0
    while True:
        page = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair_id,
            page=page_num,
            page_size=PAGE_SIZE,
            user_performing_action=user_performing_action,
        )
        matches = [a for a in page.items if a.id == index_attempt_id]
        if matches:
            return matches[0]

        # a short page means every attempt has been seen
        if len(page.items) < PAGE_SIZE:
            raise ValueError(f"IndexAttempt {index_attempt_id} not found")

        page_num += 1
|
||||
|
||||
@staticmethod
def wait_for_index_attempt_completion(
    index_attempt_id: int,
    cc_pair_id: int,
    timeout: float = MAX_DELAY,
    user_performing_action: DATestUser | None = None,
) -> None:
    """Wait for an IndexAttempt to reach a terminal status.

    Raises:
        TimeoutError: if the attempt is not terminal within `timeout` seconds.
    """
    import time  # local import to avoid touching the file's import block

    start = datetime.now()
    while True:
        index_attempt = IndexAttemptManager.get_index_attempt_by_id(
            index_attempt_id=index_attempt_id,
            cc_pair_id=cc_pair_id,
            user_performing_action=user_performing_action,
        )

        if index_attempt.status and index_attempt.status.is_terminal():
            print(f"IndexAttempt {index_attempt_id} completed")
            return

        elapsed = (datetime.now() - start).total_seconds()
        if elapsed > timeout:
            raise TimeoutError(
                f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
            )

        print(
            f"Waiting for IndexAttempt {index_attempt_id} to complete. "
            f"elapsed={elapsed:.2f} timeout={timeout}"
        )
        # pause between polls so we don't hammer the API server with
        # back-to-back requests (the loop was previously a busy-wait)
        time.sleep(1)
|
||||
|
||||
@staticmethod
def get_index_attempt_errors_for_cc_pair(
    cc_pair_id: int,
    include_resolved: bool = True,
    user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
    """Fetch the indexing errors recorded for the given CC pair."""
    query = "page_size=100"
    if include_resolved:
        query += "&include_resolved=true"
    url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?{query}"

    headers = (
        user_performing_action.headers if user_performing_action else GENERAL_HEADERS
    )
    response = requests.get(url=url, headers=headers)
    response.raise_for_status()
    return [IndexAttemptErrorPydantic(**item) for item in response.json()["items"]]
|
||||
|
@@ -25,6 +25,7 @@ from onyx.indexing.models import IndexingSetting
|
||||
from onyx.setup import setup_postgres
|
||||
from onyx.setup import setup_vespa
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.integration.common_utils.timeout import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -66,6 +67,7 @@ def _run_migrations(
|
||||
|
||||
def downgrade_postgres(
|
||||
database: str = "postgres",
|
||||
schema: str = "public",
|
||||
config_name: str = "alembic",
|
||||
revision: str = "base",
|
||||
clear_data: bool = False,
|
||||
@@ -73,8 +75,8 @@ def downgrade_postgres(
|
||||
"""Downgrade Postgres database to base state."""
|
||||
if clear_data:
|
||||
if revision != "base":
|
||||
logger.warning("Clearing data without rolling back to base state")
|
||||
# Delete all rows to allow migrations to be rolled back
|
||||
raise ValueError("Clearing data without rolling back to base state")
|
||||
|
||||
conn = psycopg2.connect(
|
||||
dbname=database,
|
||||
user=POSTGRES_USER,
|
||||
@@ -82,38 +84,33 @@ def downgrade_postgres(
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
)
|
||||
conn.autocommit = True # Need autocommit for dropping schema
|
||||
cur = conn.cursor()
|
||||
|
||||
# Disable triggers to prevent foreign key constraints from being checked
|
||||
cur.execute("SET session_replication_role = 'replica';")
|
||||
|
||||
# Fetch all table names in the current database
|
||||
# Close any existing connections to the schema before dropping
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT tablename
|
||||
FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
f"""
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = '{database}'
|
||||
AND pg_stat_activity.state = 'idle in transaction'
|
||||
AND pid <> pg_backend_pid();
|
||||
"""
|
||||
)
|
||||
|
||||
tables = cur.fetchall()
|
||||
# Drop and recreate the public schema - this removes ALL objects
|
||||
cur.execute(f"DROP SCHEMA {schema} CASCADE;")
|
||||
cur.execute(f"CREATE SCHEMA {schema};")
|
||||
|
||||
for table in tables:
|
||||
table_name = table[0]
|
||||
# Restore default privileges
|
||||
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
|
||||
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")
|
||||
|
||||
# Don't touch migration history or Kombu
|
||||
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
|
||||
continue
|
||||
|
||||
cur.execute(f'DELETE FROM "{table_name}"')
|
||||
|
||||
# Re-enable triggers
|
||||
cur.execute("SET session_replication_role = 'origin';")
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
return
|
||||
|
||||
# Downgrade to base
|
||||
conn_str = build_connection_string(
|
||||
db=database,
|
||||
@@ -157,11 +154,37 @@ def reset_postgres(
|
||||
setup_onyx: bool = True,
|
||||
) -> None:
|
||||
"""Reset the Postgres database."""
|
||||
downgrade_postgres(
|
||||
database=database, config_name=config_name, revision="base", clear_data=True
|
||||
)
|
||||
# this seems to hang due to locking issues, so run with a timeout with a few retries
|
||||
NUM_TRIES = 10
|
||||
TIMEOUT = 10
|
||||
success = False
|
||||
for _ in range(NUM_TRIES):
|
||||
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
|
||||
try:
|
||||
run_with_timeout(
|
||||
downgrade_postgres,
|
||||
TIMEOUT,
|
||||
kwargs={
|
||||
"database": database,
|
||||
"config_name": config_name,
|
||||
"revision": "base",
|
||||
"clear_data": True,
|
||||
},
|
||||
)
|
||||
success = True
|
||||
break
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("Postgres downgrade failed after 10 timeouts.")
|
||||
|
||||
logger.info("Upgrading Postgres...")
|
||||
upgrade_postgres(database=database, config_name=config_name, revision="head")
|
||||
if setup_onyx:
|
||||
logger.info("Setting up Postgres...")
|
||||
with get_session_context_manager() as db_session:
|
||||
setup_postgres(db_session)
|
||||
|
||||
|
@@ -0,0 +1,57 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import Section
|
||||
|
||||
|
||||
def create_test_document(
    doc_id: str | None = None,
    text: str = "Test content",
    link: str = "http://test.com",
    source: DocumentSource = DocumentSource.MOCK_CONNECTOR,
    metadata: dict | None = None,
) -> Document:
    """Build a single-section Document suitable for tests.

    Args:
        doc_id: Optional document ID; a random ``test-doc-<uuid4>`` is
            generated when omitted.
        text: Text content of the document's only section.
        link: Link attached to that section.
        source: The document source. Defaults to MOCK_CONNECTOR.
        metadata: Optional metadata dictionary. Defaults to empty dict.
    """
    doc_id = doc_id or f"test-doc-{uuid.uuid4()}"
    section = Section(text=text, link=link)
    return Document(
        id=doc_id,
        sections=[section],
        source=source,
        # reuse the ID as the human-readable identifier for simplicity
        semantic_identifier=doc_id,
        doc_updated_at=datetime.now(timezone.utc),
        metadata=metadata or {},
    )
|
||||
|
||||
|
||||
def create_test_document_failure(
    doc_id: str,
    failure_message: str = "Simulated failure",
    document_link: str | None = None,
) -> ConnectorFailure:
    """Build a ConnectorFailure wrapping a DocumentFailure for tests.

    Args:
        doc_id: The ID of the document that failed.
        failure_message: The failure message. Defaults to "Simulated failure".
        document_link: Optional link to the failed document.
    """
    failed_doc = DocumentFailure(
        document_id=doc_id,
        document_link=document_link,
    )
    return ConnectorFailure(
        failed_document=failed_doc,
        failure_message=failure_message,
    )
|
18
backend/tests/integration/common_utils/timeout.py
Normal file
18
backend/tests/integration/common_utils/timeout.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import multiprocessing
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")


def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
    """Run ``task(**kwargs)`` in a separate process, bounding its runtime.

    A single-worker multiprocessing pool is used (rather than a thread) so
    a hung task cannot block the main process; the pool's context manager
    terminates the worker on exit.

    Raises:
        TimeoutError: if the task does not finish within `timeout` seconds.
    """
    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(task, kwds=kwargs)
        try:
            # wait at most `timeout` seconds for the task to complete
            return async_result.get(timeout=timeout)
        except multiprocessing.TimeoutError as e:
            # chain the cause so tracebacks show the original pool timeout
            raise TimeoutError(f"Function timed out after {timeout} seconds") from e
|
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
@@ -36,16 +35,24 @@ def load_env_vars(env_file: str = ".env") -> None:
|
||||
load_env_vars()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
with get_session_context_manager() as session:
|
||||
yield session
|
||||
"""NOTE: for some reason using this seems to lead to misc
|
||||
`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly`
|
||||
errors.
|
||||
|
||||
Commenting out till we can get to the bottom of it. For now, just using
|
||||
instantiate the session directly within the test.
|
||||
"""
|
||||
# @pytest.fixture
|
||||
# def db_session() -> Generator[Session, None, None]:
|
||||
# with get_session_context_manager() as session:
|
||||
# yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vespa_client(db_session: Session) -> vespa_fixture:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return vespa_fixture(index_name=search_settings.index_name)
|
||||
def vespa_client() -> vespa_fixture:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return vespa_fixture(index_name=search_settings.index_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -56,20 +63,27 @@ def reset() -> None:
|
||||
@pytest.fixture
|
||||
def new_admin_user(reset: None) -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create(name="admin_user")
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
def admin_user() -> DATestUser:
|
||||
try:
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
|
||||
|
||||
# if there are other users for some reason, reset and try again
|
||||
if not UserManager.is_role(user, UserRole.ADMIN):
|
||||
print("Trying to reset")
|
||||
reset_all()
|
||||
user = UserManager.create(name=ADMIN_USER_NAME)
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Failed to create admin user: {e}")
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
user = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
@@ -79,10 +93,16 @@ def admin_user() -> DATestUser | None:
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if not UserManager.is_role(user, UserRole.ADMIN):
|
||||
reset_all()
|
||||
user = UserManager.create(name=ADMIN_USER_NAME)
|
||||
return user
|
||||
|
||||
return None
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Failed to create or login as admin user: {e}")
|
||||
|
||||
raise RuntimeError("Failed to create or login as admin user")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@@ -138,7 +138,9 @@ def test_google_permission_sync(
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)
|
||||
|
||||
# run indexing
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair, after=before, user_performing_action=admin_user
|
||||
)
|
||||
@@ -184,7 +186,9 @@ def test_google_permission_sync(
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)
|
||||
|
||||
# Run indexing
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
@@ -113,7 +113,9 @@ def test_slack_permission_sync(
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
@@ -305,7 +307,9 @@ def test_slack_group_permission_sync(
|
||||
)
|
||||
|
||||
# Run indexing
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
@@ -111,7 +111,9 @@ def test_slack_prune(
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
@@ -0,0 +1,20 @@
|
||||
version: '3.8'

services:
  # Lightweight FastAPI server that scripts MockConnector behavior for
  # the checkpointing integration tests.
  mock_connector_server:
    build:
      context: ./mock_connector_server
      dockerfile: Dockerfile
    ports:
      - "8001:8001"
    healthcheck:
      # same /health endpoint exposed by the FastAPI app
      test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
      interval: 10s
      timeout: 5s
      retries: 5
    networks:
      - onyx-stack_default

networks:
  # attach to the already-running onyx stack's network
  onyx-stack_default:
    name: onyx-stack_default
    external: true
@@ -0,0 +1,9 @@
|
||||
# Minimal image for the mock connector server used by the integration tests.
FROM python:3.11.7-slim-bookworm

WORKDIR /app

# Installed directly (no requirements file) to keep the build small/simple.
RUN pip install fastapi uvicorn

COPY ./main.py /app/main.py

# Port must match the mapping/healthcheck in the integration docker-compose.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
|
@@ -0,0 +1,76 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
# We would like to import these, but it makes building this so much harder/slower
|
||||
# from onyx.connectors.mock_connector.connector import SingleConnectorYield
|
||||
# from onyx.connectors.models import ConnectorCheckpoint
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Global state to store connector behavior configuration
class ConnectorBehavior(BaseModel):
    """In-memory script for the mock connector's next run."""

    # batches the connector should yield, in order
    connector_yields: list[dict] = Field(
        default_factory=list
    )  # really list[SingleConnectorYield]
    # checkpoints the connector reported back via /add-checkpoint
    called_with_checkpoints: list[dict] = Field(
        default_factory=list
    )  # really list[ConnectorCheckpoint]


# single mutable instance shared by all endpoints below
current_behavior: ConnectorBehavior = ConnectorBehavior()
|
||||
|
||||
|
||||
@app.post("/set-behavior")
async def set_behavior(behavior: list[dict]) -> None:
    """Set the behavior for the next connector run"""
    global current_behavior
    # replaces the whole state object, so any previously recorded
    # checkpoints are dropped along with the old yields
    current_behavior = ConnectorBehavior(connector_yields=behavior)
|
||||
|
||||
|
||||
@app.get("/get-documents")
async def get_documents() -> list[dict]:
    """Get the next batch of documents and update the checkpoint"""
    global current_behavior

    if not current_behavior.connector_yields:
        raise HTTPException(
            status_code=400, detail="No documents or failures configured"
        )

    connector_yields = current_behavior.connector_yields

    # Clear the current behavior after returning it
    # NOTE(review): this replaces the whole ConnectorBehavior, which also
    # discards any checkpoints already recorded via /add-checkpoint --
    # confirm that is intended.
    current_behavior = ConnectorBehavior()

    return connector_yields
|
||||
|
||||
|
||||
@app.post("/add-checkpoint")
async def add_checkpoint(checkpoint: dict) -> None:
    """Add a checkpoint to the list of checkpoints. Called by the MockConnector."""
    # `global` is not strictly required for in-place mutation; kept for
    # consistency with the other endpoints
    global current_behavior
    current_behavior.called_with_checkpoints.append(checkpoint)
|
||||
|
||||
|
||||
@app.get("/get-checkpoints")
async def get_checkpoints() -> list[dict]:
    """Get the list of checkpoints. Used by the test to verify the
    proper checkpoint ordering."""
    # read-only access; returns the live list, not a copy
    global current_behavior
    return current_behavior.called_with_checkpoints
|
||||
|
||||
|
||||
@app.post("/reset")
async def reset() -> None:
    """Reset the connector behavior to default"""
    # drops all configured yields and recorded checkpoints
    global current_behavior
    current_behavior = ConnectorBehavior()
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check() -> dict[str, str]:
    """Health check endpoint"""
    # used by the docker-compose healthcheck to gate test startup
    return {"status": "healthy"}
|
@@ -9,6 +9,8 @@ from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
@@ -101,10 +103,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
|
||||
create_index_attempt_error(
|
||||
index_attempt_id=new_attempt.id,
|
||||
batch=1,
|
||||
docs=[],
|
||||
exception_msg="",
|
||||
exception_traceback="",
|
||||
connector_credential_pair_id=cc_pair_1.id,
|
||||
failure=ConnectorFailure(
|
||||
failure_message="Test error",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=cc_pair_1.documents[0].id,
|
||||
document_link=None,
|
||||
),
|
||||
failed_entity=None,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -127,10 +134,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
)
|
||||
create_index_attempt_error(
|
||||
index_attempt_id=attempt_id,
|
||||
batch=1,
|
||||
docs=[],
|
||||
exception_msg="",
|
||||
exception_traceback="",
|
||||
connector_credential_pair_id=cc_pair_1.id,
|
||||
failure=ConnectorFailure(
|
||||
failure_message="Test error",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=cc_pair_1.documents[0].id,
|
||||
document_link=None,
|
||||
),
|
||||
failed_entity=None,
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
518
backend/tests/integration/tests/indexing/test_checkpointing.py
Normal file
518
backend/tests/integration/tests/indexing/test_checkpointing.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.test_document_utils import create_test_document
|
||||
from tests.integration.common_utils.test_document_utils import (
|
||||
create_test_document_failure,
|
||||
)
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.fixture
def mock_server_client() -> httpx.Client:
    """HTTP client pointed at the mock connector server.

    Host/port come from the MOCK_CONNECTOR_SERVER_* constants (env-overridable).
    """
    print(
        f"Initializing mock server client with host: "
        f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
        f"{MOCK_CONNECTOR_SERVER_PORT}"
    )
    return httpx.Client(
        base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
        timeout=5.0,
    )
|
||||
|
||||
|
||||
def test_mock_connector_basic_flow(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that the mock connector can successfully process documents and failures"""
    # Set up mock server behavior
    doc_uuid = uuid.uuid4()
    test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")

    # single yield: one document, a terminal checkpoint (has_more=False),
    # and no failures
    response = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [test_doc.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert response.status_code == 200

    # create CC Pair + index attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # wait for index attempt to start
    index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # wait for index attempt to finish
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.SUCCESS

    # Verify results
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == test_doc.id

    # no errors should have been recorded for a clean run
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 0
|
||||
|
||||
|
||||
def test_mock_connector_with_failures(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that the mock connector processes both successes and failures properly."""
    doc1 = create_test_document()
    doc2 = create_test_document()
    doc2_failure = create_test_document_failure(doc_id=doc2.id)

    # Single connector yield: doc1 succeeds, doc2 is reported as a failure.
    final_checkpoint = ConnectorCheckpoint(
        checkpoint_content={}, has_more=False
    ).model_dump(mode="json")
    resp = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc1.model_dump(mode="json")],
                "checkpoint": final_checkpoint,
                "failures": [doc2_failure.model_dump(mode="json")],
            }
        ],
    )
    assert resp.status_code == 200

    # Create a CC Pair for the mock connector
    connector_config = {
        "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
        "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
    }
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-failure-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config=connector_config,
        user_performing_action=admin_user,
    )

    # Wait for the index attempt to start and then complete
    index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # The attempt should finish, but flag the per-document failure.
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS

    # Verify results: doc1 should be indexed and doc2 should have an error entry
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == doc1.id

    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 1
    error = errors[0]
    assert error.failure_message == doc2_failure.failure_message
    assert error.document_id == doc2.id
||||
|
||||
|
||||
def test_mock_connector_failure_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that a failed document can be successfully indexed in a subsequent attempt
    while maintaining previously successful documents."""
    # Test documents: doc1 succeeds, doc2 fails, plus one entity-level failure.
    doc1 = create_test_document()
    doc2 = create_test_document()
    doc2_failure = create_test_document_failure(doc_id=doc2.id)
    entity_id = "test-entity-id"
    entity_failure_msg = "Simulated unhandled error"

    entity_failure = ConnectorFailure(
        failed_entity=EntityFailure(
            entity_id=entity_id,
            missed_time_range=(
                datetime.now(timezone.utc) - timedelta(days=1),
                datetime.now(timezone.utc),
            ),
        ),
        failure_message=entity_failure_msg,
    )
    resp = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc1.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [
                    doc2_failure.model_dump(mode="json"),
                    entity_failure.model_dump(mode="json"),
                ],
            }
        ],
    )
    assert resp.status_code == 200

    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # The attempt completes, but with the two configured failures attached.
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS

    # Verify initial state: doc1 indexed, doc2 + entity failed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == doc1.id

    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    assert error_doc2.failure_message == doc2_failure.failure_message
    assert not error_doc2.is_resolved

    error_entity = next(error for error in errors if error.entity_id == entity_id)
    assert error_entity.failure_message == entity_failure_msg
    assert not error_entity.is_resolved

    # Update mock server to return success for both documents
    resp = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [
                    doc1.model_dump(mode="json"),
                    doc2.model_dump(mode="json"),
                ],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert resp.status_code == 200

    # Trigger another indexing attempt
    # NOTE: must be from beginning to handle the entity failure
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_second_index_attempt.status == IndexingStatus.SUCCESS

    # Verify both documents are now indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert doc1.id in document_ids

    # Verify original failures were marked as resolved
    errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert len(errors) == 2
    error_doc2 = next(error for error in errors if error.document_id == doc2.id)
    error_entity = next(error for error in errors if error.entity_id == entity_id)

    assert error_doc2.is_resolved
    assert error_entity.is_resolved
||||
|
||||
|
||||
def test_mock_connector_checkpoint_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that checkpointing works correctly when an unhandled exception occurs
    and that subsequent runs pick up from the last successful checkpoint."""
    # Create 100 docs for the first batch; this is needed to get past the
    # `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in
    # `get_latest_valid_checkpoint`.
    docs_batch_1 = [create_test_document() for _ in range(100)]
    doc2 = create_test_document()
    doc3 = create_test_document()

    # Mock server behavior for the initial run:
    #   yield 1: 100 docs + checkpoint (has_more=True)
    #   yield 2: doc2 + checkpoint (has_more=True)
    #   yield 3: unhandled exception before any docs
    resp = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [doc2.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=True
                ).model_dump(mode="json"),
                "failures": [],
            },
            {
                "documents": [],
                # should never hit this, unhandled exception happens first
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
                "unhandled_exception": "Simulated unhandled error",
            },
        ],
    )
    assert resp.status_code == 200

    # Create CC Pair and run initial indexing attempt
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-connector-checkpoint-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
    )

    # Wait for first index attempt to complete
    initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # The unhandled exception should fail the attempt outright.
    finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=initial_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_index_attempt.status == IndexingStatus.FAILED

    # Verify initial state: everything yielded before the exception is indexed.
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 101  # 100 docs from first batch + doc2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints that were sent to the mock server
    resp = mock_server_client.get("/get-checkpoints")
    assert resp.status_code == 200
    initial_checkpoints = resp.json()

    # Verify we got the expected checkpoints in order
    assert len(initial_checkpoints) > 0
    assert (
        initial_checkpoints[0]["checkpoint_content"] == {}
    )  # Initial empty checkpoint
    assert initial_checkpoints[1]["checkpoint_content"] == {}
    assert initial_checkpoints[2]["checkpoint_content"] == {}

    # Reset the mock server for the next run
    resp = mock_server_client.post("/reset")
    assert resp.status_code == 200

    # Set up mock server behavior for recovery run - should succeed fully this time
    resp = mock_server_client.post(
        "/set-behavior",
        json=[
            {
                "documents": [doc3.model_dump(mode="json")],
                "checkpoint": ConnectorCheckpoint(
                    checkpoint_content={}, has_more=False
                ).model_dump(mode="json"),
                "failures": [],
            }
        ],
    )
    assert resp.status_code == 200

    # Trigger another indexing attempt
    CCPairManager.run_once(
        cc_pair, from_beginning=False, user_performing_action=admin_user
    )
    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[initial_index_attempt.id],
        user_performing_action=admin_user,
    )
    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # validate status
    finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_recovery_attempt.status == IndexingStatus.SUCCESS

    # Verify results
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 102  # 100 docs from first batch + doc2 + doc3
    document_ids = {doc.id for doc in documents}
    assert doc3.id in document_ids
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints from the recovery run
    resp = mock_server_client.get("/get-checkpoints")
    assert resp.status_code == 200
    recovery_checkpoints = resp.json()

    # Verify the recovery run started from the last successful checkpoint
    assert len(recovery_checkpoints) == 1
    assert recovery_checkpoints[0]["checkpoint_content"] == {}
|
Reference in New Issue
Block a user