completed code revision suggestions

This commit is contained in:
Hagen O'Neill 2024-06-05 14:11:38 -07:00
parent 713d325f42
commit 8d74176348

@ -1,6 +1,6 @@
import os
from datetime import datetime
from html.parser import HTMLParser
from datetime import timezone
from typing import Any
import msal # type: ignore
@ -19,30 +19,20 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.utils.logger import setup_logger
# import pptx # type: ignore
logger = setup_logger()
class HTMLFilter(HTMLParser):
text = ""
def handle_data(self, data: str) -> None:
self.text += data
datetime_format_string = "%Y-%m-%dT%H:%M:%S.%f%z"
def get_created_datetime(obj: ChatMessage) -> datetime:
# Extract the 'createdDateTime' value from the 'properties' dictionary
created_datetime_str = obj.properties["createdDateTime"]
# Convert the string to a datetime object
return datetime.strptime(created_datetime_str, "%Y-%m-%dT%H:%M:%S.%f%z")
# Extract the 'createdDateTime' value from the 'properties' dictionary and convert it to a datetime object
return datetime.strptime(obj.properties["createdDateTime"], datetime_format_string)
class TeamsConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@ -75,63 +65,74 @@ class TeamsConnector(LoadConnector, PollConnector):
self.graph_client = GraphClient(_acquire_token_func)
return None
def get_post_message_lists_from_channel(
self, channel_object: Channel
) -> list[list[ChatMessage]]:
base_message_list: list[
ChatMessage
] = channel_object.messages.get().execute_query()
post_message_lists: list[list[ChatMessage]] = []
for message in base_message_list:
replies = message.replies.get_all().execute_query()
post_message_list: list[ChatMessage] = [message]
post_message_list.extend(replies)
post_message_lists.append(post_message_list)
return post_message_lists
def get_channel_object_list_from_team_list(
def _get_threads_from_channel(
self,
team_object_list: list[Team],
channel: Channel,
start: datetime | None = None,
end: datetime | None = None,
) -> list[list[ChatMessage]]:
# Ensure start and end are timezone-aware
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query = channel.messages.get()
base_messages: list[ChatMessage] = query.execute_query()
threads: list[list[ChatMessage]] = []
for base_message in base_messages:
message_datetime = datetime.strptime(
base_message.properties["lastModifiedDateTime"], datetime_format_string
)
if start and message_datetime < start:
continue
if end and message_datetime > end:
continue
reply_query = base_message.replies.get_all()
replies = reply_query.execute_query()
# start a list containing the base message and its replies
thread: list[ChatMessage] = [base_message]
thread.extend(replies)
threads.append(thread)
return threads
def _get_channels_from_teams(
self,
teams: list[Team],
) -> list[Channel]:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
channels: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels.extend(channels)
channel_list: list[Channel] = []
for team_object in team_object_list:
query = team_object.channels.get()
if filter_str:
query = query.filter(filter_str)
channel_objects = query.execute_query()
channel_list.extend(channel_objects)
return channels
return channel_list
def get_all_team_objects(self) -> list[Team]:
def _get_all_teams(self) -> list[Team]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
team_object_list: list[Team] = []
teams: list[Team] = []
teams_object = self.graph_client.teams.get().execute_query()
teams = self.graph_client.teams.get().execute_query()
if len(self.requested_team_list) > 0:
for requested_team in self.requested_team_list:
adjusted_request_string = requested_team.replace(" ", "")
for team_object in teams_object:
adjusted_team_string = team_object.display_name.replace(" ", "")
for team in teams:
adjusted_team_string = team.display_name.replace(" ", "")
if adjusted_team_string == adjusted_request_string:
team_object_list.append(team_object)
teams.append(team)
else:
team_object_list.extend(teams_object)
teams.extend(teams)
return team_object_list
return teams
def _fetch_from_teams(
self, start: datetime | None = None, end: datetime | None = None
@ -139,72 +140,57 @@ class TeamsConnector(LoadConnector, PollConnector):
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
team_object_list = self.get_all_team_objects()
teams = self._get_all_teams()
channel_list = self.get_channel_object_list_from_team_list(
team_object_list=team_object_list,
start=start,
end=end,
channels = self._get_channels_from_teams(
teams=teams,
)
# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
batch_count = 0
for channel_object in channel_list:
post_message_lists = self.get_post_message_lists_from_channel(
channel_object
)
for base_message_groups in post_message_lists:
doc_batch.append(
self.convert_post_message_list_to_document(
channel_object, base_message_groups
)
)
for channel in channels:
thread_list = self._get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = self._convert_thread_to_document(channel, thread)
if converted_doc:
doc_batch.append(converted_doc)
batch_count += 1
if batch_count >= self.batch_size:
if len(doc_batch) >= self.batch_size:
yield doc_batch
batch_count = 0
doc_batch = []
yield doc_batch
def convert_post_message_list_to_document(
def _convert_thread_to_document(
self,
channel_object: Channel,
post_message_list: list[ChatMessage],
) -> Document:
channel: Channel,
thread: list[ChatMessage],
) -> Document | None:
if len(thread) <= 0:
return None
most_recent_message_datetime: datetime | None = None
semantic_string: str = (
"Channel/Post: " + channel_object.properties["displayName"]
)
post_id: str = channel_object.id
web_url: str = channel_object.web_url
messages_text = ""
top_message = thread[0]
post_members_list: list[BasicExpertInfo] = []
messages_text = ""
sorted_post_message_list = sorted(
post_message_list, key=get_created_datetime, reverse=True
)
sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)
if sorted_post_message_list:
most_recent_message = sorted_post_message_list[0]
if sorted_thread:
most_recent_message = sorted_thread[0]
most_recent_message_datetime = datetime.strptime(
most_recent_message.properties["createdDateTime"],
"%Y-%m-%dT%H:%M:%S.%f%z",
datetime_format_string,
)
for message in post_message_list:
for message in thread:
# add text and a newline
if message.body.content:
html_parser = HTMLFilter()
html_parser.feed(message.body.content)
messages_text += html_parser.text + "\n"
message_text = parse_html_page_basic(message.body.content)
messages_text += message_text
# if it has a subject, that means its the top level post message, so grab its id, url, and subject
if message.properties["subject"]:
semantic_string += "/" + message.properties["subject"]
post_id = message.properties["id"]
web_url = message.web_url
top_message = message
# check to make sure there is a valid display name
if message.properties["from"]:
@ -223,7 +209,24 @@ class TeamsConnector(LoadConnector, PollConnector):
# if there are no found post members, grab the members from the parent channel
if not post_members_list:
post_members_list = self.extract_channel_members(channel_object)
post_members_list = self._extract_channel_members(channel)
semantic_string: str = "Post: " + channel.properties["displayName"]
first_poster = top_message.properties["from"]["user"]["displayName"]
channel_name = channel.properties["displayName"]
thread_subject = top_message.properties["subject"]
snippet = parse_html_page_basic(
top_message.body.content[:50].rstrip() + "..."
if len(top_message.body.content) > 50
else top_message.body.content
)
if post_members_list:
semantic_string = (
f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
)
post_id = top_message.properties["id"]
web_url = top_message.web_url
doc = Document(
id=post_id,
@ -236,12 +239,12 @@ class TeamsConnector(LoadConnector, PollConnector):
)
return doc
def extract_channel_members(self, channel_object: Channel) -> list[BasicExpertInfo]:
def _extract_channel_members(self, channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
member_objects = channel_object.members.get().execute_query()
for member_object in member_objects:
members = channel.members.get().execute_query()
for member in members:
channel_members_list.append(
BasicExpertInfo(display_name=member_object.display_name)
BasicExpertInfo(display_name=member.display_name)
)
return channel_members_list