mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Implement source testing framework + Slack (#2650)
* Added permission sync tests for Slack * moved folders * prune test + mypy * added wait for indexing to cc_pair creation * commented out check * should fix other tests * added slack channel pool * fixed everything and mypy * reduced flake
This commit is contained in:
parent
b3c367d09c
commit
c2088602e1
2
.github/workflows/pr-Integration-tests.yml
vendored
2
.github/workflows/pr-Integration-tests.yml
vendored
@ -12,6 +12,7 @@ on:
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
@ -142,6 +143,7 @@ jobs:
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test
|
||||
continue-on-error: true
|
||||
|
@ -205,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
|
||||
"group_leave",
|
||||
"group_archive",
|
||||
"group_unarchive",
|
||||
"channel_leave",
|
||||
"channel_name",
|
||||
"channel_join",
|
||||
}
|
||||
|
||||
|
||||
def _default_msg_filter(message: MessageType) -> bool:
|
||||
def default_msg_filter(message: MessageType) -> bool:
|
||||
# Don't keep messages from bots
|
||||
if message.get("bot_id") or message.get("app_id"):
|
||||
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
|
||||
return False
|
||||
return True
|
||||
|
||||
# Uninformative
|
||||
@ -261,7 +266,7 @@ def _get_all_docs(
|
||||
channel_name_regex_enabled: bool = False,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
slack_cleaner = SlackTextCleaner(client=client)
|
||||
@ -320,7 +325,7 @@ def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
|
@ -29,15 +29,19 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
|
||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from danswer.db.models import User
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import CCPairPruningTask
|
||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||
from danswer.server.documents.models import CeleryTaskStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from danswer.server.documents.models import PaginatedIndexAttempts
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
@ -199,7 +203,7 @@ def get_cc_pair_latest_prune(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CCPairPruningTask:
|
||||
) -> CeleryTaskStatus:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@ -223,7 +227,7 @@ def get_cc_pair_latest_prune(
|
||||
detail="No pruning task found.",
|
||||
)
|
||||
|
||||
return CCPairPruningTask(
|
||||
return CeleryTaskStatus(
|
||||
id=last_pruning_task.task_id,
|
||||
name=last_pruning_task.task_name,
|
||||
status=last_pruning_task.status,
|
||||
@ -280,6 +284,95 @@ def prune_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
def get_cc_pair_latest_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CeleryTaskStatus:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
# look up the last sync task for this connector (if it exists)
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
if not last_sync_task:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="No sync task found.",
|
||||
)
|
||||
|
||||
return CeleryTaskStatus(
|
||||
id=last_sync_task.task_id,
|
||||
name=last_sync_task.task_name,
|
||||
status=last_sync_task.status,
|
||||
start_time=last_sync_task.start_time,
|
||||
register_time=last_sync_task.register_time,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[list[int]]:
|
||||
# avoiding circular refs
|
||||
from ee.danswer.background.celery.celery_app import (
|
||||
sync_external_doc_permissions_task,
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
|
||||
if last_sync_task and check_task_is_live_and_not_timed_out(
|
||||
last_sync_task, db_session
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Sync task already in progress.",
|
||||
)
|
||||
if skip_cc_pair_pruning_by_task(
|
||||
last_sync_task,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair_id),
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
|
@ -781,6 +781,7 @@ def connector_run_once(
|
||||
detail="Connector has no valid credentials, cannot create index attempts.",
|
||||
)
|
||||
|
||||
# Prevents index attempts for cc pairs that already have an index attempt currently running
|
||||
skipped_credentials = [
|
||||
credential_id
|
||||
for credential_id in credential_ids
|
||||
@ -790,15 +791,15 @@ def connector_run_once(
|
||||
credential_id=credential_id,
|
||||
),
|
||||
only_current=True,
|
||||
disinclude_finished=True,
|
||||
db_session=db_session,
|
||||
disinclude_finished=True,
|
||||
)
|
||||
]
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
|
||||
get_connector_credential_pair(connector_id, credential_id, db_session)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
]
|
||||
|
@ -268,7 +268,7 @@ class CCPairFullInfo(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class CCPairPruningTask(BaseModel):
|
||||
class CeleryTaskStatus(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
status: TaskStatus
|
||||
|
@ -1,3 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.enums import AccessType
|
||||
@ -12,10 +15,32 @@ from ee.danswer.background.task_name_builders import (
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_perform_chat_ttl_check(
|
||||
retention_limit_days: int | None, db_session: Session
|
||||
) -> bool:
|
||||
@ -28,7 +53,7 @@ def should_perform_chat_ttl_check(
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
return True
|
||||
@ -50,6 +75,9 @@ def should_perform_external_doc_permissions_check(
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@ -69,4 +97,7 @@ def should_perform_external_group_permissions_check(
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
@ -6,38 +6,15 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_external_group_permission_sync(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
@ -53,9 +30,6 @@ def run_external_group_permission_sync(
|
||||
# Not all sync connectors support group permissions so this is fine
|
||||
return
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
@ -91,9 +65,6 @@ def run_external_doc_permission_sync(
|
||||
f"No permission sync function found for source type: {source_type}"
|
||||
)
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> document mapping
|
||||
|
@ -4,7 +4,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
|
||||
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
|
||||
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
|
||||
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
|
||||
MAX_DELAY = 30
|
||||
MAX_DELAY = 45
|
||||
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
|
@ -9,7 +9,7 @@ from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.server.documents.models import CCPairPruningTask
|
||||
from danswer.server.documents.models import CeleryTaskStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
@ -85,7 +85,7 @@ class CCPairManager:
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
return _cc_pair_creator(
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
name=name,
|
||||
@ -93,6 +93,7 @@ class CCPairManager:
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
return cc_pair
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@ -103,7 +104,7 @@ class CCPairManager:
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestCCPair:
|
||||
return _cc_pair_creator(
|
||||
cc_pair = _cc_pair_creator(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=name,
|
||||
@ -111,6 +112,7 @@ class CCPairManager:
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
return cc_pair
|
||||
|
||||
@staticmethod
|
||||
def pause_cc_pair(
|
||||
@ -203,9 +205,28 @@ class CCPairManager:
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"CC pair {cc_pair.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def run_once(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
body = {
|
||||
"connector_id": cc_pair.connector_id,
|
||||
"credential_ids": [cc_pair.credential_id],
|
||||
"from_beginning": True,
|
||||
}
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||
json=body,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def wait_for_indexing(
|
||||
cc_pair_test: DATestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
@ -213,14 +234,20 @@ class CCPairManager:
|
||||
"""after: Wait for an indexing success time after this time"""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for cc_pair in cc_pairs:
|
||||
if cc_pair.cc_pair_id != cc_pair_test.id:
|
||||
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for fetched_cc_pair in fetched_cc_pairs:
|
||||
if fetched_cc_pair.cc_pair_id != cc_pair.id:
|
||||
continue
|
||||
|
||||
if cc_pair.last_success and cc_pair.last_success > after:
|
||||
print(f"cc_pair {cc_pair_test.id} indexing complete.")
|
||||
if (
|
||||
fetched_cc_pair.last_success
|
||||
and fetched_cc_pair.last_success > after
|
||||
):
|
||||
print(f"cc_pair {cc_pair.id} indexing complete.")
|
||||
return
|
||||
else:
|
||||
print("cc_pair found but not finished:")
|
||||
# print(fetched_cc_pair.__dict__)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
@ -250,7 +277,7 @@ class CCPairManager:
|
||||
def get_prune_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> CCPairPruningTask:
|
||||
) -> CeleryTaskStatus:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||
headers=user_performing_action.headers
|
||||
@ -258,11 +285,11 @@ class CCPairManager:
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CCPairPruningTask(**response.json())
|
||||
return CeleryTaskStatus(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def wait_for_prune(
|
||||
cc_pair_test: DATestCCPair,
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
@ -270,7 +297,7 @@ class CCPairManager:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action)
|
||||
task = CCPairManager.get_prune_task(cc_pair, user_performing_action)
|
||||
if not task:
|
||||
raise ValueError("Prune task not found.")
|
||||
|
||||
@ -292,16 +319,75 @@ class CCPairManager:
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
@staticmethod
|
||||
def sync(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> CeleryTaskStatus:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CeleryTaskStatus(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
|
||||
if not task:
|
||||
raise ValueError("Sync task not found.")
|
||||
|
||||
if not task.register_time or task.register_time < after:
|
||||
raise ValueError("Sync task register time is too early.")
|
||||
|
||||
if task.status == TaskStatus.SUCCESS:
|
||||
# Sync succeeded
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"CC pair syncing was not completed within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
if all(
|
||||
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
|
||||
for cc_pair in cc_pairs
|
||||
for cc_pair in fetched_cc_pairs
|
||||
):
|
||||
return
|
||||
|
||||
|
@ -50,9 +50,7 @@ class LLMProviderManager:
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
import json
|
||||
|
||||
print(json.dumps(response_data, indent=4))
|
||||
result_llm = DATestLLMProvider(
|
||||
id=response_data["id"],
|
||||
name=response_data["name"],
|
||||
|
@ -17,11 +17,14 @@ class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
email: str | None = None,
|
||||
) -> DATestUser:
|
||||
if name is None:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
email = f"{name}@test.com"
|
||||
if email is None:
|
||||
email = f"{name}@test.com"
|
||||
|
||||
password = "test"
|
||||
|
||||
body = {
|
||||
@ -44,12 +47,10 @@ class UserManager:
|
||||
)
|
||||
print(f"Created user {test_user.email}")
|
||||
|
||||
test_user.headers["Cookie"] = UserManager.login_as_user(test_user)
|
||||
|
||||
return test_user
|
||||
return UserManager.login_as_user(test_user)
|
||||
|
||||
@staticmethod
|
||||
def login_as_user(test_user: DATestUser) -> str:
|
||||
def login_as_user(test_user: DATestUser) -> DATestUser:
|
||||
data = urlencode(
|
||||
{
|
||||
"username": test_user.email,
|
||||
@ -71,7 +72,9 @@ class UserManager:
|
||||
raise Exception("Failed to login")
|
||||
|
||||
print(f"Logged in as {test_user.email}")
|
||||
return f"{result_cookie.name}={result_cookie.value}"
|
||||
cookie = f"{result_cookie.name}={result_cookie.value}"
|
||||
test_user.headers["Cookie"] = cookie
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:
|
||||
|
@ -0,0 +1,28 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
admin_user_id = SlackManager.build_slack_user_email_id_map(slack_client)[
|
||||
"admin@onyx-test.com"
|
||||
]
|
||||
|
||||
(
|
||||
public_channel,
|
||||
private_channel,
|
||||
run_id,
|
||||
) = SlackManager.get_and_provision_available_slack_channels(
|
||||
slack_client=slack_client, admin_user_id=admin_user_id
|
||||
)
|
||||
|
||||
yield public_channel, private_channel
|
||||
|
||||
# This part will always run after the test, even if it fails
|
||||
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
|
@ -0,0 +1,311 @@
|
||||
"""
|
||||
Assumptions:
|
||||
- The test users have already been created
|
||||
- General is empty of messages
|
||||
- In addition to the normal slack oauth permissions, the following scopes are needed:
|
||||
- channels:manage
|
||||
- groups:write
|
||||
- chat:write
|
||||
- chat:write.public
|
||||
"""
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from danswer.connectors.slack.connector import default_msg_filter
|
||||
from danswer.connectors.slack.connector import get_channel_messages
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
|
||||
|
||||
def _get_slack_channel_id(channel: dict[str, Any]) -> str:
|
||||
if not (channel_id := channel.get("id")):
|
||||
raise ValueError("Channel ID is missing")
|
||||
return channel_id
|
||||
|
||||
|
||||
def _get_non_general_channels(
|
||||
slack_client: WebClient,
|
||||
get_private: bool,
|
||||
get_public: bool,
|
||||
only_get_done: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
channel_types = []
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
|
||||
conversations: list[dict[str, Any]] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.conversations_list,
|
||||
exclude_archived=False,
|
||||
types=channel_types,
|
||||
):
|
||||
conversations.extend(result["channels"])
|
||||
|
||||
filtered_conversations = []
|
||||
for conversation in conversations:
|
||||
if conversation.get("is_general", False):
|
||||
continue
|
||||
if only_get_done and "done" not in conversation.get("name", ""):
|
||||
continue
|
||||
filtered_conversations.append(conversation)
|
||||
return filtered_conversations
|
||||
|
||||
|
||||
def _clear_slack_conversation_members(
|
||||
slack_client: WebClient,
|
||||
admin_user_id: str,
|
||||
channel: dict[str, Any],
|
||||
) -> None:
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
member_ids: list[str] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.conversations_members,
|
||||
channel=channel_id,
|
||||
):
|
||||
member_ids.extend(result["members"])
|
||||
|
||||
for member_id in member_ids:
|
||||
if member_id == admin_user_id:
|
||||
continue
|
||||
try:
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.conversations_kick, channel=channel_id, user=member_id
|
||||
)
|
||||
print(f"Kicked member: {member_id}")
|
||||
except Exception as e:
|
||||
if "cant_kick_self" in str(e):
|
||||
continue
|
||||
print(f"Error kicking member: {e}")
|
||||
print(member_id)
|
||||
try:
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.conversations_unarchive, channel=channel_id
|
||||
)
|
||||
channel["is_archived"] = False
|
||||
except Exception:
|
||||
# Channel is already unarchived
|
||||
pass
|
||||
|
||||
|
||||
def _add_slack_conversation_members(
|
||||
slack_client: WebClient, channel: dict[str, Any], member_ids: list[str]
|
||||
) -> None:
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
for user_id in member_ids:
|
||||
try:
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.conversations_invite,
|
||||
channel=channel_id,
|
||||
users=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
if "already_in_channel" in str(e):
|
||||
continue
|
||||
print(f"Error inviting member: {e}")
|
||||
print(user_id)
|
||||
|
||||
|
||||
def _delete_slack_conversation_messages(
|
||||
slack_client: WebClient,
|
||||
channel: dict[str, Any],
|
||||
message_to_delete: str | None = None,
|
||||
) -> None:
|
||||
"""deletes all messages from a channel if message_to_delete is None"""
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
for message_batch in get_channel_messages(slack_client, channel):
|
||||
for message in message_batch:
|
||||
if default_msg_filter(message):
|
||||
continue
|
||||
|
||||
if message_to_delete and message.get("text") != message_to_delete:
|
||||
continue
|
||||
print(" removing message: ", message.get("text"))
|
||||
|
||||
try:
|
||||
if not (ts := message.get("ts")):
|
||||
raise ValueError("Message timestamp is missing")
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.chat_delete,
|
||||
channel=channel_id,
|
||||
ts=ts,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error deleting message: {e}")
|
||||
print(message)
|
||||
|
||||
|
||||
def _build_slack_channel_from_name(
|
||||
slack_client: WebClient,
|
||||
admin_user_id: str,
|
||||
suffix: str,
|
||||
is_private: bool,
|
||||
channel: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
base = "public_channel" if not is_private else "private_channel"
|
||||
channel_name = f"{base}-{suffix}"
|
||||
if channel:
|
||||
# If channel is provided, we rename it
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_rename,
|
||||
channel=channel_id,
|
||||
name=channel_name,
|
||||
)
|
||||
else:
|
||||
# Otherwise, we create a new channel
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_create,
|
||||
name=channel_name,
|
||||
is_private=is_private,
|
||||
)
|
||||
|
||||
try:
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_unarchive,
|
||||
channel=channel_response["channel"]["id"],
|
||||
)
|
||||
except Exception:
|
||||
# Channel is already unarchived
|
||||
pass
|
||||
try:
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_invite,
|
||||
channel=channel_response["channel"]["id"],
|
||||
users=[admin_user_id],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
final_channel = channel_response["channel"] if channel_response else {}
|
||||
return final_channel
|
||||
|
||||
|
||||
class SlackManager:
|
||||
@staticmethod
|
||||
def get_slack_client(token: str) -> WebClient:
|
||||
return WebClient(token=token)
|
||||
|
||||
@staticmethod
|
||||
def get_and_provision_available_slack_channels(
|
||||
slack_client: WebClient, admin_user_id: str
|
||||
) -> tuple[dict[str, Any], dict[str, Any], str]:
|
||||
run_id = str(uuid4())
|
||||
public_channels = _get_non_general_channels(
|
||||
slack_client, get_private=False, get_public=True, only_get_done=True
|
||||
)
|
||||
|
||||
first_available_channel = (
|
||||
None if len(public_channels) < 1 else public_channels[0]
|
||||
)
|
||||
public_channel = _build_slack_channel_from_name(
|
||||
slack_client=slack_client,
|
||||
admin_user_id=admin_user_id,
|
||||
suffix=run_id,
|
||||
is_private=False,
|
||||
channel=first_available_channel,
|
||||
)
|
||||
_delete_slack_conversation_messages(
|
||||
slack_client=slack_client, channel=public_channel
|
||||
)
|
||||
|
||||
private_channels = _get_non_general_channels(
|
||||
slack_client, get_private=True, get_public=False, only_get_done=True
|
||||
)
|
||||
second_available_channel = (
|
||||
None if len(private_channels) < 1 else private_channels[0]
|
||||
)
|
||||
private_channel = _build_slack_channel_from_name(
|
||||
slack_client=slack_client,
|
||||
admin_user_id=admin_user_id,
|
||||
suffix=run_id,
|
||||
is_private=True,
|
||||
channel=second_available_channel,
|
||||
)
|
||||
_delete_slack_conversation_messages(
|
||||
slack_client=slack_client, channel=private_channel
|
||||
)
|
||||
|
||||
return public_channel, private_channel, run_id
|
||||
|
||||
@staticmethod
|
||||
def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
|
||||
users_results = make_slack_api_call_w_retries(
|
||||
slack_client.users_list,
|
||||
)
|
||||
users: list[dict[str, Any]] = users_results.get("members", [])
|
||||
user_email_id_map = {}
|
||||
for user in users:
|
||||
if not (email := user.get("profile", {}).get("email")):
|
||||
continue
|
||||
if not (user_id := user.get("id")):
|
||||
raise ValueError("User ID is missing")
|
||||
user_email_id_map[email] = user_id
|
||||
return user_email_id_map
|
||||
|
||||
@staticmethod
|
||||
def set_channel_members(
|
||||
slack_client: WebClient,
|
||||
admin_user_id: str,
|
||||
channel: dict[str, Any],
|
||||
user_ids: list[str],
|
||||
) -> None:
|
||||
_clear_slack_conversation_members(
|
||||
slack_client=slack_client,
|
||||
channel=channel,
|
||||
admin_user_id=admin_user_id,
|
||||
)
|
||||
_add_slack_conversation_members(
|
||||
slack_client=slack_client, channel=channel, member_ids=user_ids
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_message_to_channel(
|
||||
slack_client: WebClient, channel: dict[str, Any], message: str
|
||||
) -> None:
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.chat_postMessage,
|
||||
channel=channel_id,
|
||||
text=message,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_message_from_channel(
|
||||
slack_client: WebClient, channel: dict[str, Any], message: str
|
||||
) -> None:
|
||||
_delete_slack_conversation_messages(
|
||||
slack_client=slack_client, channel=channel, message_to_delete=message
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def cleanup_after_test(
|
||||
slack_client: WebClient,
|
||||
test_id: str,
|
||||
) -> None:
|
||||
channel_types = ["private_channel", "public_channel"]
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.conversations_list,
|
||||
exclude_archived=False,
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
for channel in channels:
|
||||
if test_id not in channel.get("name", ""):
|
||||
continue
|
||||
# "done" in the channel name indicates that this channel is free to be used for a new test
|
||||
new_name = f"done_{str(uuid4())}"
|
||||
try:
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.conversations_rename,
|
||||
channel=channel["id"],
|
||||
name=new_name,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
print(f"Error renaming channel {channel['id']}: {e}")
|
@ -0,0 +1,251 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
def test_slack_permission_sync(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
|
||||
) -> None:
|
||||
public_channel, private_channel = slack_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email="admin@onyx-test.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email="test_user_1@onyx-test.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_2: DATestUser = UserManager.create(
|
||||
email="test_user_2@onyx-test.com",
|
||||
)
|
||||
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
connector: DATestConnector = ConnectorManager.create(
|
||||
name="Slack",
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
is_public=True,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair: DATestCCPair = CCPairManager.create(
|
||||
credential_id=credential.id,
|
||||
connector_id=connector.id,
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Add test_user_1 and admin_user to the private channel
|
||||
desired_channel_members = [admin_user, test_user_1]
|
||||
SlackManager.set_channel_members(
|
||||
slack_client=slack_client,
|
||||
admin_user_id=admin_user_id,
|
||||
channel=private_channel,
|
||||
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||
)
|
||||
|
||||
public_message = "Steve's favorite number is 809752"
|
||||
private_message = "Sara's favorite number is 346794"
|
||||
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=public_channel,
|
||||
message=public_message,
|
||||
)
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
message=private_message,
|
||||
)
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Search as admin with access to both channels
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
|
||||
# Ensure admin user can see messages from both channels
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
|
||||
# Search as test_user_2 with access to only the public channel
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=test_user_2.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content before removing from private channel for test_user_2: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_2 can only see messages from the public channel
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message not in danswer_doc_message_strings
|
||||
|
||||
# Search as test_user_1 with access to both channels
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=test_user_1.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content before removing from private channel for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can see messages from both channels
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
|
||||
# ----------------------MAKE THE CHANGES--------------------------
|
||||
print("\nRemoving test_user_1 from the private channel")
|
||||
# Remove test_user_1 from the private channel
|
||||
desired_channel_members = [admin_user]
|
||||
SlackManager.set_channel_members(
|
||||
slack_client=slack_client,
|
||||
admin_user_id=admin_user_id,
|
||||
channel=private_channel,
|
||||
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||
# Ensure test_user_1 can no longer see messages from the private channel
|
||||
# Search as test_user_1 with access to only the public channel
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=test_user_1.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content after removing from private channel for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can only see messages from the public channel
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message not in danswer_doc_message_strings
|
@ -0,0 +1,255 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
def test_slack_prune(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
|
||||
) -> None:
|
||||
public_channel, private_channel = slack_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email="admin@onyx-test.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email="test_user_1@onyx-test.com",
|
||||
)
|
||||
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
connector: DATestConnector = ConnectorManager.create(
|
||||
name="Slack",
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
is_public=True,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair: DATestCCPair = CCPairManager.create(
|
||||
credential_id=credential.id,
|
||||
connector_id=connector.id,
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# ----------------------SETUP INITIAL SLACK STATE--------------------------
|
||||
# Add test_user_1 and admin_user to the private channel
|
||||
desired_channel_members = [admin_user, test_user_1]
|
||||
SlackManager.set_channel_members(
|
||||
slack_client=slack_client,
|
||||
admin_user_id=admin_user_id,
|
||||
channel=private_channel,
|
||||
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||
)
|
||||
|
||||
public_message = "Steve's favorite number is 809752"
|
||||
private_message = "Sara's favorite number is 346794"
|
||||
message_to_delete = "Rebecca's favorite number is 753468"
|
||||
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=public_channel,
|
||||
message=public_message,
|
||||
)
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
message=private_message,
|
||||
)
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
message=message_to_delete,
|
||||
)
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# ----------------------TEST THE SETUP--------------------------
|
||||
# Search as admin with access to both channels
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content before deleting for admin: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure admin user can see all messages
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
assert message_to_delete in danswer_doc_message_strings
|
||||
|
||||
# Search as test_user_1 with access to both channels
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=test_user_1.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content before deleting for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can see all messages
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
assert message_to_delete in danswer_doc_message_strings
|
||||
|
||||
# ----------------------MAKE THE CHANGES--------------------------
|
||||
# Delete messages
|
||||
print("\nDeleting message: ", message_to_delete)
|
||||
SlackManager.remove_message_from_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
message=message_to_delete,
|
||||
)
|
||||
|
||||
# Prune the cc_pair
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.prune(cc_pair, user_performing_action=admin_user)
|
||||
CCPairManager.wait_for_prune(cc_pair, before, user_performing_action=admin_user)
|
||||
|
||||
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||
# Ensure admin user can't see deleted messages
|
||||
# Search as admin user with access to only the public channel
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content after deleting for admin: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure admin can't see deleted messages
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
assert message_to_delete not in danswer_doc_message_strings
|
||||
|
||||
# Ensure test_user_1 can't see deleted messages
|
||||
# Search as test_user_1 with access to only the public channel
|
||||
search_request = DocumentSearchRequest(
|
||||
message="favorite number",
|
||||
search_type=SearchType.KEYWORD,
|
||||
retrieval_options=RetrievalDetails(),
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
)
|
||||
search_request_body = search_request.model_dump()
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/query/document-search",
|
||||
json=search_request_body,
|
||||
headers=test_user_1.headers,
|
||||
)
|
||||
result.raise_for_status()
|
||||
found_docs = result.json()["top_documents"]
|
||||
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||
print(
|
||||
"\ntop_documents content after prune for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can't see deleted messages
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
assert message_to_delete not in danswer_doc_message_strings
|
Loading…
x
Reference in New Issue
Block a user