mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-30 01:30:21 +02:00
Added permissions syncing for slack (#2602)
* Added permissions syncing for slack * add no email case handling * mypy fixes * frontend * minor cleanup * param tweak
This commit is contained in:
parent
728a41a35a
commit
b0056907fb
@ -63,6 +63,7 @@ def identify_connector_class(
|
||||
DocumentSource.SLACK: {
|
||||
InputType.LOAD_STATE: SlackLoadConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.PRUNE: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
|
@ -8,13 +8,12 @@ from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
@ -23,9 +22,8 @@ from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.utils import expert_info_from_slack_id
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@ -38,47 +36,18 @@ MessageType = dict[str, Any]
|
||||
# list of messages in a thread
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
|
||||
|
||||
def _make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
||||
"channel"
|
||||
]
|
||||
|
||||
|
||||
def _get_channels(
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
get_private: bool,
|
||||
channel_types: list[str],
|
||||
) -> list[ChannelType]:
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_list,
|
||||
exclude_archived=exclude_archived,
|
||||
# also get private channels the bot is added to
|
||||
types=["public_channel", "private_channel"]
|
||||
if get_private
|
||||
else ["public_channel"],
|
||||
types=channel_types,
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
|
||||
@ -88,19 +57,38 @@ def _get_channels(
|
||||
def get_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool = True,
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
# try getting private channels as well at first
|
||||
try:
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=True
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||
logger.info("trying again without private channels")
|
||||
if get_public:
|
||||
channel_types = ["public_channel"]
|
||||
else:
|
||||
logger.warning("No channels to fetch")
|
||||
return []
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
|
||||
return _get_channels(
|
||||
client=client, exclude_archived=exclude_archived, get_private=False
|
||||
)
|
||||
return channels
|
||||
|
||||
|
||||
def get_channel_messages(
|
||||
@ -112,14 +100,14 @@ def get_channel_messages(
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
_make_slack_api_call(
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
@ -131,7 +119,7 @@ def get_channel_messages(
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
"""Get all messages in a thread"""
|
||||
threads: list[MessageType] = []
|
||||
for result in _make_paginated_slack_api_call(
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
@ -266,7 +254,7 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def get_all_docs(
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
workspace: str,
|
||||
channels: list[str] | None = None,
|
||||
@ -328,7 +316,44 @@ def get_all_docs(
|
||||
)
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector):
|
||||
def _get_all_doc_ids(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
) -> set[str]:
|
||||
"""
|
||||
Get all document ids in the workspace, channel by channel
|
||||
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
|
||||
This makes it an order of magnitude faster than get_all_docs
|
||||
"""
|
||||
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
all_doc_ids = set()
|
||||
for channel in filtered_channels:
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client,
|
||||
channel=channel,
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
if msg_filter_func(message):
|
||||
continue
|
||||
|
||||
# The document id is the channel id and the ts of the first message in the thread
|
||||
# Since we already have the first message of the thread, we dont have to
|
||||
# fetch the thread for id retrieval, saving time and API calls
|
||||
all_doc_ids.add(f"{channel['id']}__{message['ts']}")
|
||||
|
||||
return all_doc_ids
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector, IdConnector):
|
||||
def __init__(
|
||||
self,
|
||||
workspace: str,
|
||||
@ -349,6 +374,16 @@ class SlackPollConnector(PollConnector):
|
||||
self.client = WebClient(token=bot_token)
|
||||
return None
|
||||
|
||||
def retrieve_all_source_ids(self) -> set[str]:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
return _get_all_doc_ids(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
@ -356,7 +391,7 @@ class SlackPollConnector(PollConnector):
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
documents: list[Document] = []
|
||||
for document in get_all_docs(
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
workspace=self.workspace,
|
||||
channels=self.channels,
|
||||
|
@ -10,11 +10,13 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_retry_wrapper = retry_builder()
|
||||
# number of messages we request per page when fetching paginated slack messages
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
@ -34,7 +36,7 @@ def get_message_link(
|
||||
)
|
||||
|
||||
|
||||
def make_slack_api_call_logged(
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
@ -47,7 +49,7 @@ def make_slack_api_call_logged(
|
||||
return logged_call
|
||||
|
||||
|
||||
def make_slack_api_call_paginated(
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||
@ -116,6 +118,24 @@ def make_slack_api_rate_limited(
|
||||
return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
user_id: str | None,
|
||||
client: WebClient,
|
||||
|
@ -26,9 +26,7 @@ from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
check_if_valid_sync_source,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
@ -4,29 +4,38 @@ from datetime import timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
DOC_PERMISSIONS_FUNC_MAP,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
GROUP_PERMISSIONS_FUNC_MAP,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# None means that the connector runs every time
|
||||
_RESTRICTED_FETCH_PERIOD: dict[DocumentSource, int | None] = {
|
||||
# Polling is supported
|
||||
DocumentSource.GOOGLE_DRIVE: None,
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
}
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def run_external_group_permission_sync(
|
||||
@ -44,6 +53,9 @@ def run_external_group_permission_sync(
|
||||
# Not all sync connectors support group permissions so this is fine
|
||||
return
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
@ -79,20 +91,8 @@ def run_external_doc_permission_sync(
|
||||
f"No permission sync function found for source type: {source_type}"
|
||||
)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
# If RESTRICTED_FETCH_PERIOD is not None, we only run sync if the
|
||||
# last sync was more than RESTRICTED_FETCH_PERIOD seconds ago.
|
||||
full_fetch_period = _RESTRICTED_FETCH_PERIOD[cc_pair.connector.source]
|
||||
if full_fetch_period is not None:
|
||||
last_sync = cc_pair.last_time_perm_sync
|
||||
if (
|
||||
last_sync
|
||||
and (
|
||||
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
|
||||
).total_seconds()
|
||||
< full_fetch_period
|
||||
):
|
||||
return
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
@ -131,6 +131,9 @@ def run_external_doc_permission_sync(
|
||||
|
||||
# update vespa
|
||||
document_index.update(update_reqs)
|
||||
|
||||
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
|
||||
|
||||
# update postgres
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
|
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
192
backend/ee/danswer/external_permissions/slack/doc_sync.py
Normal file
@ -0,0 +1,192 @@
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.slack.connector import get_channels
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_channel_id_from_doc_id(doc_id: str) -> str:
|
||||
"""
|
||||
Extracts the channel ID from a document ID string.
|
||||
|
||||
The document ID is expected to be in the format: "{channel_id}__{message_ts}"
|
||||
|
||||
Args:
|
||||
doc_id (str): The document ID string.
|
||||
|
||||
Returns:
|
||||
str: The extracted channel ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If the doc_id doesn't contain the expected separator.
|
||||
"""
|
||||
try:
|
||||
channel_id, _ = doc_id.split("__", 1)
|
||||
return channel_id
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid doc_id format: {doc_id}")
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.PRUNE,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
assert isinstance(runnable_connector, IdConnector)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_id in runnable_connector.retrieve_all_source_ids():
|
||||
channel_id = _extract_channel_id_from_doc_id(doc_id)
|
||||
if channel_id not in channel_doc_map:
|
||||
channel_doc_map[channel_id] = []
|
||||
channel_doc_map[channel_id].append(doc_id)
|
||||
|
||||
return channel_doc_map
|
||||
|
||||
|
||||
def _fetch_worspace_permissions(
|
||||
db_session: Session,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
user_emails = set()
|
||||
for email in user_id_to_email_map.values():
|
||||
user_emails.add(email)
|
||||
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
# No group<->document mapping for slack
|
||||
external_user_group_ids=set(),
|
||||
# No way to determine if slack is invite only without enterprise liscense
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_channel_permissions(
|
||||
db_session: Session,
|
||||
slack_client: WebClient,
|
||||
workspace_permissions: ExternalAccess,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> dict[str, ExternalAccess]:
|
||||
channel_permissions = {}
|
||||
public_channels = get_channels(
|
||||
client=slack_client,
|
||||
get_public=True,
|
||||
get_private=False,
|
||||
)
|
||||
public_channel_ids = [
|
||||
channel["id"] for channel in public_channels if "id" in channel
|
||||
]
|
||||
for channel_id in public_channel_ids:
|
||||
channel_permissions[channel_id] = workspace_permissions
|
||||
|
||||
private_channels = get_channels(
|
||||
client=slack_client,
|
||||
get_public=False,
|
||||
get_private=True,
|
||||
)
|
||||
private_channel_ids = [
|
||||
channel["id"] for channel in private_channels if "id" in channel
|
||||
]
|
||||
|
||||
for channel_id in private_channel_ids:
|
||||
# Collect all member ids for the channel pagination calls
|
||||
member_ids = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.conversations_members,
|
||||
channel=channel_id,
|
||||
):
|
||||
member_ids.extend(result.get("members", []))
|
||||
|
||||
# Collect all member emails for the channel
|
||||
member_emails = set()
|
||||
for member_id in member_ids:
|
||||
member_email = user_id_to_email_map.get(member_id)
|
||||
|
||||
if not member_email:
|
||||
# If the user is an external user, they wont get returned from the
|
||||
# conversations_members call so we need to make a separate call to users_info
|
||||
# and add them to the user_id_to_email_map
|
||||
member_info = slack_client.users_info(user=member_id)
|
||||
member_email = member_info["user"]["profile"].get("email")
|
||||
if not member_email:
|
||||
# If no email is found, we skip the user
|
||||
continue
|
||||
user_id_to_email_map[member_id] = member_email
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session, [member_email]
|
||||
)
|
||||
|
||||
member_emails.add(member_email)
|
||||
|
||||
channel_permissions[channel_id] = ExternalAccess(
|
||||
external_user_emails=member_emails,
|
||||
# No group<->document mapping for slack
|
||||
external_user_group_ids=set(),
|
||||
# No way to determine if slack is invite only without enterprise liscense
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
return channel_permissions
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated
|
||||
"""
|
||||
slack_client = WebClient(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_worspace_permissions(
|
||||
db_session=db_session,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
channel_permissions = _fetch_channel_permissions(
|
||||
db_session=db_session,
|
||||
slack_client=slack_client,
|
||||
workspace_permissions=workspace_permissions,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
for channel_id, ext_access in channel_permissions.items():
|
||||
doc_ids = channel_doc_map.get(channel_id)
|
||||
if not doc_ids:
|
||||
# No documents found for channel the channel_id
|
||||
continue
|
||||
|
||||
for doc_id in doc_ids:
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
92
backend/ee/danswer/external_permissions/slack/group_sync.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""
|
||||
THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING
|
||||
WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL
|
||||
SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL
|
||||
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
|
||||
"""
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_group_ids(
|
||||
slack_client: WebClient,
|
||||
) -> list[str]:
|
||||
group_ids = []
|
||||
for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
|
||||
for group in result.get("usergroups", []):
|
||||
group_ids.append(group.get("id"))
|
||||
return group_ids
|
||||
|
||||
|
||||
def _get_slack_group_members_email(
|
||||
db_session: Session,
|
||||
slack_client: WebClient,
|
||||
group_name: str,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> list[str]:
|
||||
group_member_emails = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.usergroups_users_list, usergroup=group_name
|
||||
):
|
||||
for member_id in result.get("users", []):
|
||||
member_email = user_id_to_email_map.get(member_id)
|
||||
if not member_email:
|
||||
# If the user is an external user, they wont get returned from the
|
||||
# conversations_members call so we need to make a separate call to users_info
|
||||
member_info = slack_client.users_info(user=member_id)
|
||||
member_email = member_info["user"]["profile"].get("email")
|
||||
if not member_email:
|
||||
# If no email is found, we skip the user
|
||||
continue
|
||||
user_id_to_email_map[member_id] = member_email
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session, [member_email]
|
||||
)
|
||||
group_member_emails.append(member_email)
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def slack_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> None:
|
||||
slack_client = WebClient(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
for group_name in _get_slack_group_ids(slack_client):
|
||||
group_member_emails = _get_slack_group_members_email(
|
||||
db_session=db_session,
|
||||
slack_client=slack_client,
|
||||
group_name=group_name,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=group_member_emails
|
||||
)
|
||||
if group_members:
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_name, user_ids=[user.id for user in group_members]
|
||||
)
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=danswer_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
18
backend/ee/danswer/external_permissions/slack/utils.py
Normal file
@ -0,0 +1,18 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
|
||||
|
||||
def fetch_user_id_to_email_map(
|
||||
slack_client: WebClient,
|
||||
) -> dict[str, str]:
|
||||
user_id_to_email_map = {}
|
||||
for user_info in make_paginated_slack_api_call_w_retries(
|
||||
slack_client.users_list,
|
||||
):
|
||||
for user in user_info.get("members", []):
|
||||
if user.get("profile", {}).get("email"):
|
||||
user_id_to_email_map[user.get("id")] = user.get("profile", {}).get(
|
||||
"email"
|
||||
)
|
||||
return user_id_to_email_map
|
@ -8,7 +8,7 @@ from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_s
|
||||
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
|
||||
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
SyncFuncType = Callable[
|
||||
@ -27,6 +27,7 @@ SyncFuncType = Callable[
|
||||
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
|
||||
DocumentSource.CONFLUENCE: confluence_doc_sync,
|
||||
DocumentSource.SLACK: slack_doc_sync,
|
||||
}
|
||||
|
||||
# These functions update:
|
||||
@ -39,5 +40,13 @@ GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
}
|
||||
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
}
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
return source_type in DOC_PERMISSIONS_FUNC_MAP
|
@ -44,4 +44,5 @@ export const autoSyncConfigBySource: Record<
|
||||
),
|
||||
},
|
||||
},
|
||||
slack: {},
|
||||
};
|
||||
|
@ -272,5 +272,9 @@ export type ConfigurableSources = Exclude<
|
||||
>;
|
||||
|
||||
// The sources that have auto-sync support on the backend
|
||||
export const validAutoSyncSources = ["confluence", "google_drive"] as const;
|
||||
export const validAutoSyncSources = [
|
||||
"confluence",
|
||||
"google_drive",
|
||||
"slack",
|
||||
] as const;
|
||||
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
|
||||
|
Loading…
x
Reference in New Issue
Block a user