perf: Implement checkpointing for Teams Connector. (#4601)

* Add basic foundation for teams checkpointing classes

* Fix slack connector main entrypoint

* Saving changes

* Finish teams checkpointing impl

* Remove commented out code

* Remove more unused code

* Move code around

* Add threadpool to process requests in parallel

* Fix mypy errors / warnings

* Move test import to main function only

* Address nits on PR

* Remove unnecessary check prior to entering while-loop

* Remove print statement

* Change exception message

* Address more nits

* Use indexing instead of destructuring

* Add back invocation of `run_with_timeout` instead of a direct call

* Revert slack testing code

* Move early return to before second API call

* Pull fetch to team outside of loop

* Address nits on PR

* Add back client-side filtering

* Updated connector to return after a team's indexing is finished

* Add type ignore

* Implement proper datetime range fetching

* Address comment on PR

* Rename function

* Change exception type when no team with the given id was found

* Address nit on PR

* Add comment on why `page_loaded` is needed to be specified explicitly

* Remove duplicated calls to fetching channels

* Use helper function for thread-based yielding instead of manual logic

* Move datetime filtering to message-level instead

* Address more comments on PR

* Add new utility function for yielding sections

* Add additional utility function

* Add teams tests

* Edit error message

* Address nits on PR

* Promote url-prefix to be a class level constant

* Fix mypy error

* Remove start/end parameters from function that doesn't use them anymore; move around comments

* Address more nits on PR

* Add comment
This commit is contained in:
Raunak Bhagat
2025-05-13 21:30:57 -07:00
committed by GitHub
parent 0cc0964231
commit 312e3b92bc
4 changed files with 496 additions and 259 deletions

View File

@@ -1,99 +1,239 @@
import copy
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
from office365.runtime.http.request_options import RequestOptions # type: ignore[import-untyped]
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def get_created_datetime(chat_message: ChatMessage) -> datetime:
class TeamsCheckpoint(ConnectorCheckpoint):
    """Checkpoint state for `TeamsConnector`: the teams that are still pending."""

    # IDs of the teams that still have to be indexed.
    # `None` means the list has not been collected yet (this is the state that
    # `TeamsConnector.build_dummy_checkpoint` produces). The first call to
    # `TeamsConnector.load_from_checkpoint` fills it in; every later call pops
    # one ID and indexes that team.
    todo_team_ids: list[str] | None = None
class TeamsConnector(
    CheckpointedConnector[TeamsCheckpoint],
):
    """Checkpointed connector that indexes Microsoft Teams channel messages.

    The work is split per team: the first call to `load_from_checkpoint`
    only collects the IDs of the teams to index, and every following call
    pops one team ID from the checkpoint and yields the documents of that
    team's channels.
    """

    MAX_WORKERS = 10
    AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"

    def __init__(
        self,
        # TODO: (chris) move from "Display Names" to IDs, since display names
        # are NOT guaranteed to be unique
        teams: list[str] | None = None,
        max_workers: int = MAX_WORKERS,
    ) -> None:
        """
        Args:
            teams: Display names of the teams to index. `None` or an empty
                list means that no name restriction is applied.
            max_workers: Number of worker threads used when the channels of
                a team are processed in parallel.
        """
        self.graph_client: GraphClient | None = None
        self.msal_app: msal.ConfidentialClientApplication | None = None
        self.max_workers = max_workers
        # Copy the argument so that the connector neither shares a mutable
        # default between instances nor aliases the caller's list.
        self.requested_team_list: list[str] = list(teams) if teams else []

    # impls for BaseConnector

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Build the MSAL application and the Graph client from `credentials`.

        Reads the keys `teams_client_id`, `teams_client_secret` and
        `teams_directory_id`; a missing key raises `KeyError`.
        Always returns `None`.
        """
        teams_client_id = credentials["teams_client_id"]
        teams_client_secret = credentials["teams_client_secret"]
        teams_directory_id = credentials["teams_directory_id"]

        authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
        self.msal_app = msal.ConfidentialClientApplication(
            authority=authority_url,
            client_id=teams_client_id,
            client_credential=teams_client_secret,
        )

        def _acquire_token_func() -> dict[str, Any]:
            """
            Acquire token via MSAL
            """
            if self.msal_app is None:
                raise RuntimeError("MSAL app is not initialized")

            token = self.msal_app.acquire_token_for_client(
                scopes=["https://graph.microsoft.com/.default"]
            )

            if not isinstance(token, dict):
                raise RuntimeError("`token` instance must be of type dict")

            return token

        self.graph_client = GraphClient(_acquire_token_func)
        return None

    def validate_connector_settings(self) -> None:
        """Check that the loaded credentials can list at least one team.

        Raises:
            ConnectorMissingCredentialError: `load_credentials` was not called.
            CredentialExpiredError: the request was rejected as unauthorized.
            InsufficientPermissionsError: the request was rejected as forbidden.
            UnexpectedValidationError: any other `ClientRequestException`.
            ConnectorValidationError: any other failure, or no team was found.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams credentials not loaded.")

        try:
            # Minimal call to confirm we can retrieve Teams
            # make sure it doesn't take forever, since this is a synchronous call
            found_teams = run_with_timeout(
                timeout=10,
                func=_collect_all_team_ids,
                graph_client=self.graph_client,
                requested=self.requested_team_list,
            )

        except ClientRequestException as e:
            # Compare against `None` explicitly: an HTTP response object such as
            # `requests.Response` evaluates as falsy for 4xx/5xx status codes, so
            # a plain truthiness test would treat every HTTP error as "no
            # response" and the status-code handling below would never run.
            if e.response is None:
                raise RuntimeError(f"No response provided in error; {e=}") from e

            status_code = e.response.status_code
            if status_code == 401:
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
                ) from e
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Your app lacks sufficient permissions to read Teams (403 Forbidden)."
                ) from e
            raise UnexpectedValidationError(
                f"Unexpected error retrieving teams: {e}"
            ) from e

        except Exception as e:
            # No structured status code is available here, so classify the
            # failure by looking for well-known markers in the error text.
            error_str = str(e).lower()
            if (
                "unauthorized" in error_str
                or "401" in error_str
                or "invalid_grant" in error_str
            ):
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials."
                ) from e
            elif "forbidden" in error_str or "403" in error_str:
                raise InsufficientPermissionsError(
                    "App lacks required permissions to read from Microsoft Teams."
                ) from e
            raise ConnectorValidationError(
                f"Unexpected error during Teams validation: {e}"
            ) from e

        if not found_teams:
            raise ConnectorValidationError(
                "No Teams found for the given credentials. "
                "Either there are no Teams in this tenant, or your app does not have permission to view them."
            )

    # impls for CheckpointedConnector

    def build_dummy_checkpoint(self) -> TeamsCheckpoint:
        """Return the initial checkpoint: no team list yet, more work pending."""
        return TeamsCheckpoint(
            has_more=True,
        )

    def validate_checkpoint_json(self, checkpoint_json: str) -> TeamsCheckpoint:
        """Parse `checkpoint_json` into a `TeamsCheckpoint`."""
        return TeamsCheckpoint.model_validate_json(checkpoint_json)

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: TeamsCheckpoint,
    ) -> CheckpointOutput[TeamsCheckpoint]:
        """Advance the indexing by one step and return the next checkpoint.

        If the checkpoint has no team list yet, only the team IDs are
        collected and returned in a new checkpoint. Otherwise one team ID is
        popped, the documents of that team's channels are yielded, and the
        remaining team IDs are returned in the new checkpoint.

        Raises:
            ConnectorMissingCredentialError: `load_credentials` was not called.
        """
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        # Work on a copy so that the caller's checkpoint is left untouched.
        checkpoint = cast(TeamsCheckpoint, copy.deepcopy(checkpoint))

        todos = checkpoint.todo_team_ids

        if todos is None:
            root_todos = _collect_all_team_ids(
                graph_client=self.graph_client,
                requested=self.requested_team_list,
            )
            return TeamsCheckpoint(
                todo_team_ids=root_todos,
                has_more=bool(root_todos),
            )

        # `todos.pop()` should always return an element. This is because if
        # `todos` was the empty list, then we would have set `has_more=False`
        # during the previous invocation of `TeamsConnector.load_from_checkpoint`,
        # meaning that this function wouldn't have been called in the first place.
        todo_team_id = todos.pop()
        team = _get_team_by_id(
            graph_client=self.graph_client,
            team_id=todo_team_id,
        )
        channels = _collect_all_channels_from_team(
            graph_client=self.graph_client,
            team=team,
        )

        # One lazy generator per channel; `parallel_yield` drives them.
        docs = [
            _collect_document_for_channel_id(
                graph_client=self.graph_client,
                team=team,
                channel=channel,
                start=start,
                end=end,
            )
            for channel in channels
        ]

        for doc in parallel_yield(
            gens=docs,
            max_workers=self.max_workers,
        ):
            # Falsy results (such as `None`) are skipped.
            if doc:
                yield doc

        logger.info(
            f"Processed team with id {todo_team_id}; {len(todos)} team(s) left to process"
        )

        return TeamsCheckpoint(
            todo_team_ids=todos,
            has_more=bool(todos),
        )
