Mirror of https://github.com/danswer-ai/danswer.git
Sharepoint fixes (#3826)
* Sharepoint connector fixes
* Refactor sharepoint to be better
* Improve env variable naming
* Fix
* Add new secrets
* Fix unstructured failure
@@ -39,6 +39,12 @@ env:
   AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
   AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
   AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
+  # Sharepoint
+  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
+  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
+  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
+  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+
 jobs:
   connectors-check:
     # See https://runs-on.com/runners/linux/
@@ -1,17 +1,14 @@
 import io
 import os
-from dataclasses import dataclass
-from dataclasses import field
 from datetime import datetime
 from datetime import timezone
 from typing import Any
-from typing import Optional
 from urllib.parse import unquote

 import msal  # type: ignore
 from office365.graph_client import GraphClient  # type: ignore
 from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore
-from office365.onedrive.sites.site import Site  # type: ignore
+from pydantic import BaseModel

 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
@@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-@dataclass
-class SiteData:
-    url: str | None
-    folder: Optional[str]
-    sites: list = field(default_factory=list)
-    driveitems: list = field(default_factory=list)
+class SiteDescriptor(BaseModel):
+    """Data class for storing SharePoint site information.
+
+    Args:
+        url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
+        drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
+            If None, all drives will be accessed.
+        folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
+            If None, all folders will be accessed.
+    """
+
+    url: str
+    drive_name: str | None
+    folder_path: str | None


 def _convert_driveitem_to_document(
     driveitem: DriveItem,
+    drive_name: str,
 ) -> Document:
     file_text = extract_file_text(
         file=io.BytesIO(driveitem.get_content().execute_query().value),
@@ -59,7 +65,7 @@ def _convert_driveitem_to_document(
                 email=driveitem.last_modified_by.user.email,
             )
         ],
-        metadata={},
+        metadata={"drive": drive_name},
     )
     return doc
@@ -71,106 +77,171 @@ class SharepointConnector(LoadConnector, PollConnector):
         sites: list[str] = [],
     ) -> None:
         self.batch_size = batch_size
-        self.graph_client: GraphClient | None = None
-        self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
+        self._graph_client: GraphClient | None = None
+        self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
+            sites
+        )
+
+    @property
+    def graph_client(self) -> GraphClient:
+        if self._graph_client is None:
+            raise ConnectorMissingCredentialError("Sharepoint")
+
+        return self._graph_client

     @staticmethod
-    def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
+    def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
         site_data_list = []
         for url in site_urls:
             parts = url.strip().split("/")
             if "sites" in parts:
                 sites_index = parts.index("sites")
                 site_url = "/".join(parts[: sites_index + 2])
-                folder = (
-                    "/".join(unquote(part) for part in parts[sites_index + 2 :])
-                    if len(parts) > sites_index + 2
-                    else None
-                )
-                # Handling for new URL structure
-                if folder and folder.startswith("Shared Documents/"):
-                    folder = folder[len("Shared Documents/") :]
+                remaining_parts = parts[sites_index + 2 :]
+
+                # Extract drive name and folder path
+                if remaining_parts:
+                    drive_name = unquote(remaining_parts[0])
+                    folder_path = (
+                        "/".join(unquote(part) for part in remaining_parts[1:])
+                        if len(remaining_parts) > 1
+                        else None
+                    )
+                else:
+                    drive_name = None
+                    folder_path = None

                 site_data_list.append(
-                    SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
+                    SiteDescriptor(
+                        url=site_url,
+                        drive_name=drive_name,
+                        folder_path=folder_path,
+                    )
                 )
         return site_data_list

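As a reading aid (not part of the commit), a minimal sketch of how site URLs decompose under the new _extract_site_and_drive_info, using the example URLs from the SiteDescriptor docstring; expected results are (url, drive_name, folder_path):

# Illustration only: expected parse results for a few URLs, following the
# splitting logic in _extract_site_and_drive_info shown above.
EXAMPLES = {
    "https://danswerai.sharepoint.com/sites/sharepoint-tests": (
        "https://danswerai.sharepoint.com/sites/sharepoint-tests", None, None
    ),
    "https://danswerai.sharepoint.com/sites/sharepoint-tests/Shared Documents/test": (
        "https://danswerai.sharepoint.com/sites/sharepoint-tests", "Shared Documents", "test"
    ),
    "https://danswerai.sharepoint.com/sites/sharepoint-tests/Other Library": (
        "https://danswerai.sharepoint.com/sites/sharepoint-tests", "Other Library", None
    ),
}

for raw_url, (url, drive_name, folder_path) in EXAMPLES.items():
    print(f"{raw_url}\n  -> url={url!r} drive_name={drive_name!r} folder_path={folder_path!r}")

A bare site URL leaves drive_name and folder_path as None, which later makes _fetch_driveitems walk every drive in the site.
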
-    def _populate_sitedata_driveitems(
+    def _fetch_driveitems(
         self,
+        site_descriptor: SiteDescriptor,
         start: datetime | None = None,
         end: datetime | None = None,
-    ) -> None:
+    ) -> list[tuple[DriveItem, str]]:
         filter_str = ""
         if start is not None and end is not None:
-            filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
+            filter_str = (
+                f"last_modified_datetime ge {start.isoformat()} and "
+                f"last_modified_datetime le {end.isoformat()}"
+            )

-        for element in self.site_data:
-            sites: list[Site] = []
-            for site in element.sites:
-                site_sublist = site.lists.get().execute_query()
-                sites.extend(site_sublist)
-
-            for site in sites:
-                try:
-                    query = site.drive.root.get_files(True, 1000)
-                    if filter_str:
-                        query = query.filter(filter_str)
-                    driveitems = query.execute_query()
-                    if element.folder:
-                        expected_path = f"/root:/{element.folder}"
-                        filtered_driveitems = [
-                            item
-                            for item in driveitems
-                            if item.parent_reference.path.endswith(expected_path)
-                        ]
-                        if len(filtered_driveitems) == 0:
-                            all_paths = [
-                                item.parent_reference.path for item in driveitems
-                            ]
-                            logger.warning(
-                                f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}"
-                            )
-                        element.driveitems.extend(filtered_driveitems)
-                    else:
-                        element.driveitems.extend(driveitems)
-
-                except Exception:
-                    # Sites include things that do not contain .drive.root so this fails
-                    # but this is fine, as there are no actually documents in those
-                    pass
-
-    def _populate_sitedata_sites(self) -> None:
-        if self.graph_client is None:
-            raise ConnectorMissingCredentialError("Sharepoint")
-
-        if self.site_data:
-            for element in self.site_data:
-                element.sites = [
-                    self.graph_client.sites.get_by_url(element.url)
-                    .get()
-                    .execute_query()
-                ]
-        else:
-            sites = self.graph_client.sites.get_all().execute_query()
-            self.site_data = [
-                SiteData(url=None, folder=None, sites=sites, driveitems=[])
-            ]
+        final_driveitems: list[tuple[DriveItem, str]] = []
+        try:
+            site = self.graph_client.sites.get_by_url(site_descriptor.url)
+
+            # Get all drives in the site
+            drives = site.drives.get().execute_query()
+            logger.debug(f"Found drives: {[drive.name for drive in drives]}")
+
+            # Filter drives based on the requested drive name
+            if site_descriptor.drive_name:
+                drives = [
+                    drive
+                    for drive in drives
+                    if drive.name == site_descriptor.drive_name
+                    or (
+                        drive.name == "Documents"
+                        and site_descriptor.drive_name == "Shared Documents"
+                    )
+                ]
+                if not drives:
+                    logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
+                    return []
+
+            # Process each matching drive
+            for drive in drives:
+                try:
+                    root_folder = drive.root
+                    if site_descriptor.folder_path:
+                        # If a specific folder is requested, navigate to it
+                        for folder_part in site_descriptor.folder_path.split("/"):
+                            root_folder = root_folder.get_by_path(folder_part)
+
+                    # Get all items recursively
+                    query = root_folder.get_files(True, 1000)
+                    if filter_str:
+                        query = query.filter(filter_str)
+                    driveitems = query.execute_query()
+                    logger.debug(
+                        f"Found {len(driveitems)} items in drive '{drive.name}'"
+                    )
+
+                    # Use "Shared Documents" as the library name for the default "Documents" drive
+                    drive_name = (
+                        "Shared Documents" if drive.name == "Documents" else drive.name
+                    )
+
+                    if site_descriptor.folder_path:
+                        # Filter items to ensure they're in the specified folder or its subfolders
+                        # The path will be in format: /drives/{drive_id}/root:/folder/path
+                        filtered_driveitems = [
+                            (item, drive_name)
+                            for item in driveitems
+                            if any(
+                                path_part == site_descriptor.folder_path
+                                or path_part.startswith(
+                                    site_descriptor.folder_path + "/"
+                                )
+                                for path_part in item.parent_reference.path.split(
+                                    "root:/"
+                                )[1].split("/")
+                            )
+                        ]
+                        if len(filtered_driveitems) == 0:
+                            all_paths = [
+                                item.parent_reference.path for item in driveitems
+                            ]
+                            logger.warning(
+                                f"Nothing found for folder '{site_descriptor.folder_path}' "
+                                f"in any of valid paths: {all_paths}"
+                            )
+                        final_driveitems.extend(filtered_driveitems)
+                    else:
+                        final_driveitems.extend(
+                            [(item, drive_name) for item in driveitems]
+                        )
+                except Exception as e:
+                    # Some drives might not be accessible
+                    logger.warning(f"Failed to process drive: {str(e)}")
+
+        except Exception as e:
+            # Sites include things that do not contain drives so this fails
+            # but this is fine, as there are no actual documents in those
+            logger.warning(f"Failed to process site: {str(e)}")
+
+        return final_driveitems
+
+    def _fetch_sites(self) -> list[SiteDescriptor]:
+        sites = self.graph_client.sites.get_all().execute_query()
+        site_descriptors = [
+            SiteDescriptor(
+                url=sites.resource_url,
+                drive_name=None,
+                folder_path=None,
+            )
+        ]
+        return site_descriptors

     def _fetch_from_sharepoint(
         self, start: datetime | None = None, end: datetime | None = None
     ) -> GenerateDocumentsOutput:
-        if self.graph_client is None:
-            raise ConnectorMissingCredentialError("Sharepoint")
-
-        self._populate_sitedata_sites()
-        self._populate_sitedata_driveitems(start=start, end=end)
+        site_descriptors = self.site_descriptors or self._fetch_sites()

         # goes over all urls, converts them into Document objects and then yields them in batches
         doc_batch: list[Document] = []
-        for element in self.site_data:
-            for driveitem in element.driveitems:
+        for site_descriptor in site_descriptors:
+            driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
+            for driveitem, drive_name in driveitems:
                 logger.debug(f"Processing: {driveitem.web_url}")
-                doc_batch.append(_convert_driveitem_to_document(driveitem))
+                doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))

                 if len(doc_batch) >= self.batch_size:
                     yield doc_batch
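The folder filter above keys off the parent_reference.path returned by Graph, which the comment in the diff describes as /drives/{drive_id}/root:/folder/path. A small standalone sketch of that membership check (the sample path and drive id below are made up):

# Illustration only: mirrors the any(...) filter in _fetch_driveitems above.
# Assumes parent_reference.path looks like "/drives/<drive_id>/root:/test/nested with spaces".
def in_requested_folder(parent_path: str, folder_path: str) -> bool:
    path_parts = parent_path.split("root:/")[1].split("/")
    return any(
        part == folder_path or part.startswith(folder_path + "/")
        for part in path_parts
    )

print(in_requested_folder("/drives/abc123/root:/test/nested with spaces", "test"))  # True
print(in_requested_folder("/drives/abc123/root:/other", "test"))                    # False

Matching individual path parts is what lets a request for the "test" folder also pick up items under "test/nested with spaces", which the new integration test relies on.
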
@@ -197,7 +268,7 @@ class SharepointConnector(LoadConnector, PollConnector):
             )
             return token

-        self.graph_client = GraphClient(_acquire_token_func)
+        self._graph_client = GraphClient(_acquire_token_func)
         return None

     def load_from_state(self) -> GenerateDocumentsOutput:
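For context, _acquire_token_func is defined earlier in load_credentials and is not changed by this hunk; only the attribute it is stored on is renamed. A rough sketch of the MSAL client-credentials flow that GraphClient(_acquire_token_func) relies on (the names below are assumptions based on the credential keys used elsewhere in this diff, not the file's actual body):

# Sketch only: the real _acquire_token_func lives inside load_credentials.
import msal  # type: ignore


def _acquire_token_sketch(sp_client_id: str, sp_client_secret: str, sp_directory_id: str) -> dict:
    app = msal.ConfidentialClientApplication(
        client_id=sp_client_id,
        client_credential=sp_client_secret,
        authority=f"https://login.microsoftonline.com/{sp_directory_id}",
    )
    # GraphClient expects its callback to return the token dict from MSAL
    return app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
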
@@ -206,19 +277,19 @@ class SharepointConnector(LoadConnector, PollConnector):
     def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> GenerateDocumentsOutput:
-        start_datetime = datetime.utcfromtimestamp(start)
-        end_datetime = datetime.utcfromtimestamp(end)
+        start_datetime = datetime.fromtimestamp(start, timezone.utc)
+        end_datetime = datetime.fromtimestamp(end, timezone.utc)
         return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)


 if __name__ == "__main__":
-    connector = SharepointConnector(sites=os.environ["SITES"].split(","))
+    connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))

     connector.load_credentials(
         {
-            "sp_client_id": os.environ["SP_CLIENT_ID"],
-            "sp_client_secret": os.environ["SP_CLIENT_SECRET"],
-            "sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
+            "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
+            "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
+            "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
         }
     )
     document_batches = connector.load_from_state()
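The poll_source change swaps datetime.utcfromtimestamp, which returns a naive datetime and is deprecated as of Python 3.12, for timezone-aware timestamps; the new integration test asserts tzinfo == timezone.utc on indexed documents. A quick illustration:

from datetime import datetime, timezone

ts = 1_700_000_000
naive = datetime.utcfromtimestamp(ts)             # deprecated; tzinfo is None
aware = datetime.fromtimestamp(ts, timezone.utc)  # tzinfo is timezone.utc

print(naive.tzinfo)                         # None
print(aware.tzinfo)                         # UTC
print(naive == aware.replace(tzinfo=None))  # True: same UTC wall-clock instant
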
@@ -1,7 +1,5 @@
 import os
-from collections.abc import Generator
 from unittest.mock import MagicMock
-from unittest.mock import patch

 import pytest
 from pydantic import BaseModel
@@ -127,15 +125,6 @@ def create_test_document(
     )


-@pytest.fixture
-def mock_get_api_key() -> Generator[MagicMock, None, None]:
-    with patch(
-        "onyx.file_processing.extract_file_text.get_unstructured_api_key",
-        return_value=None,
-    ) as mock:
-        yield mock
-
-
 def compare_documents(
     actual_docs: list[Document], expected_docs: list[Document]
 ) -> None:
@@ -191,7 +180,7 @@ def compare_documents(


 def test_airtable_connector_basic(
-    mock_get_api_key: MagicMock, airtable_config: AirtableConfig
+    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
 ) -> None:
     """Test behavior when all non-attachment fields are treated as metadata."""
     connector = AirtableConnector(
@@ -253,7 +242,7 @@ def test_airtable_connector_basic(


 def test_airtable_connector_all_metadata(
-    mock_get_api_key: MagicMock, airtable_config: AirtableConfig
+    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
 ) -> None:
     connector = AirtableConnector(
         base_id=airtable_config.base_id,
new file: backend/tests/daily/connectors/conftest.py (+14)
@@ -0,0 +1,14 @@
+from collections.abc import Generator
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+
+@pytest.fixture
+def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
+    with patch(
+        "onyx.file_processing.extract_file_text.get_unstructured_api_key",
+        return_value=None,
+    ) as mock:
+        yield mock
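Because the fixture now lives in a shared conftest.py, pytest injects it into any test under backend/tests/daily/connectors that names it as a parameter, with no import needed. A hypothetical example (not part of the commit):

# Hypothetical test: pytest resolves the fixture by parameter name from conftest.py.
from unittest.mock import MagicMock


def test_some_connector(mock_get_unstructured_api_key: MagicMock) -> None:
    # While this test runs, get_unstructured_api_key() returns None,
    # so file extraction falls back to the non-unstructured code path.
    assert mock_get_unstructured_api_key is not None
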
@@ -0,0 +1,178 @@
+import os
+from dataclasses import dataclass
+from datetime import datetime
+from datetime import timezone
+from unittest.mock import MagicMock
+
+import pytest
+
+from onyx.configs.constants import DocumentSource
+from onyx.connectors.models import Document
+from onyx.connectors.sharepoint.connector import SharepointConnector
+
+
+@dataclass
+class ExpectedDocument:
+    semantic_identifier: str
+    content: str
+    folder_path: str | None = None
+    library: str = "Shared Documents"  # Default to main library
+
+
+EXPECTED_DOCUMENTS = [
+    ExpectedDocument(
+        semantic_identifier="test1.docx",
+        content="test1",
+        folder_path="test",
+    ),
+    ExpectedDocument(
+        semantic_identifier="test2.docx",
+        content="test2",
+        folder_path="test/nested with spaces",
+    ),
+    ExpectedDocument(
+        semantic_identifier="should-not-index-on-specific-folder.docx",
+        content="should-not-index-on-specific-folder",
+        folder_path=None,  # root folder
+    ),
+    ExpectedDocument(
+        semantic_identifier="other.docx",
+        content="other",
+        folder_path=None,
+        library="Other Library",
+    ),
+]
+
+
+def verify_document_metadata(doc: Document) -> None:
+    """Verify common metadata that should be present on all documents."""
+    assert isinstance(doc.doc_updated_at, datetime)
+    assert doc.doc_updated_at.tzinfo == timezone.utc
+    assert doc.source == DocumentSource.SHAREPOINT
+    assert doc.primary_owners is not None
+    assert len(doc.primary_owners) == 1
+    owner = doc.primary_owners[0]
+    assert owner.display_name is not None
+    assert owner.email is not None
+
+
+def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
+    """Verify a document matches its expected content."""
+    assert doc.semantic_identifier == expected.semantic_identifier
+    assert len(doc.sections) == 1
+    assert expected.content in doc.sections[0].text
+    verify_document_metadata(doc)
+
+
+def find_document(documents: list[Document], semantic_identifier: str) -> Document:
+    """Find a document by its semantic identifier."""
+    matching_docs = [
+        d for d in documents if d.semantic_identifier == semantic_identifier
+    ]
+    assert (
+        len(matching_docs) == 1
+    ), f"Expected exactly one document with identifier {semantic_identifier}"
+    return matching_docs[0]
+
+
+@pytest.fixture
+def sharepoint_credentials() -> dict[str, str]:
+    return {
+        "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
+        "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
+        "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
+    }
+
+
+def test_sharepoint_connector_specific_folder(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the test site URL and specific folder
+    connector = SharepointConnector(
+        sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
+    )
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+
+    # Should only find documents in the test folder
+    test_folder_docs = [
+        doc
+        for doc in EXPECTED_DOCUMENTS
+        if doc.folder_path and doc.folder_path.startswith("test")
+    ]
+    assert len(found_documents) == len(
+        test_folder_docs
+    ), "Should only find documents in test folder"
+
+    # Verify each expected document
+    for expected in test_folder_docs:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)
+
+
+def test_sharepoint_connector_root_folder(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the base site URL
+    connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+
+    assert len(found_documents) == len(
+        EXPECTED_DOCUMENTS
+    ), "Should find all documents in main library"
+
+    # Verify each expected document
+    for expected in EXPECTED_DOCUMENTS:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)
+
+
+def test_sharepoint_connector_other_library(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the other library
+    connector = SharepointConnector(
+        sites=[
+            os.environ["SHAREPOINT_SITE"] + "/Other Library",
+        ]
+    )
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+    expected_documents: list[ExpectedDocument] = [
+        doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
+    ]
+
+    # Should find all documents in `Other Library`
+    assert len(found_documents) == len(
+        expected_documents
+    ), "Should find all documents in `Other Library`"
+
+    # Verify each expected document
+    for expected in expected_documents:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)