Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/sharepoint_app_init

# Conflicts:
#	backend/onyx/connectors/sharepoint/connector.py
Richard Kuo (Danswer)
2025-01-28 21:11:13 -08:00
15 changed files with 731 additions and 289 deletions

View File

@@ -39,6 +39,12 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }} AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }} AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }} AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs: jobs:
connectors-check: connectors-check:
# See https://runs-on.com/runners/linux/ # See https://runs-on.com/runners/linux/

View File

@@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues # NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather # These field types are considered metadata by default when
# than sections # treat_all_non_attachment_fields_as_metadata is False
_METADATA_FIELD_TYPES = { DEFAULT_METADATA_FIELD_TYPES = {
"singlecollaborator", "singlecollaborator",
"collaborator", "collaborator",
"createdby", "createdby",
@@ -60,12 +60,16 @@ class AirtableConnector(LoadConnector):
self, self,
base_id: str, base_id: str,
table_name_or_id: str, table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
batch_size: int = INDEX_BATCH_SIZE, batch_size: int = INDEX_BATCH_SIZE,
) -> None: ) -> None:
self.base_id = base_id self.base_id = base_id
self.table_name_or_id = table_name_or_id self.table_name_or_id = table_name_or_id
self.batch_size = batch_size self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None self.airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"]) self.airtable_client = AirtableApi(credentials["airtable_access_token"])
@@ -166,8 +170,14 @@ class AirtableConnector(LoadConnector):
return [(str(field_info), default_link)] return [(str(field_info), default_link)]
def _should_be_metadata(self, field_type: str) -> bool: def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata.""" """Determine if a field type should be treated as metadata.
return field_type.lower() in _METADATA_FIELD_TYPES
When treat_all_non_attachment_fields_as_metadata is True, all fields except
attachments are treated as metadata. Otherwise, only fields with types listed
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
if self.treat_all_non_attachment_fields_as_metadata:
return field_type.lower() != "multipleattachments"
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
def _process_field( def _process_field(
self, self,
@@ -233,7 +243,7 @@ class AirtableConnector(LoadConnector):
record: RecordDict, record: RecordDict,
table_schema: TableSchema, table_schema: TableSchema,
primary_field_name: str | None, primary_field_name: str | None,
) -> Document: ) -> Document | None:
"""Process a single Airtable record into a Document. """Process a single Airtable record into a Document.
Args: Args:
@@ -277,6 +287,10 @@ class AirtableConnector(LoadConnector):
sections.extend(field_sections) sections.extend(field_sections)
metadata.update(field_metadata) metadata.update(field_metadata)
if not sections:
logger.warning(f"No sections found for record {record_id}")
return None
semantic_id = ( semantic_id = (
f"{table_name}: {primary_field_value}" f"{table_name}: {primary_field_value}"
if primary_field_value if primary_field_value
@@ -320,7 +334,8 @@ class AirtableConnector(LoadConnector):
table_schema=table_schema, table_schema=table_schema,
primary_field_name=primary_field_name, primary_field_name=primary_field_name,
) )
record_documents.append(document) if document:
record_documents.append(document)
if len(record_documents) >= self.batch_size: if len(record_documents) >= self.batch_size:
yield record_documents yield record_documents
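
The new treat_all_non_attachment_fields_as_metadata flag changes which Airtable fields become document sections and which become metadata. A minimal sketch of that decision, using the Airtable field-type strings visible in the diff (the standalone function is illustrative, not the connector's actual API):

# Illustrative sketch of the new metadata decision in AirtableConnector.
DEFAULT_METADATA_FIELD_TYPES = {
    "singlecollaborator",
    "collaborator",
    "createdby",
    # ... remaining types listed in the connector
}

def should_be_metadata(
    field_type: str, treat_all_non_attachment_fields_as_metadata: bool
) -> bool:
    if treat_all_non_attachment_fields_as_metadata:
        # Only attachment fields remain indexable content; everything else is metadata.
        return field_type.lower() != "multipleattachments"
    return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

assert should_be_metadata("multilineText", True)            # text becomes metadata
assert not should_be_metadata("multipleAttachments", True)  # attachments stay content
assert should_be_metadata("createdBy", False)               # default metadata type
assert not should_be_metadata("multilineText", False)       # default: text is a section

With the flag enabled, a record whose only content would have been text fields can end up with no sections at all, which is why _process_record now returns Document | None and the caller skips records that produce no sections.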

View File

@@ -1,17 +1,14 @@
import io import io
import os import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from typing import Any from typing import Any
from typing import Optional
from urllib.parse import unquote from urllib.parse import unquote
import msal # type: ignore import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource from onyx.configs.constants import DocumentSource
@@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@dataclass class SiteDescriptor(BaseModel):
class SiteData: """Data class for storing SharePoint site information.
url: str | None
folder: Optional[str] Args:
sites: list = field(default_factory=list) url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
driveitems: list = field(default_factory=list) drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
If None, all folders will be accessed.
"""
url: str
drive_name: str | None
folder_path: str | None
def _convert_driveitem_to_document( def _convert_driveitem_to_document(
driveitem: DriveItem, driveitem: DriveItem,
drive_name: str,
) -> Document: ) -> Document:
file_text = extract_file_text( file_text = extract_file_text(
file=io.BytesIO(driveitem.get_content().execute_query().value), file=io.BytesIO(driveitem.get_content().execute_query().value),
@@ -59,7 +65,7 @@ def _convert_driveitem_to_document(
email=driveitem.last_modified_by.user.email, email=driveitem.last_modified_by.user.email,
) )
], ],
metadata={}, metadata={"drive": drive_name},
) )
return doc return doc
@@ -71,107 +77,172 @@ class SharepointConnector(LoadConnector, PollConnector):
sites: list[str] = [], sites: list[str] = [],
) -> None: ) -> None:
self.batch_size = batch_size self.batch_size = batch_size
self.graph_client: GraphClient | None = None self._graph_client: GraphClient | None = None
self.site_data: list[SiteData] = self._extract_site_and_folder(sites) self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
sites
)
self.msal_app: msal.ConfidentialClientApplication | None = None self.msal_app: msal.ConfidentialClientApplication | None = None
@property
def graph_client(self) -> GraphClient:
if self._graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
return self._graph_client
@staticmethod @staticmethod
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]: def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
site_data_list = [] site_data_list = []
for url in site_urls: for url in site_urls:
parts = url.strip().split("/") parts = url.strip().split("/")
if "sites" in parts: if "sites" in parts:
sites_index = parts.index("sites") sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2]) site_url = "/".join(parts[: sites_index + 2])
folder = ( remaining_parts = parts[sites_index + 2 :]
"/".join(unquote(part) for part in parts[sites_index + 2 :])
if len(parts) > sites_index + 2 # Extract drive name and folder path
else None if remaining_parts:
) drive_name = unquote(remaining_parts[0])
# Handling for new URL structure folder_path = (
if folder and folder.startswith("Shared Documents/"): "/".join(unquote(part) for part in remaining_parts[1:])
folder = folder[len("Shared Documents/") :] if len(remaining_parts) > 1
else None
)
else:
drive_name = None
folder_path = None
site_data_list.append( site_data_list.append(
SiteData(url=site_url, folder=folder, sites=[], driveitems=[]) SiteDescriptor(
url=site_url,
drive_name=drive_name,
folder_path=folder_path,
)
) )
return site_data_list return site_data_list
def _populate_sitedata_driveitems( def _fetch_driveitems(
self, self,
site_descriptor: SiteDescriptor,
start: datetime | None = None, start: datetime | None = None,
end: datetime | None = None, end: datetime | None = None,
) -> None: ) -> list[tuple[DriveItem, str]]:
filter_str = "" filter_str = ""
if start is not None and end is not None: if start is not None and end is not None:
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}" filter_str = (
f"last_modified_datetime ge {start.isoformat()} and "
f"last_modified_datetime le {end.isoformat()}"
)
for element in self.site_data: final_driveitems: list[tuple[DriveItem, str]] = []
sites: list[Site] = [] try:
for site in element.sites: site = self.graph_client.sites.get_by_url(site_descriptor.url)
site_sublist = site.lists.get().execute_query()
sites.extend(site_sublist)
for site in sites: # Get all drives in the site
drives = site.drives.get().execute_query()
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
# Filter drives based on the requested drive name
if site_descriptor.drive_name:
drives = [
drive
for drive in drives
if drive.name == site_descriptor.drive_name
or (
drive.name == "Documents"
and site_descriptor.drive_name == "Shared Documents"
)
]
if not drives:
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
return []
# Process each matching drive
for drive in drives:
try: try:
query = site.drive.root.get_files(True, 1000) root_folder = drive.root
if site_descriptor.folder_path:
# If a specific folder is requested, navigate to it
for folder_part in site_descriptor.folder_path.split("/"):
root_folder = root_folder.get_by_path(folder_part)
# Get all items recursively
query = root_folder.get_files(True, 1000)
if filter_str: if filter_str:
query = query.filter(filter_str) query = query.filter(filter_str)
driveitems = query.execute_query() driveitems = query.execute_query()
if element.folder: logger.debug(
expected_path = f"/root:/{element.folder}" f"Found {len(driveitems)} items in drive '{drive.name}'"
)
# Use "Shared Documents" as the library name for the default "Documents" drive
drive_name = (
"Shared Documents" if drive.name == "Documents" else drive.name
)
if site_descriptor.folder_path:
# Filter items to ensure they're in the specified folder or its subfolders
# The path will be in format: /drives/{drive_id}/root:/folder/path
filtered_driveitems = [ filtered_driveitems = [
item (item, drive_name)
for item in driveitems for item in driveitems
if item.parent_reference.path.endswith(expected_path) if any(
path_part == site_descriptor.folder_path
or path_part.startswith(
site_descriptor.folder_path + "/"
)
for path_part in item.parent_reference.path.split(
"root:/"
)[1].split("/")
)
] ]
if len(filtered_driveitems) == 0: if len(filtered_driveitems) == 0:
all_paths = [ all_paths = [
item.parent_reference.path for item in driveitems item.parent_reference.path for item in driveitems
] ]
logger.warning( logger.warning(
f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}" f"Nothing found for folder '{site_descriptor.folder_path}' "
f"in; any of valid paths: {all_paths}"
) )
element.driveitems.extend(filtered_driveitems) final_driveitems.extend(filtered_driveitems)
else: else:
element.driveitems.extend(driveitems) final_driveitems.extend(
[(item, drive_name) for item in driveitems]
)
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")
except Exception: except Exception as e:
# Sites include things that do not contain .drive.root so this fails # Sites include things that do not contain drives so this fails
# but this is fine, as there are no actually documents in those # but this is fine, as there are no actual documents in those
pass logger.warning(f"Failed to process site: {str(e)}")
def _populate_sitedata_sites(self) -> None: return final_driveitems
if self.graph_client is None:
raise ConnectorMissingCredentialError("Sharepoint")
if self.site_data: def _fetch_sites(self) -> list[SiteDescriptor]:
for element in self.site_data: sites = self.graph_client.sites.get_all().execute_query()
element.sites = [ site_descriptors = [
self.graph_client.sites.get_by_url(element.url) SiteDescriptor(
.get() url=sites.resource_url,
.execute_query() drive_name=None,
] folder_path=None,
else: )
sites = self.graph_client.sites.get_all().execute_query() ]
self.site_data = [ return site_descriptors
SiteData(url=None, folder=None, sites=sites, driveitems=[])
]
def _fetch_from_sharepoint( def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput: ) -> GenerateDocumentsOutput:
if self.graph_client is None: site_descriptors = self.site_descriptors or self._fetch_sites()
raise ConnectorMissingCredentialError("Sharepoint")
self._populate_sitedata_sites()
self._populate_sitedata_driveitems(start=start, end=end)
# goes over all urls, converts them into Document objects and then yields them in batches # goes over all urls, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = [] doc_batch: list[Document] = []
for element in self.site_data: for site_descriptor in site_descriptors:
for driveitem in element.driveitems: driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
logger.debug(f"Processing: {driveitem.web_url}") logger.debug(f"Processing: {driveitem.web_url}")
doc_batch.append(_convert_driveitem_to_document(driveitem)) doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
if len(doc_batch) >= self.batch_size: if len(doc_batch) >= self.batch_size:
yield doc_batch yield doc_batch
@@ -202,7 +273,7 @@ class SharepointConnector(LoadConnector, PollConnector):
) )
return token return token
self.graph_client = GraphClient(_acquire_token_func) self._graph_client = GraphClient(_acquire_token_func)
return None return None
def load_from_state(self) -> GenerateDocumentsOutput: def load_from_state(self) -> GenerateDocumentsOutput:
@@ -211,19 +282,19 @@ class SharepointConnector(LoadConnector, PollConnector):
def poll_source( def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput: ) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start) start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.utcfromtimestamp(end) end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime) return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
if __name__ == "__main__": if __name__ == "__main__":
connector = SharepointConnector(sites=os.environ["SITES"].split(",")) connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
connector.load_credentials( connector.load_credentials(
{ {
"sp_client_id": os.environ["SP_CLIENT_ID"], "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SP_CLIENT_SECRET"], "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"], "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
} }
) )
document_batches = connector.load_from_state() document_batches = connector.load_from_state()
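
Site URLs are now split into a site, an optional drive (document library), and an optional folder path, instead of treating everything after the site as a folder. A rough re-implementation of that parsing, assuming URLs of the /sites/<site>/<drive>/<folder...> shape shown in the diff (the helper name is made up for illustration):

from urllib.parse import unquote

def split_sharepoint_url(url: str) -> tuple[str, str | None, str | None]:
    """Return (site_url, drive_name, folder_path) for a SharePoint site URL."""
    parts = url.strip().split("/")
    if "sites" not in parts:
        return url, None, None
    sites_index = parts.index("sites")
    site_url = "/".join(parts[: sites_index + 2])
    remaining = parts[sites_index + 2 :]
    drive_name = unquote(remaining[0]) if remaining else None
    folder_path = (
        "/".join(unquote(p) for p in remaining[1:]) if len(remaining) > 1 else None
    )
    return site_url, drive_name, folder_path

print(split_sharepoint_url(
    "https://example.sharepoint.com/sites/sharepoint-tests/Shared%20Documents/test/nested%20with%20spaces"
))
# ('https://example.sharepoint.com/sites/sharepoint-tests', 'Shared Documents', 'test/nested with spaces')

When drives are later filtered, the default "Documents" drive is matched against the "Shared Documents" library name, so URLs written either way resolve to the same drive. Separately, poll_source swaps the deprecated datetime.utcfromtimestamp for datetime.fromtimestamp(ts, timezone.utc), so the start/end values used in the last_modified_datetime filter are timezone-aware and their isoformat() carries an explicit UTC offset.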

View File

@@ -537,30 +537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
# Let the tag flow handle this case, don't reply twice # Let the tag flow handle this case, don't reply twice
return False return False
if event.get("bot_profile"): # Check if this is a bot message (either via bot_profile or bot_message subtype)
is_bot_message = bool(
event.get("bot_profile") or event.get("subtype") == "bot_message"
)
if is_bot_message:
channel_name, _ = get_channel_name_from_id( channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel client=client.web_client, channel_id=channel
) )
with get_session_with_tenant(client.tenant_id) as db_session: with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel( slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session, db_session=db_session,
slack_bot_id=client.slack_bot_id, slack_bot_id=client.slack_bot_id,
channel_name=channel_name, channel_name=channel_name,
) )
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and ( if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_channel_config not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots") or not slack_channel_config.channel_config.get("respond_to_bots")
): ):
channel_specific_logger.info("Ignoring message from bot") channel_specific_logger.info(
"Ignoring message from bot since respond_to_bots is disabled"
)
return False return False
# Ignore things like channel_join, channel_leave, etc. # Ignore things like channel_join, channel_leave, etc.
# NOTE: "file_share" is just a message with a file attachment, so we # NOTE: "file_share" is just a message with a file attachment, so we
# should not ignore it # should not ignore it
message_subtype = event.get("subtype") message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]: if message_subtype not in [None, "file_share", "bot_message"]:
channel_specific_logger.info( channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since it is a special message type" f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
) )
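
The prefilter now treats a message as bot-authored if either signal is present, and "bot_message" is added to the allowed subtypes so such messages are not dropped by the later subtype check. A small sketch of the combined predicate (the standalone function is illustrative):

def is_bot_message(event: dict) -> bool:
    # A message counts as bot-authored if it carries a bot_profile
    # or uses the "bot_message" subtype.
    return bool(event.get("bot_profile") or event.get("subtype") == "bot_message")

assert is_bot_message({"bot_profile": {"name": "reporting-bot"}})
assert is_bot_message({"subtype": "bot_message", "text": "nightly summary"})
assert not is_bot_message({"subtype": "file_share", "text": "see attachment"})

Whether a bot message is actually answered still depends on the channel's respond_to_bots setting or an explicit bot tag, exactly as before.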

View File

@@ -1,8 +1,8 @@
import os import os
from unittest.mock import MagicMock from unittest.mock import MagicMock
from unittest.mock import patch
import pytest import pytest
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource from onyx.configs.constants import DocumentSource
from onyx.connectors.airtable.airtable_connector import AirtableConnector from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -10,25 +10,24 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section from onyx.connectors.models import Section
@pytest.fixture( class AirtableConfig(BaseModel):
params=[ base_id: str
("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]), table_identifier: str
("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]), access_token: str
]
)
def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
param_type, table_identifier = request.param
connector = AirtableConnector(
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
table_name_or_id=table_identifier,
)
connector.load_credentials(
{ @pytest.fixture(params=[True, False])
"airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"], def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig:
} table_identifier = (
os.environ["AIRTABLE_TEST_TABLE_NAME"]
if request.param
else os.environ["AIRTABLE_TEST_TABLE_ID"]
)
return AirtableConfig(
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
table_identifier=table_identifier,
access_token=os.environ["AIRTABLE_ACCESS_TOKEN"],
) )
return connector
def create_test_document( def create_test_document(
@@ -46,18 +45,37 @@ def create_test_document(
assignee: str, assignee: str,
days_since_status_change: int | None, days_since_status_change: int | None,
attachments: list[tuple[str, str]] | None = None, attachments: list[tuple[str, str]] | None = None,
all_fields_as_metadata: bool = False,
) -> Document: ) -> Document:
link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}" base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
sections = [ table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
Section( missing_vars = []
text=f"Title:\n------------------------\n{title}\n------------------------", if not base_id:
link=f"{link_base}/{id}", missing_vars.append("AIRTABLE_TEST_BASE_ID")
), if not table_id:
Section( missing_vars.append("AIRTABLE_TEST_TABLE_ID")
text=f"Description:\n------------------------\n{description}\n------------------------",
link=f"{link_base}/{id}", if missing_vars:
), raise RuntimeError(
] f"Required environment variables not set: {', '.join(missing_vars)}. "
"These variables are required to run Airtable connector tests."
)
link_base = f"https://airtable.com/{base_id}/{table_id}"
sections = []
if not all_fields_as_metadata:
sections.extend(
[
Section(
text=f"Title:\n------------------------\n{title}\n------------------------",
link=f"{link_base}/{id}",
),
Section(
text=f"Description:\n------------------------\n{description}\n------------------------",
link=f"{link_base}/{id}",
),
]
)
if attachments: if attachments:
for attachment_text, attachment_link in attachments: for attachment_text, attachment_link in attachments:
@@ -68,26 +86,36 @@ def create_test_document(
), ),
) )
metadata: dict[str, str | list[str]] = {
# "Category": category,
"Assignee": assignee,
"Submitted by": submitted_by,
"Priority": priority,
"Status": status,
"Created time": created_time,
"ID": ticket_id,
"Status last changed": status_last_changed,
**(
{"Days since status change": str(days_since_status_change)}
if days_since_status_change is not None
else {}
),
}
if all_fields_as_metadata:
metadata.update(
{
"Title": title,
"Description": description,
}
)
return Document( return Document(
id=f"airtable__{id}", id=f"airtable__{id}",
sections=sections, sections=sections,
source=DocumentSource.AIRTABLE, source=DocumentSource.AIRTABLE,
semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}", semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}",
metadata={ metadata=metadata,
# "Category": category,
"Assignee": assignee,
"Submitted by": submitted_by,
"Priority": priority,
"Status": status,
"Created time": created_time,
"ID": ticket_id,
"Status last changed": status_last_changed,
**(
{"Days since status change": str(days_since_status_change)}
if days_since_status_change is not None
else {}
),
},
doc_updated_at=None, doc_updated_at=None,
primary_owners=None, primary_owners=None,
secondary_owners=None, secondary_owners=None,
@@ -97,15 +125,75 @@ def create_test_document(
) )
@patch( def compare_documents(
"onyx.file_processing.extract_file_text.get_unstructured_api_key", actual_docs: list[Document], expected_docs: list[Document]
return_value=None,
)
def test_airtable_connector_basic(
mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
) -> None: ) -> None:
doc_batch_generator = airtable_connector.load_from_state() """Utility function to compare actual and expected documents, ignoring order."""
actual_docs_dict = {doc.id: doc for doc in actual_docs}
expected_docs_dict = {doc.id: doc for doc in expected_docs}
assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch"
for doc_id in actual_docs_dict:
actual = actual_docs_dict[doc_id]
expected = expected_docs_dict[doc_id]
assert (
actual.source == expected.source
), f"Source mismatch for document {doc_id}"
assert (
actual.semantic_identifier == expected.semantic_identifier
), f"Semantic identifier mismatch for document {doc_id}"
assert (
actual.metadata == expected.metadata
), f"Metadata mismatch for document {doc_id}"
assert (
actual.doc_updated_at == expected.doc_updated_at
), f"Updated at mismatch for document {doc_id}"
assert (
actual.primary_owners == expected.primary_owners
), f"Primary owners mismatch for document {doc_id}"
assert (
actual.secondary_owners == expected.secondary_owners
), f"Secondary owners mismatch for document {doc_id}"
assert actual.title == expected.title, f"Title mismatch for document {doc_id}"
assert (
actual.from_ingestion_api == expected.from_ingestion_api
), f"Ingestion API flag mismatch for document {doc_id}"
assert (
actual.additional_info == expected.additional_info
), f"Additional info mismatch for document {doc_id}"
# Compare sections
assert len(actual.sections) == len(
expected.sections
), f"Number of sections mismatch for document {doc_id}"
for i, (actual_section, expected_section) in enumerate(
zip(actual.sections, expected.sections)
):
assert (
actual_section.text == expected_section.text
), f"Section {i} text mismatch for document {doc_id}"
assert (
actual_section.link == expected_section.link
), f"Section {i} link mismatch for document {doc_id}"
def test_airtable_connector_basic(
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
"""Test behavior when all non-attachment fields are treated as metadata."""
connector = AirtableConnector(
base_id=airtable_config.base_id,
table_name_or_id=airtable_config.table_identifier,
treat_all_non_attachment_fields_as_metadata=False,
)
connector.load_credentials(
{
"airtable_access_token": airtable_config.access_token,
}
)
doc_batch_generator = connector.load_from_state()
doc_batch = next(doc_batch_generator) doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration): with pytest.raises(StopIteration):
next(doc_batch_generator) next(doc_batch_generator)
@@ -119,15 +207,62 @@ def test_airtable_connector_basic(
description="The internet connection is very slow.", description="The internet connection is very slow.",
priority="Medium", priority="Medium",
status="In Progress", status="In Progress",
# Link to another record is skipped for now
# category="Data Science",
ticket_id="2", ticket_id="2",
created_time="2024-12-24T21:02:49.000Z", created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z", status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0, days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)", assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)", submitted_by="Chris Weaver (chris@onyx.app)",
all_fields_as_metadata=False,
), ),
create_test_document(
id="reccSlIA4pZEFxPBg",
title="Printer Issue",
description="The office printer is not working.",
priority="High",
status="Open",
ticket_id="1",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
attachments=[
(
"Test.pdf:\ntesting!!!",
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
)
],
all_fields_as_metadata=False,
),
]
# Compare documents using the utility function
compare_documents(doc_batch, expected_docs)
def test_airtable_connector_all_metadata(
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
connector = AirtableConnector(
base_id=airtable_config.base_id,
table_name_or_id=airtable_config.table_identifier,
treat_all_non_attachment_fields_as_metadata=True,
)
connector.load_credentials(
{
"airtable_access_token": airtable_config.access_token,
}
)
doc_batch_generator = connector.load_from_state()
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
# NOTE: one of the rows has no attachments -> no content -> no document
assert len(doc_batch) == 1
expected_docs = [
create_test_document( create_test_document(
id="reccSlIA4pZEFxPBg", id="reccSlIA4pZEFxPBg",
title="Printer Issue", title="Printer Issue",
@@ -149,50 +284,9 @@ def test_airtable_connector_basic(
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide", "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
) )
], ],
all_fields_as_metadata=True,
), ),
] ]
# Compare each document field by field # Compare documents using the utility function
for actual, expected in zip(doc_batch, expected_docs): compare_documents(doc_batch, expected_docs)
assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
assert (
actual.source == expected.source
), f"Source mismatch for document {actual.id}"
assert (
actual.semantic_identifier == expected.semantic_identifier
), f"Semantic identifier mismatch for document {actual.id}"
assert (
actual.metadata == expected.metadata
), f"Metadata mismatch for document {actual.id}"
assert (
actual.doc_updated_at == expected.doc_updated_at
), f"Updated at mismatch for document {actual.id}"
assert (
actual.primary_owners == expected.primary_owners
), f"Primary owners mismatch for document {actual.id}"
assert (
actual.secondary_owners == expected.secondary_owners
), f"Secondary owners mismatch for document {actual.id}"
assert (
actual.title == expected.title
), f"Title mismatch for document {actual.id}"
assert (
actual.from_ingestion_api == expected.from_ingestion_api
), f"Ingestion API flag mismatch for document {actual.id}"
assert (
actual.additional_info == expected.additional_info
), f"Additional info mismatch for document {actual.id}"
# Compare sections
assert len(actual.sections) == len(
expected.sections
), f"Number of sections mismatch for document {actual.id}"
for i, (actual_section, expected_section) in enumerate(
zip(actual.sections, expected.sections)
):
assert (
actual_section.text == expected_section.text
), f"Section {i} text mismatch for document {actual.id}"
assert (
actual_section.link == expected_section.link
), f"Section {i} link mismatch for document {actual.id}"

View File

@@ -0,0 +1,14 @@
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
@pytest.fixture
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
with patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
) as mock:
yield mock

View File

@@ -0,0 +1,178 @@
import os
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.sharepoint.connector import SharepointConnector
@dataclass
class ExpectedDocument:
semantic_identifier: str
content: str
folder_path: str | None = None
library: str = "Shared Documents" # Default to main library
EXPECTED_DOCUMENTS = [
ExpectedDocument(
semantic_identifier="test1.docx",
content="test1",
folder_path="test",
),
ExpectedDocument(
semantic_identifier="test2.docx",
content="test2",
folder_path="test/nested with spaces",
),
ExpectedDocument(
semantic_identifier="should-not-index-on-specific-folder.docx",
content="should-not-index-on-specific-folder",
folder_path=None, # root folder
),
ExpectedDocument(
semantic_identifier="other.docx",
content="other",
folder_path=None,
library="Other Library",
),
]
def verify_document_metadata(doc: Document) -> None:
"""Verify common metadata that should be present on all documents."""
assert isinstance(doc.doc_updated_at, datetime)
assert doc.doc_updated_at.tzinfo == timezone.utc
assert doc.source == DocumentSource.SHAREPOINT
assert doc.primary_owners is not None
assert len(doc.primary_owners) == 1
owner = doc.primary_owners[0]
assert owner.display_name is not None
assert owner.email is not None
def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
"""Verify a document matches its expected content."""
assert doc.semantic_identifier == expected.semantic_identifier
assert len(doc.sections) == 1
assert expected.content in doc.sections[0].text
verify_document_metadata(doc)
def find_document(documents: list[Document], semantic_identifier: str) -> Document:
"""Find a document by its semantic identifier."""
matching_docs = [
d for d in documents if d.semantic_identifier == semantic_identifier
]
assert (
len(matching_docs) == 1
), f"Expected exactly one document with identifier {semantic_identifier}"
return matching_docs[0]
@pytest.fixture
def sharepoint_credentials() -> dict[str, str]:
return {
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
}
def test_sharepoint_connector_specific_folder(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the test site URL and specific folder
connector = SharepointConnector(
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
# Should only find documents in the test folder
test_folder_docs = [
doc
for doc in EXPECTED_DOCUMENTS
if doc.folder_path and doc.folder_path.startswith("test")
]
assert len(found_documents) == len(
test_folder_docs
), "Should only find documents in test folder"
# Verify each expected document
for expected in test_folder_docs:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)
def test_sharepoint_connector_root_folder(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the base site URL
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
assert len(found_documents) == len(
EXPECTED_DOCUMENTS
), "Should find all documents in main library"
# Verify each expected document
for expected in EXPECTED_DOCUMENTS:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)
def test_sharepoint_connector_other_library(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the other library
connector = SharepointConnector(
sites=[
os.environ["SHAREPOINT_SITE"] + "/Other Library",
]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get all documents
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
expected_documents: list[ExpectedDocument] = [
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
]
# Should find all documents in `Other Library`
assert len(found_documents) == len(
expected_documents
), "Should find all documents in `Other Library`"
# Verify each expected document
for expected in expected_documents:
doc = find_document(found_documents, expected.semantic_identifier)
verify_document_content(doc, expected)

View File

@@ -76,13 +76,7 @@ import {
import { buildFilters } from "@/lib/search/utils"; import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider"; import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone"; import Dropzone from "react-dropzone";
import { import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
checkLLMSupportsImageInput,
getFinalLLM,
destructureValue,
getLLMProviderOverrideForPersona,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar"; import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext"; import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
@@ -203,6 +197,12 @@ export function ChatPage({
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
const selectedChatSession = chatSessions.find(
(chatSession) => chatSession.id === existingChatSessionId
);
useEffect(() => { useEffect(() => {
if (user?.is_anonymous_user) { if (user?.is_anonymous_user) {
Cookies.set( Cookies.set(
@@ -240,12 +240,6 @@ export function ChatPage({
} }
}; };
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
const selectedChatSession = chatSessions.find(
(chatSession) => chatSession.id === existingChatSessionId
);
const chatSessionIdRef = useRef<string | null>(existingChatSessionId); const chatSessionIdRef = useRef<string | null>(existingChatSessionId);
// Only updates on session load (ie. rename / switching chat session) // Only updates on session load (ie. rename / switching chat session)
@@ -293,12 +287,6 @@ export function ChatPage({
); );
}; };
const llmOverrideManager = useLlmOverride(
llmProviders,
user?.preferences.default_model,
selectedChatSession
);
const [alternativeAssistant, setAlternativeAssistant] = const [alternativeAssistant, setAlternativeAssistant] =
useState<Persona | null>(null); useState<Persona | null>(null);
@@ -307,12 +295,27 @@ export function ChatPage({
const { recentAssistants, refreshRecentAssistants } = useAssistants(); const { recentAssistants, refreshRecentAssistants } = useAssistants();
const liveAssistant: Persona | undefined = const liveAssistant: Persona | undefined = useMemo(
alternativeAssistant || () =>
selectedAssistant || alternativeAssistant ||
recentAssistants[0] || selectedAssistant ||
finalAssistants[0] || recentAssistants[0] ||
availableAssistants[0]; finalAssistants[0] ||
availableAssistants[0],
[
alternativeAssistant,
selectedAssistant,
recentAssistants,
finalAssistants,
availableAssistants,
]
);
const llmOverrideManager = useLlmOverride(
llmProviders,
selectedChatSession,
liveAssistant
);
const noAssistants = liveAssistant == null || liveAssistant == undefined; const noAssistants = liveAssistant == null || liveAssistant == undefined;
@@ -320,24 +323,6 @@ export function ChatPage({
const uniqueSources = Array.from(new Set(availableSources)); const uniqueSources = Array.from(new Set(availableSources));
const sources = uniqueSources.map((source) => getSourceMetadata(source)); const sources = uniqueSources.map((source) => getSourceMetadata(source));
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
useEffect(() => {
if (noAssistants) return;
const personaDefault = getLLMProviderOverrideForPersona(
liveAssistant,
llmProviders
);
if (personaDefault) {
llmOverrideManager.updateLLMOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.updateLLMOverride(
destructureValue(user?.preferences.default_model)
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [liveAssistant, user?.preferences.default_model]);
const stopGenerating = () => { const stopGenerating = () => {
const currentSession = currentSessionId(); const currentSession = currentSessionId();
const controller = abortControllers.get(currentSession); const controller = abortControllers.get(currentSession);
@@ -419,7 +404,6 @@ export function ChatPage({
filterManager.setTimeRange(null); filterManager.setTimeRange(null);
// reset LLM overrides (based on chat session!) // reset LLM overrides (based on chat session!)
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
llmOverrideManager.updateTemperature(null); llmOverrideManager.updateTemperature(null);
// remove uploaded files // remove uploaded files
@@ -1283,13 +1267,11 @@ export function ChatPage({
modelProvider: modelProvider:
modelOverRide?.name || modelOverRide?.name ||
llmOverrideManager.llmOverride.name || llmOverrideManager.llmOverride.name ||
llmOverrideManager.globalDefault.name ||
undefined, undefined,
modelVersion: modelVersion:
modelOverRide?.modelName || modelOverRide?.modelName ||
llmOverrideManager.llmOverride.modelName || llmOverrideManager.llmOverride.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
llmOverrideManager.globalDefault.modelName ||
undefined, undefined,
temperature: llmOverrideManager.temperature || undefined, temperature: llmOverrideManager.temperature || undefined,
systemPromptOverride: systemPromptOverride:
@@ -1952,6 +1934,7 @@ export function ChatPage({
}; };
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [router]); }, [router]);
const [sharedChatSession, setSharedChatSession] = const [sharedChatSession, setSharedChatSession] =
useState<ChatSession | null>(); useState<ChatSession | null>();
@@ -2059,7 +2042,9 @@ export function ChatPage({
{(settingsToggled || userSettingsToggled) && ( {(settingsToggled || userSettingsToggled) && (
<UserSettingsModal <UserSettingsModal
setPopup={setPopup} setPopup={setPopup}
setLlmOverride={llmOverrideManager.setGlobalDefault} setLlmOverride={(newOverride) =>
llmOverrideManager.updateLLMOverride(newOverride)
}
defaultModel={user?.preferences.default_model!} defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders} llmProviders={llmProviders}
onClose={() => { onClose={() => {
@@ -2749,6 +2734,7 @@ export function ChatPage({
</button> </button>
</div> </div>
)} )}
<ChatInputBar <ChatInputBar
toggleDocumentSidebar={toggleDocumentSidebar} toggleDocumentSidebar={toggleDocumentSidebar}
availableSources={sources} availableSources={sources}

View File

@@ -40,8 +40,8 @@ export default function LLMPopover({
currentAssistant, currentAssistant,
}: LLMPopoverProps) { }: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false); const [isOpen, setIsOpen] = useState(false);
const { llmOverride, updateLLMOverride, globalDefault } = llmOverrideManager; const { llmOverride, updateLLMOverride } = llmOverrideManager;
const currentLlm = llmOverride.modelName || globalDefault.modelName; const currentLlm = llmOverride.modelName;
const llmOptionsByProvider: { const llmOptionsByProvider: {
[provider: string]: { [provider: string]: {

View File

@@ -1,13 +1,5 @@
import { import { useContext, useEffect, useRef } from "react";
Dispatch,
SetStateAction,
useContext,
useEffect,
useRef,
useState,
} from "react";
import { Modal } from "@/components/Modal"; import { Modal } from "@/components/Modal";
import Text from "@/components/ui/text";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks"; import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
@@ -33,7 +25,7 @@ export function UserSettingsModal({
}: { }: {
setPopup: (popupSpec: PopupSpec | null) => void; setPopup: (popupSpec: PopupSpec | null) => void;
llmProviders: LLMProviderDescriptor[]; llmProviders: LLMProviderDescriptor[];
setLlmOverride?: Dispatch<SetStateAction<LlmOverride>>; setLlmOverride?: (newOverride: LlmOverride) => void;
onClose: () => void; onClose: () => void;
defaultModel: string | null; defaultModel: string | null;
}) { }) {

View File

@@ -1,3 +1,3 @@
export const SEARCH_TOOL_NAME = "SearchTool"; export const SEARCH_TOOL_NAME = "run_search";
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search"; export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation"; export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";

View File

@@ -1106,6 +1106,14 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
name: "table_name_or_id", name: "table_name_or_id",
optional: false, optional: false,
}, },
{
type: "checkbox",
label: "Treat all fields except attachments as metadata",
name: "treat_all_non_attachment_fields_as_metadata",
description:
"Choose this if the primary content to index are attachments and all other columns are metadata for these attachments.",
optional: false,
},
], ],
advanced_values: [], advanced_values: [],
overrideDefaultFreq: 60 * 60 * 24, overrideDefaultFreq: 60 * 60 * 24,

View File

@@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react"; import { useContext, useEffect, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { Filters, SourceMetadata } from "./search/interfaces"; import { Filters, SourceMetadata } from "./search/interfaces";
import { destructureValue, structureValue } from "./llm/utils"; import {
destructureValue,
findProviderForModel,
structureValue,
} from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces"; import { ChatSession } from "@/app/chat/interfaces";
import { AllUsersResponse } from "./types"; import { AllUsersResponse } from "./types";
import { Credential } from "./connectors/credentials"; import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider"; import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaLabel } from "@/app/admin/assistants/interfaces"; import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces"; import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources"; import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants"; import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";
const CREDENTIAL_URL = "/api/manage/admin/credential"; const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -355,82 +360,141 @@ export interface LlmOverride {
export interface LlmOverrideManager { export interface LlmOverrideManager {
llmOverride: LlmOverride; llmOverride: LlmOverride;
updateLLMOverride: (newOverride: LlmOverride) => void; updateLLMOverride: (newOverride: LlmOverride) => void;
globalDefault: LlmOverride;
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null; temperature: number | null;
updateTemperature: (temperature: number | null) => void; updateTemperature: (temperature: number | null) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void; updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
imageFilesPresent: boolean; imageFilesPresent: boolean;
updateImageFilesPresent: (present: boolean) => void; updateImageFilesPresent: (present: boolean) => void;
liveAssistant: Persona | null;
} }
/*
LLM override precedence is as follows (i.e. this order)
- User override (explicitly set in the chat input bar)
- User preference (defaults to the system-wide default if no preference is set)
On switching to an existing or new chat session or a different assistant:
- If we have a live assistant with a model override after any switch, use that; otherwise use the above hierarchy
Thus, the inputs are
- User preference
- LLM Providers (which contain the system-wide default)
- Current assistant
Changes take place when
- liveAssistant or currentChatSession changes (and the associated model override is set)
- (updateLLMOverride) the user explicitly sets a model override (we then set the userSpecifiedOverride, which is used in place of the user preference unless overridden by an assistant)
If we have a live assistant, we should use that model override
*/
export function useLlmOverride( export function useLlmOverride(
llmProviders: LLMProviderDescriptor[], llmProviders: LLMProviderDescriptor[],
globalModel?: string | null,
currentChatSession?: ChatSession, currentChatSession?: ChatSession,
defaultTemperature?: number liveAssistant?: Persona
): LlmOverrideManager { ): LlmOverrideManager {
const { user } = useUser();
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const llmOverrideUpdate = () => {
if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
return;
}
if (liveAssistant?.llm_model_version_override) {
setLlmOverride(
getValidLlmOverride(liveAssistant.llm_model_version_override)
);
} else if (currentChatSession?.current_alternate_model) {
setLlmOverride(
getValidLlmOverride(currentChatSession.current_alternate_model)
);
} else if (user?.preferences?.default_model) {
setLlmOverride(getValidLlmOverride(user.preferences.default_model));
return;
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
if (defaultProvider) {
setLlmOverride({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
}
}
setChatSession(currentChatSession || null);
};
const getValidLlmOverride = ( const getValidLlmOverride = (
overrideModel: string | null | undefined overrideModel: string | null | undefined
): LlmOverride => { ): LlmOverride => {
if (overrideModel) { if (overrideModel) {
const model = destructureValue(overrideModel); const model = destructureValue(overrideModel);
const provider = llmProviders.find( if (!(model.modelName && model.modelName.length > 0)) {
(p) => const provider = llmProviders.find((p) =>
p.model_names.includes(model.modelName) && p.model_names.includes(overrideModel)
p.provider === model.provider );
if (provider) {
return {
modelName: overrideModel,
name: provider.name,
provider: provider.provider,
};
}
}
const provider = llmProviders.find((p) =>
p.model_names.includes(model.modelName)
); );
if (provider) { if (provider) {
return { ...model, name: provider.name }; return { ...model, name: provider.name };
} }
} }
return { name: "", provider: "", modelName: "" }; return { name: "", provider: "", modelName: "" };
}; };
const [imageFilesPresent, setImageFilesPresent] = useState(false); const [imageFilesPresent, setImageFilesPresent] = useState(false);
const updateImageFilesPresent = (present: boolean) => { const updateImageFilesPresent = (present: boolean) => {
setImageFilesPresent(present); setImageFilesPresent(present);
}; };
const [globalDefault, setGlobalDefault] = useState<LlmOverride>( const [llmOverride, setLlmOverride] = useState<LlmOverride>({
getValidLlmOverride(globalModel) name: "",
); provider: "",
const updateLLMOverride = (newOverride: LlmOverride) => { modelName: "",
setLlmOverride( });
getValidLlmOverride(
structureValue(
newOverride.name,
newOverride.provider,
newOverride.modelName
)
)
);
};
const [llmOverride, setLlmOverride] = useState<LlmOverride>( // Manually set the override
currentChatSession && currentChatSession.current_alternate_model const updateLLMOverride = (newOverride: LlmOverride) => {
? getValidLlmOverride(currentChatSession.current_alternate_model) const provider =
: { name: "", provider: "", modelName: "" } newOverride.provider ||
); findProviderForModel(llmProviders, newOverride.modelName);
const structuredValue = structureValue(
newOverride.name,
provider,
newOverride.modelName
);
setLlmOverride(getValidLlmOverride(structuredValue));
};
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => { const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
setLlmOverride( if (chatSession && chatSession.current_alternate_model?.length > 0) {
chatSession && chatSession.current_alternate_model setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
? getValidLlmOverride(chatSession.current_alternate_model) }
: globalDefault
);
}; };
const [temperature, setTemperature] = useState<number | null>( const [temperature, setTemperature] = useState<number | null>(0);
defaultTemperature !== undefined ? defaultTemperature : 0
);
useEffect(() => { useEffect(() => {
setGlobalDefault(getValidLlmOverride(globalModel)); llmOverrideUpdate();
}, [globalModel, llmProviders]); }, [liveAssistant, currentChatSession]);
useEffect(() => {
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
}, [defaultTemperature]);
useEffect(() => { useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) { if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
@@ -450,12 +514,11 @@ export function useLlmOverride(
updateModelOverrideForChatSession, updateModelOverrideForChatSession,
llmOverride, llmOverride,
updateLLMOverride, updateLLMOverride,
globalDefault,
setGlobalDefault,
temperature, temperature,
updateTemperature, updateTemperature,
imageFilesPresent, imageFilesPresent,
updateImageFilesPresent, updateImageFilesPresent,
liveAssistant: liveAssistant ?? null,
}; };
} }

View File

@@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => {
modelName, modelName,
}; };
}; };
export const findProviderForModel = (
llmProviders: LLMProviderDescriptor[],
modelName: string
): string => {
const provider = llmProviders.find((p) => p.model_names.includes(modelName));
return provider ? provider.provider : "";
};

View File

@@ -358,7 +358,8 @@ export type ConfigurableSources = Exclude<
export const oauthSupportedSources: ConfigurableSources[] = [ export const oauthSupportedSources: ConfigurableSources[] = [
ValidSources.Slack, ValidSources.Slack,
ValidSources.GoogleDrive, // NOTE: temporarily disabled until our GDrive App is approved
// ValidSources.GoogleDrive,
]; ];
export type OAuthSupportedSource = (typeof oauthSupportedSources)[number]; export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];