Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/sharepoint_app_init

# Conflicts:
#	backend/onyx/connectors/sharepoint/connector.py
Richard Kuo (Danswer) 2025-01-28 21:11:13 -08:00
commit 58e5deba01
15 changed files with 731 additions and 289 deletions

View File

@ -39,6 +39,12 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
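These secrets correspond to the credential keys and site URL that the SharePoint integration tests further down in this diff read from the environment. A minimal sketch of that mapping, assuming the same variable names (the helper function itself is hypothetical):

import os

# Hypothetical helper mirroring the sharepoint_credentials fixture later in this diff.
def sharepoint_credentials_from_env() -> dict[str, str]:
    return {
        "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
        "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
        "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
    }

# SHAREPOINT_SITE holds the base site URL the test connector is pointed at.
site_url = os.environ["SHAREPOINT_SITE"]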

View File

@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
# These field types are considered metadata by default when
# treat_all_non_attachment_fields_as_metadata is False
DEFAULT_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
@ -60,12 +60,16 @@ class AirtableConnector(LoadConnector):
self,
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
@ -166,8 +170,14 @@ class AirtableConnector(LoadConnector):
return [(str(field_info), default_link)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
"""Determine if a field type should be treated as metadata.
When treat_all_non_attachment_fields_as_metadata is True, all fields except
attachments are treated as metadata. Otherwise, only fields with types listed
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
if self.treat_all_non_attachment_fields_as_metadata:
return field_type.lower() != "multipleattachments"
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
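To make the two modes concrete, a minimal illustrative sketch (the base/table IDs are placeholders and the field-type strings are just examples, not an exhaustive list):

from onyx.connectors.airtable.airtable_connector import AirtableConnector

# Default mode: only field types in DEFAULT_METADATA_FIELD_TYPES become metadata.
default_connector = AirtableConnector(
    base_id="appXXXXXXXXXXXXXX",
    table_name_or_id="tblYYYYYYYYYYYYYY",
)
default_connector._should_be_metadata("createdby")            # True: listed in the default set
default_connector._should_be_metadata("multipleattachments")  # False: attachments become sections

# All-metadata mode: every field except attachments becomes metadata.
all_metadata_connector = AirtableConnector(
    base_id="appXXXXXXXXXXXXXX",
    table_name_or_id="tblYYYYYYYYYYYYYY",
    treat_all_non_attachment_fields_as_metadata=True,
)
all_metadata_connector._should_be_metadata("multilinetext")        # True: any non-attachment field
all_metadata_connector._should_be_metadata("multipleattachments")  # False: still indexed as sections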
def _process_field(
self,
@ -233,7 +243,7 @@ class AirtableConnector(LoadConnector):
record: RecordDict,
table_schema: TableSchema,
primary_field_name: str | None,
) -> Document:
) -> Document | None:
"""Process a single Airtable record into a Document.
Args:
@ -277,6 +287,10 @@ class AirtableConnector(LoadConnector):
sections.extend(field_sections)
metadata.update(field_metadata)
if not sections:
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
@ -320,7 +334,8 @@ class AirtableConnector(LoadConnector):
table_schema=table_schema,
primary_field_name=primary_field_name,
)
record_documents.append(document)
if document:
record_documents.append(document)
if len(record_documents) >= self.batch_size:
yield record_documents

View File

@ -1,17 +1,14 @@
import io
import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from urllib.parse import unquote
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
@dataclass
class SiteData:
url: str | None
folder: Optional[str]
sites: list = field(default_factory=list)
driveitems: list = field(default_factory=list)
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
Args:
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
If None, all folders will be accessed.
"""
url: str
drive_name: str | None
folder_path: str | None
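A worked illustration of how _extract_site_and_drive_info (defined below) fills these fields, using a made-up URL:

from onyx.connectors.sharepoint.connector import SharepointConnector

descriptors = SharepointConnector._extract_site_and_drive_info(
    ["https://example.sharepoint.com/sites/onyx-tests/Shared Documents/test/nested with spaces"]
)
# descriptors[0].url         == "https://example.sharepoint.com/sites/onyx-tests"
# descriptors[0].drive_name  == "Shared Documents"
# descriptors[0].folder_path == "test/nested with spaces"
# A bare site URL ("https://example.sharepoint.com/sites/onyx-tests") yields
# drive_name=None and folder_path=None, i.e. all drives and folders are indexed.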
def _convert_driveitem_to_document(
driveitem: DriveItem,
drive_name: str,
) -> Document:
file_text = extract_file_text(
file=io.BytesIO(driveitem.get_content().execute_query().value),
@ -59,7 +65,7 @@ def _convert_driveitem_to_document(
email=driveitem.last_modified_by.user.email,
)
],
metadata={},
metadata={"drive": drive_name},
)
return doc
@ -71,107 +77,172 @@ class SharepointConnector(LoadConnector, PollConnector):
sites: list[str] = [],
) -> None:
self.batch_size = batch_size
self.graph_client: GraphClient | None = None
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
self._graph_client: GraphClient | None = None
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
self.msal_app: msal.ConfidentialClientApplication | None = None
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
return self._graph_client
@staticmethod
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
if "sites" in parts:
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
folder = (
"/".join(unquote(part) for part in parts[sites_index + 2 :])
if len(parts) > sites_index + 2
else None
)
# Handling for new URL structure
if folder and folder.startswith("Shared Documents/"):
folder = folder[len("Shared Documents/") :]
remaining_parts = parts[sites_index + 2 :]
# Extract drive name and folder path
if remaining_parts:
drive_name = unquote(remaining_parts[0])
folder_path = (
"/".join(unquote(part) for part in remaining_parts[1:])
if len(remaining_parts) > 1
else None
)
else:
drive_name = None
folder_path = None
site_data_list.append(
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
SiteDescriptor(
url=site_url,
drive_name=drive_name,
folder_path=folder_path,
)
)
return site_data_list
def _populate_sitedata_driveitems(
def _fetch_driveitems(
self,
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> None:
) -> list[tuple[DriveItem, str]]:
filter_str = ""
if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
filter_str = (
f"last_modified_datetime ge {start.isoformat()} and "
f"last_modified_datetime le {end.isoformat()}"
)
for element in self.site_data:
sites: list[Site] = []
for site in element.sites:
site_sublist = site.lists.get().execute_query()
sites.extend(site_sublist)
final_driveitems: list[tuple[DriveItem, str]] = []
try:
site = self.graph_client.sites.get_by_url(site_descriptor.url)
for site in sites:
# Get all drives in the site
drives = site.drives.get().execute_query()
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
# Filter drives based on the requested drive name
if site_descriptor.drive_name:
drives = [
drive
for drive in drives
if drive.name == site_descriptor.drive_name
or (
drive.name == "Documents"
and site_descriptor.drive_name == "Shared Documents"
)
]
if not drives:
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
return []
# Process each matching drive
for drive in drives:
try:
query = site.drive.root.get_files(True, 1000)
root_folder = drive.root
if site_descriptor.folder_path:
# If a specific folder is requested, navigate to it
for folder_part in site_descriptor.folder_path.split("/"):
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
query = root_folder.get_files(True, 1000)
if filter_str:
query = query.filter(filter_str)
driveitems = query.execute_query()
if element.folder:
expected_path = f"/root:/{element.folder}"
logger.debug(
f"Found {len(driveitems)} items in drive '{drive.name}'"
)
# Use "Shared Documents" as the library name for the default "Documents" drive
drive_name = (
"Shared Documents" if drive.name == "Documents" else drive.name
)
if site_descriptor.folder_path:
# Filter items to ensure they're in the specified folder or its subfolders
# The path will be in format: /drives/{drive_id}/root:/folder/path
filtered_driveitems = [
item
(item, drive_name)
for item in driveitems
if item.parent_reference.path.endswith(expected_path)
if any(
path_part == site_descriptor.folder_path
or path_part.startswith(
site_descriptor.folder_path + "/"
)
for path_part in item.parent_reference.path.split(
"root:/"
)[1].split("/")
)
]
if len(filtered_driveitems) == 0:
all_paths = [
item.parent_reference.path for item in driveitems
]
logger.warning(
f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}"
f"Nothing found for folder '{site_descriptor.folder_path}' "
f"in; any of valid paths: {all_paths}"
)
element.driveitems.extend(filtered_driveitems)
final_driveitems.extend(filtered_driveitems)
else:
element.driveitems.extend(driveitems)
final_driveitems.extend(
[(item, drive_name) for item in driveitems]
)
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")
except Exception:
# Sites include things that do not contain .drive.root so this fails
# but this is fine, as there are no actually documents in those
pass
except Exception as e:
# Sites include things that do not contain drives so this fails
# but this is fine, as there are no actual documents in those
logger.warning(f"Failed to process site: {str(e)}")
def _populate_sitedata_sites(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
return final_driveitems
if self.site_data:
for element in self.site_data:
element.sites = [
self.graph_client.sites.get_by_url(element.url)
.get()
.execute_query()
]
else:
sites = self.graph_client.sites.get_all().execute_query()
self.site_data = [
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]
def _fetch_sites(self) -> list[SiteDescriptor]:
sites = self.graph_client.sites.get_all().execute_query()
site_descriptors = [
SiteDescriptor(
url=sites.resource_url,
drive_name=None,
folder_path=None,
)
]
return site_descriptors
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
self._populate_sitedata_sites()
self._populate_sitedata_driveitems(start=start, end=end)
site_descriptors = self.site_descriptors or self._fetch_sites()
# go over all site descriptors, convert their drive items into Document objects, and yield them in batches
doc_batch: list[Document] = []
for element in self.site_data:
for driveitem in element.driveitems:
for site_descriptor in site_descriptors:
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem))
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
if len(doc_batch) >= self.batch_size:
yield doc_batch
@ -202,7 +273,7 @@ class SharepointConnector(LoadConnector, PollConnector):
)
return token
self.graph_client = GraphClient(_acquire_token_func)
self._graph_client = GraphClient(_acquire_token_func)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@ -211,19 +282,19 @@ class SharepointConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
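The change from datetime.utcfromtimestamp to datetime.fromtimestamp(..., timezone.utc) makes the poll window timezone-aware (utcfromtimestamp returns naive datetimes and is deprecated in recent Python versions), which keeps the ISO-8601 filter built in _fetch_driveitems unambiguous. A small sketch with arbitrary epoch values:

from datetime import datetime, timezone

start = datetime.fromtimestamp(1_700_000_000, timezone.utc)
end = datetime.fromtimestamp(1_700_086_400, timezone.utc)

# isoformat() now includes the UTC offset, e.g. "2023-11-14T22:13:20+00:00"
filter_str = (
    f"last_modified_datetime ge {start.isoformat()} and "
    f"last_modified_datetime le {end.isoformat()}"
)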
if __name__ == "__main__":
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
connector.load_credentials(
{
"sp_client_id": os.environ["SP_CLIENT_ID"],
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
}
)
document_batches = connector.load_from_state()

View File

@ -537,30 +537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
# Let the tag flow handle this case, don't reply twice
return False
if event.get("bot_profile"):
# Check if this is a bot message (either via bot_profile or bot_message subtype)
is_bot_message = bool(
event.get("bot_profile") or event.get("subtype") == "bot_message"
)
if is_bot_message:
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info("Ignoring message from bot")
channel_specific_logger.info(
"Ignoring message from bot since respond_to_bots is disabled"
)
return False
# Ignore things like channel_join, channel_leave, etc.
# NOTE: "file_share" is just a message with a file attachment, so we
# should not ignore it
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
if message_subtype not in [None, "file_share", "bot_message"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)

View File

@ -1,8 +1,8 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.airtable.airtable_connector import AirtableConnector
@ -10,25 +10,24 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section
@pytest.fixture(
params=[
("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]),
("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]),
]
)
def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
param_type, table_identifier = request.param
connector = AirtableConnector(
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
table_name_or_id=table_identifier,
)
class AirtableConfig(BaseModel):
base_id: str
table_identifier: str
access_token: str
connector.load_credentials(
{
"airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"],
}
@pytest.fixture(params=[True, False])
def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig:
table_identifier = (
os.environ["AIRTABLE_TEST_TABLE_NAME"]
if request.param
else os.environ["AIRTABLE_TEST_TABLE_ID"]
)
return AirtableConfig(
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
table_identifier=table_identifier,
access_token=os.environ["AIRTABLE_ACCESS_TOKEN"],
)
return connector
def create_test_document(
@ -46,18 +45,37 @@ def create_test_document(
assignee: str,
days_since_status_change: int | None,
attachments: list[tuple[str, str]] | None = None,
all_fields_as_metadata: bool = False,
) -> Document:
link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
sections = [
Section(
text=f"Title:\n------------------------\n{title}\n------------------------",
link=f"{link_base}/{id}",
),
Section(
text=f"Description:\n------------------------\n{description}\n------------------------",
link=f"{link_base}/{id}",
),
]
base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
missing_vars = []
if not base_id:
missing_vars.append("AIRTABLE_TEST_BASE_ID")
if not table_id:
missing_vars.append("AIRTABLE_TEST_TABLE_ID")
if missing_vars:
raise RuntimeError(
f"Required environment variables not set: {', '.join(missing_vars)}. "
"These variables are required to run Airtable connector tests."
)
link_base = f"https://airtable.com/{base_id}/{table_id}"
sections = []
if not all_fields_as_metadata:
sections.extend(
[
Section(
text=f"Title:\n------------------------\n{title}\n------------------------",
link=f"{link_base}/{id}",
),
Section(
text=f"Description:\n------------------------\n{description}\n------------------------",
link=f"{link_base}/{id}",
),
]
)
if attachments:
for attachment_text, attachment_link in attachments:
@ -68,26 +86,36 @@ def create_test_document(
),
)
metadata: dict[str, str | list[str]] = {
# "Category": category,
"Assignee": assignee,
"Submitted by": submitted_by,
"Priority": priority,
"Status": status,
"Created time": created_time,
"ID": ticket_id,
"Status last changed": status_last_changed,
**(
{"Days since status change": str(days_since_status_change)}
if days_since_status_change is not None
else {}
),
}
if all_fields_as_metadata:
metadata.update(
{
"Title": title,
"Description": description,
}
)
return Document(
id=f"airtable__{id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}",
metadata={
# "Category": category,
"Assignee": assignee,
"Submitted by": submitted_by,
"Priority": priority,
"Status": status,
"Created time": created_time,
"ID": ticket_id,
"Status last changed": status_last_changed,
**(
{"Days since status change": str(days_since_status_change)}
if days_since_status_change is not None
else {}
),
},
semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}",
metadata=metadata,
doc_updated_at=None,
primary_owners=None,
secondary_owners=None,
@ -97,15 +125,75 @@ def create_test_document(
)
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_airtable_connector_basic(
mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
def compare_documents(
actual_docs: list[Document], expected_docs: list[Document]
) -> None:
doc_batch_generator = airtable_connector.load_from_state()
"""Utility function to compare actual and expected documents, ignoring order."""
actual_docs_dict = {doc.id: doc for doc in actual_docs}
expected_docs_dict = {doc.id: doc for doc in expected_docs}
assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch"
for doc_id in actual_docs_dict:
actual = actual_docs_dict[doc_id]
expected = expected_docs_dict[doc_id]
assert (
actual.source == expected.source
), f"Source mismatch for document {doc_id}"
assert (
actual.semantic_identifier == expected.semantic_identifier
), f"Semantic identifier mismatch for document {doc_id}"
assert (
actual.metadata == expected.metadata
), f"Metadata mismatch for document {doc_id}"
assert (
actual.doc_updated_at == expected.doc_updated_at
), f"Updated at mismatch for document {doc_id}"
assert (
actual.primary_owners == expected.primary_owners
), f"Primary owners mismatch for document {doc_id}"
assert (
actual.secondary_owners == expected.secondary_owners
), f"Secondary owners mismatch for document {doc_id}"
assert actual.title == expected.title, f"Title mismatch for document {doc_id}"
assert (
actual.from_ingestion_api == expected.from_ingestion_api
), f"Ingestion API flag mismatch for document {doc_id}"
assert (
actual.additional_info == expected.additional_info
), f"Additional info mismatch for document {doc_id}"
# Compare sections
assert len(actual.sections) == len(
expected.sections
), f"Number of sections mismatch for document {doc_id}"
for i, (actual_section, expected_section) in enumerate(
zip(actual.sections, expected.sections)
):
assert (
actual_section.text == expected_section.text
), f"Section {i} text mismatch for document {doc_id}"
assert (
actual_section.link == expected_section.link
), f"Section {i} link mismatch for document {doc_id}"
def test_airtable_connector_basic(
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
"""Test behavior when all non-attachment fields are treated as metadata."""
connector = AirtableConnector(
base_id=airtable_config.base_id,
table_name_or_id=airtable_config.table_identifier,
treat_all_non_attachment_fields_as_metadata=False,
)
connector.load_credentials(
{
"airtable_access_token": airtable_config.access_token,
}
)
doc_batch_generator = connector.load_from_state()
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
@ -119,15 +207,62 @@ def test_airtable_connector_basic(
description="The internet connection is very slow.",
priority="Medium",
status="In Progress",
# Link to another record is skipped for now
# category="Data Science",
ticket_id="2",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
all_fields_as_metadata=False,
),
create_test_document(
id="reccSlIA4pZEFxPBg",
title="Printer Issue",
description="The office printer is not working.",
priority="High",
status="Open",
ticket_id="1",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
attachments=[
(
"Test.pdf:\ntesting!!!",
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
)
],
all_fields_as_metadata=False,
),
]
# Compare documents using the utility function
compare_documents(doc_batch, expected_docs)
def test_airtable_connector_all_metadata(
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
connector = AirtableConnector(
base_id=airtable_config.base_id,
table_name_or_id=airtable_config.table_identifier,
treat_all_non_attachment_fields_as_metadata=True,
)
connector.load_credentials(
{
"airtable_access_token": airtable_config.access_token,
}
)
doc_batch_generator = connector.load_from_state()
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
# NOTE: one of the rows has no attachments -> no content -> no document
assert len(doc_batch) == 1
expected_docs = [
create_test_document(
id="reccSlIA4pZEFxPBg",
title="Printer Issue",
@ -149,50 +284,9 @@ def test_airtable_connector_basic(
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
)
],
all_fields_as_metadata=True,
),
]
# Compare each document field by field
for actual, expected in zip(doc_batch, expected_docs):
assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
assert (
actual.source == expected.source
), f"Source mismatch for document {actual.id}"
assert (
actual.semantic_identifier == expected.semantic_identifier
), f"Semantic identifier mismatch for document {actual.id}"
assert (
actual.metadata == expected.metadata
), f"Metadata mismatch for document {actual.id}"
assert (
actual.doc_updated_at == expected.doc_updated_at
), f"Updated at mismatch for document {actual.id}"
assert (
actual.primary_owners == expected.primary_owners
), f"Primary owners mismatch for document {actual.id}"
assert (
actual.secondary_owners == expected.secondary_owners
), f"Secondary owners mismatch for document {actual.id}"
assert (
actual.title == expected.title
), f"Title mismatch for document {actual.id}"
assert (
actual.from_ingestion_api == expected.from_ingestion_api
), f"Ingestion API flag mismatch for document {actual.id}"
assert (
actual.additional_info == expected.additional_info
), f"Additional info mismatch for document {actual.id}"
# Compare sections
assert len(actual.sections) == len(
expected.sections
), f"Number of sections mismatch for document {actual.id}"
for i, (actual_section, expected_section) in enumerate(
zip(actual.sections, expected.sections)
):
assert (
actual_section.text == expected_section.text
), f"Section {i} text mismatch for document {actual.id}"
assert (
actual_section.link == expected_section.link
), f"Section {i} link mismatch for document {actual.id}"
# Compare documents using the utility function
compare_documents(doc_batch, expected_docs)

View File

@ -0,0 +1,14 @@
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
@pytest.fixture
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
with patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
) as mock:
yield mock

View File

@ -0,0 +1,178 @@
import os
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.sharepoint.connector import SharepointConnector
@dataclass
class ExpectedDocument:
semantic_identifier: str
content: str
folder_path: str | None = None
library: str = "Shared Documents" # Default to main library
EXPECTED_DOCUMENTS = [
ExpectedDocument(
semantic_identifier="test1.docx",
content="test1",
folder_path="test",
),
ExpectedDocument(
semantic_identifier="test2.docx",
content="test2",
folder_path="test/nested with spaces",
),
ExpectedDocument(
semantic_identifier="should-not-index-on-specific-folder.docx",
content="should-not-index-on-specific-folder",
folder_path=None, # root folder
),
ExpectedDocument(
semantic_identifier="other.docx",
content="other",
folder_path=None,
library="Other Library",
),
]
def verify_document_metadata(doc: Document) -> None:
"""Verify common metadata that should be present on all documents."""
assert isinstance(doc.doc_updated_at, datetime)
assert doc.doc_updated_at.tzinfo == timezone.utc
assert doc.source == DocumentSource.SHAREPOINT
assert doc.primary_owners is not None
assert len(doc.primary_owners) == 1
owner = doc.primary_owners[0]
assert owner.display_name is not None
assert owner.email is not None
def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
"""Verify a document matches its expected content."""
assert doc.semantic_identifier == expected.semantic_identifier
assert len(doc.sections) == 1
assert expected.content in doc.sections[0].text
verify_document_metadata(doc)
def find_document(documents: list[Document], semantic_identifier: str) -> Document:
"""Find a document by its semantic identifier."""
matching_docs = [
d for d in documents if d.semantic_identifier == semantic_identifier
]
assert (
len(matching_docs) == 1
), f"Expected exactly one document with identifier {semantic_identifier}"
return matching_docs[0]
@pytest.fixture
def sharepoint_credentials() -> dict[str, str]:
return {
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
}
def test_sharepoint_connector_specific_folder(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the test site URL and specific folder
connector = SharepointConnector(
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
# Should only find documents in the test folder
test_folder_docs = [
doc
for doc in EXPECTED_DOCUMENTS
if doc.folder_path and doc.folder_path.startswith("test")
]
assert len(found_documents) == len(
test_folder_docs
), "Should only find documents in test folder"
# Verify each expected document
for expected in test_folder_docs:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)
def test_sharepoint_connector_root_folder(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the base site URL
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
assert len(found_documents) == len(
EXPECTED_DOCUMENTS
), "Should find all documents in main library"
# Verify each expected document
for expected in EXPECTED_DOCUMENTS:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)
def test_sharepoint_connector_other_library(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the other library
connector = SharepointConnector(
sites=[
os.environ["SHAREPOINT_SITE"] + "/Other Library",
]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
expected_documents: list[ExpectedDocument] = [
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
]
# Should find all documents in `Other Library`
assert len(found_documents) == len(
expected_documents
), "Should find all documents in `Other Library`"
# Verify each expected document
for expected in expected_documents:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)

View File

@ -76,13 +76,7 @@ import {
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import {
checkLLMSupportsImageInput,
getFinalLLM,
destructureValue,
getLLMProviderOverrideForPersona,
} from "@/lib/llm/utils";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
@ -203,6 +197,12 @@ export function ChatPage({
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
const selectedChatSession = chatSessions.find(
(chatSession) => chatSession.id === existingChatSessionId
);
useEffect(() => {
if (user?.is_anonymous_user) {
Cookies.set(
@ -240,12 +240,6 @@ export function ChatPage({
}
};
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
const selectedChatSession = chatSessions.find(
(chatSession) => chatSession.id === existingChatSessionId
);
const chatSessionIdRef = useRef<string | null>(existingChatSessionId);
// Only updates on session load (ie. rename / switching chat session)
@ -293,12 +287,6 @@ export function ChatPage({
);
};
const llmOverrideManager = useLlmOverride(
llmProviders,
user?.preferences.default_model,
selectedChatSession
);
const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null);
@ -307,12 +295,27 @@ export function ChatPage({
const { recentAssistants, refreshRecentAssistants } = useAssistants();
const liveAssistant: Persona | undefined =
alternativeAssistant ||
selectedAssistant ||
recentAssistants[0] ||
finalAssistants[0] ||
availableAssistants[0];
const liveAssistant: Persona | undefined = useMemo(
() =>
alternativeAssistant ||
selectedAssistant ||
recentAssistants[0] ||
finalAssistants[0] ||
availableAssistants[0],
[
alternativeAssistant,
selectedAssistant,
recentAssistants,
finalAssistants,
availableAssistants,
]
);
const llmOverrideManager = useLlmOverride(
llmProviders,
selectedChatSession,
liveAssistant
);
const noAssistants = liveAssistant == null || liveAssistant == undefined;
@ -320,24 +323,6 @@ export function ChatPage({
const uniqueSources = Array.from(new Set(availableSources));
const sources = uniqueSources.map((source) => getSourceMetadata(source));
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
useEffect(() => {
if (noAssistants) return;
const personaDefault = getLLMProviderOverrideForPersona(
liveAssistant,
llmProviders
);
if (personaDefault) {
llmOverrideManager.updateLLMOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.updateLLMOverride(
destructureValue(user?.preferences.default_model)
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [liveAssistant, user?.preferences.default_model]);
const stopGenerating = () => {
const currentSession = currentSessionId();
const controller = abortControllers.get(currentSession);
@ -419,7 +404,6 @@ export function ChatPage({
filterManager.setTimeRange(null);
// reset LLM overrides (based on chat session!)
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
llmOverrideManager.updateTemperature(null);
// remove uploaded files
@ -1283,13 +1267,11 @@ export function ChatPage({
modelProvider:
modelOverRide?.name ||
llmOverrideManager.llmOverride.name ||
llmOverrideManager.globalDefault.name ||
undefined,
modelVersion:
modelOverRide?.modelName ||
llmOverrideManager.llmOverride.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
llmOverrideManager.globalDefault.modelName ||
undefined,
temperature: llmOverrideManager.temperature || undefined,
systemPromptOverride:
@ -1952,6 +1934,7 @@ export function ChatPage({
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [router]);
const [sharedChatSession, setSharedChatSession] =
useState<ChatSession | null>();
@ -2059,7 +2042,9 @@ export function ChatPage({
{(settingsToggled || userSettingsToggled) && (
<UserSettingsModal
setPopup={setPopup}
setLlmOverride={llmOverrideManager.setGlobalDefault}
setLlmOverride={(newOverride) =>
llmOverrideManager.updateLLMOverride(newOverride)
}
defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders}
onClose={() => {
@ -2749,6 +2734,7 @@ export function ChatPage({
</button>
</div>
)}
<ChatInputBar
toggleDocumentSidebar={toggleDocumentSidebar}
availableSources={sources}

View File

@ -40,8 +40,8 @@ export default function LLMPopover({
currentAssistant,
}: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false);
const { llmOverride, updateLLMOverride, globalDefault } = llmOverrideManager;
const currentLlm = llmOverride.modelName || globalDefault.modelName;
const { llmOverride, updateLLMOverride } = llmOverrideManager;
const currentLlm = llmOverride.modelName;
const llmOptionsByProvider: {
[provider: string]: {

View File

@ -1,13 +1,5 @@
import {
Dispatch,
SetStateAction,
useContext,
useEffect,
useRef,
useState,
} from "react";
import { useContext, useEffect, useRef } from "react";
import { Modal } from "@/components/Modal";
import Text from "@/components/ui/text";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
@ -33,7 +25,7 @@ export function UserSettingsModal({
}: {
setPopup: (popupSpec: PopupSpec | null) => void;
llmProviders: LLMProviderDescriptor[];
setLlmOverride?: Dispatch<SetStateAction<LlmOverride>>;
setLlmOverride?: (newOverride: LlmOverride) => void;
onClose: () => void;
defaultModel: string | null;
}) {

View File

@ -1,3 +1,3 @@
export const SEARCH_TOOL_NAME = "SearchTool";
export const SEARCH_TOOL_NAME = "run_search";
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";

View File

@ -1106,6 +1106,14 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
name: "table_name_or_id",
optional: false,
},
{
type: "checkbox",
label: "Treat all fields except attachments as metadata",
name: "treat_all_non_attachment_fields_as_metadata",
description:
"Choose this if the primary content to index are attachments and all other columns are metadata for these attachments.",
optional: false,
},
],
advanced_values: [],
overrideDefaultFreq: 60 * 60 * 24,

View File

@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { Filters, SourceMetadata } from "./search/interfaces";
import { destructureValue, structureValue } from "./llm/utils";
import {
destructureValue,
findProviderForModel,
structureValue,
} from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces";
import { AllUsersResponse } from "./types";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@ -355,82 +360,141 @@ export interface LlmOverride {
export interface LlmOverrideManager {
llmOverride: LlmOverride;
updateLLMOverride: (newOverride: LlmOverride) => void;
globalDefault: LlmOverride;
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null;
updateTemperature: (temperature: number | null) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
imageFilesPresent: boolean;
updateImageFilesPresent: (present: boolean) => void;
liveAssistant: Persona | null;
}
/*
LLM Override is as follows (i.e. this order)
- User override (explicitly set in the chat input bar)
- User preference (defaults to system wide default if no preference set)
On switching to an existing or new chat session or a different assistant:
- If the live assistant after the switch has a model override, use it; otherwise use the hierarchy above
Thus, the input should be
- User preference
- LLM Providers (which contain the system wide default)
- Current assistant
Changes take place as
- liveAssistant or currentChatSession changes (and the associated model override is set)
- (updateLLMOverride) the user explicitly setting a model override, which we then use in place of the user preference unless an assistant override applies
If we have a live assistant, we should use that model override
*/
export function useLlmOverride(
llmProviders: LLMProviderDescriptor[],
globalModel?: string | null,
currentChatSession?: ChatSession,
defaultTemperature?: number
liveAssistant?: Persona
): LlmOverrideManager {
const { user } = useUser();
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const llmOverrideUpdate = () => {
if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
return;
}
if (liveAssistant?.llm_model_version_override) {
setLlmOverride(
getValidLlmOverride(liveAssistant.llm_model_version_override)
);
} else if (currentChatSession?.current_alternate_model) {
setLlmOverride(
getValidLlmOverride(currentChatSession.current_alternate_model)
);
} else if (user?.preferences?.default_model) {
setLlmOverride(getValidLlmOverride(user.preferences.default_model));
return;
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
if (defaultProvider) {
setLlmOverride({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
}
}
setChatSession(currentChatSession || null);
};
const getValidLlmOverride = (
overrideModel: string | null | undefined
): LlmOverride => {
if (overrideModel) {
const model = destructureValue(overrideModel);
const provider = llmProviders.find(
(p) =>
p.model_names.includes(model.modelName) &&
p.provider === model.provider
if (!(model.modelName && model.modelName.length > 0)) {
const provider = llmProviders.find((p) =>
p.model_names.includes(overrideModel)
);
if (provider) {
return {
modelName: overrideModel,
name: provider.name,
provider: provider.provider,
};
}
}
const provider = llmProviders.find((p) =>
p.model_names.includes(model.modelName)
);
if (provider) {
return { ...model, name: provider.name };
}
}
return { name: "", provider: "", modelName: "" };
};
const [imageFilesPresent, setImageFilesPresent] = useState(false);
const updateImageFilesPresent = (present: boolean) => {
setImageFilesPresent(present);
};
const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
getValidLlmOverride(globalModel)
);
const updateLLMOverride = (newOverride: LlmOverride) => {
setLlmOverride(
getValidLlmOverride(
structureValue(
newOverride.name,
newOverride.provider,
newOverride.modelName
)
)
);
};
const [llmOverride, setLlmOverride] = useState<LlmOverride>({
name: "",
provider: "",
modelName: "",
});
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? getValidLlmOverride(currentChatSession.current_alternate_model)
: { name: "", provider: "", modelName: "" }
);
// Manually set the override
const updateLLMOverride = (newOverride: LlmOverride) => {
const provider =
newOverride.provider ||
findProviderForModel(llmProviders, newOverride.modelName);
const structuredValue = structureValue(
newOverride.name,
provider,
newOverride.modelName
);
setLlmOverride(getValidLlmOverride(structuredValue));
};
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
setLlmOverride(
chatSession && chatSession.current_alternate_model
? getValidLlmOverride(chatSession.current_alternate_model)
: globalDefault
);
if (chatSession && chatSession.current_alternate_model?.length > 0) {
setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
}
};
const [temperature, setTemperature] = useState<number | null>(
defaultTemperature !== undefined ? defaultTemperature : 0
);
const [temperature, setTemperature] = useState<number | null>(0);
useEffect(() => {
setGlobalDefault(getValidLlmOverride(globalModel));
}, [globalModel, llmProviders]);
useEffect(() => {
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
}, [defaultTemperature]);
llmOverrideUpdate();
}, [liveAssistant, currentChatSession]);
useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
@ -450,12 +514,11 @@ export function useLlmOverride(
updateModelOverrideForChatSession,
llmOverride,
updateLLMOverride,
globalDefault,
setGlobalDefault,
temperature,
updateTemperature,
imageFilesPresent,
updateImageFilesPresent,
liveAssistant: liveAssistant ?? null,
};
}

View File

@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => {
modelName,
};
};
export const findProviderForModel = (
llmProviders: LLMProviderDescriptor[],
modelName: string
): string => {
const provider = llmProviders.find((p) => p.model_names.includes(modelName));
return provider ? provider.provider : "";
};

View File

@ -358,7 +358,8 @@ export type ConfigurableSources = Exclude<
export const oauthSupportedSources: ConfigurableSources[] = [
ValidSources.Slack,
ValidSources.GoogleDrive,
// NOTE: temporarily disabled until our GDrive App is approved
// ValidSources.GoogleDrive,
];
export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];