mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-25 23:40:58 +02:00
moved methods to top and fixed logic errors
This commit is contained in:
parent
ff59858327
commit
cbc53fd500
@ -34,6 +34,149 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
    """Return one ``BasicExpertInfo`` per member of ``channel``.

    Executes the channel's member query and wraps every returned member's
    ``display_name`` in a ``BasicExpertInfo``, preserving the query order.
    """
    fetched_members = channel.members.get().execute_query()
    return [
        BasicExpertInfo(display_name=person.display_name)
        for person in fetched_members
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_threads_from_channel(
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Collect the message threads of ``channel`` that fall inside a time window.

    A thread is a list whose first element is a top-level channel message and
    whose remaining elements are the replies fetched for that message.

    Args:
        channel: Channel whose messages are queried.
        start: Optional lower bound; top-level messages whose
            ``lastModifiedDateTime`` is earlier are skipped.
        end: Optional upper bound; top-level messages whose
            ``lastModifiedDateTime`` is later are skipped.

    Returns:
        One list per retained top-level message: the message itself followed
        by its replies.
    """
    # Naive bounds are interpreted as UTC so they can be compared below.
    if start is not None and start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end is not None and end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)

    base_messages: list[ChatMessage] = channel.messages.get().execute_query()

    threads: list[list[ChatMessage]] = []
    for base_message in base_messages:
        message_datetime = datetime.strptime(
            base_message.properties["lastModifiedDateTime"], datetime_format_string
        )
        # strptime only returns an aware datetime when the format contains %z.
        # The bounds above are always aware, and comparing naive with aware
        # raises TypeError, so a naive result gets the same UTC treatment.
        if message_datetime.tzinfo is None:
            message_datetime = message_datetime.replace(tzinfo=timezone.utc)

        if start is not None and message_datetime < start:
            continue
        if end is not None and message_datetime > end:
            continue

        replies = base_message.replies.get_all().execute_query()

        # The base message always leads its thread, followed by the replies.
        threads.append([base_message, *replies])

    return threads
|
||||||
|
|
||||||
|
|
||||||
|
def _get_channels_from_teams(
    teams: list[Team],
) -> list[Channel]:
    """Return the channels of every team in ``teams`` as one flat list.

    Channels are queried team by team and appear in the result in that order.
    """
    collected: list[Channel] = []
    for current_team in teams:
        team_channels = current_team.channels.get().execute_query()
        collected += team_channels
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
    """Build a human-readable identifier for a thread.

    The result has the form
    ``"<poster> in <channel> about <subject>: <snippet>"``, where the snippet
    is the parsed body of ``top_message`` cut to 50 characters (with ``"..."``
    appended when it was cut).

    Args:
        channel: Channel the thread lives in; its ``displayName`` property is used.
        top_message: Message whose sender, subject and body describe the thread.

    Returns:
        The identifier string. Missing or empty fields are replaced by the
        placeholders ``"Unknown User"`` / ``"Unknown"``, and an empty body
        yields an empty snippet.
    """
    # The thread-conversion code in this file treats "from", "user", "subject"
    # and the body content as possibly empty. A dict.get default only covers
    # a *missing* key, so `or` is used to also cover a key holding None.
    sender = top_message.properties.get("from") or {}
    user = sender.get("user") or {}
    first_poster = user.get("displayName") or "Unknown User"
    channel_name = channel.properties.get("displayName") or "Unknown"
    thread_subject = top_message.properties.get("subject") or "Unknown"

    raw_body = top_message.body.content or ""
    # Parse first, truncate afterwards, so the cut never lands inside markup.
    snippet = parse_html_page_basic(raw_body.rstrip()) if raw_body else ""
    if len(snippet) > 50:
        snippet = snippet[:50] + "..."

    return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_thread_to_document(
    channel: Channel,
    thread: list[ChatMessage],
) -> Document | None:
    """Convert one thread (a top-level message plus replies) into a ``Document``.

    Args:
        channel: Channel the thread belongs to. Used for the semantic
            identifier and, when no message sender can be determined, as the
            source of the document's owners.
        thread: Messages making up the thread.

    Returns:
        The assembled ``Document``, or ``None`` when the thread is empty, when
        none of its messages produced any text, or when no semantic identifier
        could be built.
    """
    if not thread:
        return None

    top_message = thread[0]
    post_members_list: list[BasicExpertInfo] = []
    seen_senders: set[str] = set()
    text_parts: list[str] = []

    for message in thread:
        # Append the parsed body text.
        # NOTE(review): texts are joined without any separator, exactly as
        # before; confirm whether a newline between messages is wanted.
        if message.body.content:
            text_parts.append(parse_html_page_basic(message.body.content))

        # A message with a subject is the top-level post; its id, url and
        # subject are used for the document.
        if message.properties.get("subject"):
            top_message = message

        # Only record senders that have a usable display name, once each.
        sender = message.properties.get("from") or {}
        user = sender.get("user") or {}
        message_sender = user.get("displayName")
        if message_sender and message_sender not in seen_senders:
            seen_senders.add(message_sender)
            post_members_list.append(BasicExpertInfo(display_name=message_sender))

    thread_text = "".join(text_parts)
    if not thread_text:
        # Bail out before the channel-members query below, which would be
        # wasted work for a thread that yields no document.
        return None

    # If no poster could be identified, fall back to the parent channel's members.
    if not post_members_list:
        post_members_list = _extract_channel_members(channel)

    semantic_string = _construct_semantic_identifier(channel, top_message)
    if not semantic_string:
        return None

    # The newest message determines the document's update time.
    most_recent_message = max(thread, key=get_created_datetime)
    most_recent_message_datetime = datetime.strptime(
        most_recent_message.properties["createdDateTime"],
        datetime_format_string,
    )

    return Document(
        id=top_message.properties["id"],
        sections=[Section(link=top_message.web_url, text=thread_text)],
        source=DocumentSource.TEAMS,
        semantic_identifier=semantic_string,
        title="",  # teams threads don't really have a "title"
        doc_updated_at=most_recent_message_datetime,
        primary_owners=post_members_list,
        metadata={},
    )
|
||||||
|
|
||||||
|
|
||||||
class TeamsConnector(LoadConnector, PollConnector):
|
class TeamsConnector(LoadConnector, PollConnector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -67,55 +210,6 @@ class TeamsConnector(LoadConnector, PollConnector):
|
|||||||
self.graph_client = GraphClient(_acquire_token_func)
|
self.graph_client = GraphClient(_acquire_token_func)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_threads_from_channel(
|
|
||||||
self,
|
|
||||||
channel: Channel,
|
|
||||||
start: datetime | None = None,
|
|
||||||
end: datetime | None = None,
|
|
||||||
) -> list[list[ChatMessage]]:
|
|
||||||
# Ensure start and end are timezone-aware
|
|
||||||
if start and start.tzinfo is None:
|
|
||||||
start = start.replace(tzinfo=timezone.utc)
|
|
||||||
if end and end.tzinfo is None:
|
|
||||||
end = end.replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
query = channel.messages.get()
|
|
||||||
base_messages: list[ChatMessage] = query.execute_query()
|
|
||||||
|
|
||||||
threads: list[list[ChatMessage]] = []
|
|
||||||
for base_message in base_messages:
|
|
||||||
message_datetime = datetime.strptime(
|
|
||||||
base_message.properties["lastModifiedDateTime"], datetime_format_string
|
|
||||||
)
|
|
||||||
|
|
||||||
if start and message_datetime < start:
|
|
||||||
continue
|
|
||||||
if end and message_datetime > end:
|
|
||||||
continue
|
|
||||||
|
|
||||||
reply_query = base_message.replies.get_all()
|
|
||||||
replies = reply_query.execute_query()
|
|
||||||
|
|
||||||
# start a list containing the base message and its replies
|
|
||||||
thread: list[ChatMessage] = [base_message]
|
|
||||||
thread.extend(replies)
|
|
||||||
|
|
||||||
threads.append(thread)
|
|
||||||
|
|
||||||
return threads
|
|
||||||
|
|
||||||
def _get_channels_from_teams(
|
|
||||||
self,
|
|
||||||
teams: list[Team],
|
|
||||||
) -> list[Channel]:
|
|
||||||
channels_list: list[Channel] = []
|
|
||||||
for team in teams:
|
|
||||||
query = team.channels.get()
|
|
||||||
channels = query.execute_query()
|
|
||||||
channels_list.extend(channels)
|
|
||||||
|
|
||||||
return channels_list
|
|
||||||
|
|
||||||
def _get_all_teams(self) -> list[Team]:
|
def _get_all_teams(self) -> list[Team]:
|
||||||
if self.graph_client is None:
|
if self.graph_client is None:
|
||||||
raise ConnectorMissingCredentialError("Teams")
|
raise ConnectorMissingCredentialError("Teams")
|
||||||
@ -144,16 +238,16 @@ class TeamsConnector(LoadConnector, PollConnector):
|
|||||||
|
|
||||||
teams = self._get_all_teams()
|
teams = self._get_all_teams()
|
||||||
|
|
||||||
channels = self._get_channels_from_teams(
|
channels = _get_channels_from_teams(
|
||||||
teams=teams,
|
teams=teams,
|
||||||
)
|
)
|
||||||
|
|
||||||
# goes over channels, converts them into Document objects and then yields them in batches
|
# goes over channels, converts them into Document objects and then yields them in batches
|
||||||
doc_batch: list[Document] = []
|
doc_batch: list[Document] = []
|
||||||
for channel in channels:
|
for channel in channels:
|
||||||
thread_list = self._get_threads_from_channel(channel, start=start, end=end)
|
thread_list = _get_threads_from_channel(channel, start=start, end=end)
|
||||||
for thread in thread_list:
|
for thread in thread_list:
|
||||||
converted_doc = self._convert_thread_to_document(channel, thread)
|
converted_doc = _convert_thread_to_document(channel, thread)
|
||||||
if converted_doc:
|
if converted_doc:
|
||||||
doc_batch.append(converted_doc)
|
doc_batch.append(converted_doc)
|
||||||
|
|
||||||
@ -162,107 +256,6 @@ class TeamsConnector(LoadConnector, PollConnector):
|
|||||||
doc_batch = []
|
doc_batch = []
|
||||||
yield doc_batch
|
yield doc_batch
|
||||||
|
|
||||||
def _construct_semantic_identifier(
|
|
||||||
self, channel: Channel, top_message: ChatMessage
|
|
||||||
) -> str | None:
|
|
||||||
first_poster = (
|
|
||||||
top_message.properties.get("from", {})
|
|
||||||
.get("user", {})
|
|
||||||
.get("displayName", {})
|
|
||||||
)
|
|
||||||
channel_name = channel.properties["displayName"]
|
|
||||||
thread_subject = top_message.properties["subject"]
|
|
||||||
snippet = parse_html_page_basic(
|
|
||||||
top_message.body.content[:50].rstrip() + "..."
|
|
||||||
if len(top_message.body.content) > 50
|
|
||||||
else top_message.body.content
|
|
||||||
)
|
|
||||||
if not first_poster or not channel_name or not thread_subject or not snippet:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
|
|
||||||
|
|
||||||
def _convert_thread_to_document(
|
|
||||||
self,
|
|
||||||
channel: Channel,
|
|
||||||
thread: list[ChatMessage],
|
|
||||||
) -> Document | None:
|
|
||||||
if len(thread) <= 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
most_recent_message_datetime: datetime | None = None
|
|
||||||
top_message = thread[0]
|
|
||||||
post_members_list: list[BasicExpertInfo] = []
|
|
||||||
thread_text = ""
|
|
||||||
|
|
||||||
sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)
|
|
||||||
|
|
||||||
if sorted_thread:
|
|
||||||
most_recent_message = sorted_thread[0]
|
|
||||||
most_recent_message_datetime = datetime.strptime(
|
|
||||||
most_recent_message.properties["createdDateTime"],
|
|
||||||
datetime_format_string,
|
|
||||||
)
|
|
||||||
|
|
||||||
for message in thread:
|
|
||||||
# add text and a newline
|
|
||||||
if message.body.content:
|
|
||||||
message_text = parse_html_page_basic(message.body.content)
|
|
||||||
thread_text += message_text
|
|
||||||
|
|
||||||
# if it has a subject, that means its the top level post message, so grab its id, url, and subject
|
|
||||||
if message.properties["subject"]:
|
|
||||||
top_message = message
|
|
||||||
|
|
||||||
# check to make sure there is a valid display name
|
|
||||||
if message.properties["from"]:
|
|
||||||
if message.properties["from"]["user"]:
|
|
||||||
if message.properties["from"]["user"]["displayName"]:
|
|
||||||
message_sender = message.properties["from"]["user"][
|
|
||||||
"displayName"
|
|
||||||
]
|
|
||||||
# if its not a duplicate, add it to the list
|
|
||||||
if message_sender not in [
|
|
||||||
member.display_name for member in post_members_list
|
|
||||||
]:
|
|
||||||
post_members_list.append(
|
|
||||||
BasicExpertInfo(display_name=message_sender)
|
|
||||||
)
|
|
||||||
|
|
||||||
# if there are no found post members, grab the members from the parent channel
|
|
||||||
if not post_members_list:
|
|
||||||
post_members_list = self._extract_channel_members(channel)
|
|
||||||
|
|
||||||
if not thread_text:
|
|
||||||
return None
|
|
||||||
|
|
||||||
semantic_string = self._construct_semantic_identifier(channel, top_message)
|
|
||||||
if not semantic_string:
|
|
||||||
return None
|
|
||||||
|
|
||||||
post_id = top_message.properties["id"]
|
|
||||||
web_url = top_message.web_url
|
|
||||||
|
|
||||||
doc = Document(
|
|
||||||
id=post_id,
|
|
||||||
sections=[Section(link=web_url, text=thread_text)],
|
|
||||||
source=DocumentSource.TEAMS,
|
|
||||||
semantic_identifier=semantic_string,
|
|
||||||
doc_updated_at=most_recent_message_datetime,
|
|
||||||
primary_owners=post_members_list,
|
|
||||||
metadata={},
|
|
||||||
)
|
|
||||||
return doc
|
|
||||||
|
|
||||||
def _extract_channel_members(self, channel: Channel) -> list[BasicExpertInfo]:
|
|
||||||
channel_members_list: list[BasicExpertInfo] = []
|
|
||||||
members = channel.members.get().execute_query()
|
|
||||||
for member in members:
|
|
||||||
channel_members_list.append(
|
|
||||||
BasicExpertInfo(display_name=member.display_name)
|
|
||||||
)
|
|
||||||
return channel_members_list
|
|
||||||
|
|
||||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||||
return self._fetch_from_teams()
|
return self._fetch_from_teams()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user