rkuo-danswer 5f5cc9a724
Feature/redis connector refactor (#2992)
* refactor RedisConnectorDeletion into RedisConnector

* refactor redis stop and deletion

* port pruning

* nest pruning

* port deletion

* port indexing

* refactor into individual files

* refactor redis connector index to take search settings at init

* move back to debug level log

* refactor doc set and user group (mostly)

* mypy fixes
2024-11-02 19:53:04 +00:00

419 lines
14 KiB
Python

import time
from datetime import datetime
from typing import Any
from uuid import uuid4
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import TaskStatus
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def _cc_pair_creator(
    connector_id: int,
    credential_id: int,
    name: str | None = None,
    access_type: AccessType = AccessType.PUBLIC,
    groups: list[int] | None = None,
    user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
    """Attach an existing connector and credential as a new CC pair via the API.

    Args:
        connector_id: ID of the connector to attach.
        credential_id: ID of the credential to attach.
        name: Base name; "-cc-pair" is appended to it. When falsy, a random
            "test-cc-pair-<uuid4>" name is generated instead.
        access_type: Access type sent to the server and recorded locally.
        groups: Group IDs to associate; ``None`` is sent as an empty list.
        user_performing_action: User whose headers authenticate the request;
            GENERAL_HEADERS is used when omitted.

    Returns:
        A DATestCCPair whose ``id`` is the ``data`` field of the JSON response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    group_ids = groups or []

    if name:
        pair_name = f"{name}-cc-pair"
    else:
        pair_name = f"test-cc-pair-{uuid4()}"

    if user_performing_action:
        headers = user_performing_action.headers
    else:
        headers = GENERAL_HEADERS

    payload = {
        "name": pair_name,
        "access_type": access_type,
        "groups": group_ids,
    }
    response = requests.put(
        url=f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}",
        json=payload,
        headers=headers,
    )
    response.raise_for_status()

    new_id = response.json()["data"]
    return DATestCCPair(
        id=new_id,
        name=pair_name,
        connector_id=connector_id,
        credential_id=credential_id,
        access_type=access_type,
        groups=group_ids,
    )
class CCPairManager:
    """Integration-test helpers that create, inspect, and drive CC pairs
    (connector-credential pairs) through the API server's HTTP endpoints."""

    @staticmethod
    def _headers(user_performing_action: DATestUser | None) -> dict[str, str]:
        """Return the acting user's request headers, or GENERAL_HEADERS if none."""
        if user_performing_action:
            return user_performing_action.headers
        return GENERAL_HEADERS

    @staticmethod
    def create_from_scratch(
        name: str | None = None,
        access_type: AccessType = AccessType.PUBLIC,
        groups: list[int] | None = None,
        source: DocumentSource = DocumentSource.FILE,
        input_type: InputType = InputType.LOAD_STATE,
        connector_specific_config: dict[str, Any] | None = None,
        credential_json: dict[str, Any] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestCCPair:
        """Create a connector and a credential, then link them into a CC pair.

        The connector is public and the credential is curator-public exactly
        when ``access_type`` is PUBLIC.
        """
        is_public = access_type == AccessType.PUBLIC
        connector = ConnectorManager.create(
            name=name,
            source=source,
            input_type=input_type,
            connector_specific_config=connector_specific_config,
            is_public=is_public,
            groups=groups,
            user_performing_action=user_performing_action,
        )
        credential = CredentialManager.create(
            credential_json=credential_json,
            name=name,
            source=source,
            curator_public=is_public,
            groups=groups,
            user_performing_action=user_performing_action,
        )
        return _cc_pair_creator(
            connector_id=connector.id,
            credential_id=credential.id,
            name=name,
            access_type=access_type,
            groups=groups,
            user_performing_action=user_performing_action,
        )

    @staticmethod
    def create(
        connector_id: int,
        credential_id: int,
        name: str | None = None,
        access_type: AccessType = AccessType.PUBLIC,
        groups: list[int] | None = None,
        user_performing_action: DATestUser | None = None,
    ) -> DATestCCPair:
        """Link an already-existing connector and credential into a CC pair."""
        return _cc_pair_creator(
            connector_id=connector_id,
            credential_id=credential_id,
            name=name,
            access_type=access_type,
            groups=groups,
            user_performing_action=user_performing_action,
        )

    @staticmethod
    def pause_cc_pair(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Set the CC pair's status to PAUSED. Raises requests.HTTPError on failure."""
        result = requests.put(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
            json={"status": "PAUSED"},
            headers=CCPairManager._headers(user_performing_action),
        )
        result.raise_for_status()

    @staticmethod
    def delete(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Request deletion of the CC pair (identified by connector + credential IDs).

        This only posts the deletion attempt; use wait_for_deletion_completion
        to wait for it to finish.
        """
        cc_pair_identifier = ConnectorCredentialPairIdentifier(
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
        )
        result = requests.post(
            url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
            json=cc_pair_identifier.model_dump(),
            headers=CCPairManager._headers(user_performing_action),
        )
        result.raise_for_status()

    @staticmethod
    def get_one(
        cc_pair_id: int,
        user_performing_action: DATestUser | None = None,
    ) -> ConnectorIndexingStatus | None:
        """Return the indexing status for ``cc_pair_id``, or None if not listed."""
        # Reuse get_all so the endpoint/parsing logic lives in one place.
        return next(
            (
                status
                for status in CCPairManager.get_all(user_performing_action)
                if status.cc_pair_id == cc_pair_id
            ),
            None,
        )

    @staticmethod
    def get_all(
        user_performing_action: DATestUser | None = None,
    ) -> list[ConnectorIndexingStatus]:
        """Fetch the indexing status of every CC pair visible to the acting user."""
        response = requests.get(
            f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
            headers=CCPairManager._headers(user_performing_action),
        )
        response.raise_for_status()
        return [ConnectorIndexingStatus(**cc_pair) for cc_pair in response.json()]

    @staticmethod
    def verify(
        cc_pair: DATestCCPair,
        verify_deleted: bool = False,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Check the server's view of ``cc_pair`` against the local record.

        With ``verify_deleted=False``, returns once an entry with the same ID and
        matching name, connector ID, credential ID, access type and groups is
        found. With ``verify_deleted=True``, returns only if no entry has the ID.

        Raises:
            ValueError: if the pair is present but should be deleted, is present
                with mismatching fields, or is absent when it should exist.
        """
        found_mismatch = False
        for retrieved_cc_pair in CCPairManager.get_all(user_performing_action):
            if retrieved_cc_pair.cc_pair_id != cc_pair.id:
                continue
            if verify_deleted:
                # We assume that this check will be performed after the deletion is
                # already waited for
                raise ValueError(f"CC pair {cc_pair.id} found but should be deleted")
            if (
                retrieved_cc_pair.name == cc_pair.name
                and retrieved_cc_pair.connector.id == cc_pair.connector_id
                and retrieved_cc_pair.credential.id == cc_pair.credential_id
                and retrieved_cc_pair.access_type == cc_pair.access_type
                and set(retrieved_cc_pair.groups) == set(cc_pair.groups)
            ):
                return
            # Keep scanning (as before), but remember that the ID was present so
            # the final error says "mismatch" rather than "not found".
            found_mismatch = True

        if verify_deleted:
            return
        if found_mismatch:
            raise ValueError(
                f"CC pair {cc_pair.id} found but its name, connector, credential, "
                "access type, or groups do not match the expected values"
            )
        raise ValueError(f"CC pair {cc_pair.id} not found")

    @staticmethod
    def run_once(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Trigger a single from-the-beginning indexing run for the CC pair."""
        body = {
            "connector_id": cc_pair.connector_id,
            "credential_ids": [cc_pair.credential_id],
            "from_beginning": True,
        }
        result = requests.post(
            url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
            json=body,
            headers=CCPairManager._headers(user_performing_action),
        )
        result.raise_for_status()

    @staticmethod
    def wait_for_indexing(
        cc_pair: DATestCCPair,
        after: datetime,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """after: Wait for an indexing success time after this time

        Polls every 5 seconds until the CC pair is not in progress and its
        ``last_success`` is later than ``after``.

        Raises:
            TimeoutError: if that does not happen within ``timeout`` seconds.
        """
        start = time.monotonic()
        while True:
            for fetched_cc_pair in CCPairManager.get_all(user_performing_action):
                if fetched_cc_pair.cc_pair_id != cc_pair.id:
                    continue
                if fetched_cc_pair.in_progress:
                    continue
                if (
                    fetched_cc_pair.last_success
                    and fetched_cc_pair.last_success > after
                ):
                    print(f"Indexing complete: cc_pair={cc_pair.id}")
                    return

            elapsed = time.monotonic() - start
            if elapsed > timeout:
                raise TimeoutError(
                    f"Indexing wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
                )

            print(
                f"Indexing wait for completion: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
            )
            time.sleep(5)

    @staticmethod
    def prune(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Ask the server to start pruning the CC pair."""
        result = requests.post(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
            headers=CCPairManager._headers(user_performing_action),
        )
        result.raise_for_status()

    @staticmethod
    def last_pruned(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> datetime | None:
        """Return the CC pair's last-pruned time.

        Returns None when the response body is not a string or is not a valid
        ISO-8601 datetime.
        """
        response = requests.get(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
            headers=CCPairManager._headers(user_performing_action),
        )
        response.raise_for_status()

        response_str = response.json()
        # If the response itself is a datetime string, parse it
        if not isinstance(response_str, str):
            return None
        try:
            return datetime.fromisoformat(response_str)
        except ValueError:
            return None

    @staticmethod
    def wait_for_prune(
        cc_pair: DATestCCPair,
        after: datetime,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """after: The task register time must be after this time.

        Polls every 5 seconds until ``last_pruned`` reports a time later than
        ``after``.

        Raises:
            TimeoutError: if that does not happen within ``timeout`` seconds.
        """
        start = time.monotonic()
        while True:
            last_pruned = CCPairManager.last_pruned(cc_pair, user_performing_action)
            if last_pruned and last_pruned > after:
                print(f"Pruning complete: cc_pair={cc_pair.id}")
                break

            elapsed = time.monotonic() - start
            if elapsed > timeout:
                raise TimeoutError(
                    f"CC pair pruning was not completed within {timeout} seconds"
                )

            print(
                f"Waiting for CC pruning to complete. elapsed={elapsed:.2f} timeout={timeout}"
            )
            time.sleep(5)

    @staticmethod
    def sync(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """Ask the server to start a sync task for the CC pair."""
        result = requests.post(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
            headers=CCPairManager._headers(user_performing_action),
        )
        result.raise_for_status()

    @staticmethod
    def get_sync_task(
        cc_pair: DATestCCPair,
        user_performing_action: DATestUser | None = None,
    ) -> CeleryTaskStatus:
        """Fetch the status of the CC pair's sync task."""
        response = requests.get(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
            headers=CCPairManager._headers(user_performing_action),
        )
        response.raise_for_status()
        return CeleryTaskStatus(**response.json())

    @staticmethod
    def wait_for_sync(
        cc_pair: DATestCCPair,
        after: datetime,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """after: The task register time must be after this time.

        Polls every 5 seconds until the sync task reports SUCCESS.

        Raises:
            ValueError: immediately, if the task has no register time or was
                registered before ``after``.
            TimeoutError: if SUCCESS is not reached within ``timeout`` seconds.
        """
        start = time.monotonic()
        while True:
            task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
            # Defensive: get_sync_task currently returns a model or raises.
            if not task:
                raise ValueError("Sync task not found.")

            if not task.register_time or task.register_time < after:
                raise ValueError("Sync task register time is too early.")

            if task.status == TaskStatus.SUCCESS:
                # Sync succeeded
                return

            elapsed = time.monotonic() - start
            if elapsed > timeout:
                raise TimeoutError(
                    f"CC pair syncing was not completed within {timeout} seconds"
                )

            print(
                f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
            )
            time.sleep(5)

    @staticmethod
    def wait_for_deletion_completion(
        cc_pair_id: int | None = None,
        user_performing_action: DATestUser | None = None,
        timeout: float = MAX_DELAY,
    ) -> None:
        """if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
        if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.
        We had a bug where the connector was paused in the middle of deleting, so specifying the
        cc_pair_id is good to do.

        Polls every 2 seconds.

        Args:
            timeout: Seconds to wait before giving up (defaults to MAX_DELAY,
                matching the other wait_* helpers).

        Raises:
            TimeoutError: if the condition is not met within ``timeout`` seconds.
        """
        start = time.monotonic()
        while True:
            cc_pairs = CCPairManager.get_all(user_performing_action)
            if cc_pair_id is not None:
                if not any(p.cc_pair_id == cc_pair_id for p in cc_pairs):
                    return
            elif all(
                p.cc_pair_status != ConnectorCredentialPairStatus.DELETING
                for p in cc_pairs
            ):
                return

            if time.monotonic() - start > timeout:
                raise TimeoutError(
                    f"CC pairs deletion was not completed within the {timeout} seconds"
                )

            print("Some CC pairs are still being deleted, waiting...")
            time.sleep(2)