mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-02 21:22:51 +02:00
344 lines
12 KiB
Python
344 lines
12 KiB
Python
import os
|
|
from datetime import datetime
|
|
from datetime import timezone
|
|
from typing import Any
|
|
|
|
import msal # type: ignore
|
|
from office365.graph_client import GraphClient # type: ignore
|
|
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
|
|
from office365.teams.channels.channel import Channel # type: ignore
|
|
from office365.teams.chats.messages.message import ChatMessage # type: ignore
|
|
from office365.teams.team import Team # type: ignore
|
|
|
|
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
|
from onyx.configs.constants import DocumentSource
|
|
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
|
from onyx.connectors.exceptions import ConnectorValidationError
|
|
from onyx.connectors.exceptions import CredentialExpiredError
|
|
from onyx.connectors.exceptions import InsufficientPermissionsError
|
|
from onyx.connectors.exceptions import UnexpectedError
|
|
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
|
from onyx.connectors.interfaces import LoadConnector
|
|
from onyx.connectors.interfaces import PollConnector
|
|
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
|
from onyx.connectors.models import BasicExpertInfo
|
|
from onyx.connectors.models import ConnectorMissingCredentialError
|
|
from onyx.connectors.models import Document
|
|
from onyx.connectors.models import Section
|
|
from onyx.file_processing.html_utils import parse_html_page_basic
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger shared by the helper functions and the connector class below.
logger = setup_logger()
|
|
|
|
|
|
def get_created_datetime(chat_message: ChatMessage) -> datetime:
    """Return the creation time of a Teams chat message as a UTC datetime.

    The raw ``createdDateTime`` string stored in the message's ``properties``
    mapping is converted with ``time_str_to_utc``. A missing key raises
    ``KeyError``.
    """
    created_raw = chat_message.properties["createdDateTime"]
    return time_str_to_utc(created_raw)
|
|
|
|
|
|
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
    """Fetch every member of *channel* and wrap each one's display name.

    Issues one Graph request (``channel.members.get()`` executed with retry)
    and returns one ``BasicExpertInfo`` per member, in the order the API
    returned them.
    """
    fetched_members = channel.members.get().execute_query_retry()
    return [
        BasicExpertInfo(display_name=channel_member.display_name)
        for channel_member in fetched_members
    ]