def _get_created_datetime(chat_message: ChatMessage) -> datetime:
    """Return the message's `createdDateTime` property as converted by `time_str_to_utc`.

    Raises `KeyError` when the message has no `createdDateTime` property.
    """
    created_at_str = chat_message.properties["createdDateTime"]
    return time_str_to_utc(created_at_str)
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
    """Return one `BasicExpertInfo` (display name only) per member of `channel`.

    The member list is fetched with a single `get_all` query.
    """
    members = channel.members.get_all(
        # explicitly needed because of incorrect type definitions provided by the `office365` library
        page_loaded=lambda _: None
    ).execute_query_retry()

    return [BasicExpertInfo(display_name=member.display_name) for member in members]
def _get_threads_from_channel(
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Collect the threads of `channel` as lists of `[base message, *replies]`.

    A thread is kept only when the `lastModifiedDateTime` of its base message
    is not before `start` and not after `end`; a bound that is not given is
    not applied. Replies are fetched only for the threads that are kept.
    """
    # Naive bounds get UTC attached so that both are timezone-aware
    if start and start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end and end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)

    base_messages: list[ChatMessage] = channel.messages.get_all().execute_query_retry()

    threads: list[list[ChatMessage]] = []
    for base_message in base_messages:
        modified_at = time_str_to_utc(base_message.properties["lastModifiedDateTime"])

        if (start and modified_at < start) or (end and modified_at > end):
            continue

        replies = base_message.replies.get_all().execute_query_retry()

        # the base message comes first, followed by its replies
        threads.append([base_message, *replies])

    return threads
def _get_channels_from_teams(
    teams: list[Team],
) -> list[Channel]:
    """Return the channels of every team in `teams` as one flat list, in team order."""
    return [
        channel
        for team in teams
        for channel in team.channels.get_all().execute_query_retry()
    ]
def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
# NOTE: needs to be done this weird way because sometime we get back `None` for
# the fields which causes things to explode
@@ -133,7 +273,7 @@ def _convert_thread_to_document(
post_members_list: list[BasicExpertInfo] = []
thread_text = ""
sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)
sorted_thread = sorted(thread, key=_get_created_datetime, reverse=True)
if sorted_thread:
most_recent_message = sorted_thread[0]
@@ -164,13 +304,13 @@ def _convert_thread_to_document(
BasicExpertInfo(display_name=message_sender)
)
if not thread_text:
return None
# if there are no found post members, grab the members from the parent channel
if not post_members_list:
post_members_list = _extract_channel_members(channel)
if not thread_text:
return None
semantic_string = _construct_semantic_identifier(channel, top_message)
post_id = top_message.properties["id"]
@@ -189,219 +329,231 @@ def _convert_thread_to_document(
return doc
class TeamsConnector(LoadConnector, PollConnector):
MAX_CHANNELS_TO_LOG = 50
def _update_request_url(request: RequestOptions, next_url: str) -> None:
    """Point `request` at `next_url` by overwriting its `url` attribute.

    Used in this file as a `before_execute` hook so that a query is sent to
    the URL of the next result page instead of its original URL.
    """
    request.url = next_url
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
# TODO: (chris) move from "Display Names" to IDs, since display names
# are NOT guaranteed to be unique
teams: list[str] = [],
) -> None:
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.requested_team_list: list[str] = teams
self.msal_app: msal.ConfidentialClientApplication | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
teams_client_id = credentials["teams_client_id"]
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
def _collect_all_team_ids(
graph_client: GraphClient,
requested: list[str] | None = None,
) -> list[str]:
team_ids: list[str] = []
next_url: str | None = None
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
client_credential=teams_client_secret,
)
filter = None
if requested:
filter = " or ".join(f"displayName eq '{team_name}'" for team_name in requested)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
while True:
if filter:
query = graph_client.teams.get().filter(filter)
else:
query = graph_client.teams.get_all(
# explicitly needed because of incorrect type definitions provided by the `office365` library
page_loaded=lambda _: None
)
return token
self.graph_client = GraphClient(_acquire_token_func)
if next_url:
url = next_url
query.before_execute(
lambda req: _update_request_url(request=req, next_url=url)
)
team_collection = query.execute_query()
filtered_team_ids = [
team_id
for team_id in [
_filter_team_id(team=team, requested=requested)
for team in team_collection
]
if team_id
]
team_ids.extend(filtered_team_ids)
if team_collection.has_next:
if not isinstance(team_collection._next_request_url, str):
raise ValueError(
f"The next request url field should be a string, instead got {type(team_collection._next_request_url)}"
)
next_url = team_collection._next_request_url
else:
break
return team_ids
def _filter_team_id(
team: Team,
requested: list[str] | None = None,
) -> str | None:
"""
Returns the Team ID if:
- Team is not expired / deleted
- Team has a display-name and ID
- Team display-name is in the requested teams list
Otherwise, returns `None`.
"""
if not team.id or not team.display_name:
return None
def _get_all_teams(self) -> list[Team]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
if requested and team.display_name not in requested:
return None
teams: list[Team] = []
try:
# Use get_all() to handle pagination automatically
if not self.requested_team_list:
teams = self.graph_client.teams.get_all().execute_query()
else:
# Construct filter using proper Microsoft Graph API syntax
filter_conditions = " or ".join(
[
f"displayName eq '{team_name}'"
for team_name in self.requested_team_list
]
)
props = team.properties
# Initialize pagination variables
page_size = 100 # Maximum allowed by Microsoft Graph API
skip = 0
if props.get("expirationDateTime") or props.get("deletedDateTime"):
return None
while True:
# Get a page of teams with the filter
teams_page = (
self.graph_client.teams.get()
.filter(filter_conditions)
.top(page_size)
.skip(skip)
.execute_query()
)
return team.id
if not teams_page:
break
teams.extend(teams_page)
skip += page_size
def _get_team_by_id(
graph_client: GraphClient,
team_id: str,
) -> Team:
team_collection = (
graph_client.teams.get().filter(f"id eq '{team_id}'").top(1).execute_query()
)
# If we got fewer results than the page size, we've reached the end
if len(teams_page) < page_size:
break
if not team_collection:
raise ValueError(f"No team with {team_id=} was found")
elif team_collection.has_next:
# shouldn't happen, but catching it regardless
raise RuntimeError(f"Multiple teams with {team_id=} were found")
# Validate that we found all requested teams
if len(teams) != len(self.requested_team_list):
found_team_names = {
team.properties["displayName"] for team in teams
}
missing_teams = set(self.requested_team_list) - found_team_names
raise ConnectorValidationError(
f"Requested teams not found: {list(missing_teams)}"
)
except ClientRequestException as e:
if e.response.status_code == 403:
raise InsufficientPermissionsError(
"App lacks required permissions to read Teams. "
"Please ensure the app has the following permissions: "
"Team.ReadBasic.All, TeamMember.Read.All, "
"Channel.ReadBasic.All, ChannelMessage.Read.All, "
"Group.Read.All, TeamSettings.ReadWrite.All, "
"ChannelMember.Read.All, ChannelSettings.ReadWrite.All"
)
raise
return team_collection[0]
return teams
def _fetch_from_teams(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
def _collect_all_channels_from_team(
graph_client: GraphClient,
team: Team,
) -> list[Channel]:
if not team.id:
raise RuntimeError(f"The {team=} has an empty `id` field")
teams = self._get_all_teams()
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
if not teams:
msg = "No teams found."
logger.error(msg)
raise ValueError(msg)
channels: list[Channel] = []
next_url = None
channels = _get_channels_from_teams(
teams=teams,
while True:
query = team.channels.get_all(
# explicitly needed because of incorrect type definitions provided by the `office365` library
page_loaded=lambda _: None
)
logger.debug(
f"Found available channels (max {TeamsConnector.MAX_CHANNELS_TO_LOG} shown): "
f"{[c.id for c in channels[:TeamsConnector.MAX_CHANNELS_TO_LOG]]}"
)
if not channels:
msg = "No channels found."
logger.error(msg)
raise ValueError(msg)
# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
logger.info(f"Fetching threads from channel: {channel.id}")
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread)
if converted_doc:
doc_batch.append(converted_doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_teams()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
try:
# Minimal call to confirm we can retrieve Teams
# make sure it doesn't take forever, since this is a synchronous call
found_teams = run_with_timeout(10, self._get_all_teams)
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
if (
"unauthorized" in error_str
or "401" in error_str
or "invalid_grant" in error_str
):
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials."
)
elif "forbidden" in error_str or "403" in error_str:
raise InsufficientPermissionsError(
"App lacks required permissions to read from Microsoft Teams."
)
raise ConnectorValidationError(
f"Unexpected error during Teams validation: {e}"
if next_url:
url = next_url
query = query.before_execute(
lambda req: _update_request_url(request=req, next_url=url)
)
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
channel_collection = query.execute_query()
channels.extend(channel for channel in channel_collection if channel.id)
if not channel_collection.has_next:
break
return channels
def _collect_document_for_channel_id(
    graph_client: GraphClient,
    team: Team,
    channel: Channel,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> Iterator[Document | None]:
    """
    Yields exactly one item: the result of converting the channel's in-range
    messages with `_convert_thread_to_document`.

    It is written as a generator because `parallel_yield` expects an
    `Iterator`; nothing is fetched until the generator is first advanced.
    """

    def _in_range(message: ChatMessage) -> bool:
        return _filter_message(message=message, start=start, end=end)

    # Server-side filter conditions are not supported on the chat-messages API.
    # Therefore, we have to do this client-side, which is quite a bit more inefficient.
    #
    # Not allowed:
    # message_collection = channel.messages.get().filter(f"createdDateTime gt {start}").execute_query()
    all_messages = channel.messages.get_all(
        # explicitly needed because of incorrect type definitions provided by the `office365` library
        page_loaded=lambda _: None
    ).execute_query()

    thread = list(filter(_in_range, all_messages))

    document = _convert_thread_to_document(
        channel=channel,
        thread=thread,
    )
    yield document
def _filter_message(
    message: ChatMessage,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> bool:
    """
    Returns `True` if the given message was created / modified within the start-to-end datetime range.
    Returns `False` otherwise.

    The range is half-open: `start <= timestamp < end`, with `start` and `end`
    given in seconds since the Unix epoch. `lastModifiedDateTime` is used when
    it is present as a string, otherwise `createdDateTime`. Messages that
    carry a `deletedDateTime`, or neither of the two timestamps, are rejected.
    """
    props = message.properties

    if props.get("deletedDateTime"):
        return False

    def compare(dt_str: str) -> bool:
        # `datetime.fromisoformat` only accepts a trailing "Z" from Python 3.11
        # on, so spell the UTC designator as an explicit offset.
        if dt_str.endswith("Z"):
            dt_str = f"{dt_str[:-1]}+00:00"

        dt = datetime.fromisoformat(dt_str)

        # Only a naive value is assumed to be UTC. An offset that is present in
        # the string must be honoured, not overwritten, or the instant shifts.
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)

        return start <= dt.timestamp() < end

    if modified_at := props.get("lastModifiedDateTime"):
        if isinstance(modified_at, str):
            return compare(modified_at)

    if created_at := props.get("createdDateTime"):
        if isinstance(created_at, str):
            return compare(created_at)

    logger.warning(
        "No `lastModifiedDateTime` or `createdDateTime` fields found in `message.properties`"
    )

    return False
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
app_id = os.environ["TEAMS_APPLICATION_ID"]
dir_id = os.environ["TEAMS_DIRECTORY_ID"]
secret = os.environ["TEAMS_SECRET"]
teams_env_var = os.environ.get("TEAMS", None)
teams = teams_env_var.split(",") if teams_env_var else []
connector = TeamsConnector(teams=teams)
connector.load_credentials(
{
"teams_client_id": os.environ["TEAMS_CLIENT_ID"],
"teams_client_secret": os.environ["TEAMS_CLIENT_SECRET"],
"teams_directory_id": os.environ["TEAMS_CLIENT_DIRECTORY_ID"],
"teams_client_id": app_id,
"teams_directory_id": dir_id,
"teams_client_secret": secret,
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
connector.validate_connector_settings()
print(
load_everything_from_checkpoint_connector(
connector=connector,
start=0.0,
end=datetime.now(tz=timezone.utc).timestamp(),
)
)

View File

@@ -8,11 +8,11 @@ from pytest import FixtureRequest
from slack_sdk import WebClient
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.slack.connector import SlackConnector
from shared_configs.contextvars import get_current_tenant_id
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
from tests.daily.connectors.utils import to_sections
from tests.daily.connectors.utils import to_text_sections
@pytest.fixture
@@ -106,18 +106,7 @@ def test_indexing_channels_with_message_count(
end=time.time(),
)
messages: list[str] = []
for doc_or_error in docs:
if not isinstance(doc_or_error, Document):
raise RuntimeError(doc_or_error)
messages.extend(
section.text
for section in doc_or_error.sections
if isinstance(section, TextSection)
)
actual_messages = set(messages)
actual_messages = set(to_text_sections(to_sections(iter(docs))))
assert expected_messages == actual_messages

View File

@@ -0,0 +1,75 @@
import os
import time
import pytest
from onyx.connectors.teams.connector import TeamsConnector
from tests.daily.connectors.utils import load_everything_from_checkpoint_connector
from tests.daily.connectors.utils import to_sections
from tests.daily.connectors.utils import to_text_sections
@pytest.fixture
def teams_credentials() -> dict[str, str]:
    """Credentials dict for `TeamsConnector.load_credentials`, read from the environment.

    Requires `TEAMS_APPLICATION_ID`, `TEAMS_DIRECTORY_ID` and `TEAMS_SECRET`;
    a missing variable raises `KeyError`.
    """
    env = os.environ
    return {
        "teams_client_id": env["TEAMS_APPLICATION_ID"],
        "teams_directory_id": env["TEAMS_DIRECTORY_ID"],
        "teams_client_secret": env["TEAMS_SECRET"],
    }
@pytest.fixture
def teams_connector(
    request: pytest.FixtureRequest,
    teams_credentials: dict[str, str],
) -> TeamsConnector:
    """Build a `TeamsConnector` with credentials loaded.

    When the fixture is parametrized indirectly, `request.param` is the list
    of team names to restrict the connector to; it must be `None` or a
    `list[str]`. Without a parameter no team restriction is applied.
    """
    requested = getattr(request, "param", None)

    if requested is not None:
        if not isinstance(requested, list):
            raise ValueError(
                f"`request.param` must either be `None` or of type `list[str]`; instead got {type(requested)}"
            )
        for name in requested:
            assert isinstance(name, str)

    connector = TeamsConnector(teams=requested or [])
    connector.load_credentials(teams_credentials)
    return connector
@pytest.mark.parametrize(
    "teams_connector,expected_messages",
    [
        (["Onyx-Testing"], {"This is the first message in Onyx-Testing ..."}),
        (
            ["Onyx"],
            {
                "Hello, world!",
                "My favorite color is red.\n\xa0\nPablos favorite color is blue",
                "but not leastyeah!",
            },
        ),
    ],
    indirect=["teams_connector"],
)
def test_teams_connector(
    teams_connector: TeamsConnector,
    expected_messages: set[str],
) -> None:
    """Index everything from the requested team and compare the message texts."""
    documents = load_everything_from_checkpoint_connector(
        connector=teams_connector,
        start=0.0,
        end=time.time(),
    )

    sections = to_sections(iter(documents))
    actual_messages = set(to_text_sections(sections))

    assert actual_messages == expected_messages

View File

@@ -1,3 +1,4 @@
from collections.abc import Iterator
from typing import cast
from typing import TypeVar
@@ -7,6 +8,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
_ITERATION_LIMIT = 100_000
@@ -68,3 +71,21 @@ def load_everything_from_checkpoint_connector(
raise RuntimeError("Too many iterations. Infinite loop?")
return outputs
def to_sections(
    iterator: Iterator[Document | ConnectorFailure],
) -> Iterator[TextSection | ImageSection]:
    """Yield the sections of every `Document` in `iterator`, in order.

    Raises:
        RuntimeError: as soon as an item that is not a `Document` is reached;
            the offending item is passed as the exception argument.
    """
    for item in iterator:
        if isinstance(item, Document):
            yield from item.sections
        else:
            raise RuntimeError(item)
def to_text_sections(iterator: Iterator[TextSection | ImageSection]) -> Iterator[str]:
    """Yield the `text` of every `TextSection` in `iterator`; other sections are skipped."""
    yield from (
        section.text for section in iterator if isinstance(section, TextSection)
    )