diff --git a/backend/danswer/connectors/teams/connector.py b/backend/danswer/connectors/teams/connector.py index c26f1b5a0..ba41bdff2 100644 --- a/backend/danswer/connectors/teams/connector.py +++ b/backend/danswer/connectors/teams/connector.py @@ -1,6 +1,6 @@ import os from datetime import datetime -from html.parser import HTMLParser +from datetime import timezone from typing import Any import msal # type: ignore @@ -19,30 +19,20 @@ from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import Document from danswer.connectors.models import Section +from danswer.file_processing.html_utils import parse_html_page_basic from danswer.utils.logger import setup_logger -# import pptx # type: ignore - logger = setup_logger() - -class HTMLFilter(HTMLParser): - text = "" - - def handle_data(self, data: str) -> None: - self.text += data +datetime_format_string = "%Y-%m-%dT%H:%M:%S.%f%z" def get_created_datetime(obj: ChatMessage) -> datetime: - # Extract the 'createdDateTime' value from the 'properties' dictionary - created_datetime_str = obj.properties["createdDateTime"] - - # Convert the string to a datetime object - return datetime.strptime(created_datetime_str, "%Y-%m-%dT%H:%M:%S.%f%z") + # Extract the 'createdDateTime' value from the 'properties' dictionary and convert it to a datetime object + return datetime.strptime(obj.properties["createdDateTime"], datetime_format_string) class TeamsConnector(LoadConnector, PollConnector): - def __init__( self, batch_size: int = INDEX_BATCH_SIZE, @@ -75,63 +65,74 @@ class TeamsConnector(LoadConnector, PollConnector): self.graph_client = GraphClient(_acquire_token_func) return None - def get_post_message_lists_from_channel( - self, channel_object: Channel - ) -> list[list[ChatMessage]]: - base_message_list: list[ - ChatMessage - ] = channel_object.messages.get().execute_query() - - post_message_lists: list[list[ChatMessage]] = [] - for message in base_message_list: - replies = message.replies.get_all().execute_query() - - post_message_list: list[ChatMessage] = [message] - post_message_list.extend(replies) - - post_message_lists.append(post_message_list) - - return post_message_lists - - def get_channel_object_list_from_team_list( + def _get_threads_from_channel( self, - team_object_list: list[Team], + channel: Channel, start: datetime | None = None, end: datetime | None = None, + ) -> list[list[ChatMessage]]: + # Ensure start and end are timezone-aware + if start and start.tzinfo is None: + start = start.replace(tzinfo=timezone.utc) + if end and end.tzinfo is None: + end = end.replace(tzinfo=timezone.utc) + + query = channel.messages.get() + base_messages: list[ChatMessage] = query.execute_query() + + threads: list[list[ChatMessage]] = [] + for base_message in base_messages: + message_datetime = datetime.strptime( + base_message.properties["lastModifiedDateTime"], datetime_format_string + ) + + if start and message_datetime < start: + continue + if end and message_datetime > end: + continue + + reply_query = base_message.replies.get_all() + replies = reply_query.execute_query() + + # start a list containing the base message and its replies + thread: list[ChatMessage] = [base_message] + thread.extend(replies) + + threads.append(thread) + + return threads + + def _get_channels_from_teams( + self, + teams: list[Team], ) -> list[Channel]: - filter_str = "" - if start is not None and end is not None: - filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}" + channels: list[Channel] = [] + for team in teams: + query = team.channels.get() + channels = query.execute_query() + channels.extend(channels) - channel_list: list[Channel] = [] - for team_object in team_object_list: - query = team_object.channels.get() - if filter_str: - query = query.filter(filter_str) - channel_objects = query.execute_query() - channel_list.extend(channel_objects) + return channels - return channel_list - - def get_all_team_objects(self) -> list[Team]: + def _get_all_teams(self) -> list[Team]: if self.graph_client is None: raise ConnectorMissingCredentialError("Teams") - team_object_list: list[Team] = [] + teams: list[Team] = [] - teams_object = self.graph_client.teams.get().execute_query() + teams = self.graph_client.teams.get().execute_query() if len(self.requested_team_list) > 0: for requested_team in self.requested_team_list: adjusted_request_string = requested_team.replace(" ", "") - for team_object in teams_object: - adjusted_team_string = team_object.display_name.replace(" ", "") + for team in teams: + adjusted_team_string = team.display_name.replace(" ", "") if adjusted_team_string == adjusted_request_string: - team_object_list.append(team_object) + teams.append(team) else: - team_object_list.extend(teams_object) + teams.extend(teams) - return team_object_list + return teams def _fetch_from_teams( self, start: datetime | None = None, end: datetime | None = None @@ -139,72 +140,57 @@ class TeamsConnector(LoadConnector, PollConnector): if self.graph_client is None: raise ConnectorMissingCredentialError("Teams") - team_object_list = self.get_all_team_objects() + teams = self._get_all_teams() - channel_list = self.get_channel_object_list_from_team_list( - team_object_list=team_object_list, - start=start, - end=end, + channels = self._get_channels_from_teams( + teams=teams, ) # goes over channels, converts them into Document objects and then yields them in batches doc_batch: list[Document] = [] - batch_count = 0 - for channel_object in channel_list: - post_message_lists = self.get_post_message_lists_from_channel( - channel_object - ) - for base_message_groups in post_message_lists: - doc_batch.append( - self.convert_post_message_list_to_document( - channel_object, base_message_groups - ) - ) + for channel in channels: + thread_list = self._get_threads_from_channel(channel, start=start, end=end) + for thread in thread_list: + converted_doc = self._convert_thread_to_document(channel, thread) + if converted_doc: + doc_batch.append(converted_doc) - batch_count += 1 - if batch_count >= self.batch_size: + if len(doc_batch) >= self.batch_size: yield doc_batch - batch_count = 0 doc_batch = [] yield doc_batch - def convert_post_message_list_to_document( + def _convert_thread_to_document( self, - channel_object: Channel, - post_message_list: list[ChatMessage], - ) -> Document: + channel: Channel, + thread: list[ChatMessage], + ) -> Document | None: + if len(thread) <= 0: + return None + most_recent_message_datetime: datetime | None = None - semantic_string: str = ( - "Channel/Post: " + channel_object.properties["displayName"] - ) - post_id: str = channel_object.id - web_url: str = channel_object.web_url - messages_text = "" + top_message = thread[0] post_members_list: list[BasicExpertInfo] = [] + messages_text = "" - sorted_post_message_list = sorted( - post_message_list, key=get_created_datetime, reverse=True - ) + sorted_thread = sorted(thread, key=get_created_datetime, reverse=True) - if sorted_post_message_list: - most_recent_message = sorted_post_message_list[0] + if sorted_thread: + most_recent_message = sorted_thread[0] most_recent_message_datetime = datetime.strptime( most_recent_message.properties["createdDateTime"], - "%Y-%m-%dT%H:%M:%S.%f%z", + datetime_format_string, ) - for message in post_message_list: + for message in thread: # add text and a newline if message.body.content: - html_parser = HTMLFilter() - html_parser.feed(message.body.content) - messages_text += html_parser.text + "\n" + message_text = parse_html_page_basic(message.body.content) + messages_text += message_text # if it has a subject, that means its the top level post message, so grab its id, url, and subject if message.properties["subject"]: - semantic_string += "/" + message.properties["subject"] - post_id = message.properties["id"] - web_url = message.web_url + top_message = message # check to make sure there is a valid display name if message.properties["from"]: @@ -223,7 +209,24 @@ class TeamsConnector(LoadConnector, PollConnector): # if there are no found post members, grab the members from the parent channel if not post_members_list: - post_members_list = self.extract_channel_members(channel_object) + post_members_list = self._extract_channel_members(channel) + + semantic_string: str = "Post: " + channel.properties["displayName"] + + first_poster = top_message.properties["from"]["user"]["displayName"] + channel_name = channel.properties["displayName"] + thread_subject = top_message.properties["subject"] + snippet = parse_html_page_basic( + top_message.body.content[:50].rstrip() + "..." + if len(top_message.body.content) > 50 + else top_message.body.content + ) + if post_members_list: + semantic_string = ( + f"{first_poster} in {channel_name} about {thread_subject}: {snippet}" + ) + post_id = top_message.properties["id"] + web_url = top_message.web_url doc = Document( id=post_id, @@ -236,12 +239,12 @@ class TeamsConnector(LoadConnector, PollConnector): ) return doc - def extract_channel_members(self, channel_object: Channel) -> list[BasicExpertInfo]: + def _extract_channel_members(self, channel: Channel) -> list[BasicExpertInfo]: channel_members_list: list[BasicExpertInfo] = [] - member_objects = channel_object.members.get().execute_query() - for member_object in member_objects: + members = channel.members.get().execute_query() + for member in members: channel_members_list.append( - BasicExpertInfo(display_name=member_object.display_name) + BasicExpertInfo(display_name=member.display_name) ) return channel_members_list