mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-21 21:41:03 +02:00
Implement source testing framework + Slack (#2650)
* Added permission sync tests for Slack * moved folders * prune test + mypy * added wait for indexing to cc_pair creation * commented out check * should fix other tests * added slack channel pool * fixed everything and mypy * reduced flake
This commit is contained in:
parent
b3c367d09c
commit
c2088602e1
2
.github/workflows/pr-Integration-tests.yml
vendored
2
.github/workflows/pr-Integration-tests.yml
vendored
@ -12,6 +12,7 @@ on:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
integration-tests:
|
integration-tests:
|
||||||
@ -142,6 +143,7 @@ jobs:
|
|||||||
-e REDIS_HOST=cache \
|
-e REDIS_HOST=cache \
|
||||||
-e API_SERVER_HOST=api_server \
|
-e API_SERVER_HOST=api_server \
|
||||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||||
|
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||||
-e TEST_WEB_HOSTNAME=test-runner \
|
-e TEST_WEB_HOSTNAME=test-runner \
|
||||||
danswer/danswer-integration:test
|
danswer/danswer-integration:test
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
|
@ -205,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
|
|||||||
"group_leave",
|
"group_leave",
|
||||||
"group_archive",
|
"group_archive",
|
||||||
"group_unarchive",
|
"group_unarchive",
|
||||||
|
"channel_leave",
|
||||||
|
"channel_name",
|
||||||
|
"channel_join",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _default_msg_filter(message: MessageType) -> bool:
|
def default_msg_filter(message: MessageType) -> bool:
|
||||||
# Don't keep messages from bots
|
# Don't keep messages from bots
|
||||||
if message.get("bot_id") or message.get("app_id"):
|
if message.get("bot_id") or message.get("app_id"):
|
||||||
|
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Uninformative
|
# Uninformative
|
||||||
@ -261,7 +266,7 @@ def _get_all_docs(
|
|||||||
channel_name_regex_enabled: bool = False,
|
channel_name_regex_enabled: bool = False,
|
||||||
oldest: str | None = None,
|
oldest: str | None = None,
|
||||||
latest: str | None = None,
|
latest: str | None = None,
|
||||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||||
) -> Generator[Document, None, None]:
|
) -> Generator[Document, None, None]:
|
||||||
"""Get all documents in the workspace, channel by channel"""
|
"""Get all documents in the workspace, channel by channel"""
|
||||||
slack_cleaner = SlackTextCleaner(client=client)
|
slack_cleaner = SlackTextCleaner(client=client)
|
||||||
@ -320,7 +325,7 @@ def _get_all_doc_ids(
|
|||||||
client: WebClient,
|
client: WebClient,
|
||||||
channels: list[str] | None = None,
|
channels: list[str] | None = None,
|
||||||
channel_name_regex_enabled: bool = False,
|
channel_name_regex_enabled: bool = False,
|
||||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Get all document ids in the workspace, channel by channel
|
Get all document ids in the workspace, channel by channel
|
||||||
|
@ -29,15 +29,19 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
|
|||||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
|
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||||
from danswer.db.tasks import get_latest_task
|
from danswer.db.tasks import get_latest_task
|
||||||
from danswer.server.documents.models import CCPairFullInfo
|
from danswer.server.documents.models import CCPairFullInfo
|
||||||
from danswer.server.documents.models import CCPairPruningTask
|
|
||||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||||
|
from danswer.server.documents.models import CeleryTaskStatus
|
||||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||||
from danswer.server.documents.models import PaginatedIndexAttempts
|
from danswer.server.documents.models import PaginatedIndexAttempts
|
||||||
from danswer.server.models import StatusResponse
|
from danswer.server.models import StatusResponse
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
from ee.danswer.background.task_name_builders import (
|
||||||
|
name_sync_external_doc_permissions_task,
|
||||||
|
)
|
||||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -199,7 +203,7 @@ def get_cc_pair_latest_prune(
|
|||||||
cc_pair_id: int,
|
cc_pair_id: int,
|
||||||
user: User = Depends(current_curator_or_admin_user),
|
user: User = Depends(current_curator_or_admin_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> CCPairPruningTask:
|
) -> CeleryTaskStatus:
|
||||||
cc_pair = get_connector_credential_pair_from_id(
|
cc_pair = get_connector_credential_pair_from_id(
|
||||||
cc_pair_id=cc_pair_id,
|
cc_pair_id=cc_pair_id,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@ -223,7 +227,7 @@ def get_cc_pair_latest_prune(
|
|||||||
detail="No pruning task found.",
|
detail="No pruning task found.",
|
||||||
)
|
)
|
||||||
|
|
||||||
return CCPairPruningTask(
|
return CeleryTaskStatus(
|
||||||
id=last_pruning_task.task_id,
|
id=last_pruning_task.task_id,
|
||||||
name=last_pruning_task.task_name,
|
name=last_pruning_task.task_name,
|
||||||
status=last_pruning_task.status,
|
status=last_pruning_task.status,
|
||||||
@ -280,6 +284,95 @@ def prune_cc_pair(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
|
||||||
|
def get_cc_pair_latest_sync(
|
||||||
|
cc_pair_id: int,
|
||||||
|
user: User = Depends(current_curator_or_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> CeleryTaskStatus:
|
||||||
|
cc_pair = get_connector_credential_pair_from_id(
|
||||||
|
cc_pair_id=cc_pair_id,
|
||||||
|
db_session=db_session,
|
||||||
|
user=user,
|
||||||
|
get_editable=False,
|
||||||
|
)
|
||||||
|
if not cc_pair:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Connection not found for current user's permissions",
|
||||||
|
)
|
||||||
|
|
||||||
|
# look up the last sync task for this connector (if it exists)
|
||||||
|
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||||
|
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||||
|
if not last_sync_task:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
detail="No sync task found.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return CeleryTaskStatus(
|
||||||
|
id=last_sync_task.task_id,
|
||||||
|
name=last_sync_task.task_name,
|
||||||
|
status=last_sync_task.status,
|
||||||
|
start_time=last_sync_task.start_time,
|
||||||
|
register_time=last_sync_task.register_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
|
||||||
|
def sync_cc_pair(
|
||||||
|
cc_pair_id: int,
|
||||||
|
user: User = Depends(current_curator_or_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> StatusResponse[list[int]]:
|
||||||
|
# avoiding circular refs
|
||||||
|
from ee.danswer.background.celery.celery_app import (
|
||||||
|
sync_external_doc_permissions_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_pair = get_connector_credential_pair_from_id(
|
||||||
|
cc_pair_id=cc_pair_id,
|
||||||
|
db_session=db_session,
|
||||||
|
user=user,
|
||||||
|
get_editable=False,
|
||||||
|
)
|
||||||
|
if not cc_pair:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Connection not found for current user's permissions",
|
||||||
|
)
|
||||||
|
|
||||||
|
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||||
|
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||||
|
|
||||||
|
if last_sync_task and check_task_is_live_and_not_timed_out(
|
||||||
|
last_sync_task, db_session
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.CONFLICT,
|
||||||
|
detail="Sync task already in progress.",
|
||||||
|
)
|
||||||
|
if skip_cc_pair_pruning_by_task(
|
||||||
|
last_sync_task,
|
||||||
|
db_session=db_session,
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=HTTPStatus.CONFLICT,
|
||||||
|
detail="Sync task already in progress.",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||||
|
sync_external_doc_permissions_task.apply_async(
|
||||||
|
kwargs=dict(cc_pair_id=cc_pair_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
return StatusResponse(
|
||||||
|
success=True,
|
||||||
|
message="Successfully created the sync task.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||||
def associate_credential_to_connector(
|
def associate_credential_to_connector(
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
|
@ -781,6 +781,7 @@ def connector_run_once(
|
|||||||
detail="Connector has no valid credentials, cannot create index attempts.",
|
detail="Connector has no valid credentials, cannot create index attempts.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Prevents index attempts for cc pairs that already have an index attempt currently running
|
||||||
skipped_credentials = [
|
skipped_credentials = [
|
||||||
credential_id
|
credential_id
|
||||||
for credential_id in credential_ids
|
for credential_id in credential_ids
|
||||||
@ -790,15 +791,15 @@ def connector_run_once(
|
|||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
),
|
),
|
||||||
only_current=True,
|
only_current=True,
|
||||||
disinclude_finished=True,
|
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
|
disinclude_finished=True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
search_settings = get_current_search_settings(db_session)
|
search_settings = get_current_search_settings(db_session)
|
||||||
|
|
||||||
connector_credential_pairs = [
|
connector_credential_pairs = [
|
||||||
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
|
get_connector_credential_pair(connector_id, credential_id, db_session)
|
||||||
for credential_id in credential_ids
|
for credential_id in credential_ids
|
||||||
if credential_id not in skipped_credentials
|
if credential_id not in skipped_credentials
|
||||||
]
|
]
|
||||||
|
@ -268,7 +268,7 @@ class CCPairFullInfo(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CCPairPruningTask(BaseModel):
|
class CeleryTaskStatus(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.db.enums import AccessType
|
from danswer.db.enums import AccessType
|
||||||
@ -12,10 +15,32 @@ from ee.danswer.background.task_name_builders import (
|
|||||||
from ee.danswer.background.task_name_builders import (
|
from ee.danswer.background.task_name_builders import (
|
||||||
name_sync_external_group_permissions_task,
|
name_sync_external_group_permissions_task,
|
||||||
)
|
)
|
||||||
|
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||||
|
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||||
|
|
||||||
|
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||||
|
if not source_sync_period:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If the last sync is None, it has never been run so we run the sync
|
||||||
|
if cc_pair.last_time_perm_sync is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||||
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# If the last sync is greater than the full fetch period, we run the sync
|
||||||
|
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def should_perform_chat_ttl_check(
|
def should_perform_chat_ttl_check(
|
||||||
retention_limit_days: int | None, db_session: Session
|
retention_limit_days: int | None, db_session: Session
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@ -28,7 +53,7 @@ def should_perform_chat_ttl_check(
|
|||||||
if not latest_task:
|
if not latest_task:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
|
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@ -50,6 +75,9 @@ def should_perform_external_doc_permissions_check(
|
|||||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if not _is_time_to_run_sync(cc_pair):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -69,4 +97,7 @@ def should_perform_external_group_permissions_check(
|
|||||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if not _is_time_to_run_sync(cc_pair):
|
||||||
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
@ -6,38 +6,15 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.access.access import get_access_for_documents
|
from danswer.access.access import get_access_for_documents
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||||
from danswer.db.models import ConnectorCredentialPair
|
|
||||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||||
from danswer.document_index.interfaces import UpdateRequest
|
from danswer.document_index.interfaces import UpdateRequest
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
|
||||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
|
||||||
|
|
||||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
|
||||||
if not source_sync_period:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If the last sync is None, it has never been run so we run the sync
|
|
||||||
if cc_pair.last_time_perm_sync is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
|
||||||
current_time = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# If the last sync is greater than the full fetch period, we run the sync
|
|
||||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def run_external_group_permission_sync(
|
def run_external_group_permission_sync(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
cc_pair_id: int,
|
cc_pair_id: int,
|
||||||
@ -53,9 +30,6 @@ def run_external_group_permission_sync(
|
|||||||
# Not all sync connectors support group permissions so this is fine
|
# Not all sync connectors support group permissions so this is fine
|
||||||
return
|
return
|
||||||
|
|
||||||
if not _is_time_to_run_sync(cc_pair):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# This function updates:
|
# This function updates:
|
||||||
# - the user_email <-> external_user_group_id mapping
|
# - the user_email <-> external_user_group_id mapping
|
||||||
@ -91,9 +65,6 @@ def run_external_doc_permission_sync(
|
|||||||
f"No permission sync function found for source type: {source_type}"
|
f"No permission sync function found for source type: {source_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not _is_time_to_run_sync(cc_pair):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# This function updates:
|
# This function updates:
|
||||||
# - the user_email <-> document mapping
|
# - the user_email <-> document mapping
|
||||||
|
@ -4,7 +4,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
|
|||||||
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
|
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
|
||||||
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
|
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
|
||||||
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
|
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
|
||||||
MAX_DELAY = 30
|
MAX_DELAY = 45
|
||||||
|
|
||||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from danswer.connectors.models import InputType
|
|||||||
from danswer.db.enums import AccessType
|
from danswer.db.enums import AccessType
|
||||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||||
from danswer.db.enums import TaskStatus
|
from danswer.db.enums import TaskStatus
|
||||||
from danswer.server.documents.models import CCPairPruningTask
|
from danswer.server.documents.models import CeleryTaskStatus
|
||||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||||
from danswer.server.documents.models import DocumentSource
|
from danswer.server.documents.models import DocumentSource
|
||||||
@ -85,7 +85,7 @@ class CCPairManager:
|
|||||||
groups=groups,
|
groups=groups,
|
||||||
user_performing_action=user_performing_action,
|
user_performing_action=user_performing_action,
|
||||||
)
|
)
|
||||||
return _cc_pair_creator(
|
cc_pair = _cc_pair_creator(
|
||||||
connector_id=connector.id,
|
connector_id=connector.id,
|
||||||
credential_id=credential.id,
|
credential_id=credential.id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -93,6 +93,7 @@ class CCPairManager:
|
|||||||
groups=groups,
|
groups=groups,
|
||||||
user_performing_action=user_performing_action,
|
user_performing_action=user_performing_action,
|
||||||
)
|
)
|
||||||
|
return cc_pair
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
@ -103,7 +104,7 @@ class CCPairManager:
|
|||||||
groups: list[int] | None = None,
|
groups: list[int] | None = None,
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
) -> DATestCCPair:
|
) -> DATestCCPair:
|
||||||
return _cc_pair_creator(
|
cc_pair = _cc_pair_creator(
|
||||||
connector_id=connector_id,
|
connector_id=connector_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -111,6 +112,7 @@ class CCPairManager:
|
|||||||
groups=groups,
|
groups=groups,
|
||||||
user_performing_action=user_performing_action,
|
user_performing_action=user_performing_action,
|
||||||
)
|
)
|
||||||
|
return cc_pair
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def pause_cc_pair(
|
def pause_cc_pair(
|
||||||
@ -203,9 +205,28 @@ class CCPairManager:
|
|||||||
if not verify_deleted:
|
if not verify_deleted:
|
||||||
raise ValueError(f"CC pair {cc_pair.id} not found")
|
raise ValueError(f"CC pair {cc_pair.id} not found")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run_once(
|
||||||
|
cc_pair: DATestCCPair,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> None:
|
||||||
|
body = {
|
||||||
|
"connector_id": cc_pair.connector_id,
|
||||||
|
"credential_ids": [cc_pair.credential_id],
|
||||||
|
"from_beginning": True,
|
||||||
|
}
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||||
|
json=body,
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait_for_indexing(
|
def wait_for_indexing(
|
||||||
cc_pair_test: DATestCCPair,
|
cc_pair: DATestCCPair,
|
||||||
after: datetime,
|
after: datetime,
|
||||||
timeout: float = MAX_DELAY,
|
timeout: float = MAX_DELAY,
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
@ -213,14 +234,20 @@ class CCPairManager:
|
|||||||
"""after: Wait for an indexing success time after this time"""
|
"""after: Wait for an indexing success time after this time"""
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
while True:
|
while True:
|
||||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||||
for cc_pair in cc_pairs:
|
for fetched_cc_pair in fetched_cc_pairs:
|
||||||
if cc_pair.cc_pair_id != cc_pair_test.id:
|
if fetched_cc_pair.cc_pair_id != cc_pair.id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cc_pair.last_success and cc_pair.last_success > after:
|
if (
|
||||||
print(f"cc_pair {cc_pair_test.id} indexing complete.")
|
fetched_cc_pair.last_success
|
||||||
|
and fetched_cc_pair.last_success > after
|
||||||
|
):
|
||||||
|
print(f"cc_pair {cc_pair.id} indexing complete.")
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
print("cc_pair found but not finished:")
|
||||||
|
# print(fetched_cc_pair.__dict__)
|
||||||
|
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
if elapsed > timeout:
|
if elapsed > timeout:
|
||||||
@ -250,7 +277,7 @@ class CCPairManager:
|
|||||||
def get_prune_task(
|
def get_prune_task(
|
||||||
cc_pair: DATestCCPair,
|
cc_pair: DATestCCPair,
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
) -> CCPairPruningTask:
|
) -> CeleryTaskStatus:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
|
||||||
headers=user_performing_action.headers
|
headers=user_performing_action.headers
|
||||||
@ -258,11 +285,11 @@ class CCPairManager:
|
|||||||
else GENERAL_HEADERS,
|
else GENERAL_HEADERS,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return CCPairPruningTask(**response.json())
|
return CeleryTaskStatus(**response.json())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait_for_prune(
|
def wait_for_prune(
|
||||||
cc_pair_test: DATestCCPair,
|
cc_pair: DATestCCPair,
|
||||||
after: datetime,
|
after: datetime,
|
||||||
timeout: float = MAX_DELAY,
|
timeout: float = MAX_DELAY,
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
@ -270,7 +297,7 @@ class CCPairManager:
|
|||||||
"""after: The task register time must be after this time."""
|
"""after: The task register time must be after this time."""
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
while True:
|
while True:
|
||||||
task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action)
|
task = CCPairManager.get_prune_task(cc_pair, user_performing_action)
|
||||||
if not task:
|
if not task:
|
||||||
raise ValueError("Prune task not found.")
|
raise ValueError("Prune task not found.")
|
||||||
|
|
||||||
@ -292,16 +319,75 @@ class CCPairManager:
|
|||||||
)
|
)
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sync(
|
||||||
|
cc_pair: DATestCCPair,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> None:
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_sync_task(
|
||||||
|
cc_pair: DATestCCPair,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> CeleryTaskStatus:
|
||||||
|
response = requests.get(
|
||||||
|
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||||
|
headers=user_performing_action.headers
|
||||||
|
if user_performing_action
|
||||||
|
else GENERAL_HEADERS,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return CeleryTaskStatus(**response.json())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wait_for_sync(
|
||||||
|
cc_pair: DATestCCPair,
|
||||||
|
after: datetime,
|
||||||
|
timeout: float = MAX_DELAY,
|
||||||
|
user_performing_action: DATestUser | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""after: The task register time must be after this time."""
|
||||||
|
start = time.monotonic()
|
||||||
|
while True:
|
||||||
|
task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
|
||||||
|
if not task:
|
||||||
|
raise ValueError("Sync task not found.")
|
||||||
|
|
||||||
|
if not task.register_time or task.register_time < after:
|
||||||
|
raise ValueError("Sync task register time is too early.")
|
||||||
|
|
||||||
|
if task.status == TaskStatus.SUCCESS:
|
||||||
|
# Sync succeeded
|
||||||
|
return
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
if elapsed > timeout:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"CC pair syncing was not completed within {timeout} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||||
|
)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wait_for_deletion_completion(
|
def wait_for_deletion_completion(
|
||||||
user_performing_action: DATestUser | None = None,
|
user_performing_action: DATestUser | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
while True:
|
while True:
|
||||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||||
if all(
|
if all(
|
||||||
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
|
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
|
||||||
for cc_pair in cc_pairs
|
for cc_pair in fetched_cc_pairs
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -50,9 +50,7 @@ class LLMProviderManager:
|
|||||||
)
|
)
|
||||||
llm_response.raise_for_status()
|
llm_response.raise_for_status()
|
||||||
response_data = llm_response.json()
|
response_data = llm_response.json()
|
||||||
import json
|
|
||||||
|
|
||||||
print(json.dumps(response_data, indent=4))
|
|
||||||
result_llm = DATestLLMProvider(
|
result_llm = DATestLLMProvider(
|
||||||
id=response_data["id"],
|
id=response_data["id"],
|
||||||
name=response_data["name"],
|
name=response_data["name"],
|
||||||
|
@ -17,11 +17,14 @@ class UserManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
|
email: str | None = None,
|
||||||
) -> DATestUser:
|
) -> DATestUser:
|
||||||
if name is None:
|
if name is None:
|
||||||
name = f"test{str(uuid4())}"
|
name = f"test{str(uuid4())}"
|
||||||
|
|
||||||
|
if email is None:
|
||||||
email = f"{name}@test.com"
|
email = f"{name}@test.com"
|
||||||
|
|
||||||
password = "test"
|
password = "test"
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
@ -44,12 +47,10 @@ class UserManager:
|
|||||||
)
|
)
|
||||||
print(f"Created user {test_user.email}")
|
print(f"Created user {test_user.email}")
|
||||||
|
|
||||||
test_user.headers["Cookie"] = UserManager.login_as_user(test_user)
|
return UserManager.login_as_user(test_user)
|
||||||
|
|
||||||
return test_user
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def login_as_user(test_user: DATestUser) -> str:
|
def login_as_user(test_user: DATestUser) -> DATestUser:
|
||||||
data = urlencode(
|
data = urlencode(
|
||||||
{
|
{
|
||||||
"username": test_user.email,
|
"username": test_user.email,
|
||||||
@ -71,7 +72,9 @@ class UserManager:
|
|||||||
raise Exception("Failed to login")
|
raise Exception("Failed to login")
|
||||||
|
|
||||||
print(f"Logged in as {test_user.email}")
|
print(f"Logged in as {test_user.email}")
|
||||||
return f"{result_cookie.name}={result_cookie.value}"
|
cookie = f"{result_cookie.name}={result_cookie.value}"
|
||||||
|
test_user.headers["Cookie"] = cookie
|
||||||
|
return test_user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:
|
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
|
||||||
|
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||||
|
admin_user_id = SlackManager.build_slack_user_email_id_map(slack_client)[
|
||||||
|
"admin@onyx-test.com"
|
||||||
|
]
|
||||||
|
|
||||||
|
(
|
||||||
|
public_channel,
|
||||||
|
private_channel,
|
||||||
|
run_id,
|
||||||
|
) = SlackManager.get_and_provision_available_slack_channels(
|
||||||
|
slack_client=slack_client, admin_user_id=admin_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
yield public_channel, private_channel
|
||||||
|
|
||||||
|
# This part will always run after the test, even if it fails
|
||||||
|
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
|
@ -0,0 +1,311 @@
|
|||||||
|
"""
|
||||||
|
Assumptions:
|
||||||
|
- The test users have already been created
|
||||||
|
- General is empty of messages
|
||||||
|
- In addition to the normal slack oauth permissions, the following scopes are needed:
|
||||||
|
- channels:manage
|
||||||
|
- groups:write
|
||||||
|
- chat:write
|
||||||
|
- chat:write.public
|
||||||
|
"""
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from slack_sdk import WebClient
|
||||||
|
from slack_sdk.errors import SlackApiError
|
||||||
|
|
||||||
|
from danswer.connectors.slack.connector import default_msg_filter
|
||||||
|
from danswer.connectors.slack.connector import get_channel_messages
|
||||||
|
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||||
|
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||||
|
|
||||||
|
|
||||||
|
def _get_slack_channel_id(channel: dict[str, Any]) -> str:
|
||||||
|
if not (channel_id := channel.get("id")):
|
||||||
|
raise ValueError("Channel ID is missing")
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
|
||||||
|
def _get_non_general_channels(
|
||||||
|
slack_client: WebClient,
|
||||||
|
get_private: bool,
|
||||||
|
get_public: bool,
|
||||||
|
only_get_done: bool = False,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
channel_types = []
|
||||||
|
if get_private:
|
||||||
|
channel_types.append("private_channel")
|
||||||
|
if get_public:
|
||||||
|
channel_types.append("public_channel")
|
||||||
|
|
||||||
|
conversations: list[dict[str, Any]] = []
|
||||||
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_list,
|
||||||
|
exclude_archived=False,
|
||||||
|
types=channel_types,
|
||||||
|
):
|
||||||
|
conversations.extend(result["channels"])
|
||||||
|
|
||||||
|
filtered_conversations = []
|
||||||
|
for conversation in conversations:
|
||||||
|
if conversation.get("is_general", False):
|
||||||
|
continue
|
||||||
|
if only_get_done and "done" not in conversation.get("name", ""):
|
||||||
|
continue
|
||||||
|
filtered_conversations.append(conversation)
|
||||||
|
return filtered_conversations
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_slack_conversation_members(
|
||||||
|
slack_client: WebClient,
|
||||||
|
admin_user_id: str,
|
||||||
|
channel: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
channel_id = _get_slack_channel_id(channel)
|
||||||
|
member_ids: list[str] = []
|
||||||
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_members,
|
||||||
|
channel=channel_id,
|
||||||
|
):
|
||||||
|
member_ids.extend(result["members"])
|
||||||
|
|
||||||
|
for member_id in member_ids:
|
||||||
|
if member_id == admin_user_id:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_kick, channel=channel_id, user=member_id
|
||||||
|
)
|
||||||
|
print(f"Kicked member: {member_id}")
|
||||||
|
except Exception as e:
|
||||||
|
if "cant_kick_self" in str(e):
|
||||||
|
continue
|
||||||
|
print(f"Error kicking member: {e}")
|
||||||
|
print(member_id)
|
||||||
|
try:
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_unarchive, channel=channel_id
|
||||||
|
)
|
||||||
|
channel["is_archived"] = False
|
||||||
|
except Exception:
|
||||||
|
# Channel is already unarchived
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _add_slack_conversation_members(
|
||||||
|
slack_client: WebClient, channel: dict[str, Any], member_ids: list[str]
|
||||||
|
) -> None:
|
||||||
|
channel_id = _get_slack_channel_id(channel)
|
||||||
|
for user_id in member_ids:
|
||||||
|
try:
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_invite,
|
||||||
|
channel=channel_id,
|
||||||
|
users=user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if "already_in_channel" in str(e):
|
||||||
|
continue
|
||||||
|
print(f"Error inviting member: {e}")
|
||||||
|
print(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_slack_conversation_messages(
|
||||||
|
slack_client: WebClient,
|
||||||
|
channel: dict[str, Any],
|
||||||
|
message_to_delete: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""deletes all messages from a channel if message_to_delete is None"""
|
||||||
|
channel_id = _get_slack_channel_id(channel)
|
||||||
|
for message_batch in get_channel_messages(slack_client, channel):
|
||||||
|
for message in message_batch:
|
||||||
|
if default_msg_filter(message):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if message_to_delete and message.get("text") != message_to_delete:
|
||||||
|
continue
|
||||||
|
print(" removing message: ", message.get("text"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (ts := message.get("ts")):
|
||||||
|
raise ValueError("Message timestamp is missing")
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.chat_delete,
|
||||||
|
channel=channel_id,
|
||||||
|
ts=ts,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error deleting message: {e}")
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_slack_channel_from_name(
|
||||||
|
slack_client: WebClient,
|
||||||
|
admin_user_id: str,
|
||||||
|
suffix: str,
|
||||||
|
is_private: bool,
|
||||||
|
channel: dict[str, Any] | None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
base = "public_channel" if not is_private else "private_channel"
|
||||||
|
channel_name = f"{base}-{suffix}"
|
||||||
|
if channel:
|
||||||
|
# If channel is provided, we rename it
|
||||||
|
channel_id = _get_slack_channel_id(channel)
|
||||||
|
channel_response = make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_rename,
|
||||||
|
channel=channel_id,
|
||||||
|
name=channel_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Otherwise, we create a new channel
|
||||||
|
channel_response = make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_create,
|
||||||
|
name=channel_name,
|
||||||
|
is_private=is_private,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
channel_response = make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_unarchive,
|
||||||
|
channel=channel_response["channel"]["id"],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Channel is already unarchived
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
channel_response = make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_invite,
|
||||||
|
channel=channel_response["channel"]["id"],
|
||||||
|
users=[admin_user_id],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
final_channel = channel_response["channel"] if channel_response else {}
|
||||||
|
return final_channel
|
||||||
|
|
||||||
|
|
||||||
|
class SlackManager:
|
||||||
|
@staticmethod
|
||||||
|
def get_slack_client(token: str) -> WebClient:
|
||||||
|
return WebClient(token=token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_and_provision_available_slack_channels(
|
||||||
|
slack_client: WebClient, admin_user_id: str
|
||||||
|
) -> tuple[dict[str, Any], dict[str, Any], str]:
|
||||||
|
run_id = str(uuid4())
|
||||||
|
public_channels = _get_non_general_channels(
|
||||||
|
slack_client, get_private=False, get_public=True, only_get_done=True
|
||||||
|
)
|
||||||
|
|
||||||
|
first_available_channel = (
|
||||||
|
None if len(public_channels) < 1 else public_channels[0]
|
||||||
|
)
|
||||||
|
public_channel = _build_slack_channel_from_name(
|
||||||
|
slack_client=slack_client,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
suffix=run_id,
|
||||||
|
is_private=False,
|
||||||
|
channel=first_available_channel,
|
||||||
|
)
|
||||||
|
_delete_slack_conversation_messages(
|
||||||
|
slack_client=slack_client, channel=public_channel
|
||||||
|
)
|
||||||
|
|
||||||
|
private_channels = _get_non_general_channels(
|
||||||
|
slack_client, get_private=True, get_public=False, only_get_done=True
|
||||||
|
)
|
||||||
|
second_available_channel = (
|
||||||
|
None if len(private_channels) < 1 else private_channels[0]
|
||||||
|
)
|
||||||
|
private_channel = _build_slack_channel_from_name(
|
||||||
|
slack_client=slack_client,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
suffix=run_id,
|
||||||
|
is_private=True,
|
||||||
|
channel=second_available_channel,
|
||||||
|
)
|
||||||
|
_delete_slack_conversation_messages(
|
||||||
|
slack_client=slack_client, channel=private_channel
|
||||||
|
)
|
||||||
|
|
||||||
|
return public_channel, private_channel, run_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
|
||||||
|
users_results = make_slack_api_call_w_retries(
|
||||||
|
slack_client.users_list,
|
||||||
|
)
|
||||||
|
users: list[dict[str, Any]] = users_results.get("members", [])
|
||||||
|
user_email_id_map = {}
|
||||||
|
for user in users:
|
||||||
|
if not (email := user.get("profile", {}).get("email")):
|
||||||
|
continue
|
||||||
|
if not (user_id := user.get("id")):
|
||||||
|
raise ValueError("User ID is missing")
|
||||||
|
user_email_id_map[email] = user_id
|
||||||
|
return user_email_id_map
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_channel_members(
|
||||||
|
slack_client: WebClient,
|
||||||
|
admin_user_id: str,
|
||||||
|
channel: dict[str, Any],
|
||||||
|
user_ids: list[str],
|
||||||
|
) -> None:
|
||||||
|
_clear_slack_conversation_members(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=channel,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
)
|
||||||
|
_add_slack_conversation_members(
|
||||||
|
slack_client=slack_client, channel=channel, member_ids=user_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_message_to_channel(
|
||||||
|
slack_client: WebClient, channel: dict[str, Any], message: str
|
||||||
|
) -> None:
|
||||||
|
channel_id = _get_slack_channel_id(channel)
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.chat_postMessage,
|
||||||
|
channel=channel_id,
|
||||||
|
text=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_message_from_channel(
|
||||||
|
slack_client: WebClient, channel: dict[str, Any], message: str
|
||||||
|
) -> None:
|
||||||
|
_delete_slack_conversation_messages(
|
||||||
|
slack_client=slack_client, channel=channel, message_to_delete=message
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cleanup_after_test(
|
||||||
|
slack_client: WebClient,
|
||||||
|
test_id: str,
|
||||||
|
) -> None:
|
||||||
|
channel_types = ["private_channel", "public_channel"]
|
||||||
|
channels: list[dict[str, Any]] = []
|
||||||
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_list,
|
||||||
|
exclude_archived=False,
|
||||||
|
types=channel_types,
|
||||||
|
):
|
||||||
|
channels.extend(result["channels"])
|
||||||
|
|
||||||
|
for channel in channels:
|
||||||
|
if test_id not in channel.get("name", ""):
|
||||||
|
continue
|
||||||
|
# "done" in the channel name indicates that this channel is free to be used for a new test
|
||||||
|
new_name = f"done_{str(uuid4())}"
|
||||||
|
try:
|
||||||
|
make_slack_api_call_w_retries(
|
||||||
|
slack_client.conversations_rename,
|
||||||
|
channel=channel["id"],
|
||||||
|
name=new_name,
|
||||||
|
)
|
||||||
|
except SlackApiError as e:
|
||||||
|
print(f"Error renaming channel {channel['id']}: {e}")
|
@ -0,0 +1,251 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from datetime import timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from danswer.connectors.models import InputType
|
||||||
|
from danswer.db.enums import AccessType
|
||||||
|
from danswer.search.enums import LLMEvaluationType
|
||||||
|
from danswer.search.enums import SearchType
|
||||||
|
from danswer.search.models import RetrievalDetails
|
||||||
|
from danswer.server.documents.models import DocumentSource
|
||||||
|
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||||
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
|
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||||
|
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||||
|
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||||
|
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||||
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.test_models import DATestCCPair
|
||||||
|
from tests.integration.common_utils.test_models import DATestConnector
|
||||||
|
from tests.integration.common_utils.test_models import DATestCredential
|
||||||
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
from tests.integration.common_utils.vespa import vespa_fixture
|
||||||
|
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_permission_sync(
|
||||||
|
reset: None,
|
||||||
|
vespa_client: vespa_fixture,
|
||||||
|
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
public_channel, private_channel = slack_test_setup
|
||||||
|
|
||||||
|
# Creating an admin user (first user created is automatically an admin)
|
||||||
|
admin_user: DATestUser = UserManager.create(
|
||||||
|
email="admin@onyx-test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Creating a non-admin user
|
||||||
|
test_user_1: DATestUser = UserManager.create(
|
||||||
|
email="test_user_1@onyx-test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Creating a non-admin user
|
||||||
|
test_user_2: DATestUser = UserManager.create(
|
||||||
|
email="test_user_2@onyx-test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||||
|
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||||
|
admin_user_id = email_id_map[admin_user.email]
|
||||||
|
|
||||||
|
LLMProviderManager.create(user_performing_action=admin_user)
|
||||||
|
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
credential: DATestCredential = CredentialManager.create(
|
||||||
|
source=DocumentSource.SLACK,
|
||||||
|
credential_json={
|
||||||
|
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||||
|
},
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
connector: DATestConnector = ConnectorManager.create(
|
||||||
|
name="Slack",
|
||||||
|
input_type=InputType.POLL,
|
||||||
|
source=DocumentSource.SLACK,
|
||||||
|
connector_specific_config={
|
||||||
|
"workspace": "onyx-test-workspace",
|
||||||
|
"channels": [public_channel["name"], private_channel["name"]],
|
||||||
|
},
|
||||||
|
is_public=True,
|
||||||
|
groups=[],
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
cc_pair: DATestCCPair = CCPairManager.create(
|
||||||
|
credential_id=credential.id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
access_type=AccessType.SYNC,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
CCPairManager.wait_for_indexing(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add test_user_1 and admin_user to the private channel
|
||||||
|
desired_channel_members = [admin_user, test_user_1]
|
||||||
|
SlackManager.set_channel_members(
|
||||||
|
slack_client=slack_client,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
channel=private_channel,
|
||||||
|
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||||
|
)
|
||||||
|
|
||||||
|
public_message = "Steve's favorite number is 809752"
|
||||||
|
private_message = "Sara's favorite number is 346794"
|
||||||
|
|
||||||
|
SlackManager.add_message_to_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=public_channel,
|
||||||
|
message=public_message,
|
||||||
|
)
|
||||||
|
SlackManager.add_message_to_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=private_channel,
|
||||||
|
message=private_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run indexing
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
CCPairManager.run_once(cc_pair, admin_user)
|
||||||
|
CCPairManager.wait_for_indexing(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run permission sync
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
CCPairManager.sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
CCPairManager.wait_for_sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search as admin with access to both channels
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=admin_user.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
|
||||||
|
# Ensure admin user can see messages from both channels
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# Search as test_user_2 with access to only the public channel
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=test_user_2.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content before removing from private channel for test_user_2: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure test_user_2 can only see messages from the public channel
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message not in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# Search as test_user_1 with access to both channels
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=test_user_1.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content before removing from private channel for test_user_1: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure test_user_1 can see messages from both channels
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# ----------------------MAKE THE CHANGES--------------------------
|
||||||
|
print("\nRemoving test_user_1 from the private channel")
|
||||||
|
# Remove test_user_1 from the private channel
|
||||||
|
desired_channel_members = [admin_user]
|
||||||
|
SlackManager.set_channel_members(
|
||||||
|
slack_client=slack_client,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
channel=private_channel,
|
||||||
|
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run permission sync
|
||||||
|
CCPairManager.sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
CCPairManager.wait_for_sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||||
|
# Ensure test_user_1 can no longer see messages from the private channel
|
||||||
|
# Search as test_user_1 with access to only the public channel
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=test_user_1.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content after removing from private channel for test_user_1: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure test_user_1 can only see messages from the public channel
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message not in danswer_doc_message_strings
|
@ -0,0 +1,255 @@
|
|||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from datetime import timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from danswer.connectors.models import InputType
|
||||||
|
from danswer.db.enums import AccessType
|
||||||
|
from danswer.search.enums import LLMEvaluationType
|
||||||
|
from danswer.search.enums import SearchType
|
||||||
|
from danswer.search.models import RetrievalDetails
|
||||||
|
from danswer.server.documents.models import DocumentSource
|
||||||
|
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||||
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
|
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||||
|
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||||
|
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||||
|
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||||
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.test_models import DATestCCPair
|
||||||
|
from tests.integration.common_utils.test_models import DATestConnector
|
||||||
|
from tests.integration.common_utils.test_models import DATestCredential
|
||||||
|
from tests.integration.common_utils.test_models import DATestUser
|
||||||
|
from tests.integration.common_utils.vespa import vespa_fixture
|
||||||
|
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_prune(
|
||||||
|
reset: None,
|
||||||
|
vespa_client: vespa_fixture,
|
||||||
|
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
public_channel, private_channel = slack_test_setup
|
||||||
|
|
||||||
|
# Creating an admin user (first user created is automatically an admin)
|
||||||
|
admin_user: DATestUser = UserManager.create(
|
||||||
|
email="admin@onyx-test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Creating a non-admin user
|
||||||
|
test_user_1: DATestUser = UserManager.create(
|
||||||
|
email="test_user_1@onyx-test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||||
|
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||||
|
admin_user_id = email_id_map[admin_user.email]
|
||||||
|
|
||||||
|
LLMProviderManager.create(user_performing_action=admin_user)
|
||||||
|
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
credential: DATestCredential = CredentialManager.create(
|
||||||
|
source=DocumentSource.SLACK,
|
||||||
|
credential_json={
|
||||||
|
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||||
|
},
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
connector: DATestConnector = ConnectorManager.create(
|
||||||
|
name="Slack",
|
||||||
|
input_type=InputType.POLL,
|
||||||
|
source=DocumentSource.SLACK,
|
||||||
|
connector_specific_config={
|
||||||
|
"workspace": "onyx-test-workspace",
|
||||||
|
"channels": [public_channel["name"], private_channel["name"]],
|
||||||
|
},
|
||||||
|
is_public=True,
|
||||||
|
groups=[],
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
cc_pair: DATestCCPair = CCPairManager.create(
|
||||||
|
credential_id=credential.id,
|
||||||
|
connector_id=connector.id,
|
||||||
|
access_type=AccessType.SYNC,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
CCPairManager.wait_for_indexing(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------SETUP INITIAL SLACK STATE--------------------------
|
||||||
|
# Add test_user_1 and admin_user to the private channel
|
||||||
|
desired_channel_members = [admin_user, test_user_1]
|
||||||
|
SlackManager.set_channel_members(
|
||||||
|
slack_client=slack_client,
|
||||||
|
admin_user_id=admin_user_id,
|
||||||
|
channel=private_channel,
|
||||||
|
user_ids=[email_id_map[user.email] for user in desired_channel_members],
|
||||||
|
)
|
||||||
|
|
||||||
|
public_message = "Steve's favorite number is 809752"
|
||||||
|
private_message = "Sara's favorite number is 346794"
|
||||||
|
message_to_delete = "Rebecca's favorite number is 753468"
|
||||||
|
|
||||||
|
SlackManager.add_message_to_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=public_channel,
|
||||||
|
message=public_message,
|
||||||
|
)
|
||||||
|
SlackManager.add_message_to_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=private_channel,
|
||||||
|
message=private_message,
|
||||||
|
)
|
||||||
|
SlackManager.add_message_to_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=private_channel,
|
||||||
|
message=message_to_delete,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run indexing
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
CCPairManager.run_once(cc_pair, admin_user)
|
||||||
|
CCPairManager.wait_for_indexing(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run permission sync
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
CCPairManager.sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
CCPairManager.wait_for_sync(
|
||||||
|
cc_pair=cc_pair,
|
||||||
|
after=before,
|
||||||
|
user_performing_action=admin_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------TEST THE SETUP--------------------------
|
||||||
|
# Search as admin with access to both channels
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=admin_user.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content before deleting for admin: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure admin user can see all messages
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
assert message_to_delete in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# Search as test_user_1 with access to both channels
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=test_user_1.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content before deleting for test_user_1: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure test_user_1 can see all messages
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
assert message_to_delete in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# ----------------------MAKE THE CHANGES--------------------------
|
||||||
|
# Delete messages
|
||||||
|
print("\nDeleting message: ", message_to_delete)
|
||||||
|
SlackManager.remove_message_from_channel(
|
||||||
|
slack_client=slack_client,
|
||||||
|
channel=private_channel,
|
||||||
|
message=message_to_delete,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prune the cc_pair
|
||||||
|
before = datetime.now(timezone.utc)
|
||||||
|
CCPairManager.prune(cc_pair, user_performing_action=admin_user)
|
||||||
|
CCPairManager.wait_for_prune(cc_pair, before, user_performing_action=admin_user)
|
||||||
|
|
||||||
|
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||||
|
# Ensure admin user can't see deleted messages
|
||||||
|
# Search as admin user with access to only the public channel
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=admin_user.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content after deleting for admin: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure admin can't see deleted messages
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
assert message_to_delete not in danswer_doc_message_strings
|
||||||
|
|
||||||
|
# Ensure test_user_1 can't see deleted messages
|
||||||
|
# Search as test_user_1 with access to only the public channel
|
||||||
|
search_request = DocumentSearchRequest(
|
||||||
|
message="favorite number",
|
||||||
|
search_type=SearchType.KEYWORD,
|
||||||
|
retrieval_options=RetrievalDetails(),
|
||||||
|
evaluation_type=LLMEvaluationType.SKIP,
|
||||||
|
)
|
||||||
|
search_request_body = search_request.model_dump()
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{API_SERVER_URL}/query/document-search",
|
||||||
|
json=search_request_body,
|
||||||
|
headers=test_user_1.headers,
|
||||||
|
)
|
||||||
|
result.raise_for_status()
|
||||||
|
found_docs = result.json()["top_documents"]
|
||||||
|
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
|
||||||
|
print(
|
||||||
|
"\ntop_documents content after prune for test_user_1: ",
|
||||||
|
danswer_doc_message_strings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure test_user_1 can't see deleted messages
|
||||||
|
assert public_message in danswer_doc_message_strings
|
||||||
|
assert private_message in danswer_doc_message_strings
|
||||||
|
assert message_to_delete not in danswer_doc_message_strings
|
Loading…
x
Reference in New Issue
Block a user