Added permissions syncing for slack (#2602)

* Added permissions syncing for slack

* add no email case handling

* mypy fixes

* frontend

* minor cleanup

* param tweak
hagen-danswer 2024-09-30 08:14:43 -07:00 committed by GitHub
parent 728a41a35a
commit b0056907fb
11 changed files with 457 additions and 84 deletions

View File

@@ -63,6 +63,7 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,

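A minimal sketch of what the new mapping entry enables, assuming identify_connector_class is called with a source and an input type as the hunk header above suggests: requesting InputType.PRUNE for Slack now resolves to SlackPollConnector, which the permission sync code later in this commit relies on.

from danswer.configs.constants import DocumentSource
from danswer.connectors.factory import identify_connector_class
from danswer.connectors.models import InputType

# Resolve the connector class for the new PRUNE entry (the call shape is an assumption)
connector_cls = identify_connector_class(DocumentSource.SLACK, InputType.PRUNE)
# connector_cls should be SlackPollConnector per the mapping above
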
View File

@@ -8,13 +8,12 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
@@ -23,9 +22,8 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger
@@ -38,47 +36,18 @@ MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
basic_retry_wrapper = retry_builder()
def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
def _get_channels(
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])
@@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# first try fetching all requested channel types (including private channels)
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels
def get_channel_messages(
@@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -266,7 +254,7 @@ def filter_channels(
]
def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
@@ -328,7 +316,44 @@ def get_all_docs(
)
class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel
This is nearly identical to _get_all_docs, but it returns a set of ids instead of documents,
which makes it an order of magnitude faster than _get_all_docs
"""
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
all_doc_ids = set()
for channel in filtered_channels:
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we don't have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
return all_doc_ids
class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
@@ -349,6 +374,16 @@ class SlackPollConnector(PollConnector):
self.client = WebClient(token=bot_token)
return None
def retrieve_all_source_ids(self) -> set[str]:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
@@ -356,7 +391,7 @@ class SlackPollConnector(PollConnector):
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,

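A hedged usage sketch of the new IdConnector path added above. The constructor arguments, the load_credentials method, and the "slack_bot_token" credential key are assumptions inferred from this diff rather than a documented public API:

from danswer.connectors.slack.connector import SlackPollConnector

connector = SlackPollConnector(workspace="my-workspace", channels=["general"])
connector.load_credentials({"slack_bot_token": "xoxb-..."})  # placeholder token
doc_ids = connector.retrieve_all_source_ids()
# doc_ids is a set of "{channel_id}__{message_ts}" strings, one per top-level message
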
View File

@@ -10,11 +10,13 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
@@ -34,7 +36,7 @@
)
def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
@@ -47,7 +49,7 @@ def make_slack_api_call_logged(
return logged_call
def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
@@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)
def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,

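A short sketch of how the new make_paginated_slack_api_call_w_retries helper is consumed, mirroring its use in connector.py above; the token is a placeholder:

from slack_sdk import WebClient

from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries

client = WebClient(token="xoxb-...")  # placeholder token
channels: list[dict] = []
for page in make_paginated_slack_api_call_w_retries(
    client.conversations_list,
    types=["public_channel"],
):
    # pagination, rate limiting and retries are all handled inside the wrapper
    channels.extend(page["channels"])
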
View File

@@ -26,9 +26,7 @@ from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()

View File

@@ -4,29 +4,38 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_function_map import (
DOC_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
GROUP_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
logger = setup_logger()
# None means that the connector runs every time
_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
# Polling is supported
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
}
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If no sync period is configured for this source in PERMISSION_SYNC_PERIODS, we always run the sync.
if not source_sync_period:
return True
# If the last sync is None, it has never been run so we run the sync
if cc_pair.last_time_perm_sync is None:
return True
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
current_time = datetime.now(timezone.utc)
# If the time since the last sync exceeds the sync period, we run the sync
if (current_time - last_sync).total_seconds() > source_sync_period:
return True
return False
def run_external_group_permission_sync(
@@ -44,6 +53,9 @@ def run_external_group_permission_sync(
# Not all sync connectors support group permissions so this is fine
return
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
# - the user_email <-> external_user_group_id mapping
@@ -79,20 +91,8 @@ def run_external_doc_permission_sync(
f"No permission sync function found for source type: {source_type}"
)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
# If RESTRICTED_FETCH_PERIOD is not None, we only run sync if the
# last sync was more than RESTRICTED_FETCH_PERIOD seconds ago.
full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
if full_fetch_period is not None:
last_sync = cc_pair.last_time_perm_sync
if (
last_sync
and (
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
).total_seconds()
< full_fetch_period
):
return
if not _is_time_to_run_sync(cc_pair):
return
try:
# This function updates:
@@ -131,6 +131,9 @@ def run_external_doc_permission_sync(
# update vespa
document_index.update(update_reqs)
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
# update postgres
db_session.commit()
except Exception as e:

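A worked illustration of the gating arithmetic in _is_time_to_run_sync, with made-up timestamps: PERMISSION_SYNC_PERIODS maps Slack to 5 * 60 = 300 seconds, so a sync only runs once more than 300 seconds have passed since last_time_perm_sync, while sources with no entry run on every celery beat.

from datetime import datetime, timedelta, timezone

source_sync_period = 5 * 60  # the Slack/Confluence entry in PERMISSION_SYNC_PERIODS
last_sync = datetime.now(timezone.utc) - timedelta(seconds=200)  # assumed last sync time
should_run = (datetime.now(timezone.utc) - last_sync).total_seconds() > source_sync_period
# should_run is False here; it flips to True once 300+ seconds have elapsed
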
View File

@@ -0,0 +1,192 @@
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.models import InputType
from danswer.connectors.slack.connector import get_channels
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
logger = setup_logger()
def _extract_channel_id_from_doc_id(doc_id: str) -> str:
"""
Extracts the channel ID from a document ID string.
The document ID is expected to be in the format: "{channel_id}__{message_ts}"
Args:
doc_id (str): The document ID string.
Returns:
str: The extracted channel ID.
Raises:
ValueError: If the doc_id doesn't contain the expected separator.
"""
try:
channel_id, _ = doc_id.split("__", 1)
return channel_id
except ValueError:
raise ValueError(f"Invalid doc_id format: {doc_id}")
def _get_slack_document_ids_and_channels(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.PRUNE,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, IdConnector)
channel_doc_map: dict[str, list[str]] = {}
for doc_id in runnable_connector.retrieve_all_source_ids():
channel_id = _extract_channel_id_from_doc_id(doc_id)
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_id)
return channel_doc_map
def _fetch_workspace_permissions(
db_session: Session,
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
user_emails = set()
for email in user_id_to_email_map.values():
user_emails.add(email)
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
return ExternalAccess(
external_user_emails=user_emails,
# No group<->document mapping for slack
external_user_group_ids=set(),
# No way to determine if slack is invite only without an enterprise license
is_public=False,
)
def _fetch_channel_permissions(
db_session: Session,
slack_client: WebClient,
workspace_permissions: ExternalAccess,
user_id_to_email_map: dict[str, str],
) -> dict[str, ExternalAccess]:
channel_permissions = {}
public_channels = get_channels(
client=slack_client,
get_public=True,
get_private=False,
)
public_channel_ids = [
channel["id"] for channel in public_channels if "id" in channel
]
for channel_id in public_channel_ids:
channel_permissions[channel_id] = workspace_permissions
private_channels = get_channels(
client=slack_client,
get_public=False,
get_private=True,
)
private_channel_ids = [
channel["id"] for channel in private_channels if "id" in channel
]
for channel_id in private_channel_ids:
# Collect all member ids for the channel via paginated API calls
member_ids = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.conversations_members,
channel=channel_id,
):
member_ids.extend(result.get("members", []))
# Collect all member emails for the channel
member_emails = set()
for member_id in member_ids:
member_email = user_id_to_email_map.get(member_id)
if not member_email:
# If the user is an external user, their email won't be in the map built from
# users_list, so we need to make a separate call to users_info
# and add them to the user_id_to_email_map
member_info = slack_client.users_info(user=member_id)
member_email = member_info["user"]["profile"].get("email")
if not member_email:
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
member_emails.add(member_email)
channel_permissions[channel_id] = ExternalAccess(
external_user_emails=member_emails,
# No group<->document mapping for slack
external_user_group_ids=set(),
# No way to determine if slack is invite only without an enterprise license
is_public=False,
)
return channel_permissions
def slack_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
"""
Adds the external permissions to the documents in postgres.
If a document doesn't already exist in postgres, we create
it in postgres so that when it does get created later, the permissions are
already populated.
"""
slack_client = WebClient(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
db_session=db_session,
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
db_session=db_session,
user_id_to_email_map=user_id_to_email_map,
)
channel_permissions = _fetch_channel_permissions(
db_session=db_session,
slack_client=slack_client,
workspace_permissions=workspace_permissions,
user_id_to_email_map=user_id_to_email_map,
)
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for this channel
continue
for doc_id in doc_ids:
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)
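
A tiny sketch of the "{channel_id}__{message_ts}" document id convention that _get_slack_document_ids_and_channels and _extract_channel_id_from_doc_id rely on, using example values only:

doc_id = "C0123456789__1695859200.000100"  # example values, not from a real workspace
channel_id, message_ts = doc_id.split("__", 1)
assert channel_id == "C0123456789"
assert message_ts == "1695859200.000100"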

View File

@@ -0,0 +1,92 @@
"""
THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING
WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL
SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
"""
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
logger = setup_logger()
def _get_slack_group_ids(
slack_client: WebClient,
) -> list[str]:
group_ids = []
for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
for group in result.get("usergroups", []):
group_ids.append(group.get("id"))
return group_ids
def _get_slack_group_members_email(
db_session: Session,
slack_client: WebClient,
group_name: str,
user_id_to_email_map: dict[str, str],
) -> list[str]:
group_member_emails = []
for result in make_paginated_slack_api_call_w_retries(
slack_client.usergroups_users_list, usergroup=group_name
):
for member_id in result.get("users", []):
member_email = user_id_to_email_map.get(member_id)
if not member_email:
# If the user is an external user, their email won't be in the map built from
# users_list, so we need to make a separate call to users_info
member_info = slack_client.users_info(user=member_id)
member_email = member_info["user"]["profile"].get("email")
if not member_email:
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
group_member_emails.append(member_email)
return group_member_emails
def slack_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
slack_client = WebClient(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
danswer_groups: list[ExternalUserGroup] = []
for group_name in _get_slack_group_ids(slack_client):
group_member_emails = _get_slack_group_members_email(
db_session=db_session,
slack_client=slack_client,
group_name=group_name,
user_id_to_email_map=user_id_to_email_map,
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)

View File

@@ -0,0 +1,18 @@
from slack_sdk import WebClient
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
def fetch_user_id_to_email_map(
slack_client: WebClient,
) -> dict[str, str]:
user_id_to_email_map = {}
for user_info in make_paginated_slack_api_call_w_retries(
slack_client.users_list,
):
for user in user_info.get("members", []):
if user.get("profile", {}).get("email"):
user_id_to_email_map[user.get("id")] = user.get("profile", {}).get(
"email"
)
return user_id_to_email_map
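
A small sketch of how the id-to-email map built above is used by the doc and group sync code; the ids and emails are made up:

user_id_to_email_map = {"U111": "alice@example.com", "U222": "bob@example.com"}
member_ids = ["U111", "U333"]  # U333 has no email in the map, e.g. an external user
emails = [user_id_to_email_map[m] for m in member_ids if m in user_id_to_email_map]
# -> ["alice@example.com"]; the sync code above falls back to users_info for U333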

View File

@@ -8,7 +8,7 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
# Defining the input/output types for the sync functions
SyncFuncType = Callable[
@@ -27,6 +27,7 @@ SyncFuncType = Callable[
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
}
# These functions update:
@@ -39,5 +40,13 @@ GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.SLACK: 5 * 60,
}
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in DOC_PERMISSIONS_FUNC_MAP
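
A short sketch of how the new sync_params entries are consulted; the module and symbol names are taken from the diff above:

from danswer.configs.constants import DocumentSource
from ee.danswer.external_permissions.sync_params import (
    DOC_PERMISSIONS_FUNC_MAP,
    PERMISSION_SYNC_PERIODS,
    check_if_valid_sync_source,
)

assert check_if_valid_sync_source(DocumentSource.SLACK)
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK]  # slack_doc_sync
period = PERMISSION_SYNC_PERIODS.get(DocumentSource.SLACK)  # 300 seconds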

View File

@@ -44,4 +44,5 @@ export const autoSyncConfigBySource: Record<
),
},
},
slack: {},
};

View File

@@ -272,5 +272,9 @@ export type ConfigurableSources = Exclude<
>;
// The sources that have auto-sync support on the backend
export const validAutoSyncSources = ["confluence", "google_drive"] as const;
export const validAutoSyncSources = [
"confluence",
"google_drive",
"slack",
] as const;
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];