diff --git a/backend/onyx/connectors/sharepoint/connector.py b/backend/onyx/connectors/sharepoint/connector.py
index 5e35bf871076..141d9babcbbf 100644
--- a/backend/onyx/connectors/sharepoint/connector.py
+++ b/backend/onyx/connectors/sharepoint/connector.py
@@ -1,5 +1,6 @@
 import io
 import os
+from collections.abc import Generator
 from datetime import datetime
 from datetime import timezone
 from typing import Any
@@ -8,6 +9,8 @@ from urllib.parse import unquote
 import msal  # type: ignore
 from office365.graph_client import GraphClient  # type: ignore
 from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore
+from office365.onedrive.sites.site import Site  # type: ignore
+from office365.onedrive.sites.sites_with_root import SitesWithRoot  # type: ignore
 from pydantic import BaseModel
 
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -227,14 +230,50 @@ class SharepointConnector(LoadConnector, PollConnector):
 
         return final_driveitems
 
+    def _handle_paginated_sites(
+        self, sites: SitesWithRoot
+    ) -> Generator[Site, None, None]:
+        """Yield every site in ``sites``, loading further pages as needed.
+
+        Args:
+            sites: Site collection returned by an executed query.
+
+        Yields:
+            Each site of the current page, then of each following page until
+            the collection reports that no next page exists.
+        """
+        while sites:
+            if sites.current_page:
+                yield from sites.current_page
+            if not sites.has_next:
+                break
+            # NOTE: `_get_next` is underscore-prefixed, i.e. private to the
+            # client library; re-check this call when upgrading it.
+            sites = sites._get_next().execute_query()
+
     def _fetch_sites(self) -> list[SiteDescriptor]:
-        sites = self.graph_client.sites.get_all().execute_query()
+        """Return one descriptor per site listed by the Graph client.
+
+        Sites that report no web URL are left out.
+
+        Raises:
+            RuntimeError: If the initial query returns no sites.
+        """
+        sites = self.graph_client.sites.get_all_sites().execute_query()
+
+        if not sites:
+            raise RuntimeError("No sites found in the tenant")
+
         site_descriptors = [
             SiteDescriptor(
-                url=sites.resource_url,
+                url=site.web_url,
                 drive_name=None,
                 folder_path=None,
             )
+            for site in self._handle_paginated_sites(sites)
+            # NOTE(review): skip sites that report no web URL instead of
+            # passing an empty url on; confirm the client can return such sites.
+            if site.web_url
         ]
         return site_descriptors
 
diff --git a/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py
index a25c04973742..fabbab9616b5 100644
--- a/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py
+++ b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py
@@ -85,6 +85,22 @@ def sharepoint_credentials() -> dict[str, str]:
     }
 
 
+def test_sharepoint_connector_all_sites(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with no sites
+    connector = SharepointConnector()
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Not asserting expected sites because that can change in test tenant at any time
+    # Finding any docs is good enough to verify that the connector is working
+    document_batches = list(connector.load_from_state())
+    assert document_batches, "Should find documents from all sites"
+
+
 def test_sharepoint_connector_specific_folder(
     mock_get_unstructured_api_key: MagicMock,
     sharepoint_credentials: dict[str, str],