mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-04 12:58:42 +02:00
perf: Implement checkpointing for Teams Connector. (#4601)
* Add basic foundation for teams checkpointing classes * Fix slack connector main entrypoint * Saving changes * Finish teams checkpointing impl * Remove commented out code * Remove more unused code * Move code around * Add threadpool to process requests in parallel * Fix mypy errors / warnings * Move test import to main function only * Address nits on PR * Remove unnecessary check prior to entering while-loop * Remove print statement * Change exception message * Address more nits * Use indexing instead of destructuring * Add back invocation of `run_with_timeout` instead of a direct call * Revert slack testing code * Move early return to before second API call * Pull fetch to team outside of loop * Address nits on PR * Add back client-side filtering * Updated connector to return after a team's indexing is finished * Add type ignore * Implement proper datetime range fetching * Address comment on PR * Rename function * Change exception type when no team with the given id was found * Address nit on PR * Add comment on why `page_loaded` is needed to be specified explicitly * Remove duplicated calls to fetching channels * Use helper function for thread-based yielding instead of manual logic * Move datetime filtering to message-level instead * Address more comments on PR * Add new utility function for yielding sections * Add additional utility function * Add teams tests * Edit error message * Address nits on PR * Promote url-prefix to be a class level constant * Fix mypy error * Remove start/end parameters from function that doesn't use them anymore; move around comments * Address more nits on PR * Add comment
This commit is contained in:
@@ -1,99 +1,239 @@
|
||||
import copy
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import msal # type: ignore
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
|
||||
from office365.runtime.http.request_options import RequestOptions # type: ignore[import-untyped]
|
||||
from office365.teams.channels.channel import Channel # type: ignore
|
||||
from office365.teams.chats.messages.message import ChatMessage # type: ignore
|
||||
from office365.teams.team import Team # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_created_datetime(chat_message: ChatMessage) -> datetime:
|
||||
class TeamsCheckpoint(ConnectorCheckpoint):
    """Checkpoint state carried between `TeamsConnector.load_from_checkpoint` calls."""

    # IDs of the teams that still have to be processed. `None` (the default)
    # is what `TeamsConnector.load_from_checkpoint` checks for before it
    # collects the team IDs.
    todo_team_ids: list[str] | None = None
|
||||
|
||||
|
||||
class TeamsConnector(
    CheckpointedConnector[TeamsCheckpoint],
):
    """Checkpointed connector for Microsoft Teams.

    A call to `load_from_checkpoint` either collects the IDs of the teams to
    process (when the checkpoint's `todo_team_ids` is `None`) or processes
    exactly one team and returns a checkpoint holding the teams that remain.
    """

    MAX_WORKERS = 10
    AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"

    def __init__(
        self,
        # TODO: (chris) move from "Display Names" to IDs, since display names
        # are NOT guaranteed to be unique
        teams: list[str] = [],
        max_workers: int = MAX_WORKERS,
    ) -> None:
        """Store the requested team display names and the worker count.

        Args:
            teams: Display names of the teams to index; an empty list means no
                name restriction is passed on to `_collect_all_team_ids`.
            max_workers: Passed to `parallel_yield` as `max_workers`.
        """
        self.graph_client: GraphClient | None = None
        self.msal_app: msal.ConfidentialClientApplication | None = None
        self.max_workers = max_workers
        # Copy the argument so neither the shared default list nor a
        # caller-owned list is aliased by this instance.
        self.requested_team_list: list[str] = list(teams)

    # impls for BaseConnector

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Create the MSAL application and the Graph client from `credentials`.

        Args:
            credentials: Must contain the keys `teams_client_id`,
                `teams_client_secret` and `teams_directory_id`.

        Returns:
            Always `None`.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        teams_client_id = credentials["teams_client_id"]
        teams_client_secret = credentials["teams_client_secret"]
        teams_directory_id = credentials["teams_directory_id"]

        authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
        self.msal_app = msal.ConfidentialClientApplication(
            authority=authority_url,
            client_id=teams_client_id,
            client_credential=teams_client_secret,
        )

        def _acquire_token_func() -> dict[str, Any]:
            """
            Acquire token via MSAL
            """
            if self.msal_app is None:
                raise RuntimeError("MSAL app is not initialized")

            token = self.msal_app.acquire_token_for_client(
                scopes=["https://graph.microsoft.com/.default"]
            )

            if not isinstance(token, dict):
                raise RuntimeError("`token` instance must be of type dict")

            return token

        self.graph_client = GraphClient(_acquire_token_func)
        return None

    def validate_connector_settings(self) -> None:
        """Check that the loaded credentials can list at least one team.

        Raises:
            ConnectorMissingCredentialError: If `load_credentials` has not set
                the Graph client.
            CredentialExpiredError: On a 401 response or an error message that
                looks like an authentication failure.
            InsufficientPermissionsError: On a 403 response or an error message
                that looks like a permission failure.
            UnexpectedValidationError: On any other `ClientRequestException`
                that carries a response.
            ConnectorValidationError: On any other error, or when no team is
                found.
            RuntimeError: If a `ClientRequestException` carries no response.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams credentials not loaded.")

        try:
            # Minimal call to confirm we can retrieve Teams
            # make sure it doesn't take forever, since this is a synchronous call
            found_teams = run_with_timeout(
                timeout=10,
                func=_collect_all_team_ids,
                graph_client=self.graph_client,
                requested=self.requested_team_list,
            )

        except ClientRequestException as e:
            # Compare against `None` explicitly: a `requests.Response` is falsy
            # for every 4xx/5xx status, so a truthiness check would treat real
            # error responses as missing and skip the status handling below.
            # NOTE(review): assumes `e.response` is a `requests.Response` - confirm.
            if e.response is None:
                raise RuntimeError(f"No response provided in error; {e=}") from e
            status_code = e.response.status_code
            if status_code == 401:
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
                ) from e
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Your app lacks sufficient permissions to read Teams (403 Forbidden)."
                ) from e
            raise UnexpectedValidationError(
                f"Unexpected error retrieving teams: {e}"
            ) from e

        except Exception as e:
            # No structured status is available here, so classify the failure
            # by looking for well-known markers in the error text.
            error_str = str(e).lower()
            if (
                "unauthorized" in error_str
                or "401" in error_str
                or "invalid_grant" in error_str
            ):
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials."
                ) from e
            elif "forbidden" in error_str or "403" in error_str:
                raise InsufficientPermissionsError(
                    "App lacks required permissions to read from Microsoft Teams."
                ) from e
            raise ConnectorValidationError(
                f"Unexpected error during Teams validation: {e}"
            ) from e

        if not found_teams:
            raise ConnectorValidationError(
                "No Teams found for the given credentials. "
                "Either there are no Teams in this tenant, or your app does not have permission to view them."
            )

    # impls for CheckpointedConnector

    def build_dummy_checkpoint(self) -> TeamsCheckpoint:
        """Return the initial checkpoint: `has_more=True`, `todo_team_ids` left at `None`."""
        return TeamsCheckpoint(
            has_more=True,
        )

    def validate_checkpoint_json(self, checkpoint_json: str) -> TeamsCheckpoint:
        """Parse `checkpoint_json` into a `TeamsCheckpoint`."""
        return TeamsCheckpoint.model_validate_json(checkpoint_json)

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: TeamsCheckpoint,
    ) -> CheckpointOutput[TeamsCheckpoint]:
        """Process one step of the indexing run described by `checkpoint`.

        When `checkpoint.todo_team_ids` is `None`, nothing is yielded and the
        returned checkpoint holds the collected team IDs. Otherwise one team ID
        is popped, the documents of that team's channels are yielded, and the
        returned checkpoint holds the remaining team IDs.

        Args:
            start: Lower bound forwarded to `_collect_document_for_channel_id`.
            end: Upper bound forwarded to `_collect_document_for_channel_id`.
            checkpoint: State of the run; it is deep-copied, never mutated.

        Raises:
            ConnectorMissingCredentialError: If `load_credentials` has not set
                the Graph client.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        checkpoint = cast(TeamsCheckpoint, copy.deepcopy(checkpoint))

        todos = checkpoint.todo_team_ids

        if todos is None:
            root_todos = _collect_all_team_ids(
                graph_client=self.graph_client,
                requested=self.requested_team_list,
            )
            return TeamsCheckpoint(
                todo_team_ids=root_todos,
                has_more=bool(root_todos),
            )

        # An empty `todos` should not reach this point: the previous invocation
        # would have returned `has_more=False`. Finish cleanly instead of
        # raising `IndexError` from `todos.pop()` if it happens anyway.
        if not todos:
            return TeamsCheckpoint(
                todo_team_ids=[],
                has_more=False,
            )

        todo_team_id = todos.pop()
        team = _get_team_by_id(
            graph_client=self.graph_client,
            team_id=todo_team_id,
        )
        channels = _collect_all_channels_from_team(
            graph_client=self.graph_client,
            team=team,
        )

        # One lazy generator per channel; `parallel_yield` drives them.
        docs = [
            _collect_document_for_channel_id(
                graph_client=self.graph_client,
                team=team,
                channel=channel,
                start=start,
                end=end,
            )
            for channel in channels
        ]

        for doc in parallel_yield(
            gens=docs,
            max_workers=self.max_workers,
        ):
            # Channel generators may yield `None`; skip those.
            if doc:
                yield doc

        logger.info(
            f"Processed team with id {todo_team_id}; {len(todos)} team(s) left to process"
        )

        return TeamsCheckpoint(
            todo_team_ids=todos,
            has_more=bool(todos),
        )
|
||||
|
||||
|
||||
def _get_created_datetime(chat_message: ChatMessage) -> datetime:
    """Return the message's `createdDateTime` property converted by `time_str_to_utc`."""
    created_at_raw = chat_message.properties["createdDateTime"]
    return time_str_to_utc(created_at_raw)
|
||||
|
||||
|
||||
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
    """Return one `BasicExpertInfo` per member of `channel`.

    The member collection is fetched with a single query; only each member's
    `display_name` is carried over.
    """
    members = channel.members.get_all(
        # explicitly needed because of incorrect type definitions provided by the `office365` library
        page_loaded=lambda _: None
    ).execute_query_retry()
    return [BasicExpertInfo(display_name=member.display_name) for member in members]
|
||||
|
||||
|
||||
def _get_threads_from_channel(
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Return the channel's threads whose base message falls inside the window.

    Each thread is a list holding the base message followed by its replies.
    A base message is kept when its `lastModifiedDateTime` is not before
    `start` and not after `end`; a bound that is `None` is not applied.
    """

    def _assume_utc(bound: datetime | None) -> datetime | None:
        # Naive bounds are tagged as UTC so they can be compared with the
        # value returned by `time_str_to_utc`.
        if bound and bound.tzinfo is None:
            return bound.replace(tzinfo=timezone.utc)
        return bound

    start = _assume_utc(start)
    end = _assume_utc(end)

    top_level_messages: list[ChatMessage] = (
        channel.messages.get_all().execute_query_retry()
    )

    collected: list[list[ChatMessage]] = []
    for top_level in top_level_messages:
        modified_at = time_str_to_utc(top_level.properties["lastModifiedDateTime"])

        if (start and modified_at < start) or (end and modified_at > end):
            continue

        # base message first, then every reply fetched for it
        thread_replies = top_level.replies.get_all().execute_query_retry()
        collected.append([top_level, *thread_replies])

    return collected
|
||||
|
||||
|
||||
def _get_channels_from_teams(
    teams: list[Team],
) -> list[Channel]:
    """Return the channels of every team in `teams`, in team order."""
    return [
        channel
        for team in teams
        for channel in team.channels.get_all().execute_query_retry()
    ]
|
||||
|
||||
|
||||
def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
|
||||
# NOTE: needs to be done this weird way because sometime we get back `None` for
|
||||
# the fields which causes things to explode
|
||||
@@ -133,7 +273,7 @@ def _convert_thread_to_document(
|
||||
post_members_list: list[BasicExpertInfo] = []
|
||||
thread_text = ""
|
||||
|
||||
sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)
|
||||
sorted_thread = sorted(thread, key=_get_created_datetime, reverse=True)
|
||||
|
||||
if sorted_thread:
|
||||
most_recent_message = sorted_thread[0]
|
||||
@@ -164,13 +304,13 @@ def _convert_thread_to_document(
|
||||
BasicExpertInfo(display_name=message_sender)
|
||||
)
|
||||
|
||||
if not thread_text:
|
||||
return None
|
||||
|
||||
# if there are no found post members, grab the members from the parent channel
|
||||
if not post_members_list:
|
||||
post_members_list = _extract_channel_members(channel)
|
||||
|
||||
if not thread_text:
|
||||
return None
|
||||
|
||||
semantic_string = _construct_semantic_identifier(channel, top_message)
|
||||
|
||||
post_id = top_message.properties["id"]
|
||||
@@ -189,219 +329,231 @@ def _convert_thread_to_document(
|
||||
return doc
|
||||
|
||||
|
||||
class TeamsConnector(LoadConnector, PollConnector):
|
||||
MAX_CHANNELS_TO_LOG = 50
|
||||
def _update_request_url(request: RequestOptions, next_url: str) -> None:
    """Overwrite the `url` attribute of `request` with `next_url`."""
    request.url = next_url
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# TODO: (chris) move from "Display Names" to IDs, since display names
|
||||
# are NOT guaranteed to be unique
|
||||
teams: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.requested_team_list: list[str] = teams
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
teams_client_id = credentials["teams_client_id"]
|
||||
teams_client_secret = credentials["teams_client_secret"]
|
||||
teams_directory_id = credentials["teams_directory_id"]
|
||||
def _collect_all_team_ids(
|
||||
graph_client: GraphClient,
|
||||
requested: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
team_ids: list[str] = []
|
||||
next_url: str | None = None
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
client_credential=teams_client_secret,
|
||||
)
|
||||
filter = None
|
||||
if requested:
|
||||
filter = " or ".join(f"displayName eq '{team_name}'" for team_name in requested)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
while True:
|
||||
if filter:
|
||||
query = graph_client.teams.get().filter(filter)
|
||||
else:
|
||||
query = graph_client.teams.get_all(
|
||||
# explicitly needed because of incorrect type definitions provided by the `office365` library
|
||||
page_loaded=lambda _: None
|
||||
)
|
||||
return token
|
||||
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
if next_url:
|
||||
url = next_url
|
||||
query.before_execute(
|
||||
lambda req: _update_request_url(request=req, next_url=url)
|
||||
)
|
||||
|
||||
team_collection = query.execute_query()
|
||||
|
||||
filtered_team_ids = [
|
||||
team_id
|
||||
for team_id in [
|
||||
_filter_team_id(team=team, requested=requested)
|
||||
for team in team_collection
|
||||
]
|
||||
if team_id
|
||||
]
|
||||
|
||||
team_ids.extend(filtered_team_ids)
|
||||
|
||||
if team_collection.has_next:
|
||||
if not isinstance(team_collection._next_request_url, str):
|
||||
raise ValueError(
|
||||
f"The next request url field should be a string, instead got {type(team_collection._next_request_url)}"
|
||||
)
|
||||
next_url = team_collection._next_request_url
|
||||
else:
|
||||
break
|
||||
|
||||
return team_ids
|
||||
|
||||
|
||||
def _filter_team_id(
|
||||
team: Team,
|
||||
requested: list[str] | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Returns the Team ID if:
|
||||
- Team is not expired / deleted
|
||||
- Team has a display-name and ID
|
||||
- Team display-name is in the requested teams list
|
||||
|
||||
Otherwise, returns `None`.
|
||||
"""
|
||||
|
||||
if not team.id or not team.display_name:
|
||||
return None
|
||||
|
||||
def _get_all_teams(self) -> list[Team]:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Teams")
|
||||
if requested and team.display_name not in requested:
|
||||
return None
|
||||
|
||||
teams: list[Team] = []
|
||||
try:
|
||||
# Use get_all() to handle pagination automatically
|
||||
if not self.requested_team_list:
|
||||
teams = self.graph_client.teams.get_all().execute_query()
|
||||
else:
|
||||
# Construct filter using proper Microsoft Graph API syntax
|
||||
filter_conditions = " or ".join(
|
||||
[
|
||||
f"displayName eq '{team_name}'"
|
||||
for team_name in self.requested_team_list
|
||||
]
|
||||
)
|
||||
props = team.properties
|
||||
|
||||
# Initialize pagination variables
|
||||
page_size = 100 # Maximum allowed by Microsoft Graph API
|
||||
skip = 0
|
||||
if props.get("expirationDateTime") or props.get("deletedDateTime"):
|
||||
return None
|
||||
|
||||
while True:
|
||||
# Get a page of teams with the filter
|
||||
teams_page = (
|
||||
self.graph_client.teams.get()
|
||||
.filter(filter_conditions)
|
||||
.top(page_size)
|
||||
.skip(skip)
|
||||
.execute_query()
|
||||
)
|
||||
return team.id
|
||||
|
||||
if not teams_page:
|
||||
break
|
||||
|
||||
teams.extend(teams_page)
|
||||
skip += page_size
|
||||
def _get_team_by_id(
|
||||
graph_client: GraphClient,
|
||||
team_id: str,
|
||||
) -> Team:
|
||||
team_collection = (
|
||||
graph_client.teams.get().filter(f"id eq '{team_id}'").top(1).execute_query()
|
||||
)
|
||||
|
||||
# If we got fewer results than the page size, we've reached the end
|
||||
if len(teams_page) < page_size:
|
||||
break
|
||||
if not team_collection:
|
||||
raise ValueError(f"No team with {team_id=} was found")
|
||||
elif team_collection.has_next:
|
||||
# shouldn't happen, but catching it regardless
|
||||
raise RuntimeError(f"Multiple teams with {team_id=} were found")
|
||||
|
||||
# Validate that we found all requested teams
|
||||
if len(teams) != len(self.requested_team_list):
|
||||
found_team_names = {
|
||||
team.properties["displayName"] for team in teams
|
||||
}
|
||||
missing_teams = set(self.requested_team_list) - found_team_names
|
||||
raise ConnectorValidationError(
|
||||
f"Requested teams not found: {list(missing_teams)}"
|
||||
)
|
||||
except ClientRequestException as e:
|
||||
if e.response.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"App lacks required permissions to read Teams. "
|
||||
"Please ensure the app has the following permissions: "
|
||||
"Team.ReadBasic.All, TeamMember.Read.All, "
|
||||
"Channel.ReadBasic.All, ChannelMessage.Read.All, "
|
||||
"Group.Read.All, TeamSettings.ReadWrite.All, "
|
||||
"ChannelMember.Read.All, ChannelSettings.ReadWrite.All"
|
||||
)
|
||||
raise
|
||||
return team_collection[0]
|
||||
|
||||
return teams
|
||||
|
||||
def _fetch_from_teams(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Teams")
|
||||
def _collect_all_channels_from_team(
|
||||
graph_client: GraphClient,
|
||||
team: Team,
|
||||
) -> list[Channel]:
|
||||
if not team.id:
|
||||
raise RuntimeError(f"The {team=} has an empty `id` field")
|
||||
|
||||
teams = self._get_all_teams()
|
||||
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
|
||||
if not teams:
|
||||
msg = "No teams found."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
channels: list[Channel] = []
|
||||
next_url = None
|
||||
|
||||
channels = _get_channels_from_teams(
|
||||
teams=teams,
|
||||
while True:
|
||||
query = team.channels.get_all(
|
||||
# explicitly needed because of incorrect type definitions provided by the `office365` library
|
||||
page_loaded=lambda _: None
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Found available channels (max {TeamsConnector.MAX_CHANNELS_TO_LOG} shown): "
|
||||
f"{[c.id for c in channels[:TeamsConnector.MAX_CHANNELS_TO_LOG]]}"
|
||||
)
|
||||
if not channels:
|
||||
msg = "No channels found."
|
||||
logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
# goes over channels, converts them into Document objects and then yields them in batches
|
||||
doc_batch: list[Document] = []
|
||||
for channel in channels:
|
||||
logger.info(f"Fetching threads from channel: {channel.id}")
|
||||
thread_list = _get_threads_from_channel(channel, start=start, end=end)
|
||||
for thread in thread_list:
|
||||
converted_doc = _convert_thread_to_document(channel, thread)
|
||||
if converted_doc:
|
||||
doc_batch.append(converted_doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_teams()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, timezone.utc)
|
||||
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
|
||||
|
||||
try:
|
||||
# Minimal call to confirm we can retrieve Teams
|
||||
# make sure it doesn't take forever, since this is a synchronous call
|
||||
found_teams = run_with_timeout(10, self._get_all_teams)
|
||||
|
||||
except ClientRequestException as e:
|
||||
status_code = e.response.status_code
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
|
||||
)
|
||||
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"unauthorized" in error_str
|
||||
or "401" in error_str
|
||||
or "invalid_grant" in error_str
|
||||
):
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Microsoft Teams credentials."
|
||||
)
|
||||
elif "forbidden" in error_str or "403" in error_str:
|
||||
raise InsufficientPermissionsError(
|
||||
"App lacks required permissions to read from Microsoft Teams."
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected error during Teams validation: {e}"
|
||||
if next_url:
|
||||
url = next_url
|
||||
query = query.before_execute(
|
||||
lambda req: _update_request_url(request=req, next_url=url)
|
||||
)
|
||||
|
||||
if not found_teams:
|
||||
raise ConnectorValidationError(
|
||||
"No Teams found for the given credentials. "
|
||||
"Either there are no Teams in this tenant, or your app does not have permission to view them."
|
||||
)
|
||||
channel_collection = query.execute_query()
|
||||
channels.extend(channel for channel in channel_collection if channel.id)
|
||||
|
||||
if not channel_collection.has_next:
|
||||
break
|
||||
|
||||
return channels
|
||||
|
||||
|
||||
def _collect_document_for_channel_id(
    graph_client: GraphClient,
    team: Team,
    channel: Channel,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> Iterator[Document | None]:
    """
    Lazily yield the single result of converting `channel`'s messages.

    All messages of the channel are fetched, those rejected by
    `_filter_message` for the `start`/`end` range are dropped, and the rest
    are handed to `_convert_thread_to_document`; its return value is the one
    item yielded.

    It is a generator (rather than a plain function) because that is what
    `parallel_yield` expects, and so that the work happens only on iteration.
    `graph_client` and `team` are accepted but not used by the body.
    """

    # The chat-messages API does not support server-side filter conditions,
    # so the range check is done client-side (which is less efficient).
    #
    # Not allowed:
    # message_collection = channel.messages.get().filter(f"createdDateTime gt {start}").execute_query()

    all_messages = channel.messages.get_all(
        # explicitly needed because of incorrect type definitions provided by the `office365` library
        page_loaded=lambda _: None
    ).execute_query()

    in_range_messages = []
    for candidate in all_messages:
        if _filter_message(message=candidate, start=start, end=end):
            in_range_messages.append(candidate)

    yield _convert_thread_to_document(
        channel=channel,
        thread=in_range_messages,
    )
|
||||
|
||||
|
||||
def _filter_message(
|
||||
message: ChatMessage,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns `True` if the given message was created / modified within the start-to-end datetime range.
|
||||
Returns `False` otherwise.
|
||||
"""
|
||||
|
||||
props = message.properties
|
||||
|
||||
if props.get("deletedDateTime"):
|
||||
return False
|
||||
|
||||
def compare(dt_str: str) -> bool:
|
||||
dt_ts = datetime.fromisoformat(dt_str).replace(tzinfo=timezone.utc).timestamp()
|
||||
return start <= dt_ts and dt_ts < end
|
||||
|
||||
if modified_at := props.get("lastModifiedDateTime"):
|
||||
if isinstance(modified_at, str):
|
||||
return compare(modified_at)
|
||||
|
||||
if created_at := props.get("createdDateTime"):
|
||||
if isinstance(created_at, str):
|
||||
return compare(created_at)
|
||||
|
||||
logger.warn(
|
||||
"No `lastModifiedDateTime` or `createdDateTime` fields found in `message.properties`"
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run: validate the credentials, then index everything up to now.
    # The test helper is imported here so it is only needed for this entry point.
    from tests.daily.connectors.utils import load_everything_from_checkpoint_connector

    app_id = os.environ["TEAMS_APPLICATION_ID"]
    dir_id = os.environ["TEAMS_DIRECTORY_ID"]
    secret = os.environ["TEAMS_SECRET"]

    # `TEAMS` is optional: a comma-separated list of team display names.
    teams_env_var = os.environ.get("TEAMS", None)
    teams = teams_env_var.split(",") if teams_env_var else []
    connector = TeamsConnector(teams=teams)

    connector.load_credentials(
        {
            "teams_client_id": app_id,
            "teams_directory_id": dir_id,
            "teams_client_secret": secret,
        }
    )
    connector.validate_connector_settings()

    print(
        load_everything_from_checkpoint_connector(
            connector=connector,
            start=0.0,
            end=datetime.now(tz=timezone.utc).timestamp(),
        )
    )
|
||||
|
@@ -8,11 +8,11 @@ from pytest import FixtureRequest
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
|
||||
from tests.daily.connectors.utils import to_sections
|
||||
from tests.daily.connectors.utils import to_text_sections
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -106,18 +106,7 @@ def test_indexing_channels_with_message_count(
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
messages: list[str] = []
|
||||
|
||||
for doc_or_error in docs:
|
||||
if not isinstance(doc_or_error, Document):
|
||||
raise RuntimeError(doc_or_error)
|
||||
messages.extend(
|
||||
section.text
|
||||
for section in doc_or_error.sections
|
||||
if isinstance(section, TextSection)
|
||||
)
|
||||
|
||||
actual_messages = set(messages)
|
||||
actual_messages = set(to_text_sections(to_sections(iter(docs))))
|
||||
assert expected_messages == actual_messages
|
||||
|
||||
|
||||
|
75
backend/tests/daily/connectors/teams/test_teams_connector.py
Normal file
75
backend/tests/daily/connectors/teams/test_teams_connector.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
|
||||
from tests.daily.connectors.utils import to_sections
|
||||
from tests.daily.connectors.utils import to_text_sections
|
||||
|
||||
|
||||
@pytest.fixture
def teams_credentials() -> dict[str, str]:
    """Build the Teams credential mapping from the `TEAMS_*` environment variables."""
    env = os.environ
    return {
        "teams_client_id": env["TEAMS_APPLICATION_ID"],
        "teams_directory_id": env["TEAMS_DIRECTORY_ID"],
        "teams_client_secret": env["TEAMS_SECRET"],
    }
|
||||
|
||||
|
||||
@pytest.fixture
def teams_connector(
    request: pytest.FixtureRequest,
    teams_credentials: dict[str, str],
) -> TeamsConnector:
    """Return a `TeamsConnector` with credentials loaded.

    When the fixture is parametrized, `request.param` supplies the team names
    and must be `None` or a list of strings; without a parameter the connector
    is built with an empty team list.
    """
    teams: list[str] | None = getattr(request, "param", None)

    if teams is not None:
        if not isinstance(teams, list):
            raise ValueError(
                f"`request.param` must either be `None` or of type `list[str]`; instead got {type(teams)}"
            )
        for name in teams:
            assert isinstance(name, str)

    connector = TeamsConnector(teams=teams or [])
    connector.load_credentials(teams_credentials)
    return connector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "teams_connector,expected_messages",
    [
        (["Onyx-Testing"], {"This is the first message in Onyx-Testing ..."}),
        (
            ["Onyx"],
            {
                "Hello, world!",
                "My favorite color is red.\n\xa0\nPablos favorite color is blue",
                "but not leastyeah!",
            },
        ),
    ],
    indirect=["teams_connector"],
)
def test_teams_connector(
    teams_connector: TeamsConnector,
    expected_messages: set[str],
) -> None:
    """Index the parametrized team(s) and compare the text of all sections."""
    loaded = load_everything_from_checkpoint_connector(
        connector=teams_connector,
        start=0.0,
        end=time.time(),
    )
    actual_messages = {*to_text_sections(to_sections(iter(loaded)))}
    assert actual_messages == expected_messages
|
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -7,6 +8,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
_ITERATION_LIMIT = 100_000
|
||||
|
||||
@@ -68,3 +71,21 @@ def load_everything_from_checkpoint_connector(
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def to_sections(
    iterator: Iterator[Document | ConnectorFailure],
) -> Iterator[TextSection | ImageSection]:
    """Yield every section of every `Document` in `iterator`, in order.

    Raises:
        RuntimeError: As soon as an item that is not a `Document` is reached;
            the item itself is the exception argument.
    """
    for item in iterator:
        if isinstance(item, Document):
            yield from item.sections
            continue

        raise RuntimeError(item)
|
||||
|
||||
|
||||
def to_text_sections(iterator: Iterator[TextSection | ImageSection]) -> Iterator[str]:
    """Yield the `text` of each `TextSection` in `iterator`; other sections are skipped."""
    yield from (
        section.text for section in iterator if isinstance(section, TextSection)
    )
|
||||
|
Reference in New Issue
Block a user