|
|
|
|
|
|
def _get_threads_from_channel(
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Collect the threads of *channel* whose root message falls in a time window.

    A thread is a top-level (root) message followed by all of its replies.
    Roots are filtered on their ``lastModifiedDateTime``. Replies are only
    fetched for roots that pass the filter.

    Args:
        channel: the Teams channel to read from.
        start: inclusive lower bound; a naive datetime is treated as UTC.
            ``None`` disables the bound.
        end: inclusive upper bound; a naive datetime is treated as UTC.
            ``None`` disables the bound.

    Returns:
        One list per thread, shaped ``[root_message, *replies]``.
    """
    # Treat naive bounds as UTC so they compare cleanly against the
    # timezone-aware values produced by ``time_str_to_utc``.
    if start is not None and start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end is not None and end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)

    # ``get()`` loads only the first page of the collection. ``get_all()``
    # follows the paging links, so channels with many threads are not
    # silently truncated. Replies below are fetched the same way.
    base_messages: list[ChatMessage] = (
        channel.messages.get_all().execute_query_retry()
    )

    threads: list[list[ChatMessage]] = []
    for base_message in base_messages:
        message_datetime = time_str_to_utc(
            base_message.properties["lastModifiedDateTime"]
        )

        # Skip roots outside the [start, end] window before paying for the
        # replies request.
        if start is not None and message_datetime < start:
            continue
        if end is not None and message_datetime > end:
            continue

        replies = base_message.replies.get_all().execute_query_retry()

        # The root message first, then its replies.
        thread: list[ChatMessage] = [base_message]
        thread.extend(replies)
        threads.append(thread)

    return threads
|
|
|
|
|
|
def _get_channels_from_teams(
    teams: list[Team],
) -> list[Channel]:
    """Return the channels of every team in *teams*, flattened into one list.

    One Graph request (``team.channels.get()`` executed with retry) is made
    per team. Channels are grouped by team in the order of *teams*, and within
    a team they keep the order the API returned.
    """
    return [
        channel
        for team in teams
        for channel in team.channels.get().execute_query_retry()
    ]
|
|
|
|
|
|
def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
    """Build a human-readable identifier for a thread.

    Format: ``"<poster> in <channel> about <subject>: <snippet>"``. The
    snippet is the root message's text after HTML stripping, cut to 50
    characters with ``"..."`` appended when it was longer.

    Missing or null fields fall back to ``"Unknown User"`` (poster) or
    ``"Unknown"`` (channel name and subject).
    """
    # ``dict.get(key, default)`` only applies the default when the key is
    # absent. These fields may be present but null (the thread-to-document
    # conversion treats "from", "from.user", "subject" and the body content as
    # possibly falsy), so ``or`` is used to fall back in both cases.
    sender = top_message.properties.get("from") or {}
    sender_user = sender.get("user") or {}
    first_poster = sender_user.get("displayName") or "Unknown User"

    channel_name = channel.properties.get("displayName") or "Unknown"
    thread_subject = top_message.properties.get("subject") or "Unknown"

    raw_content = top_message.body.content or ""
    snippet = parse_html_page_basic(raw_content.rstrip())
    if len(snippet) > 50:
        snippet = snippet[:50] + "..."

    return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"
|
|
|
|
|
|
def _convert_thread_to_document(
    channel: Channel,
    thread: list[ChatMessage],
) -> Document | None:
    """Convert one Teams thread into an indexable ``Document``.

    Args:
        channel: the channel the thread belongs to. It is used for the
            semantic identifier and as a fallback source of owners.
        thread: the root message first, followed by its replies.

    Returns:
        The document, or ``None`` when the thread is empty or contains no
        text. The document id and link come from the last message that has
        a subject, or from ``thread[0]`` if none has one.
    """
    if not thread:
        return None

    top_message = thread[0]
    text_parts: list[str] = []
    # Sender names in first-seen order, de-duplicated via a set.
    sender_names: list[str] = []
    seen_senders: set[str] = set()

    for message in thread:
        if message.body.content:
            message_text = parse_html_page_basic(message.body.content)
            if message_text:
                text_parts.append(message_text)

        # A message with a subject is the top-level post, so use it for the
        # id, url and subject.
        if message.properties.get("subject"):
            top_message = message

        # "from" and "from.user" may be missing or null. Only record a sender
        # when a display name is actually present.
        sender_user = (message.properties.get("from") or {}).get("user") or {}
        message_sender = sender_user.get("displayName")
        if message_sender and message_sender not in seen_senders:
            seen_senders.add(message_sender)
            sender_names.append(message_sender)

    # One message per line, so adjacent messages do not run together.
    thread_text = "\n".join(text_parts)
    # Return before the owner fallback below, so empty threads do not trigger
    # a channel-members Graph request.
    if not thread_text:
        return None

    post_members_list: list[BasicExpertInfo] = [
        BasicExpertInfo(display_name=name) for name in sender_names
    ]
    # If no post members were found, use the members of the parent channel.
    if not post_members_list:
        post_members_list = _extract_channel_members(channel)

    # The thread is non-empty here, so max() always has an argument.
    most_recent_message_datetime: datetime | None = max(
        get_created_datetime(message) for message in thread
    )

    semantic_string = _construct_semantic_identifier(channel, top_message)

    post_id = top_message.properties["id"]
    web_url = top_message.web_url

    doc = Document(
        id=post_id,
        sections=[Section(link=web_url, text=thread_text)],
        source=DocumentSource.TEAMS,
        semantic_identifier=semantic_string,
        title="",  # teams threads don't really have a "title"
        doc_updated_at=most_recent_message_datetime,
        primary_owners=post_members_list,
        metadata={},
    )
    return doc
|
|
|
|
|
|
class TeamsConnector(LoadConnector, PollConnector):
    """Connector that indexes Microsoft Teams channel threads.

    Every thread (a top-level channel message plus its replies) becomes one
    ``Document``. Microsoft Graph is accessed through a ``GraphClient``.
    Tokens for it come from an MSAL ``ConfidentialClientApplication`` via
    ``acquire_token_for_client``.
    """

    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        teams: list[str] | None = None,
    ) -> None:
        """Create a connector. Call ``load_credentials`` before fetching.

        Args:
            batch_size: maximum number of documents per yielded batch.
            teams: display names of the teams to index. Spaces are ignored
                when matching against team names; matching is otherwise
                exact. ``None`` or an empty list means every team the app can
                list.
        """
        self.batch_size = batch_size
        # Populated by ``load_credentials``. ``None`` means "not loaded yet".
        self.graph_client: GraphClient | None = None
        # A ``None`` default avoids a mutable ``[]`` default shared across
        # instances. Copying protects against later mutation of the caller's
        # list.
        self.requested_team_list: list[str] = list(teams) if teams else []
        self.msal_app: msal.ConfidentialClientApplication | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Build the MSAL app and the Graph client from app-registration secrets.

        Args:
            credentials: must contain ``teams_client_id``,
                ``teams_client_secret`` and ``teams_directory_id``. The
                directory id is interpolated into the login authority URL. A
                missing key raises ``KeyError``.

        Returns:
            Always ``None``.
        """
        teams_client_id = credentials["teams_client_id"]
        teams_client_secret = credentials["teams_client_secret"]
        teams_directory_id = credentials["teams_directory_id"]

        authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
        self.msal_app = msal.ConfidentialClientApplication(
            authority=authority_url,
            client_id=teams_client_id,
            client_credential=teams_client_secret,
        )

        def _acquire_token_func() -> dict[str, Any]:
            """
            Acquire token via MSAL
            """
            # Token callback handed to GraphClient. It requests an app-only
            # token for the Graph ``.default`` scope.
            if self.msal_app is None:
                raise RuntimeError("MSAL app is not initialized")

            token = self.msal_app.acquire_token_for_client(
                scopes=["https://graph.microsoft.com/.default"]
            )
            return token

        self.graph_client = GraphClient(_acquire_token_func)
        return None

    def _get_all_teams(self) -> list[Team]:
        """Return the teams to index.

        If ``requested_team_list`` is empty, every team the app can list is
        returned. Otherwise only teams whose display name matches a requested
        name, ignoring spaces, are returned.

        Raises:
            ConnectorMissingCredentialError: if ``load_credentials`` has not
                been called.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        # ``get_all()`` follows paging links. ``get()`` would return only the
        # first page of teams.
        teams = self.graph_client.teams.get_all().execute_query_retry()

        if not self.requested_team_list:
            return list(teams)

        # Set for O(1) membership tests. Spaces are dropped on both sides.
        requested_names = {
            requested_team.replace(" ", "")
            for requested_team in self.requested_team_list
        }
        return [
            team
            for team in teams
            if team.display_name.replace(" ", "") in requested_names
        ]

    def _fetch_from_teams(
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        """Yield batches of thread documents, optionally limited to a time window.

        Args:
            start: inclusive lower bound on a thread root's last-modified
                time, or ``None``.
            end: inclusive upper bound on a thread root's last-modified time,
                or ``None``.

        Yields:
            Lists of at most ``batch_size`` documents. A final (possibly
            empty) batch is always yielded, so the generator produces at
            least one value.

        Raises:
            ConnectorMissingCredentialError: if credentials were not loaded.
            ValueError: if no teams or no channels are found.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        teams = self._get_all_teams()
        logger.debug(f"Found available teams: {[str(t) for t in teams]}")
        if not teams:
            msg = "No teams found."
            logger.error(msg)
            raise ValueError(msg)

        channels = _get_channels_from_teams(
            teams=teams,
        )
        logger.debug(f"Found available channels: {[c.id for c in channels]}")
        if not channels:
            msg = "No channels found."
            logger.error(msg)
            raise ValueError(msg)

        # Go over the channels, convert threads into Document objects and
        # yield them in batches.
        doc_batch: list[Document] = []
        for channel in channels:
            logger.debug(f"Fetching threads from channel: {channel.id}")
            thread_list = _get_threads_from_channel(channel, start=start, end=end)
            for thread in thread_list:
                converted_doc = _convert_thread_to_document(channel, thread)
                if converted_doc:
                    doc_batch.append(converted_doc)

                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
                    doc_batch = []
        yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield batches covering every thread in every selected team."""
        return self._fetch_from_teams()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches for threads whose root was modified within [start, end].

        ``start`` and ``end`` are epoch seconds, interpreted as UTC.
        """
        start_datetime = datetime.fromtimestamp(start, timezone.utc)
        end_datetime = datetime.fromtimestamp(end, timezone.utc)
        return self._fetch_from_teams(start=start_datetime, end=end_datetime)

    def validate_connector_settings(self) -> None:
        """Check that the loaded credentials can list at least one team.

        Raises:
            ConnectorMissingCredentialError: if credentials were not loaded.
            CredentialExpiredError: on HTTP 401, or an error message that
                looks like an authentication failure.
            InsufficientPermissionsError: on HTTP 403, or an error message
                that looks like a permissions failure.
            UnexpectedError: on any other Graph request error.
            ConnectorValidationError: on any other failure, or when no teams
                are found.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        try:
            # Minimal call to confirm we can retrieve Teams
            found_teams = self._get_all_teams()

        except ClientRequestException as e:
            # Guard against a missing response object. Reading
            # ``e.response.status_code`` directly would raise AttributeError
            # inside this handler. A missing status falls through to
            # UnexpectedError.
            status_code = getattr(e.response, "status_code", None)
            if status_code == 401:
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
                ) from e
            if status_code == 403:
                raise InsufficientPermissionsError(
                    "Your app lacks sufficient permissions to read Teams (403 Forbidden)."
                ) from e
            raise UnexpectedError(f"Unexpected error retrieving teams: {e}") from e

        except Exception as e:
            # Failures outside a Graph HTTP response (presumably including
            # token-acquisition errors, given the "invalid_grant" check) only
            # expose a message, so classify them heuristically.
            error_str = str(e).lower()
            if (
                "unauthorized" in error_str
                or "401" in error_str
                or "invalid_grant" in error_str
            ):
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials."
                ) from e
            if "forbidden" in error_str or "403" in error_str:
                raise InsufficientPermissionsError(
                    "App lacks required permissions to read from Microsoft Teams."
                ) from e
            raise ConnectorValidationError(
                f"Unexpected error during Teams validation: {e}"
            ) from e

        if not found_teams:
            raise ConnectorValidationError(
                "No Teams found for the given credentials. "
                "Either there are no Teams in this tenant, or your app does not have permission to view them."
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: index the comma-separated team names in $TEAMS
    # and print the first batch of documents.
    requested_teams = os.environ["TEAMS"].split(",")
    connector = TeamsConnector(teams=requested_teams)

    credentials = {
        "teams_client_id": os.environ["TEAMS_CLIENT_ID"],
        "teams_client_secret": os.environ["TEAMS_CLIENT_SECRET"],
        "teams_directory_id": os.environ["TEAMS_CLIENT_DIRECTORY_ID"],
    }
    connector.load_credentials(credentials)

    batches = connector.load_from_state()
    first_batch = next(batches)
    print(first_batch)
|