Chris Weaver f1fc8ac19b
Connector checkpointing (#3876)
* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
2025-02-16 02:34:39 +00:00

235 lines
8.5 KiB
Python

from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode
import requests
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
class IndexAttemptManager:
@staticmethod
def create_test_index_attempts(
num_attempts: int,
cc_pair_id: int,
from_beginning: bool = False,
status: IndexingStatus = IndexingStatus.SUCCESS,
new_docs_indexed: int = 10,
total_docs_indexed: int = 10,
docs_removed_from_index: int = 0,
error_msg: str | None = None,
base_time: datetime | None = None,
) -> list[DATestIndexAttempt]:
if base_time is None:
base_time = datetime.now()
attempts = []
with get_session_context_manager() as db_session:
# Get the current search settings
search_settings = get_current_search_settings(db_session)
if (
not search_settings
or search_settings.status != IndexModelStatus.PRESENT
):
raise ValueError("No current search settings found with PRESENT status")
for i in range(num_attempts):
time_created = base_time - timedelta(hours=i)
index_attempt = IndexAttempt(
connector_credential_pair_id=cc_pair_id,
from_beginning=from_beginning,
status=status,
new_docs_indexed=new_docs_indexed,
total_docs_indexed=total_docs_indexed,
docs_removed_from_index=docs_removed_from_index,
error_msg=error_msg,
time_created=time_created,
time_started=time_created,
time_updated=time_created,
search_settings_id=search_settings.id,
)
db_session.add(index_attempt)
db_session.flush() # To get the ID
attempts.append(
DATestIndexAttempt(
id=index_attempt.id,
status=index_attempt.status,
new_docs_indexed=index_attempt.new_docs_indexed,
total_docs_indexed=index_attempt.total_docs_indexed,
docs_removed_from_index=index_attempt.docs_removed_from_index,
error_msg=index_attempt.error_msg,
time_started=index_attempt.time_started,
time_updated=index_attempt.time_updated,
)
)
db_session.commit()
return attempts
@staticmethod
def get_index_attempt_page(
cc_pair_id: int,
page: int = 0,
page_size: int = 10,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[IndexAttemptSnapshot]:
query_params: dict[str, str | int] = {
"page_num": page,
"page_size": page_size,
}
url = (
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
f"?{urlencode(query_params, doseq=True)}"
)
response = requests.get(
url=url,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return PaginatedReturn(
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
total_items=data["total_items"],
)
@staticmethod
def get_latest_index_attempt_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
"""Get an IndexAttempt by ID"""
index_attempts = IndexAttemptManager.get_index_attempt_page(
cc_pair_id, user_performing_action=user_performing_action
).items
if not index_attempts:
return None
index_attempts = sorted(
index_attempts, key=lambda x: x.time_started or "0", reverse=True
)
return index_attempts[0]
@staticmethod
def wait_for_index_attempt_start(
cc_pair_id: int,
index_attempts_to_ignore: list[int] | None = None,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
"""Wait for an IndexAttempt to start"""
start = datetime.now()
index_attempts_to_ignore = index_attempts_to_ignore or []
while True:
index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
cc_pair_id=cc_pair_id,
user_performing_action=user_performing_action,
)
if (
index_attempt
and index_attempt.time_started
and index_attempt.id not in index_attempts_to_ignore
):
return index_attempt
elapsed = (datetime.now() - start).total_seconds()
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
)
@staticmethod
def get_index_attempt_by_id(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
page_num = 0
page_size = 10
while True:
page = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair_id,
page=page_num,
page_size=page_size,
user_performing_action=user_performing_action,
)
for attempt in page.items:
if attempt.id == index_attempt_id:
return attempt
if len(page.items) < page_size:
break
page_num += 1
raise ValueError(f"IndexAttempt {index_attempt_id} not found")
@staticmethod
def wait_for_index_attempt_completion(
index_attempt_id: int,
cc_pair_id: int,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = datetime.now()
while True:
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair_id,
user_performing_action=user_performing_action,
)
if index_attempt.status and index_attempt.status.is_terminal():
print(f"IndexAttempt {index_attempt_id} completed")
return
elapsed = (datetime.now() - start).total_seconds()
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
)
print(
f"Waiting for IndexAttempt {index_attempt_id} to complete. "
f"elapsed={elapsed:.2f} timeout={timeout}"
)
@staticmethod
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
include_resolved: bool = True,
user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
if include_resolved:
url += "&include_resolved=true"
response = requests.get(
url=url,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return [IndexAttemptErrorPydantic(**item) for item in data["items"]]