fixed mypy issues

This commit is contained in:
Hagen O'Neill 2024-06-04 13:35:30 -07:00
parent 750c1df0bb
commit f5deb37fde

View File

@ -1,20 +1,13 @@
import io
import os
import tempfile
from datetime import datetime
from datetime import timezone
from typing import Any
from html.parser import HTMLParser
from typing import Any
import docx # type: ignore
import msal # type: ignore
import openpyxl # type: ignore
# import pptx # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.teams.team import Team # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.outlook.mail.item_body import ItemBody # type: ignore
from office365.teams.team import Team # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@ -28,24 +21,27 @@ from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
# import pptx # type: ignore
logger = setup_logger()
class HTMLFilter(HTMLParser):
    """HTML parser that accumulates the plain-text content of a document.

    Feed markup through ``feed()``; the character data found between tags
    is appended, in document order, to ``text``. Tags themselves are
    discarded.
    """

    # Class-level default. The first ``+=`` in handle_data rebinds ``text``
    # on the instance, so separate parser instances do not share content.
    text: str = ""

    def handle_data(self, data: str) -> None:
        """Append a run of character data reported by the parser to ``text``."""
        self.text += data
def get_created_datetime(obj: ChatMessage) -> datetime:
    """Return the creation time of a chat message as a ``datetime``.

    Reads ``obj.properties["createdDateTime"]`` and parses it as a
    ``YYYY-MM-DDTHH:MM:SS[.ffffff]`` timestamp followed by a UTC offset
    (``Z`` or ``+HH:MM``). The fractional-seconds part is optional.

    Raises:
        KeyError: if ``createdDateTime`` is missing from ``obj.properties``.
        ValueError: if the value matches neither accepted format.
    """
    # Extract the 'createdDateTime' value from the 'properties' dictionary
    created_datetime_str = obj.properties["createdDateTime"]

    # Convert the string to a datetime object
    try:
        return datetime.strptime(created_datetime_str, "%Y-%m-%dT%H:%M:%S.%f%z")
    except ValueError:
        # "%f" needs at least one digit, so a timestamp serialized without
        # fractional seconds (e.g. "2024-06-04T20:35:30Z") fails above.
        # NOTE(review): presumably the API can emit that shape - confirm.
        return datetime.strptime(created_datetime_str, "%Y-%m-%dT%H:%M:%S%z")
class TeamsConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@ -77,10 +73,13 @@ class TeamsConnector(LoadConnector, PollConnector):
self.graph_client = GraphClient(_acquire_token_func)
return None
def get_post_message_lists_from_channel(self, channel_object: Channel) -> list[list[ChatMessage]]:
base_message_list: list[ChatMessage] = channel_object.messages.get().execute_query()
def get_post_message_lists_from_channel(
self, channel_object: Channel
) -> list[list[ChatMessage]]:
base_message_list: list[
ChatMessage
] = channel_object.messages.get().execute_query()
post_message_lists: list[list[ChatMessage]] = []
for message in base_message_list:
@ -90,7 +89,7 @@ class TeamsConnector(LoadConnector, PollConnector):
post_message_list.extend(replies)
post_message_lists.append(post_message_list)
return post_message_lists
def get_channel_object_list_from_team_list(
@ -151,11 +150,14 @@ class TeamsConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
batch_count = 0
for channel_object in channel_list:
post_message_lists = self.get_post_message_lists_from_channel(channel_object)
post_message_lists = self.get_post_message_lists_from_channel(
channel_object
)
for base_message_groups in post_message_lists:
doc_batch.append(
self.convert_post_message_list_to_document(channel_object,
base_message_groups)
self.convert_post_message_list_to_document(
channel_object, base_message_groups
)
)
batch_count += 1
@ -171,28 +173,34 @@ class TeamsConnector(LoadConnector, PollConnector):
post_message_list: list[ChatMessage],
) -> Document:
most_recent_message_datetime: datetime | None = None
semantic_string: str = "Channel/Post: " + channel_object.properties["displayName"]
semantic_string: str = (
"Channel/Post: " + channel_object.properties["displayName"]
)
post_id: str = channel_object.id
web_url: str = channel_object.web_url
messages_text = ""
post_members_list: list[BasicExpertInfo] = []
sorted_post_message_list = sorted(post_message_list, key=get_created_datetime, reverse=True)
sorted_post_message_list = sorted(
post_message_list, key=get_created_datetime, reverse=True
)
if sorted_post_message_list:
most_recent_message = sorted_post_message_list[0]
most_recent_message_datetime = datetime.strptime(most_recent_message.properties["createdDateTime"],
'%Y-%m-%dT%H:%M:%S.%f%z')
most_recent_message_datetime = datetime.strptime(
most_recent_message.properties["createdDateTime"],
"%Y-%m-%dT%H:%M:%S.%f%z",
)
for message in post_message_list:
# add text and a newline
if message.body.content:
html_parser = HTMLFilter()
html_parser.feed(message.body.content)
messages_text += html_parser.text + '\n'
messages_text += html_parser.text + "\n"
# if it has a subject, that means it's the top-level post message, so grab its id, url, and subject
if message.properties['subject']:
if message.properties["subject"]:
semantic_string += "/" + message.properties["subject"]
post_id = message.properties["id"]
web_url = message.web_url
@ -201,12 +209,16 @@ class TeamsConnector(LoadConnector, PollConnector):
if message.properties["from"]:
if message.properties["from"]["user"]:
if message.properties["from"]["user"]["displayName"]:
message_sender = message.properties["from"]["user"]["displayName"]
message_sender = message.properties["from"]["user"][
"displayName"
]
# if it's not a duplicate, add it to the list
if message_sender not in [member.display_name for member in post_members_list]:
if message_sender not in [
member.display_name for member in post_members_list
]:
post_members_list.append(
BasicExpertInfo(display_name=message_sender)
)
)
# if there are no found post members, grab the members from the parent channel
if not post_members_list:
@ -222,8 +234,8 @@ class TeamsConnector(LoadConnector, PollConnector):
metadata={},
)
return doc
def extract_channel_members(self, channel_object: Channel)->list[BasicExpertInfo]:
def extract_channel_members(self, channel_object: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
member_objects = channel_object.members.get().execute_query()
for member_object in member_objects: