Fix SlackBot still tagging groups (#564)

This commit is contained in:
Yuhong Sun
2023-10-12 00:32:43 -07:00
committed by GitHub
parent 51490b5cd9
commit a7578c9707
4 changed files with 49 additions and 24 deletions

View File

@@ -108,10 +108,13 @@ def build_documents_blocks(
included_docs += 1 included_docs += 1
if d.link:
block_text = f"<{d.link}|{doc_sem_id}>:\n>{remove_slack_text_interactions(match_str)}"
else:
block_text = f"{doc_sem_id}:\n>{remove_slack_text_interactions(match_str)}"
section_blocks.append( section_blocks.append(
SectionBlock( SectionBlock(text=block_text),
text=f"<{d.link}|{doc_sem_id}>:\n>{remove_slack_text_interactions(match_str)}"
),
) )
if include_feedback: if include_feedback:

View File

@@ -17,7 +17,7 @@ from danswer.bots.slack.tokens import fetch_tokens
from danswer.configs.constants import ID_SEPARATOR from danswer.configs.constants import ID_SEPARATOR
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import replace_whitespaces_w_space from danswer.utils.text_processing import replace_whitespaces_w_space
@@ -157,11 +157,12 @@ def translate_vespa_highlight_to_slack(match_strs: list[str], used_chars: int) -
def remove_slack_text_interactions(slack_str: str) -> str: def remove_slack_text_interactions(slack_str: str) -> str:
slack_str = UserIdReplacer.replace_tags_basic(slack_str) slack_str = SlackTextCleaner.replace_tags_basic(slack_str)
slack_str = UserIdReplacer.replace_channels_basic(slack_str) slack_str = SlackTextCleaner.replace_channels_basic(slack_str)
slack_str = UserIdReplacer.replace_special_mentions(slack_str) slack_str = SlackTextCleaner.replace_special_mentions(slack_str)
slack_str = UserIdReplacer.replace_links(slack_str) slack_str = SlackTextCleaner.replace_links(slack_str)
slack_str = UserIdReplacer.add_zero_width_whitespace_after_tag(slack_str) slack_str = SlackTextCleaner.replace_special_catchall(slack_str)
slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str)
return slack_str return slack_str

View File

@@ -23,7 +23,7 @@ from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@@ -132,7 +132,7 @@ def thread_to_doc(
workspace: str, workspace: str,
channel: ChannelType, channel: ChannelType,
thread: ThreadType, thread: ThreadType,
user_id_replacer: UserIdReplacer, slack_cleaner: SlackTextCleaner,
) -> Document: ) -> Document:
channel_id = channel["id"] channel_id = channel["id"]
return Document( return Document(
@@ -142,7 +142,7 @@ def thread_to_doc(
link=get_message_link( link=get_message_link(
event=m, workspace=workspace, channel_id=channel_id event=m, workspace=workspace, channel_id=channel_id
), ),
text=user_id_replacer.replace_user_ids_with_names(cast(str, m["text"])), text=slack_cleaner.index_clean(cast(str, m["text"])),
) )
for m in thread for m in thread
], ],
@@ -204,7 +204,7 @@ def get_all_docs(
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter, msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> Generator[Document, None, None]: ) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel""" """Get all documents in the workspace, channel by channel"""
user_id_replacer = UserIdReplacer(client=client) slack_cleaner = SlackTextCleaner(client=client)
all_channels = get_channels(client) all_channels = get_channels(client)
filtered_channels = _filter_channels(all_channels, channels) filtered_channels = _filter_channels(all_channels, channels)
@@ -241,7 +241,7 @@ def get_all_docs(
workspace=workspace, workspace=workspace,
channel=channel, channel=channel,
thread=filtered_thread, thread=filtered_thread,
user_id_replacer=user_id_replacer, slack_cleaner=slack_cleaner,
) )
logger.info( logger.info(

View File

@@ -101,23 +101,23 @@ def make_slack_api_rate_limited(
return rate_limited_call return rate_limited_call
class UserIdReplacer: class SlackTextCleaner:
"""Utility class to replace user IDs with usernames in a message. """Utility class to replace user IDs with usernames in a message.
Handles caching, so the same request is not made multiple times Handles caching, so the same request is not made multiple times
for the same user ID""" for the same user ID"""
def __init__(self, client: WebClient) -> None: def __init__(self, client: WebClient) -> None:
self._client = client self._client = client
self._user_id_to_name_map: dict[str, str] = {} self._id_to_name_map: dict[str, str] = {}
def _get_slack_user_name(self, user_id: str) -> str: def _get_slack_name(self, user_id: str) -> str:
if user_id not in self._user_id_to_name_map: if user_id not in self._id_to_name_map:
try: try:
response = make_slack_api_rate_limited(self._client.users_info)( response = make_slack_api_rate_limited(self._client.users_info)(
user=user_id user=user_id
) )
# prefer display name if set, since that is what is shown in Slack # prefer display name if set, since that is what is shown in Slack
self._user_id_to_name_map[user_id] = ( self._id_to_name_map[user_id] = (
response["user"]["profile"]["display_name"] response["user"]["profile"]["display_name"]
or response["user"]["profile"]["real_name"] or response["user"]["profile"]["real_name"]
) )
@@ -127,19 +127,19 @@ class UserIdReplacer:
) )
raise raise
return self._user_id_to_name_map[user_id] return self._id_to_name_map[user_id]
def replace_user_ids_with_names(self, message: str) -> str: def _replace_user_ids_with_names(self, message: str) -> str:
# Find user IDs in the message # Find user IDs in the message
user_ids = re.findall("<@(.*?)>", message) user_ids = re.findall("<@(.*?)>", message)
# Iterate over each user ID found # Iterate over each user ID found
for user_id in user_ids: for user_id in user_ids:
try: try:
if user_id in self._user_id_to_name_map: if user_id in self._id_to_name_map:
user_name = self._user_id_to_name_map[user_id] user_name = self._id_to_name_map[user_id]
else: else:
user_name = self._get_slack_user_name(user_id) user_name = self._get_slack_name(user_id)
# Replace the user ID with the username in the message # Replace the user ID with the username in the message
message = message.replace(f"<@{user_id}>", f"@{user_name}") message = message.replace(f"<@{user_id}>", f"@{user_name}")
@@ -150,6 +150,18 @@ class UserIdReplacer:
return message return message
def index_clean(self, message: str) -> str:
"""During indexing, replace pattern sets that may cause confusion to the model
Some special patterns are left in as they can provide information
ie. links that contain format text|link, both the text and the link may be informative
"""
message = self._replace_user_ids_with_names(message)
message = self.replace_tags_basic(message)
message = self.replace_channels_basic(message)
message = self.replace_special_mentions(message)
message = self.replace_special_catchall(message)
return message
@staticmethod @staticmethod
def replace_tags_basic(message: str) -> str: def replace_tags_basic(message: str) -> str:
"""Simply replaces all tags with `@<USER_ID>` in order to prevent us from """Simply replaces all tags with `@<USER_ID>` in order to prevent us from
@@ -197,6 +209,15 @@ class UserIdReplacer:
message = message.replace(f"<{possible_link}>", link_display) message = message.replace(f"<{possible_link}>", link_display)
return message return message
@staticmethod
def replace_special_catchall(message: str) -> str:
"""Replaces pattern of <!something|another-thing> with another-thing
This is added for <!subteam^TEAM-ID|@team-name> but may match other cases as well
"""
pattern = r"<!([^|]+)\|([^>]+)>"
return re.sub(pattern, r"\2", message)
@staticmethod @staticmethod
def add_zero_width_whitespace_after_tag(message: str) -> str: def add_zero_width_whitespace_after_tag(message: str) -> str:
"""Add a 0 width whitespace after every @""" """Add a 0 width whitespace after every @"""