Implement source testing framework + Slack (#2650)

* Added permission sync tests for Slack

* moved folders

* prune test + mypy

* added wait for indexing to cc_pair creation

* commented out check

* should fix other tests

* added slack channel pool

* fixed everything and mypy

* reduced flake
hagen-danswer 2024-10-02 16:16:07 -07:00 committed by GitHub
parent b3c367d09c
commit c2088602e1
15 changed files with 1098 additions and 63 deletions

View File

@ -12,6 +12,7 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
integration-tests:
@ -142,6 +143,7 @@ jobs:
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test
continue-on-error: true

View File

@ -205,12 +205,17 @@ _DISALLOWED_MSG_SUBTYPES = {
"group_leave",
"group_archive",
"group_unarchive",
"channel_leave",
"channel_name",
"channel_join",
}
def _default_msg_filter(message: MessageType) -> bool:
def default_msg_filter(message: MessageType) -> bool:
# Don't keep messages from bots
if message.get("bot_id") or message.get("app_id"):
if message.get("bot_profile", {}).get("name") == "DanswerConnector":
return False
return True
# Uninformative
@ -261,7 +266,7 @@ def _get_all_docs(
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
@ -320,7 +325,7 @@ def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel

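Note on the connector change above: `_default_msg_filter` is made public as `default_msg_filter` (the new Slack test utilities import it) and gains an exception so that messages posted by the connector's own bot are kept. A small illustration of the resulting behavior with toy message payloads (the payload shapes are assumptions; True means "filter the message out"):

```python
from danswer.connectors.slack.connector import default_msg_filter

# Toy Slack message payloads; field shapes assumed from the Slack Web API.
other_bot_msg = {"bot_id": "B456", "bot_profile": {"name": "SomeOtherBot"}, "text": "hi"}
danswer_bot_msg = {
    "bot_id": "B123",
    "bot_profile": {"name": "DanswerConnector"},
    "text": "Steve's favorite number is 809752",
}

assert default_msg_filter(other_bot_msg) is True  # other bots are still dropped
assert default_msg_filter(danswer_bot_msg) is False  # our own bot's messages are now kept
# Human messages fall through to the subtype checks not shown in this hunk.
```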
View File

@ -29,15 +29,19 @@ from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPairPruningTask
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@ -199,7 +203,7 @@ def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CCPairPruningTask:
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@ -223,7 +227,7 @@ def get_cc_pair_latest_prune(
detail="No pruning task found.",
)
return CCPairPruningTask(
return CeleryTaskStatus(
id=last_pruning_task.task_id,
name=last_pruning_task.task_name,
status=last_pruning_task.status,
@ -280,6 +284,95 @@ def prune_cc_pair(
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
# look up the last sync task for this connector (if it exists)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if not last_sync_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No sync task found.",
)
return CeleryTaskStatus(
id=last_sync_task.task_id,
name=last_sync_task.task_name,
status=last_sync_task.status,
start_time=last_sync_task.start_time,
register_time=last_sync_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.celery_app import (
sync_external_doc_permissions_task,
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=False,
)
if not cc_pair:
raise HTTPException(
status_code=400,
detail="Connection not found for current user's permissions",
)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if last_sync_task and check_task_is_live_and_not_timed_out(
last_sync_task, db_session
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
if skip_cc_pair_pruning_by_task(
last_sync_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair_id),
)
return StatusResponse(
success=True,
message="Successfully created the sync task.",
)
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,

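The two new endpoints above are meant to be driven together: POST kicks off the Celery permission-sync task (409 if one is already running), and GET reports its status. A hedged sketch of that flow from a client or test; the URL, session cookie, cc_pair id, and the serialized status value are placeholders, not taken from this diff:

```python
import time

import requests

API_SERVER_URL = "http://localhost:8080"  # assumed default, matching the test constants
headers = {
    "Content-Type": "application/json",
    "Cookie": "<session cookie from login>",  # placeholder
}
cc_pair_id = 1  # hypothetical connector-credential pair

# Kick off a permission sync; the endpoint returns 409 CONFLICT if one is already running.
requests.post(
    f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/sync", headers=headers
).raise_for_status()

# Poll the matching GET endpoint until the Celery task reports success.
while True:
    resp = requests.get(
        f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/sync", headers=headers
    )
    resp.raise_for_status()
    if resp.json()["status"] == "SUCCESS":  # assuming TaskStatus serializes to "SUCCESS"
        break
    time.sleep(5)
```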
View File

@ -781,6 +781,7 @@ def connector_run_once(
detail="Connector has no valid credentials, cannot create index attempts.",
)
# Prevents index attempts for cc pairs that already have an index attempt currently running
skipped_credentials = [
credential_id
for credential_id in credential_ids
@ -790,15 +791,15 @@ def connector_run_once(
credential_id=credential_id,
),
only_current=True,
disinclude_finished=True,
db_session=db_session,
disinclude_finished=True,
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]

View File

@ -268,7 +268,7 @@ class CCPairFullInfo(BaseModel):
)
class CCPairPruningTask(BaseModel):
class CeleryTaskStatus(BaseModel):
id: str
name: str
status: TaskStatus

View File

@ -1,3 +1,6 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.db.enums import AccessType
@ -12,10 +15,32 @@ from ee.danswer.background.task_name_builders import (
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
logger = setup_logger()
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If PERMISSION_SYNC_PERIODS[source] is unset, we always run the sync.
if not source_sync_period:
return True
# If the last sync is None, it has never been run so we run the sync
if cc_pair.last_time_perm_sync is None:
return True
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
# If the time since the last sync exceeds the configured sync period, we run the sync
if (current_time - last_sync).total_seconds() > source_sync_period:
return True
return False
def should_perform_chat_ttl_check(
retention_limit_days: int | None, db_session: Session
) -> bool:
@ -28,7 +53,7 @@ def should_perform_chat_ttl_check(
if not latest_task:
return True
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
if check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
return True
@ -50,6 +75,9 @@ def should_perform_external_doc_permissions_check(
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
if not _is_time_to_run_sync(cc_pair):
return False
return True
@ -69,4 +97,7 @@ def should_perform_external_group_permissions_check(
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
if not _is_time_to_run_sync(cc_pair):
return False
return True

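The relocated `_is_time_to_run_sync` helper now gates both the doc- and group-permission checks, throttling syncs per source via PERMISSION_SYNC_PERIODS. A minimal standalone sketch of the same decision, using a hypothetical period since the real mapping in ee/danswer/external_permissions/sync_params.py is not shown in this diff:

```python
from datetime import datetime, timedelta, timezone

# Hypothetical value; the real periods live in
# ee/danswer/external_permissions/sync_params.py (not shown in this diff).
SYNC_PERIOD_SECONDS = 300


def is_due_for_sync(last_time_perm_sync: datetime | None) -> bool:
    # Never synced before -> run it.
    if last_time_perm_sync is None:
        return True
    elapsed = datetime.now(timezone.utc) - last_time_perm_sync.replace(tzinfo=timezone.utc)
    # Run again only once the configured period has elapsed.
    return elapsed.total_seconds() > SYNC_PERIOD_SECONDS


# A pair synced 10 minutes ago is due again; one synced 1 minute ago is not.
assert is_due_for_sync(datetime.now(timezone.utc) - timedelta(minutes=10))
assert not is_due_for_sync(datetime.now(timezone.utc) - timedelta(minutes=1))
```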
View File

@ -6,38 +6,15 @@ from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
logger = setup_logger()
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is None, it has never been run so we run the sync
if cc_pair.last_time_perm_sync is None:
return True
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
# If the last sync is greater than the full fetch period, we run the sync
if (current_time - last_sync).total_seconds() > source_sync_period:
return True
return False
def run_external_group_permission_sync(
db_session: Session,
cc_pair_id: int,
@ -53,9 +30,6 @@ def run_external_group_permission_sync(
# Not all sync connectors support group permissions so this is fine
return
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
# - the user_email <-> external_user_group_id mapping
@ -91,9 +65,6 @@ def run_external_doc_permission_sync(
f"No permission sync function found for source type: {source_type}"
)
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
# - the user_email <-> document mapping

View File

@ -4,7 +4,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 30
MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@ -9,7 +9,7 @@ from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import TaskStatus
from danswer.server.documents.models import CCPairPruningTask
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import DocumentSource
@ -85,7 +85,7 @@ class CCPairManager:
groups=groups,
user_performing_action=user_performing_action,
)
return _cc_pair_creator(
cc_pair = _cc_pair_creator(
connector_id=connector.id,
credential_id=credential.id,
name=name,
@ -93,6 +93,7 @@ class CCPairManager:
groups=groups,
user_performing_action=user_performing_action,
)
return cc_pair
@staticmethod
def create(
@ -103,7 +104,7 @@ class CCPairManager:
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
return _cc_pair_creator(
cc_pair = _cc_pair_creator(
connector_id=connector_id,
credential_id=credential_id,
name=name,
@ -111,6 +112,7 @@ class CCPairManager:
groups=groups,
user_performing_action=user_performing_action,
)
return cc_pair
@staticmethod
def pause_cc_pair(
@ -203,9 +205,28 @@ class CCPairManager:
if not verify_deleted:
raise ValueError(f"CC pair {cc_pair.id} not found")
@staticmethod
def run_once(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
"credential_ids": [cc_pair.credential_id],
"from_beginning": True,
}
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
json=body,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
result.raise_for_status()
@staticmethod
def wait_for_indexing(
cc_pair_test: DATestCCPair,
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
@ -213,14 +234,20 @@ class CCPairManager:
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
while True:
cc_pairs = CCPairManager.get_all(user_performing_action)
for cc_pair in cc_pairs:
if cc_pair.cc_pair_id != cc_pair_test.id:
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if cc_pair.last_success and cc_pair.last_success > after:
print(f"cc_pair {cc_pair_test.id} indexing complete.")
if (
fetched_cc_pair.last_success
and fetched_cc_pair.last_success > after
):
print(f"cc_pair {cc_pair.id} indexing complete.")
return
else:
print("cc_pair found but not finished:")
# print(fetched_cc_pair.__dict__)
elapsed = time.monotonic() - start
if elapsed > timeout:
@ -250,7 +277,7 @@ class CCPairManager:
def get_prune_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> CCPairPruningTask:
) -> CeleryTaskStatus:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers
@ -258,11 +285,11 @@ class CCPairManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return CCPairPruningTask(**response.json())
return CeleryTaskStatus(**response.json())
@staticmethod
def wait_for_prune(
cc_pair_test: DATestCCPair,
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
@ -270,7 +297,7 @@ class CCPairManager:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
task = CCPairManager.get_prune_task(cc_pair_test, user_performing_action)
task = CCPairManager.get_prune_task(cc_pair, user_performing_action)
if not task:
raise ValueError("Prune task not found.")
@ -292,16 +319,75 @@ class CCPairManager:
)
time.sleep(5)
@staticmethod
def sync(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
result.raise_for_status()
@staticmethod
def get_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> CeleryTaskStatus:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return CeleryTaskStatus(**response.json())
@staticmethod
def wait_for_sync(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
if not task:
raise ValueError("Sync task not found.")
if not task.register_time or task.register_time < after:
raise ValueError("Sync task register time is too early.")
if task.status == TaskStatus.SUCCESS:
# Sync succeeded
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"CC pair syncing was not completed within {timeout} seconds"
)
print(
f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
@staticmethod
def wait_for_deletion_completion(
user_performing_action: DATestUser | None = None,
) -> None:
start = time.monotonic()
while True:
cc_pairs = CCPairManager.get_all(user_performing_action)
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
if all(
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
for cc_pair in cc_pairs
for cc_pair in fetched_cc_pairs
):
return

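Together with the existing prune helpers, the new run_once, sync, get_sync_task, and wait_for_sync methods let a test drive indexing and permission syncing end to end. A sketch of the intended flow, which the Slack tests added below follow; `cc_pair` and `admin_user` are assumed to come from the other managers earlier in a test:

```python
from datetime import datetime, timezone

from tests.integration.common_utils.managers.cc_pair import CCPairManager

# `cc_pair` and `admin_user` are assumed to have been created earlier via
# CCPairManager.create(...) and UserManager.create(...).
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
    cc_pair=cc_pair, after=before, user_performing_action=admin_user
)

before = datetime.now(timezone.utc)
CCPairManager.sync(cc_pair=cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_sync(
    cc_pair=cc_pair, after=before, user_performing_action=admin_user
)
```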
View File

@ -50,9 +50,7 @@ class LLMProviderManager:
)
llm_response.raise_for_status()
response_data = llm_response.json()
import json
print(json.dumps(response_data, indent=4))
result_llm = DATestLLMProvider(
id=response_data["id"],
name=response_data["name"],

View File

@ -17,11 +17,14 @@ class UserManager:
@staticmethod
def create(
name: str | None = None,
email: str | None = None,
) -> DATestUser:
if name is None:
name = f"test{str(uuid4())}"
email = f"{name}@test.com"
if email is None:
email = f"{name}@test.com"
password = "test"
body = {
@ -44,12 +47,10 @@ class UserManager:
)
print(f"Created user {test_user.email}")
test_user.headers["Cookie"] = UserManager.login_as_user(test_user)
return test_user
return UserManager.login_as_user(test_user)
@staticmethod
def login_as_user(test_user: DATestUser) -> str:
def login_as_user(test_user: DATestUser) -> DATestUser:
data = urlencode(
{
"username": test_user.email,
@ -71,7 +72,9 @@ class UserManager:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
return f"{result_cookie.name}={result_cookie.value}"
cookie = f"{result_cookie.name}={result_cookie.value}"
test_user.headers["Cookie"] = cookie
return test_user
@staticmethod
def verify_role(user_to_verify: DATestUser, target_role: UserRole) -> bool:

View File

@ -0,0 +1,28 @@
import os
from collections.abc import Generator
from typing import Any
import pytest
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
admin_user_id = SlackManager.build_slack_user_email_id_map(slack_client)[
"admin@onyx-test.com"
]
(
public_channel,
private_channel,
run_id,
) = SlackManager.get_and_provision_available_slack_channels(
slack_client=slack_client, admin_user_id=admin_user_id
)
yield public_channel, private_channel
# This part will always run after the test, even if it fails
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)

View File

@ -0,0 +1,311 @@
"""
Assumptions:
- The test users have already been created
- The #general channel contains no messages
- In addition to the normal Slack OAuth permissions, the following scopes are needed:
- channels:manage
- groups:write
- chat:write
- chat:write.public
"""
from typing import Any
from uuid import uuid4
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from danswer.connectors.slack.connector import default_msg_filter
from danswer.connectors.slack.connector import get_channel_messages
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
def _get_slack_channel_id(channel: dict[str, Any]) -> str:
if not (channel_id := channel.get("id")):
raise ValueError("Channel ID is missing")
return channel_id
def _get_non_general_channels(
slack_client: WebClient,
get_private: bool,
get_public: bool,
only_get_done: bool = False,
) -> list[dict[str, Any]]:
channel_types = []
if get_private:
channel_types.append("private_channel")
if get_public:
channel_types.append("public_channel")
conversations: list[dict[str, Any]] = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.conversations_list,
exclude_archived=False,
types=channel_types,
):
conversations.extend(result["channels"])
filtered_conversations = []
for conversation in conversations:
if conversation.get("is_general", False):
continue
if only_get_done and "done" not in conversation.get("name", ""):
continue
filtered_conversations.append(conversation)
return filtered_conversations
def _clear_slack_conversation_members(
slack_client: WebClient,
admin_user_id: str,
channel: dict[str, Any],
) -> None:
channel_id = _get_slack_channel_id(channel)
member_ids: list[str] = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.conversations_members,
channel=channel_id,
):
member_ids.extend(result["members"])
for member_id in member_ids:
if member_id == admin_user_id:
continue
try:
make_slack_api_call_w_retries(
slack_client.conversations_kick, channel=channel_id, user=member_id
)
print(f"Kicked member: {member_id}")
except Exception as e:
if "cant_kick_self" in str(e):
continue
print(f"Error kicking member: {e}")
print(member_id)
try:
make_slack_api_call_w_retries(
slack_client.conversations_unarchive, channel=channel_id
)
channel["is_archived"] = False
except Exception:
# Channel is already unarchived
pass
def _add_slack_conversation_members(
slack_client: WebClient, channel: dict[str, Any], member_ids: list[str]
) -> None:
channel_id = _get_slack_channel_id(channel)
for user_id in member_ids:
try:
make_slack_api_call_w_retries(
slack_client.conversations_invite,
channel=channel_id,
users=user_id,
)
except Exception as e:
if "already_in_channel" in str(e):
continue
print(f"Error inviting member: {e}")
print(user_id)
def _delete_slack_conversation_messages(
slack_client: WebClient,
channel: dict[str, Any],
message_to_delete: str | None = None,
) -> None:
"""deletes all messages from a channel if message_to_delete is None"""
channel_id = _get_slack_channel_id(channel)
for message_batch in get_channel_messages(slack_client, channel):
for message in message_batch:
if default_msg_filter(message):
continue
if message_to_delete and message.get("text") != message_to_delete:
continue
print(" removing message: ", message.get("text"))
try:
if not (ts := message.get("ts")):
raise ValueError("Message timestamp is missing")
make_slack_api_call_w_retries(
slack_client.chat_delete,
channel=channel_id,
ts=ts,
)
except Exception as e:
print(f"Error deleting message: {e}")
print(message)
def _build_slack_channel_from_name(
slack_client: WebClient,
admin_user_id: str,
suffix: str,
is_private: bool,
channel: dict[str, Any] | None,
) -> dict[str, Any]:
base = "public_channel" if not is_private else "private_channel"
channel_name = f"{base}-{suffix}"
if channel:
# If channel is provided, we rename it
channel_id = _get_slack_channel_id(channel)
channel_response = make_slack_api_call_w_retries(
slack_client.conversations_rename,
channel=channel_id,
name=channel_name,
)
else:
# Otherwise, we create a new channel
channel_response = make_slack_api_call_w_retries(
slack_client.conversations_create,
name=channel_name,
is_private=is_private,
)
try:
channel_response = make_slack_api_call_w_retries(
slack_client.conversations_unarchive,
channel=channel_response["channel"]["id"],
)
except Exception:
# Channel is already unarchived
pass
try:
channel_response = make_slack_api_call_w_retries(
slack_client.conversations_invite,
channel=channel_response["channel"]["id"],
users=[admin_user_id],
)
except Exception:
pass
final_channel = channel_response["channel"] if channel_response else {}
return final_channel
class SlackManager:
@staticmethod
def get_slack_client(token: str) -> WebClient:
return WebClient(token=token)
@staticmethod
def get_and_provision_available_slack_channels(
slack_client: WebClient, admin_user_id: str
) -> tuple[dict[str, Any], dict[str, Any], str]:
run_id = str(uuid4())
public_channels = _get_non_general_channels(
slack_client, get_private=False, get_public=True, only_get_done=True
)
first_available_channel = (
None if len(public_channels) < 1 else public_channels[0]
)
public_channel = _build_slack_channel_from_name(
slack_client=slack_client,
admin_user_id=admin_user_id,
suffix=run_id,
is_private=False,
channel=first_available_channel,
)
_delete_slack_conversation_messages(
slack_client=slack_client, channel=public_channel
)
private_channels = _get_non_general_channels(
slack_client, get_private=True, get_public=False, only_get_done=True
)
second_available_channel = (
None if len(private_channels) < 1 else private_channels[0]
)
private_channel = _build_slack_channel_from_name(
slack_client=slack_client,
admin_user_id=admin_user_id,
suffix=run_id,
is_private=True,
channel=second_available_channel,
)
_delete_slack_conversation_messages(
slack_client=slack_client, channel=private_channel
)
return public_channel, private_channel, run_id
@staticmethod
def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
users_results = make_slack_api_call_w_retries(
slack_client.users_list,
)
users: list[dict[str, Any]] = users_results.get("members", [])
user_email_id_map = {}
for user in users:
if not (email := user.get("profile", {}).get("email")):
continue
if not (user_id := user.get("id")):
raise ValueError("User ID is missing")
user_email_id_map[email] = user_id
return user_email_id_map
@staticmethod
def set_channel_members(
slack_client: WebClient,
admin_user_id: str,
channel: dict[str, Any],
user_ids: list[str],
) -> None:
_clear_slack_conversation_members(
slack_client=slack_client,
channel=channel,
admin_user_id=admin_user_id,
)
_add_slack_conversation_members(
slack_client=slack_client, channel=channel, member_ids=user_ids
)
@staticmethod
def add_message_to_channel(
slack_client: WebClient, channel: dict[str, Any], message: str
) -> None:
channel_id = _get_slack_channel_id(channel)
make_slack_api_call_w_retries(
slack_client.chat_postMessage,
channel=channel_id,
text=message,
)
@staticmethod
def remove_message_from_channel(
slack_client: WebClient, channel: dict[str, Any], message: str
) -> None:
_delete_slack_conversation_messages(
slack_client=slack_client, channel=channel, message_to_delete=message
)
@staticmethod
def cleanup_after_test(
slack_client: WebClient,
test_id: str,
) -> None:
channel_types = ["private_channel", "public_channel"]
channels: list[dict[str, Any]] = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.conversations_list,
exclude_archived=False,
types=channel_types,
):
channels.extend(result["channels"])
for channel in channels:
if test_id not in channel.get("name", ""):
continue
# "done" in the channel name indicates that this channel is free to be used for a new test
new_name = f"done_{str(uuid4())}"
try:
make_slack_api_call_w_retries(
slack_client.conversations_rename,
channel=channel["id"],
name=new_name,
)
except SlackApiError as e:
print(f"Error renaming channel {channel['id']}: {e}")

View File

@ -0,0 +1,251 @@
import os
from datetime import datetime
from datetime import timezone
from typing import Any
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import DocumentSource
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
def test_slack_permission_sync(
reset: None,
vespa_client: vespa_fixture,
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
public_channel, private_channel = slack_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@onyx-test.com",
)
# Creating a non-admin user
test_user_2: DATestUser = UserManager.create(
email="test_user_2@onyx-test.com",
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
LLMProviderManager.create(user_performing_action=admin_user)
before = datetime.now(timezone.utc)
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
},
user_performing_action=admin_user,
)
connector: DATestConnector = ConnectorManager.create(
name="Slack",
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
is_public=True,
groups=[],
user_performing_action=admin_user,
)
cc_pair: DATestCCPair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Add test_user_1 and admin_user to the private channel
desired_channel_members = [admin_user, test_user_1]
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[email_id_map[user.email] for user in desired_channel_members],
)
public_message = "Steve's favorite number is 809752"
private_message = "Sara's favorite number is 346794"
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=public_channel,
message=public_message,
)
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=private_message,
)
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Run permission sync
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Search as admin with access to both channels
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=admin_user.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
# Ensure admin user can see messages from both channels
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
# Search as test_user_2 with access to only the public channel
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=test_user_2.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content before removing from private channel for test_user_2: ",
danswer_doc_message_strings,
)
# Ensure test_user_2 can only see messages from the public channel
assert public_message in danswer_doc_message_strings
assert private_message not in danswer_doc_message_strings
# Search as test_user_1 with access to both channels
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=test_user_1.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content before removing from private channel for test_user_1: ",
danswer_doc_message_strings,
)
# Ensure test_user_1 can see messages from both channels
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
# ----------------------MAKE THE CHANGES--------------------------
print("\nRemoving test_user_1 from the private channel")
# Remove test_user_1 from the private channel
desired_channel_members = [admin_user]
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[email_id_map[user.email] for user in desired_channel_members],
)
# Run permission sync
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure test_user_1 can no longer see messages from the private channel
# Search as test_user_1 with access to only the public channel
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=test_user_1.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content after removing from private channel for test_user_1: ",
danswer_doc_message_strings,
)
# Ensure test_user_1 can only see messages from the public channel
assert public_message in danswer_doc_message_strings
assert private_message not in danswer_doc_message_strings

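The document-search assertions in this test (and in the prune test below) repeat the same request/extract block; a hedged sketch of a helper that could fold that repetition (the function name is a suggestion, not part of the diff):

```python
import requests

from danswer.search.enums import LLMEvaluationType, SearchType
from danswer.search.models import RetrievalDetails
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser


def search_doc_contents(query: str, user: DATestUser) -> list[str]:
    """Run a keyword document search as `user` and return the hit contents."""
    search_request = DocumentSearchRequest(
        message=query,
        search_type=SearchType.KEYWORD,
        retrieval_options=RetrievalDetails(),
        evaluation_type=LLMEvaluationType.SKIP,
    )
    result = requests.post(
        url=f"{API_SERVER_URL}/query/document-search",
        json=search_request.model_dump(),
        headers=user.headers,
    )
    result.raise_for_status()
    return [doc["content"] for doc in result.json()["top_documents"]]
```

With such a helper, each check collapses to something like `assert public_message in search_doc_contents("favorite number", test_user_1)`.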
View File

@ -0,0 +1,255 @@
import os
from datetime import datetime
from datetime import timezone
from typing import Any
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDetails
from danswer.server.documents.models import DocumentSource
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
def test_slack_prune(
reset: None,
vespa_client: vespa_fixture,
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
public_channel, private_channel = slack_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@onyx-test.com",
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
LLMProviderManager.create(user_performing_action=admin_user)
before = datetime.now(timezone.utc)
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
},
user_performing_action=admin_user,
)
connector: DATestConnector = ConnectorManager.create(
name="Slack",
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
is_public=True,
groups=[],
user_performing_action=admin_user,
)
cc_pair: DATestCCPair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# ----------------------SETUP INITIAL SLACK STATE--------------------------
# Add test_user_1 and admin_user to the private channel
desired_channel_members = [admin_user, test_user_1]
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[email_id_map[user.email] for user in desired_channel_members],
)
public_message = "Steve's favorite number is 809752"
private_message = "Sara's favorite number is 346794"
message_to_delete = "Rebecca's favorite number is 753468"
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=public_channel,
message=public_message,
)
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=private_message,
)
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=message_to_delete,
)
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Run permission sync
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# ----------------------TEST THE SETUP--------------------------
# Search as admin with access to both channels
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=admin_user.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content before deleting for admin: ",
danswer_doc_message_strings,
)
# Ensure admin user can see all messages
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
assert message_to_delete in danswer_doc_message_strings
# Search as test_user_1 with access to both channels
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=test_user_1.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content before deleting for test_user_1: ",
danswer_doc_message_strings,
)
# Ensure test_user_1 can see all messages
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
assert message_to_delete in danswer_doc_message_strings
# ----------------------MAKE THE CHANGES--------------------------
# Delete messages
print("\nDeleting message: ", message_to_delete)
SlackManager.remove_message_from_channel(
slack_client=slack_client,
channel=private_channel,
message=message_to_delete,
)
# Prune the cc_pair
before = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, before, user_performing_action=admin_user)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure admin user can't see deleted messages
# Search as admin user (who still has access to both channels)
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=admin_user.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content after deleting for admin: ",
danswer_doc_message_strings,
)
# Ensure admin can't see deleted messages
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
assert message_to_delete not in danswer_doc_message_strings
# Ensure test_user_1 can't see deleted messages
# Search as test_user_1 (who still has access to both channels)
search_request = DocumentSearchRequest(
message="favorite number",
search_type=SearchType.KEYWORD,
retrieval_options=RetrievalDetails(),
evaluation_type=LLMEvaluationType.SKIP,
)
search_request_body = search_request.model_dump()
result = requests.post(
url=f"{API_SERVER_URL}/query/document-search",
json=search_request_body,
headers=test_user_1.headers,
)
result.raise_for_status()
found_docs = result.json()["top_documents"]
danswer_doc_message_strings = [doc["content"] for doc in found_docs]
print(
"\ntop_documents content after prune for test_user_1: ",
danswer_doc_message_strings,
)
# Ensure test_user_1 can't see deleted messages
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
assert message_to_delete not in danswer_doc_message_strings