diff --git a/backend/danswer/connectors/teams/connector.py b/backend/danswer/connectors/teams/connector.py index 551831f2f..15869c3ca 100644 --- a/backend/danswer/connectors/teams/connector.py +++ b/backend/danswer/connectors/teams/connector.py @@ -1,20 +1,13 @@ -import io import os -import tempfile from datetime import datetime -from datetime import timezone -from typing import Any from html.parser import HTMLParser +from typing import Any -import docx # type: ignore import msal # type: ignore -import openpyxl # type: ignore -# import pptx # type: ignore from office365.graph_client import GraphClient # type: ignore -from office365.teams.team import Team # type: ignore from office365.teams.channels.channel import Channel # type: ignore from office365.teams.chats.messages.message import ChatMessage # type: ignore -from office365.outlook.mail.item_body import ItemBody # type: ignore +from office365.teams.team import Team # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource @@ -28,24 +21,27 @@ from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.utils.logger import setup_logger +# import pptx # type: ignore + logger = setup_logger() class HTMLFilter(HTMLParser): text = "" - def handle_data(self, data): + + def handle_data(self, data: str) -> None: self.text += data -def get_created_datetime(obj: ChatMessage): + +def get_created_datetime(obj: ChatMessage) -> datetime: # Extract the 'createdDateTime' value from the 'properties' dictionary - created_datetime_str = obj.properties['createdDateTime'] - + created_datetime_str = obj.properties["createdDateTime"] + # Convert the string to a datetime object - return datetime.strptime(created_datetime_str, '%Y-%m-%dT%H:%M:%S.%f%z') + return datetime.strptime(created_datetime_str, "%Y-%m-%dT%H:%M:%S.%f%z") class TeamsConnector(LoadConnector, PollConnector): - def __init__( self, batch_size: int = INDEX_BATCH_SIZE, @@ -77,10 +73,13 @@ class TeamsConnector(LoadConnector, PollConnector): self.graph_client = GraphClient(_acquire_token_func) return None - - def get_post_message_lists_from_channel(self, channel_object: Channel) -> list[list[ChatMessage]]: - - base_message_list: list[ChatMessage] = channel_object.messages.get().execute_query() + + def get_post_message_lists_from_channel( + self, channel_object: Channel + ) -> list[list[ChatMessage]]: + base_message_list: list[ + ChatMessage + ] = channel_object.messages.get().execute_query() post_message_lists: list[list[ChatMessage]] = [] for message in base_message_list: @@ -90,7 +89,7 @@ class TeamsConnector(LoadConnector, PollConnector): post_message_list.extend(replies) post_message_lists.append(post_message_list) - + return post_message_lists def get_channel_object_list_from_team_list( @@ -151,11 +150,14 @@ class TeamsConnector(LoadConnector, PollConnector): doc_batch: list[Document] = [] batch_count = 0 for channel_object in channel_list: - post_message_lists = self.get_post_message_lists_from_channel(channel_object) + post_message_lists = self.get_post_message_lists_from_channel( + channel_object + ) for base_message_groups in post_message_lists: doc_batch.append( - self.convert_post_message_list_to_document(channel_object, - base_message_groups) + self.convert_post_message_list_to_document( + channel_object, base_message_groups + ) ) batch_count += 1 @@ -171,28 +173,34 @@ class TeamsConnector(LoadConnector, PollConnector): post_message_list: list[ChatMessage], ) -> Document: most_recent_message_datetime: datetime | None = None - semantic_string: str = "Channel/Post: " + channel_object.properties["displayName"] + semantic_string: str = ( + "Channel/Post: " + channel_object.properties["displayName"] + ) post_id: str = channel_object.id web_url: str = channel_object.web_url messages_text = "" post_members_list: list[BasicExpertInfo] = [] - sorted_post_message_list = sorted(post_message_list, key=get_created_datetime, reverse=True) + sorted_post_message_list = sorted( + post_message_list, key=get_created_datetime, reverse=True + ) if sorted_post_message_list: most_recent_message = sorted_post_message_list[0] - most_recent_message_datetime = datetime.strptime(most_recent_message.properties["createdDateTime"], - '%Y-%m-%dT%H:%M:%S.%f%z') - + most_recent_message_datetime = datetime.strptime( + most_recent_message.properties["createdDateTime"], + "%Y-%m-%dT%H:%M:%S.%f%z", + ) + for message in post_message_list: # add text and a newline if message.body.content: html_parser = HTMLFilter() html_parser.feed(message.body.content) - messages_text += html_parser.text + '\n' + messages_text += html_parser.text + "\n" # if it has a subject, that means its the top level post message, so grab its id, url, and subject - if message.properties['subject']: + if message.properties["subject"]: semantic_string += "/" + message.properties["subject"] post_id = message.properties["id"] web_url = message.web_url @@ -201,12 +209,16 @@ class TeamsConnector(LoadConnector, PollConnector): if message.properties["from"]: if message.properties["from"]["user"]: if message.properties["from"]["user"]["displayName"]: - message_sender = message.properties["from"]["user"]["displayName"] + message_sender = message.properties["from"]["user"][ + "displayName" + ] # if its not a duplicate, add it to the list - if message_sender not in [member.display_name for member in post_members_list]: + if message_sender not in [ + member.display_name for member in post_members_list + ]: post_members_list.append( BasicExpertInfo(display_name=message_sender) - ) + ) # if there are no found post members, grab the members from the parent channel if not post_members_list: @@ -222,8 +234,8 @@ class TeamsConnector(LoadConnector, PollConnector): metadata={}, ) return doc - - def extract_channel_members(self, channel_object: Channel)->list[BasicExpertInfo]: + + def extract_channel_members(self, channel_object: Channel) -> list[BasicExpertInfo]: channel_members_list: list[BasicExpertInfo] = [] member_objects = channel_object.members.get().execute_query() for member_object in member_objects: