diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml
index c3947b233..81ab06665 100644
--- a/.github/workflows/pr-python-connector-tests.yml
+++ b/.github/workflows/pr-python-connector-tests.yml
@@ -39,6 +39,12 @@ env:
   AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
   AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
   AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
+  # Sharepoint
+  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
+  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
+  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
+  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+
 jobs:
   connectors-check:
     # See https://runs-on.com/runners/linux/
diff --git a/backend/onyx/connectors/sharepoint/connector.py b/backend/onyx/connectors/sharepoint/connector.py
index 88874db6b..5747df03b 100644
--- a/backend/onyx/connectors/sharepoint/connector.py
+++ b/backend/onyx/connectors/sharepoint/connector.py
@@ -1,17 +1,14 @@
 import io
 import os
-from dataclasses import dataclass
-from dataclasses import field
 from datetime import datetime
 from datetime import timezone
 from typing import Any
-from typing import Optional
 from urllib.parse import unquote
 
 import msal  # type: ignore
 from office365.graph_client import GraphClient  # type: ignore
 from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore
-from office365.onedrive.sites.site import Site  # type: ignore
+from pydantic import BaseModel
 
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
@@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()
 
 
-@dataclass
-class SiteData:
-    url: str | None
-    folder: Optional[str]
-    sites: list = field(default_factory=list)
-    driveitems: list = field(default_factory=list)
+class SiteDescriptor(BaseModel):
+    """Data class for storing SharePoint site information.
+
+    Args:
+        url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
+        drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
+            If None, all drives will be accessed.
+        folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
+            If None, all folders will be accessed.
+ """ + + url: str + drive_name: str | None + folder_path: str | None def _convert_driveitem_to_document( driveitem: DriveItem, + drive_name: str, ) -> Document: file_text = extract_file_text( file=io.BytesIO(driveitem.get_content().execute_query().value), @@ -59,7 +65,7 @@ def _convert_driveitem_to_document( email=driveitem.last_modified_by.user.email, ) ], - metadata={}, + metadata={"drive": drive_name}, ) return doc @@ -71,106 +77,171 @@ class SharepointConnector(LoadConnector, PollConnector): sites: list[str] = [], ) -> None: self.batch_size = batch_size - self.graph_client: GraphClient | None = None - self.site_data: list[SiteData] = self._extract_site_and_folder(sites) + self._graph_client: GraphClient | None = None + self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info( + sites + ) + + @property + def graph_client(self) -> GraphClient: + if self._graph_client is None: + raise ConnectorMissingCredentialError("Sharepoint") + + return self._graph_client @staticmethod - def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]: + def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]: site_data_list = [] for url in site_urls: parts = url.strip().split("/") if "sites" in parts: sites_index = parts.index("sites") site_url = "/".join(parts[: sites_index + 2]) - folder = ( - "/".join(unquote(part) for part in parts[sites_index + 2 :]) - if len(parts) > sites_index + 2 - else None - ) - # Handling for new URL structure - if folder and folder.startswith("Shared Documents/"): - folder = folder[len("Shared Documents/") :] + remaining_parts = parts[sites_index + 2 :] + + # Extract drive name and folder path + if remaining_parts: + drive_name = unquote(remaining_parts[0]) + folder_path = ( + "/".join(unquote(part) for part in remaining_parts[1:]) + if len(remaining_parts) > 1 + else None + ) + else: + drive_name = None + folder_path = None + site_data_list.append( - SiteData(url=site_url, folder=folder, sites=[], driveitems=[]) + SiteDescriptor( + url=site_url, + drive_name=drive_name, + folder_path=folder_path, + ) ) return site_data_list - def _populate_sitedata_driveitems( + def _fetch_driveitems( self, + site_descriptor: SiteDescriptor, start: datetime | None = None, end: datetime | None = None, - ) -> None: + ) -> list[tuple[DriveItem, str]]: filter_str = "" if start is not None and end is not None: - filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}" + filter_str = ( + f"last_modified_datetime ge {start.isoformat()} and " + f"last_modified_datetime le {end.isoformat()}" + ) - for element in self.site_data: - sites: list[Site] = [] - for site in element.sites: - site_sublist = site.lists.get().execute_query() - sites.extend(site_sublist) + final_driveitems: list[tuple[DriveItem, str]] = [] + try: + site = self.graph_client.sites.get_by_url(site_descriptor.url) - for site in sites: + # Get all drives in the site + drives = site.drives.get().execute_query() + logger.debug(f"Found drives: {[drive.name for drive in drives]}") + + # Filter drives based on the requested drive name + if site_descriptor.drive_name: + drives = [ + drive + for drive in drives + if drive.name == site_descriptor.drive_name + or ( + drive.name == "Documents" + and site_descriptor.drive_name == "Shared Documents" + ) + ] + if not drives: + logger.warning(f"Drive '{site_descriptor.drive_name}' not found") + return [] + + # Process each matching drive + for drive in drives: try: - query = 
site.drive.root.get_files(True, 1000) + root_folder = drive.root + if site_descriptor.folder_path: + # If a specific folder is requested, navigate to it + for folder_part in site_descriptor.folder_path.split("/"): + root_folder = root_folder.get_by_path(folder_part) + + # Get all items recursively + query = root_folder.get_files(True, 1000) if filter_str: query = query.filter(filter_str) driveitems = query.execute_query() - if element.folder: - expected_path = f"/root:/{element.folder}" + logger.debug( + f"Found {len(driveitems)} items in drive '{drive.name}'" + ) + + # Use "Shared Documents" as the library name for the default "Documents" drive + drive_name = ( + "Shared Documents" if drive.name == "Documents" else drive.name + ) + + if site_descriptor.folder_path: + # Filter items to ensure they're in the specified folder or its subfolders + # The path will be in format: /drives/{drive_id}/root:/folder/path filtered_driveitems = [ - item + (item, drive_name) for item in driveitems - if item.parent_reference.path.endswith(expected_path) + if any( + path_part == site_descriptor.folder_path + or path_part.startswith( + site_descriptor.folder_path + "/" + ) + for path_part in item.parent_reference.path.split( + "root:/" + )[1].split("/") + ) ] if len(filtered_driveitems) == 0: all_paths = [ item.parent_reference.path for item in driveitems ] logger.warning( - f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}" + f"Nothing found for folder '{site_descriptor.folder_path}' " + f"in; any of valid paths: {all_paths}" ) - element.driveitems.extend(filtered_driveitems) + final_driveitems.extend(filtered_driveitems) else: - element.driveitems.extend(driveitems) + final_driveitems.extend( + [(item, drive_name) for item in driveitems] + ) + except Exception as e: + # Some drives might not be accessible + logger.warning(f"Failed to process drive: {str(e)}") - except Exception: - # Sites include things that do not contain .drive.root so this fails - # but this is fine, as there are no actually documents in those - pass + except Exception as e: + # Sites include things that do not contain drives so this fails + # but this is fine, as there are no actual documents in those + logger.warning(f"Failed to process site: {str(e)}") - def _populate_sitedata_sites(self) -> None: - if self.graph_client is None: - raise ConnectorMissingCredentialError("Sharepoint") + return final_driveitems - if self.site_data: - for element in self.site_data: - element.sites = [ - self.graph_client.sites.get_by_url(element.url) - .get() - .execute_query() - ] - else: - sites = self.graph_client.sites.get_all().execute_query() - self.site_data = [ - SiteData(url=None, folder=None, sites=sites, driveitems=[]) - ] + def _fetch_sites(self) -> list[SiteDescriptor]: + sites = self.graph_client.sites.get_all().execute_query() + site_descriptors = [ + SiteDescriptor( + url=sites.resource_url, + drive_name=None, + folder_path=None, + ) + ] + return site_descriptors def _fetch_from_sharepoint( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: - if self.graph_client is None: - raise ConnectorMissingCredentialError("Sharepoint") - - self._populate_sitedata_sites() - self._populate_sitedata_driveitems(start=start, end=end) + site_descriptors = self.site_descriptors or self._fetch_sites() # goes over all urls, converts them into Document objects and then yields them in batches doc_batch: list[Document] = [] - for element in self.site_data: - for driveitem in 
element.driveitems: + for site_descriptor in site_descriptors: + driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end) + for driveitem, drive_name in driveitems: logger.debug(f"Processing: {driveitem.web_url}") - doc_batch.append(_convert_driveitem_to_document(driveitem)) + doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name)) if len(doc_batch) >= self.batch_size: yield doc_batch @@ -197,7 +268,7 @@ class SharepointConnector(LoadConnector, PollConnector): ) return token - self.graph_client = GraphClient(_acquire_token_func) + self._graph_client = GraphClient(_acquire_token_func) return None def load_from_state(self) -> GenerateDocumentsOutput: @@ -206,19 +277,19 @@ class SharepointConnector(LoadConnector, PollConnector): def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - start_datetime = datetime.utcfromtimestamp(start) - end_datetime = datetime.utcfromtimestamp(end) + start_datetime = datetime.fromtimestamp(start, timezone.utc) + end_datetime = datetime.fromtimestamp(end, timezone.utc) return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime) if __name__ == "__main__": - connector = SharepointConnector(sites=os.environ["SITES"].split(",")) + connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(",")) connector.load_credentials( { - "sp_client_id": os.environ["SP_CLIENT_ID"], - "sp_client_secret": os.environ["SP_CLIENT_SECRET"], - "sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"], + "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"], + "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"], + "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"], } ) document_batches = connector.load_from_state() diff --git a/backend/tests/daily/connectors/airtable/test_airtable_basic.py b/backend/tests/daily/connectors/airtable/test_airtable_basic.py index bb5e312f1..6610d91d6 100644 --- a/backend/tests/daily/connectors/airtable/test_airtable_basic.py +++ b/backend/tests/daily/connectors/airtable/test_airtable_basic.py @@ -1,7 +1,5 @@ import os -from collections.abc import Generator from unittest.mock import MagicMock -from unittest.mock import patch import pytest from pydantic import BaseModel @@ -127,15 +125,6 @@ def create_test_document( ) -@pytest.fixture -def mock_get_api_key() -> Generator[MagicMock, None, None]: - with patch( - "onyx.file_processing.extract_file_text.get_unstructured_api_key", - return_value=None, - ) as mock: - yield mock - - def compare_documents( actual_docs: list[Document], expected_docs: list[Document] ) -> None: @@ -191,7 +180,7 @@ def compare_documents( def test_airtable_connector_basic( - mock_get_api_key: MagicMock, airtable_config: AirtableConfig + mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig ) -> None: """Test behavior when all non-attachment fields are treated as metadata.""" connector = AirtableConnector( @@ -253,7 +242,7 @@ def test_airtable_connector_basic( def test_airtable_connector_all_metadata( - mock_get_api_key: MagicMock, airtable_config: AirtableConfig + mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig ) -> None: connector = AirtableConnector( base_id=airtable_config.base_id, diff --git a/backend/tests/daily/connectors/conftest.py b/backend/tests/daily/connectors/conftest.py new file mode 100644 index 000000000..88a00b57a --- /dev/null +++ b/backend/tests/daily/connectors/conftest.py @@ -0,0 +1,14 @@ +from collections.abc import Generator +from 
unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]: + with patch( + "onyx.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, + ) as mock: + yield mock diff --git a/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py new file mode 100644 index 000000000..8fc40564f --- /dev/null +++ b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py @@ -0,0 +1,178 @@ +import os +from dataclasses import dataclass +from datetime import datetime +from datetime import timezone +from unittest.mock import MagicMock + +import pytest + +from onyx.configs.constants import DocumentSource +from onyx.connectors.models import Document +from onyx.connectors.sharepoint.connector import SharepointConnector + + +@dataclass +class ExpectedDocument: + semantic_identifier: str + content: str + folder_path: str | None = None + library: str = "Shared Documents" # Default to main library + + +EXPECTED_DOCUMENTS = [ + ExpectedDocument( + semantic_identifier="test1.docx", + content="test1", + folder_path="test", + ), + ExpectedDocument( + semantic_identifier="test2.docx", + content="test2", + folder_path="test/nested with spaces", + ), + ExpectedDocument( + semantic_identifier="should-not-index-on-specific-folder.docx", + content="should-not-index-on-specific-folder", + folder_path=None, # root folder + ), + ExpectedDocument( + semantic_identifier="other.docx", + content="other", + folder_path=None, + library="Other Library", + ), +] + + +def verify_document_metadata(doc: Document) -> None: + """Verify common metadata that should be present on all documents.""" + assert isinstance(doc.doc_updated_at, datetime) + assert doc.doc_updated_at.tzinfo == timezone.utc + assert doc.source == DocumentSource.SHAREPOINT + assert doc.primary_owners is not None + assert len(doc.primary_owners) == 1 + owner = doc.primary_owners[0] + assert owner.display_name is not None + assert owner.email is not None + + +def verify_document_content(doc: Document, expected: ExpectedDocument) -> None: + """Verify a document matches its expected content.""" + assert doc.semantic_identifier == expected.semantic_identifier + assert len(doc.sections) == 1 + assert expected.content in doc.sections[0].text + verify_document_metadata(doc) + + +def find_document(documents: list[Document], semantic_identifier: str) -> Document: + """Find a document by its semantic identifier.""" + matching_docs = [ + d for d in documents if d.semantic_identifier == semantic_identifier + ] + assert ( + len(matching_docs) == 1 + ), f"Expected exactly one document with identifier {semantic_identifier}" + return matching_docs[0] + + +@pytest.fixture +def sharepoint_credentials() -> dict[str, str]: + return { + "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"], + "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"], + "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"], + } + + +def test_sharepoint_connector_specific_folder( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the test site URL and specific folder + connector = SharepointConnector( + sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"] + ) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = 
list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + + # Should only find documents in the test folder + test_folder_docs = [ + doc + for doc in EXPECTED_DOCUMENTS + if doc.folder_path and doc.folder_path.startswith("test") + ] + assert len(found_documents) == len( + test_folder_docs + ), "Should only find documents in test folder" + + # Verify each expected document + for expected in test_folder_docs: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected) + + +def test_sharepoint_connector_root_folder( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the base site URL + connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]]) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + + assert len(found_documents) == len( + EXPECTED_DOCUMENTS + ), "Should find all documents in main library" + + # Verify each expected document + for expected in EXPECTED_DOCUMENTS: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected) + + +def test_sharepoint_connector_other_library( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the other library + connector = SharepointConnector( + sites=[ + os.environ["SHAREPOINT_SITE"] + "/Other Library", + ] + ) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + expected_documents: list[ExpectedDocument] = [ + doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library" + ] + + # Should find all documents in `Other Library` + assert len(found_documents) == len( + expected_documents + ), "Should find all documents in `Other Library`" + + # Verify each expected document + for expected in expected_documents: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected)
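
Reviewer note: a minimal sketch (outside the diff) of how the new _extract_site_and_drive_info
logic above is expected to split a site URL into SiteDescriptor fields. The example URL and the
helper name split_site_url are hypothetical; the parsing mirrors the connector code in this PR
and assumes the URL contains a "/sites/" segment.

    from urllib.parse import unquote

    def split_site_url(url: str) -> tuple[str, str | None, str | None]:
        """Return (site_url, drive_name, folder_path) the way the connector parses them."""
        parts = url.strip().split("/")
        # Assumes a "/sites/<name>" segment, as in the connector's "sites" check
        sites_index = parts.index("sites")
        site_url = "/".join(parts[: sites_index + 2])
        remaining = parts[sites_index + 2 :]
        if not remaining:
            # Bare site URL: index every drive and folder
            return site_url, None, None
        drive_name = unquote(remaining[0])
        folder_path = (
            "/".join(unquote(p) for p in remaining[1:]) if len(remaining) > 1 else None
        )
        return site_url, drive_name, folder_path

    # e.g. "https://example.sharepoint.com/sites/demo/Shared%20Documents/test/nested"
    # -> ("https://example.sharepoint.com/sites/demo", "Shared Documents", "test/nested")

This matches the three test cases above: a bare site URL walks every drive, a
"/Shared Documents/test" suffix restricts indexing to that folder (and its subfolders), and a
"/Other Library" suffix selects a non-default document library by name.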