* Fix teams

* Use get_all

* Add comment
This commit is contained in:
Chris Weaver 2025-04-28 11:53:22 -07:00 committed by GitHub
parent eebfa5be18
commit 47b9e7aa62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,6 +27,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import TextSection from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger() logger = setup_logger()
@ -38,7 +39,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]: def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = [] channel_members_list: list[BasicExpertInfo] = []
members = channel.members.get().execute_query_retry() members = channel.members.get_all().execute_query_retry()
for member in members: for member in members:
channel_members_list.append(BasicExpertInfo(display_name=member.display_name)) channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
return channel_members_list return channel_members_list
@ -55,7 +56,7 @@ def _get_threads_from_channel(
if end and end.tzinfo is None: if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc) end = end.replace(tzinfo=timezone.utc)
query = channel.messages.get() query = channel.messages.get_all()
base_messages: list[ChatMessage] = query.execute_query_retry() base_messages: list[ChatMessage] = query.execute_query_retry()
threads: list[list[ChatMessage]] = [] threads: list[list[ChatMessage]] = []
@ -86,7 +87,7 @@ def _get_channels_from_teams(
) -> list[Channel]: ) -> list[Channel]:
channels_list: list[Channel] = [] channels_list: list[Channel] = []
for team in teams: for team in teams:
query = team.channels.get() query = team.channels.get_all()
channels = query.execute_query_retry() channels = query.execute_query_retry()
channels_list.extend(channels) channels_list.extend(channels)
@ -180,6 +181,8 @@ class TeamsConnector(LoadConnector, PollConnector):
def __init__( def __init__(
self, self,
batch_size: int = INDEX_BATCH_SIZE, batch_size: int = INDEX_BATCH_SIZE,
# TODO: (chris) move from "Display Names" to IDs, since display names
# are NOT guaranteed to be unique
teams: list[str] = [], teams: list[str] = [],
) -> None: ) -> None:
self.batch_size = batch_size self.batch_size = batch_size
@ -218,24 +221,66 @@ class TeamsConnector(LoadConnector, PollConnector):
if self.graph_client is None: if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams") raise ConnectorMissingCredentialError("Teams")
teams_list: list[Team] = [] teams: list[Team] = []
try:
# Use get_all() to handle pagination automatically
if not self.requested_team_list:
teams = self.graph_client.teams.get_all().execute_query()
else:
# Construct filter using proper Microsoft Graph API syntax
filter_conditions = " or ".join(
[
f"displayName eq '{team_name}'"
for team_name in self.requested_team_list
]
)
teams = self.graph_client.teams.get().execute_query_retry() # Initialize pagination variables
page_size = 100 # Maximum allowed by Microsoft Graph API
skip = 0
if len(self.requested_team_list) > 0: while True:
adjusted_request_strings = [ # Get a page of teams with the filter
requested_team.replace(" ", "") teams_page = (
for requested_team in self.requested_team_list self.graph_client.teams.get()
] .filter(filter_conditions)
teams_list = [ .top(page_size)
team .skip(skip)
for team in teams .execute_query()
if team.display_name.replace(" ", "") in adjusted_request_strings )
]
else:
teams_list.extend(teams)
return teams_list if not teams_page:
break
teams.extend(teams_page)
skip += page_size
# If we got fewer results than the page size, we've reached the end
if len(teams_page) < page_size:
break
# Validate that we found all requested teams
if len(teams) != len(self.requested_team_list):
found_team_names = {
team.properties["displayName"] for team in teams
}
missing_teams = set(self.requested_team_list) - found_team_names
raise ConnectorValidationError(
f"Requested teams not found: {list(missing_teams)}"
)
except ClientRequestException as e:
if e.response.status_code == 403:
raise InsufficientPermissionsError(
"App lacks required permissions to read Teams. "
"Please ensure the app has the following permissions: "
"Team.ReadBasic.All, TeamMember.Read.All, "
"Channel.ReadBasic.All, ChannelMessage.Read.All, "
"Group.Read.All, TeamSettings.ReadWrite.All, "
"ChannelMember.Read.All, ChannelSettings.ReadWrite.All"
)
raise
return teams
def _fetch_from_teams( def _fetch_from_teams(
self, start: datetime | None = None, end: datetime | None = None self, start: datetime | None = None, end: datetime | None = None
@ -262,7 +307,7 @@ class TeamsConnector(LoadConnector, PollConnector):
# goes over channels, converts them into Document objects and then yields them in batches # goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = [] doc_batch: list[Document] = []
for channel in channels: for channel in channels:
logger.debug(f"Fetching threads from channel: {channel.id}") logger.info(f"Fetching threads from channel: {channel.id}")
thread_list = _get_threads_from_channel(channel, start=start, end=end) thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list: for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread) converted_doc = _convert_thread_to_document(channel, thread)
@ -290,7 +335,8 @@ class TeamsConnector(LoadConnector, PollConnector):
try: try:
# Minimal call to confirm we can retrieve Teams # Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams() # make sure it doesn't take forever, since this is a syncronous call
found_teams = run_with_timeout(10, self._get_all_teams)
except ClientRequestException as e: except ClientRequestException as e:
status_code = e.response.status_code status_code = e.response.status_code