Fixed indexing when no sites are specified (#4822)

* Fixed indexing when no sites are specified

* Added test for Sharepoint all sites index

* Accounted for paginated results.

* Typing

* Typing

---------

Co-authored-by: Wenxi Onyx <wenxi-onyx@Wenxis-MacBook-Pro.local>
This commit is contained in:
Wenxi
2025-06-05 16:25:20 -07:00
committed by GitHub
parent affb9e6941
commit dc4b9bc003
2 changed files with 36 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
import io import io
import os import os
from collections.abc import Generator
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from typing import Any from typing import Any
@@ -8,6 +9,8 @@ from urllib.parse import unquote
import msal # type: ignore import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -227,14 +230,29 @@ class SharepointConnector(LoadConnector, PollConnector):
return final_driveitems return final_driveitems
def _handle_paginated_sites(
self, sites: SitesWithRoot
) -> Generator[Site, None, None]:
while sites:
if sites.current_page:
yield from sites.current_page
if not sites.has_next:
break
sites = sites._get_next().execute_query()
def _fetch_sites(self) -> list[SiteDescriptor]:
    """Build a SiteDescriptor for every site in the tenant.

    Used when the connector is configured without an explicit site list:
    enumerates all sites via the Graph API (following pagination) and
    wraps each site's web URL with no drive/folder restriction.

    Raises:
        RuntimeError: if the tenant query returns no sites at all.
    """
    all_sites = self.graph_client.sites.get_all_sites().execute_query()
    if not all_sites:
        raise RuntimeError("No sites found in the tenant")
    return [
        SiteDescriptor(url=site.web_url, drive_name=None, folder_path=None)
        for site in self._handle_paginated_sites(all_sites)
    ]

View File

@@ -85,6 +85,22 @@ def sharepoint_credentials() -> dict[str, str]:
} }
def test_sharepoint_connector_all_sites(
    mock_get_unstructured_api_key: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    """Smoke test: with no sites configured, the connector indexes the
    whole tenant and yields at least one document batch."""
    # No site list -> connector should discover every site in the tenant.
    connector = SharepointConnector()
    connector.load_credentials(sharepoint_credentials)
    # The test tenant's content can change at any time, so we only assert
    # that some documents come back rather than pinning expected sites.
    batches = list(connector.load_from_state())
    assert batches, "Should find documents from all sites"
def test_sharepoint_connector_specific_folder( def test_sharepoint_connector_specific_folder(
mock_get_unstructured_api_key: MagicMock, mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str], sharepoint_credentials: dict[str, str],