mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-01 18:49:27 +02:00
Added permissions syncing for slack (#2602)
* Added permissions syncing for slack * add no email case handling * mypy fixes * frontend * minor cleanup * param tweak
This commit is contained in:
parent
728a41a35a
commit
b0056907fb
@ -63,6 +63,7 @@ def identify_connector_class(
|
|||||||
DocumentSource.SLACK: {
|
DocumentSource.SLACK: {
|
||||||
InputType.LOAD_STATE: SlackLoadConnector,
|
InputType.LOAD_STATE: SlackLoadConnector,
|
||||||
InputType.POLL: SlackPollConnector,
|
InputType.POLL: SlackPollConnector,
|
||||||
|
InputType.PRUNE: SlackPollConnector,
|
||||||
},
|
},
|
||||||
DocumentSource.GITHUB: GithubConnector,
|
DocumentSource.GITHUB: GithubConnector,
|
||||||
DocumentSource.GMAIL: GmailConnector,
|
DocumentSource.GMAIL: GmailConnector,
|
||||||
|
@ -8,13 +8,12 @@ from typing import cast
|
|||||||
|
|
||||||
from slack_sdk import WebClient
|
from slack_sdk import WebClient
|
||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
from slack_sdk.web import SlackResponse
|
|
||||||
|
|
||||||
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
|
||||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||||
|
from danswer.connectors.interfaces import IdConnector
|
||||||
from danswer.connectors.interfaces import PollConnector
|
from danswer.connectors.interfaces import PollConnector
|
||||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||||
from danswer.connectors.models import BasicExpertInfo
|
from danswer.connectors.models import BasicExpertInfo
|
||||||
@ -23,9 +22,8 @@ from danswer.connectors.models import Document
|
|||||||
from danswer.connectors.models import Section
|
from danswer.connectors.models import Section
|
||||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||||
from danswer.connectors.slack.utils import get_message_link
|
from danswer.connectors.slack.utils import get_message_link
|
||||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
|
||||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
@ -38,47 +36,18 @@ MessageType = dict[str, Any]
|
|||||||
# list of messages in a thread
|
# list of messages in a thread
|
||||||
ThreadType = list[MessageType]
|
ThreadType = list[MessageType]
|
||||||
|
|
||||||
basic_retry_wrapper = retry_builder()
|
|
||||||
|
|
||||||
|
def _collect_paginated_channels(
|
||||||
def _make_paginated_slack_api_call(
|
|
||||||
call: Callable[..., SlackResponse], **kwargs: Any
|
|
||||||
) -> Generator[dict[str, Any], None, None]:
|
|
||||||
return make_slack_api_call_paginated(
|
|
||||||
basic_retry_wrapper(
|
|
||||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
|
||||||
)
|
|
||||||
)(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_slack_api_call(
|
|
||||||
call: Callable[..., SlackResponse], **kwargs: Any
|
|
||||||
) -> SlackResponse:
|
|
||||||
return basic_retry_wrapper(
|
|
||||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
|
||||||
)(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
|
||||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
|
||||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
|
||||||
"channel"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_channels(
|
|
||||||
client: WebClient,
|
client: WebClient,
|
||||||
exclude_archived: bool,
|
exclude_archived: bool,
|
||||||
get_private: bool,
|
channel_types: list[str],
|
||||||
) -> list[ChannelType]:
|
) -> list[ChannelType]:
|
||||||
channels: list[dict[str, Any]] = []
|
channels: list[dict[str, Any]] = []
|
||||||
for result in _make_paginated_slack_api_call(
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
client.conversations_list,
|
client.conversations_list,
|
||||||
exclude_archived=exclude_archived,
|
exclude_archived=exclude_archived,
|
||||||
# also get private channels the bot is added to
|
# also get private channels the bot is added to
|
||||||
types=["public_channel", "private_channel"]
|
types=channel_types,
|
||||||
if get_private
|
|
||||||
else ["public_channel"],
|
|
||||||
):
|
):
|
||||||
channels.extend(result["channels"])
|
channels.extend(result["channels"])
|
||||||
|
|
||||||
@ -88,19 +57,38 @@ def _get_channels(
|
|||||||
def get_channels(
|
def get_channels(
|
||||||
client: WebClient,
|
client: WebClient,
|
||||||
exclude_archived: bool = True,
|
exclude_archived: bool = True,
|
||||||
|
get_public: bool = True,
|
||||||
|
get_private: bool = True,
|
||||||
) -> list[ChannelType]:
|
) -> list[ChannelType]:
|
||||||
"""Get all channels in the workspace"""
|
"""Get all channels in the workspace"""
|
||||||
|
channels: list[dict[str, Any]] = []
|
||||||
|
channel_types = []
|
||||||
|
if get_public:
|
||||||
|
channel_types.append("public_channel")
|
||||||
|
if get_private:
|
||||||
|
channel_types.append("private_channel")
|
||||||
# try getting private channels as well at first
|
# try getting private channels as well at first
|
||||||
try:
|
try:
|
||||||
return _get_channels(
|
channels = _collect_paginated_channels(
|
||||||
client=client, exclude_archived=exclude_archived, get_private=True
|
client=client,
|
||||||
|
exclude_archived=exclude_archived,
|
||||||
|
channel_types=channel_types,
|
||||||
)
|
)
|
||||||
except SlackApiError as e:
|
except SlackApiError as e:
|
||||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||||
|
logger.info("trying again without private channels")
|
||||||
|
if get_public:
|
||||||
|
channel_types = ["public_channel"]
|
||||||
|
else:
|
||||||
|
logger.warning("No channels to fetch")
|
||||||
|
return []
|
||||||
|
channels = _collect_paginated_channels(
|
||||||
|
client=client,
|
||||||
|
exclude_archived=exclude_archived,
|
||||||
|
channel_types=channel_types,
|
||||||
|
)
|
||||||
|
|
||||||
return _get_channels(
|
return channels
|
||||||
client=client, exclude_archived=exclude_archived, get_private=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_channel_messages(
|
def get_channel_messages(
|
||||||
@ -112,14 +100,14 @@ def get_channel_messages(
|
|||||||
"""Get all messages in a channel"""
|
"""Get all messages in a channel"""
|
||||||
# join so that the bot can access messages
|
# join so that the bot can access messages
|
||||||
if not channel["is_member"]:
|
if not channel["is_member"]:
|
||||||
_make_slack_api_call(
|
make_slack_api_call_w_retries(
|
||||||
client.conversations_join,
|
client.conversations_join,
|
||||||
channel=channel["id"],
|
channel=channel["id"],
|
||||||
is_private=channel["is_private"],
|
is_private=channel["is_private"],
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully joined '{channel['name']}'")
|
logger.info(f"Successfully joined '{channel['name']}'")
|
||||||
|
|
||||||
for result in _make_paginated_slack_api_call(
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
client.conversations_history,
|
client.conversations_history,
|
||||||
channel=channel["id"],
|
channel=channel["id"],
|
||||||
oldest=oldest,
|
oldest=oldest,
|
||||||
@ -131,7 +119,7 @@ def get_channel_messages(
|
|||||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||||
"""Get all messages in a thread"""
|
"""Get all messages in a thread"""
|
||||||
threads: list[MessageType] = []
|
threads: list[MessageType] = []
|
||||||
for result in _make_paginated_slack_api_call(
|
for result in make_paginated_slack_api_call_w_retries(
|
||||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||||
):
|
):
|
||||||
threads.extend(result["messages"])
|
threads.extend(result["messages"])
|
||||||
@ -266,7 +254,7 @@ def filter_channels(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_all_docs(
|
def _get_all_docs(
|
||||||
client: WebClient,
|
client: WebClient,
|
||||||
workspace: str,
|
workspace: str,
|
||||||
channels: list[str] | None = None,
|
channels: list[str] | None = None,
|
||||||
@ -328,7 +316,44 @@ def get_all_docs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SlackPollConnector(PollConnector):
|
def _get_all_doc_ids(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> set[str]:
    """
    Collect the ids of every document in the workspace, one channel at a time.

    Mirrors the document-fetching path of this connector, but only builds the
    id strings instead of full documents, so no thread has to be fetched.

    Args:
        client: Slack client used for all API calls.
        channels: Optional channel selection, forwarded to `filter_channels`.
        channel_name_regex_enabled: Forwarded to `filter_channels`.
        msg_filter_func: Predicate; messages for which it returns True are skipped.

    Returns:
        The set of document ids, each formatted as "{channel_id}__{message_ts}".
    """
    selected_channels = filter_channels(
        get_channels(client), channels, channel_name_regex_enabled
    )

    doc_ids: set[str] = set()
    for channel in selected_channels:
        for batch in get_channel_messages(client=client, channel=channel):
            # A document id is the channel id plus the ts of the message itself.
            # The message is already in hand, so the thread does not need to be
            # fetched, which saves time and API calls.
            doc_ids.update(
                f"{channel['id']}__{message['ts']}"
                for message in batch
                if not msg_filter_func(message)
            )

    return doc_ids
|
||||||
|
|
||||||
|
|
||||||
|
class SlackPollConnector(PollConnector, IdConnector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workspace: str,
|
workspace: str,
|
||||||
@ -349,6 +374,16 @@ class SlackPollConnector(PollConnector):
|
|||||||
self.client = WebClient(token=bot_token)
|
self.client = WebClient(token=bot_token)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def retrieve_all_source_ids(self) -> set[str]:
    """
    Return the ids of all documents currently reachable by this connector.

    Raises:
        ConnectorMissingCredentialError: if no Slack client has been set yet.
    """
    if self.client is None:
        raise ConnectorMissingCredentialError("Slack")

    # Id collection is delegated to the module-level helper, using the
    # channel selection this connector was configured with.
    id_lookup_kwargs = {
        "client": self.client,
        "channels": self.channels,
        "channel_name_regex_enabled": self.channel_regex_enabled,
    }
    return _get_all_doc_ids(**id_lookup_kwargs)
|
||||||
|
|
||||||
def poll_source(
|
def poll_source(
|
||||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||||
) -> GenerateDocumentsOutput:
|
) -> GenerateDocumentsOutput:
|
||||||
@ -356,7 +391,7 @@ class SlackPollConnector(PollConnector):
|
|||||||
raise ConnectorMissingCredentialError("Slack")
|
raise ConnectorMissingCredentialError("Slack")
|
||||||
|
|
||||||
documents: list[Document] = []
|
documents: list[Document] = []
|
||||||
for document in get_all_docs(
|
for document in _get_all_docs(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
workspace=self.workspace,
|
workspace=self.workspace,
|
||||||
channels=self.channels,
|
channels=self.channels,
|
||||||
|
@ -10,11 +10,13 @@ from slack_sdk import WebClient
|
|||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
from slack_sdk.web import SlackResponse
|
from slack_sdk.web import SlackResponse
|
||||||
|
|
||||||
|
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||||
from danswer.connectors.models import BasicExpertInfo
|
from danswer.connectors.models import BasicExpertInfo
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
basic_retry_wrapper = retry_builder()
|
||||||
# number of messages we request per page when fetching paginated slack messages
|
# number of messages we request per page when fetching paginated slack messages
|
||||||
_SLACK_LIMIT = 900
|
_SLACK_LIMIT = 900
|
||||||
|
|
||||||
@ -34,7 +36,7 @@ def get_message_link(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_slack_api_call_logged(
|
def _make_slack_api_call_logged(
|
||||||
call: Callable[..., SlackResponse],
|
call: Callable[..., SlackResponse],
|
||||||
) -> Callable[..., SlackResponse]:
|
) -> Callable[..., SlackResponse]:
|
||||||
@wraps(call)
|
@wraps(call)
|
||||||
@ -47,7 +49,7 @@ def make_slack_api_call_logged(
|
|||||||
return logged_call
|
return logged_call
|
||||||
|
|
||||||
|
|
||||||
def make_slack_api_call_paginated(
|
def _make_slack_api_call_paginated(
|
||||||
call: Callable[..., SlackResponse],
|
call: Callable[..., SlackResponse],
|
||||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||||
@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
|
|||||||
return rate_limited_call
|
return rate_limited_call
|
||||||
|
|
||||||
|
|
||||||
|
def make_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
    """
    Invoke a single Slack API method through the shared wrapper stack.

    The call is wrapped, innermost first, by `_make_slack_api_call_logged`,
    `make_slack_api_rate_limited` and `basic_retry_wrapper`, then invoked
    with `kwargs`.

    Args:
        call: The Slack client method to invoke.
        **kwargs: Keyword arguments passed through to `call`.

    Returns:
        The response produced by the wrapped call.
    """
    logged_call = _make_slack_api_call_logged(call)
    rate_limited_call = make_slack_api_rate_limited(logged_call)
    resilient_call = basic_retry_wrapper(rate_limited_call)
    return resilient_call(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_paginated_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
    """
    Invoke a paginated Slack API method through the shared wrapper stack.

    Each underlying request goes through `_make_slack_api_call_logged`,
    `make_slack_api_rate_limited` and `basic_retry_wrapper`; the result is
    then handed to `_make_slack_api_call_paginated`, which handles pagination.

    Args:
        call: The Slack client method to invoke.
        **kwargs: Keyword arguments passed through to `call`.

    Returns:
        A generator yielding one response dict per page.
    """
    resilient_call = basic_retry_wrapper(
        make_slack_api_rate_limited(_make_slack_api_call_logged(call))
    )
    paginated_call = _make_slack_api_call_paginated(resilient_call)
    return paginated_call(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def expert_info_from_slack_id(
|
def expert_info_from_slack_id(
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
client: WebClient,
|
client: WebClient,
|
||||||
|
@ -26,9 +26,7 @@ from danswer.db.models import UserRole
|
|||||||
from danswer.server.models import StatusResponse
|
from danswer.server.models import StatusResponse
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||||
check_if_valid_sync_source,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
@ -4,29 +4,38 @@ from datetime import timezone
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.access.access import get_access_for_documents
|
from danswer.access.access import get_access_for_documents
|
||||||
from danswer.configs.constants import DocumentSource
|
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||||
|
from danswer.db.models import ConnectorCredentialPair
|
||||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||||
from danswer.document_index.interfaces import UpdateRequest
|
from danswer.document_index.interfaces import UpdateRequest
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||||
DOC_PERMISSIONS_FUNC_MAP,
|
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||||
)
|
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
|
||||||
GROUP_PERMISSIONS_FUNC_MAP,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
# None means that the connector runs every time
|
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||||
_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
|
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||||
# Polling is supported
|
|
||||||
DocumentSource.GOOGLE_DRIVE: None,
|
# If PERMISSION_SYNC_PERIODS has no sync period for this source, we always run the sync.
|
||||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
if not source_sync_period:
|
||||||
DocumentSource.CONFLUENCE: 5 * 60,
|
return True
|
||||||
}
|
|
||||||
|
# If the last sync is None, it has never been run so we run the sync
|
||||||
|
if cc_pair.last_time_perm_sync is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||||
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# If more time than the sync period has elapsed since the last sync, we run the sync
|
||||||
|
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def run_external_group_permission_sync(
|
def run_external_group_permission_sync(
|
||||||
@ -44,6 +53,9 @@ def run_external_group_permission_sync(
|
|||||||
# Not all sync connectors support group permissions so this is fine
|
# Not all sync connectors support group permissions so this is fine
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if not _is_time_to_run_sync(cc_pair):
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# This function updates:
|
# This function updates:
|
||||||
# - the user_email <-> external_user_group_id mapping
|
# - the user_email <-> external_user_group_id mapping
|
||||||
@ -79,20 +91,8 @@ def run_external_doc_permission_sync(
|
|||||||
f"No permission sync function found for source type: {source_type}"
|
f"No permission sync function found for source type: {source_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
if not _is_time_to_run_sync(cc_pair):
|
||||||
# If RESTRICTED_FETCH_PERIOD is not None, we only run sync if the
|
return
|
||||||
# last sync was more than RESTRICTED_FETCH_PERIOD seconds ago.
|
|
||||||
full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
|
|
||||||
if full_fetch_period is not None:
|
|
||||||
last_sync = cc_pair.last_time_perm_sync
|
|
||||||
if (
|
|
||||||
last_sync
|
|
||||||
and (
|
|
||||||
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
|
|
||||||
).total_seconds()
|
|
||||||
< full_fetch_period
|
|
||||||
):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# This function updates:
|
# This function updates:
|
||||||
@ -131,6 +131,9 @@ def run_external_doc_permission_sync(
|
|||||||
|
|
||||||
# update vespa
|
# update vespa
|
||||||
document_index.update(update_reqs)
|
document_index.update(update_reqs)
|
||||||
|
|
||||||
|
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# update postgres
|
# update postgres
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
from slack_sdk import WebClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.access.models import ExternalAccess
|
||||||
|
from danswer.connectors.factory import instantiate_connector
|
||||||
|
from danswer.connectors.interfaces import IdConnector
|
||||||
|
from danswer.connectors.models import InputType
|
||||||
|
from danswer.connectors.slack.connector import get_channels
|
||||||
|
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||||
|
from danswer.db.models import ConnectorCredentialPair
|
||||||
|
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||||
|
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_channel_id_from_doc_id(doc_id: str) -> str:
|
||||||
|
"""
|
||||||
|
Extracts the channel ID from a document ID string.
|
||||||
|
|
||||||
|
The document ID is expected to be in the format: "{channel_id}__{message_ts}"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id (str): The document ID string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The extracted channel ID.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the doc_id doesn't contain the expected separator.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
channel_id, _ = doc_id.split("__", 1)
|
||||||
|
return channel_id
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Invalid doc_id format: {doc_id}")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_slack_document_ids_and_channels(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
    """
    Group the connector's document ids by the channel they belong to.

    Args:
        db_session: Session passed to `instantiate_connector`.
        cc_pair: Connector/credential pair whose documents are enumerated.

    Returns:
        A mapping of channel id -> list of document ids in that channel.
    """
    # The PRUNE input type is used to obtain a connector that can enumerate
    # every document id that needs its permissions updated.
    runnable_connector = instantiate_connector(
        db_session=db_session,
        source=cc_pair.connector.source,
        input_type=InputType.PRUNE,
        connector_specific_config=cc_pair.connector.connector_specific_config,
        credential=cc_pair.credential,
    )

    assert isinstance(runnable_connector, IdConnector)

    docs_by_channel: dict[str, list[str]] = {}
    for doc_id in runnable_connector.retrieve_all_source_ids():
        owning_channel = _extract_channel_id_from_doc_id(doc_id)
        docs_by_channel.setdefault(owning_channel, []).append(doc_id)

    return docs_by_channel
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_worspace_permissions(
    db_session: Session,
    user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
    """
    Build the access object that covers every known workspace user.

    Every email in `user_id_to_email_map` is passed to
    `batch_add_non_web_user_if_not_exists__no_commit` and included in the
    returned access object.

    Args:
        db_session: Session used for the user upsert (not committed here).
        user_id_to_email_map: Slack user id -> email for the workspace.

    Returns:
        An ExternalAccess listing all workspace emails, with no groups.
    """
    workspace_emails = set(user_id_to_email_map.values())
    batch_add_non_web_user_if_not_exists__no_commit(
        db_session, list(workspace_emails)
    )
    return ExternalAccess(
        external_user_emails=workspace_emails,
        # No group<->document mapping for slack
        external_user_group_ids=set(),
        # No way to determine if slack is invite only without enterprise license
        is_public=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_channel_permissions(
    db_session: Session,
    slack_client: WebClient,
    workspace_permissions: ExternalAccess,
    user_id_to_email_map: dict[str, str],
) -> dict[str, ExternalAccess]:
    """
    Build the external access object for every channel the client can list.

    Public channels all share `workspace_permissions`. Each private channel
    gets its own ExternalAccess listing the emails of its members.

    Args:
        db_session: Session used to add newly discovered users (not committed here).
        slack_client: Slack client used for all API calls.
        workspace_permissions: Access object assigned to every public channel.
        user_id_to_email_map: Slack user id -> email. Mutated in place: emails
            discovered through `users_info` are added to it.

    Returns:
        A mapping of channel id -> ExternalAccess.
    """
    channel_permissions: dict[str, ExternalAccess] = {}

    public_channels = get_channels(
        client=slack_client,
        get_public=True,
        get_private=False,
    )
    for channel in public_channels:
        if "id" in channel:
            channel_permissions[channel["id"]] = workspace_permissions

    private_channels = get_channels(
        client=slack_client,
        get_public=False,
        get_private=True,
    )
    private_channel_ids = [
        channel["id"] for channel in private_channels if "id" in channel
    ]

    # Members for which users_info returned no email. Remembering them avoids
    # repeating the same fruitless users_info call for every private channel
    # the member belongs to.
    member_ids_without_email: set[str] = set()

    for channel_id in private_channel_ids:
        # Collect all member ids for the channel across the paginated calls
        member_ids: list[str] = []
        for result in make_paginated_slack_api_call_w_retries(
            slack_client.conversations_members,
            channel=channel_id,
        ):
            member_ids.extend(result.get("members", []))

        # Collect all member emails for the channel
        member_emails: set[str] = set()
        for member_id in member_ids:
            if member_id in member_ids_without_email:
                continue

            member_email = user_id_to_email_map.get(member_id)

            if not member_email:
                # The member is missing from the pre-built user map (e.g. an
                # external user that was not part of the workspace user
                # listing), so make a separate users_info call and add the
                # result to user_id_to_email_map.
                member_info = slack_client.users_info(user=member_id)
                member_email = member_info["user"]["profile"].get("email")
                if not member_email:
                    # If no email is found, we skip the user
                    member_ids_without_email.add(member_id)
                    continue
                user_id_to_email_map[member_id] = member_email
                batch_add_non_web_user_if_not_exists__no_commit(
                    db_session, [member_email]
                )

            member_emails.add(member_email)

        channel_permissions[channel_id] = ExternalAccess(
            external_user_emails=member_emails,
            # No group<->document mapping for slack
            external_user_group_ids=set(),
            # No way to determine if slack is invite only without enterprise license
            is_public=False,
        )

    return channel_permissions
|
||||||
|
|
||||||
|
|
||||||
|
def slack_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
    """
    Write the external (Slack) permissions of each document to postgres.

    Documents that don't exist in postgres yet are still passed to the
    upsert, so their permissions are already populated when the document
    is created later. Nothing is committed here.

    Args:
        db_session: Session used for all reads and upserts.
        cc_pair: Connector/credential pair being synced; its credential must
            contain a "slack_bot_token".
    """
    bot_token = cc_pair.credential.credential_json["slack_bot_token"]
    slack_client = WebClient(token=bot_token)

    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
    channel_doc_map = _get_slack_document_ids_and_channels(
        db_session=db_session,
        cc_pair=cc_pair,
    )
    workspace_permissions = _fetch_worspace_permissions(
        db_session=db_session,
        user_id_to_email_map=user_id_to_email_map,
    )
    channel_permissions = _fetch_channel_permissions(
        db_session=db_session,
        slack_client=slack_client,
        workspace_permissions=workspace_permissions,
        user_id_to_email_map=user_id_to_email_map,
    )

    for channel_id, ext_access in channel_permissions.items():
        # Channels without any known documents contribute nothing
        for doc_id in channel_doc_map.get(channel_id) or []:
            upsert_document_external_perms__no_commit(
                db_session=db_session,
                doc_id=doc_id,
                external_access=ext_access,
                source_type=cc_pair.connector.source,
            )
|
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
"""
|
||||||
|
THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING
|
||||||
|
WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL
|
||||||
|
SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL
|
||||||
|
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
|
||||||
|
"""
|
||||||
|
from slack_sdk import WebClient
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||||
|
from danswer.db.models import ConnectorCredentialPair
|
||||||
|
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||||
|
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||||
|
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_slack_group_ids(
    slack_client: WebClient,
) -> list[str]:
    """
    List the id of every usergroup returned by the `usergroups_list` call.

    Args:
        slack_client: Slack client used for the paginated API call.

    Returns:
        The "id" field of each usergroup, in the order returned.
    """
    usergroup_pages = make_paginated_slack_api_call_w_retries(
        slack_client.usergroups_list
    )
    return [
        usergroup.get("id")
        for page in usergroup_pages
        for usergroup in page.get("usergroups", [])
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_slack_group_members_email(
    db_session: Session,
    slack_client: WebClient,
    group_name: str,
    user_id_to_email_map: dict[str, str],
) -> list[str]:
    """
    Resolve the members of a usergroup to their email addresses.

    Args:
        db_session: Session used to add newly discovered users (not committed here).
        slack_client: Slack client used for all API calls.
        group_name: Value passed as `usergroup` to `usergroups_users_list`.
        user_id_to_email_map: Slack user id -> email. Mutated in place: emails
            discovered through `users_info` are added to it.

    Returns:
        The emails of the group's members; members without an email are omitted.
    """
    emails: list[str] = []
    member_pages = make_paginated_slack_api_call_w_retries(
        slack_client.usergroups_users_list, usergroup=group_name
    )
    for page in member_pages:
        for member_id in page.get("users", []):
            known_email = user_id_to_email_map.get(member_id)
            if known_email:
                emails.append(known_email)
                continue

            # The member is missing from the pre-built user map (e.g. an
            # external user), so look them up individually via users_info.
            profile = slack_client.users_info(user=member_id)["user"]["profile"]
            looked_up_email = profile.get("email")
            if not looked_up_email:
                # If no email is found, we skip the user
                continue

            user_id_to_email_map[member_id] = looked_up_email
            batch_add_non_web_user_if_not_exists__no_commit(
                db_session, [looked_up_email]
            )
            emails.append(looked_up_email)

    return emails
|
||||||
|
|
||||||
|
|
||||||
|
def slack_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
    """Rebuild the external user groups of ``cc_pair`` from Slack user groups.

    A Slack client is created from the ``slack_bot_token`` stored in the
    pair's credential. For every Slack user group, the member emails are
    resolved and handed to ``batch_add_non_web_user_if_not_exists__no_commit``;
    groups for which that call returns no users are skipped. The collected
    groups then replace the existing ones for this cc_pair. This function
    does not call commit on ``db_session`` itself.
    """
    bot_token = cc_pair.credential.credential_json["slack_bot_token"]
    slack_client = WebClient(token=bot_token)
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)

    danswer_groups: list[ExternalUserGroup] = []
    for slack_group_id in _get_slack_group_ids(slack_client):
        member_emails = _get_slack_group_members_email(
            db_session=db_session,
            slack_client=slack_client,
            group_name=slack_group_id,
            user_id_to_email_map=user_id_to_email_map,
        )
        danswer_users = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=member_emails
        )
        if not danswer_users:
            # Nothing to record for a group without any resolvable member.
            continue

        member_user_ids = [user.id for user in danswer_users]
        danswer_groups.append(
            ExternalUserGroup(id=slack_group_id, user_ids=member_user_ids)
        )

    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
|
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from slack_sdk import WebClient
|
||||||
|
|
||||||
|
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_user_id_to_email_map(
    slack_client: WebClient,
) -> dict[str, str]:
    """Build a mapping from Slack user ID to that user's profile email.

    Walks every page of ``users_list`` and records each member whose profile
    carries a non-empty ``email``; members without one are omitted.
    """
    email_by_user_id = {}
    for page in make_paginated_slack_api_call_w_retries(slack_client.users_list):
        for member in page.get("members", []):
            email = member.get("profile", {}).get("email")
            if not email:
                continue
            email_by_user_id[member.get("id")] = email
    return email_by_user_id
|
@ -8,7 +8,7 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_s
|
|||||||
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
||||||
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||||
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
|
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||||
|
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
|
||||||
|
|
||||||
# Defining the input/output types for the sync functions
|
# Defining the input/output types for the sync functions
|
||||||
SyncFuncType = Callable[
|
SyncFuncType = Callable[
|
||||||
@ -27,6 +27,7 @@ SyncFuncType = Callable[
|
|||||||
# Document-permission sync function registered for each supported source.
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
    DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
    DocumentSource.CONFLUENCE: confluence_doc_sync,
    DocumentSource.SLACK: slack_doc_sync,
}
|
||||||
|
|
||||||
# These functions update:
|
# These functions update:
|
||||||
@ -39,5 +40,13 @@ GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Sources without an entry here have their doc_sync run every time the celery
# beat runs.
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
    # Polling is not supported for these sources, so all of their doc
    # permissions are fetched every 5 minutes.
    DocumentSource.CONFLUENCE: 5 * 60,
    DocumentSource.SLACK: 5 * 60,
}
|
||||||
|
|
||||||
|
|
||||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
    """Tell whether ``source_type`` has an entry in ``DOC_PERMISSIONS_FUNC_MAP``."""
    has_doc_sync_func = source_type in DOC_PERMISSIONS_FUNC_MAP
    return has_doc_sync_func
|
@ -44,4 +44,5 @@ export const autoSyncConfigBySource: Record<
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
slack: {},
|
||||||
};
|
};
|
||||||
|
@ -272,5 +272,9 @@ export type ConfigurableSources = Exclude<
|
|||||||
>;
|
>;
|
||||||
|
|
||||||
// The sources that have auto-sync support on the backend
|
// The sources that have auto-sync support on the backend
|
||||||
export const validAutoSyncSources = ["confluence", "google_drive"] as const;
|
export const validAutoSyncSources = [
|
||||||
|
"confluence",
|
||||||
|
"google_drive",
|
||||||
|
"slack",
|
||||||
|
] as const;
|
||||||
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
|
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user