Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/sharepoint_app_init

Conflicts:
    backend/onyx/connectors/sharepoint/connector.py
GitHub Actions workflow (connector test secrets):

@@ -39,6 +39,12 @@ env:
   AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
   AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
   AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
+  # Sharepoint
+  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
+  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
+  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
+  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
+
 jobs:
   connectors-check:
     # See https://runs-on.com/runners/linux/
backend/onyx/connectors/airtable/airtable_connector.py:

@@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()

 # NOTE: all are made lowercase to avoid case sensitivity issues
-# these are the field types that are considered metadata rather
-# than sections
-_METADATA_FIELD_TYPES = {
+# These field types are considered metadata by default when
+# treat_all_non_attachment_fields_as_metadata is False
+DEFAULT_METADATA_FIELD_TYPES = {
     "singlecollaborator",
     "collaborator",
     "createdby",
@@ -60,12 +60,16 @@ class AirtableConnector(LoadConnector):
         self,
         base_id: str,
         table_name_or_id: str,
+        treat_all_non_attachment_fields_as_metadata: bool = False,
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
         self.base_id = base_id
         self.table_name_or_id = table_name_or_id
         self.batch_size = batch_size
         self.airtable_client: AirtableApi | None = None
+        self.treat_all_non_attachment_fields_as_metadata = (
+            treat_all_non_attachment_fields_as_metadata
+        )

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         self.airtable_client = AirtableApi(credentials["airtable_access_token"])
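For reference, the new flag defaults to False, so existing Airtable connector setups keep their current sectioning behavior; an attachment-centric table has to opt in explicitly. A minimal usage sketch (the base ID, table ID, and token below are placeholders, not values from this commit):

    from onyx.connectors.airtable.airtable_connector import AirtableConnector

    # Placeholder IDs/token. Setting the flag to True tells the connector to
    # index only attachments as document sections and push everything else
    # into metadata.
    connector = AirtableConnector(
        base_id="appXXXXXXXXXXXXXX",
        table_name_or_id="tblXXXXXXXXXXXXXX",
        treat_all_non_attachment_fields_as_metadata=True,
    )
    connector.load_credentials({"airtable_access_token": "patXXXXXXXXXXXXXX"})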
@@ -166,8 +170,14 @@ class AirtableConnector(LoadConnector):
         return [(str(field_info), default_link)]

     def _should_be_metadata(self, field_type: str) -> bool:
-        """Determine if a field type should be treated as metadata."""
-        return field_type.lower() in _METADATA_FIELD_TYPES
+        """Determine if a field type should be treated as metadata.
+
+        When treat_all_non_attachment_fields_as_metadata is True, all fields except
+        attachments are treated as metadata. Otherwise, only fields with types listed
+        in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
+        if self.treat_all_non_attachment_fields_as_metadata:
+            return field_type.lower() != "multipleattachments"
+        return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

     def _process_field(
         self,
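To make the two modes concrete, here is a standalone sketch of the same decision logic; the field-type strings follow the lowercase convention noted above, and the default set is abridged:

    # Standalone sketch of _should_be_metadata (abridged type set).
    DEFAULT_METADATA_FIELD_TYPES = {"singlecollaborator", "collaborator", "createdby"}

    def should_be_metadata(field_type: str, treat_all_non_attachment: bool) -> bool:
        if treat_all_non_attachment:
            return field_type.lower() != "multipleattachments"
        return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

    assert should_be_metadata("createdBy", False)           # in the default set
    assert not should_be_metadata("multilineText", False)   # indexed as a section
    assert should_be_metadata("multilineText", True)        # everything but attachments
    assert not should_be_metadata("multipleAttachments", True)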
@@ -233,7 +243,7 @@ class AirtableConnector(LoadConnector):
         record: RecordDict,
         table_schema: TableSchema,
         primary_field_name: str | None,
-    ) -> Document:
+    ) -> Document | None:
         """Process a single Airtable record into a Document.

         Args:
@@ -277,6 +287,10 @@ class AirtableConnector(LoadConnector):
             sections.extend(field_sections)
             metadata.update(field_metadata)

+        if not sections:
+            logger.warning(f"No sections found for record {record_id}")
+            return None
+
         semantic_id = (
             f"{table_name}: {primary_field_value}"
             if primary_field_value
@@ -320,7 +334,8 @@ class AirtableConnector(LoadConnector):
                 table_schema=table_schema,
                 primary_field_name=primary_field_name,
             )
-            record_documents.append(document)
+            if document:
+                record_documents.append(document)

             if len(record_documents) >= self.batch_size:
                 yield record_documents
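Since the record processor can now return None, the generator keeps only truthy documents before batching. A condensed sketch of the pattern, with process_record standing in for the connector's method:

    # Sketch: skip records that produced no sections, then batch the rest.
    def batched_documents(records, process_record, batch_size):
        batch = []
        for record in records:
            document = process_record(record)  # may return None for empty records
            if document:
                batch.append(document)
            if len(batch) >= batch_size:
                yield batch
                batch = []
        if batch:
            yield batch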
backend/onyx/connectors/sharepoint/connector.py:

@@ -1,17 +1,14 @@
 import io
 import os
-from dataclasses import dataclass
-from dataclasses import field
 from datetime import datetime
 from datetime import timezone
 from typing import Any
-from typing import Optional
 from urllib.parse import unquote

 import msal  # type: ignore
 from office365.graph_client import GraphClient  # type: ignore
 from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore
-from office365.onedrive.sites.site import Site  # type: ignore
+from pydantic import BaseModel

 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
@@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


-@dataclass
-class SiteData:
-    url: str | None
-    folder: Optional[str]
-    sites: list = field(default_factory=list)
-    driveitems: list = field(default_factory=list)
+class SiteDescriptor(BaseModel):
+    """Data class for storing SharePoint site information.
+
+    Args:
+        url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
+        drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
+            If None, all drives will be accessed.
+        folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
+            If None, all folders will be accessed.
+    """
+
+    url: str
+    drive_name: str | None
+    folder_path: str | None


 def _convert_driveitem_to_document(
     driveitem: DriveItem,
+    drive_name: str,
 ) -> Document:
     file_text = extract_file_text(
         file=io.BytesIO(driveitem.get_content().execute_query().value),
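The docstring's example URLs imply the split that the connector performs: everything through /sites/&lt;name&gt; is the site URL, the next path segment is the drive (document library) name, and any remainder is the folder path. A standalone sketch of that parsing, mirroring _extract_site_and_drive_info:

    from urllib.parse import unquote

    def parse_site_url(url: str) -> tuple[str, str | None, str | None]:
        # Mirrors the connector's split: site URL, then drive name, then folder path.
        parts = url.strip().split("/")
        sites_index = parts.index("sites")
        site_url = "/".join(parts[: sites_index + 2])
        remaining = parts[sites_index + 2 :]
        drive_name = unquote(remaining[0]) if remaining else None
        folder_path = (
            "/".join(unquote(p) for p in remaining[1:]) if len(remaining) > 1 else None
        )
        return site_url, drive_name, folder_path

    url = "https://danswerai.sharepoint.com/sites/sharepoint-tests/Shared Documents/test"
    assert parse_site_url(url) == (
        "https://danswerai.sharepoint.com/sites/sharepoint-tests",
        "Shared Documents",
        "test",
    )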
@@ -59,7 +65,7 @@ def _convert_driveitem_to_document(
                 email=driveitem.last_modified_by.user.email,
             )
         ],
-        metadata={},
+        metadata={"drive": drive_name},
     )
     return doc

@@ -71,107 +77,172 @@ class SharepointConnector(LoadConnector, PollConnector):
         sites: list[str] = [],
     ) -> None:
         self.batch_size = batch_size
-        self.graph_client: GraphClient | None = None
-        self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
+        self._graph_client: GraphClient | None = None
+        self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
+            sites
+        )
         self.msal_app: msal.ConfidentialClientApplication | None = None

+    @property
+    def graph_client(self) -> GraphClient:
+        if self._graph_client is None:
+            raise ConnectorMissingCredentialError("Sharepoint")
+
+        return self._graph_client
+
     @staticmethod
-    def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
+    def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
         site_data_list = []
         for url in site_urls:
             parts = url.strip().split("/")
             if "sites" in parts:
                 sites_index = parts.index("sites")
                 site_url = "/".join(parts[: sites_index + 2])
-                folder = (
-                    "/".join(unquote(part) for part in parts[sites_index + 2 :])
-                    if len(parts) > sites_index + 2
-                    else None
-                )
-                # Handling for new URL structure
-                if folder and folder.startswith("Shared Documents/"):
-                    folder = folder[len("Shared Documents/") :]
+                remaining_parts = parts[sites_index + 2 :]
+
+                # Extract drive name and folder path
+                if remaining_parts:
+                    drive_name = unquote(remaining_parts[0])
+                    folder_path = (
+                        "/".join(unquote(part) for part in remaining_parts[1:])
+                        if len(remaining_parts) > 1
+                        else None
+                    )
+                else:
+                    drive_name = None
+                    folder_path = None
+
                 site_data_list.append(
-                    SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
+                    SiteDescriptor(
+                        url=site_url,
+                        drive_name=drive_name,
+                        folder_path=folder_path,
+                    )
                 )
         return site_data_list

-    def _populate_sitedata_driveitems(
+    def _fetch_driveitems(
         self,
+        site_descriptor: SiteDescriptor,
         start: datetime | None = None,
         end: datetime | None = None,
-    ) -> None:
+    ) -> list[tuple[DriveItem, str]]:
         filter_str = ""
         if start is not None and end is not None:
-            filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
+            filter_str = (
+                f"last_modified_datetime ge {start.isoformat()} and "
+                f"last_modified_datetime le {end.isoformat()}"
+            )

-        for element in self.site_data:
-            sites: list[Site] = []
-            for site in element.sites:
-                site_sublist = site.lists.get().execute_query()
-                sites.extend(site_sublist)
-
-            for site in sites:
-                try:
-                    query = site.drive.root.get_files(True, 1000)
-                    if filter_str:
-                        query = query.filter(filter_str)
-                    driveitems = query.execute_query()
-                    if element.folder:
-                        expected_path = f"/root:/{element.folder}"
-                        filtered_driveitems = [
-                            item
-                            for item in driveitems
-                            if item.parent_reference.path.endswith(expected_path)
-                        ]
-                        if len(filtered_driveitems) == 0:
-                            all_paths = [
-                                item.parent_reference.path for item in driveitems
-                            ]
-                            logger.warning(
-                                f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}"
-                            )
-                        element.driveitems.extend(filtered_driveitems)
-                    else:
-                        element.driveitems.extend(driveitems)
-
-                except Exception:
-                    # Sites include things that do not contain .drive.root so this fails
-                    # but this is fine, as there are no actually documents in those
-                    pass
+        final_driveitems: list[tuple[DriveItem, str]] = []
+        try:
+            site = self.graph_client.sites.get_by_url(site_descriptor.url)
+
+            # Get all drives in the site
+            drives = site.drives.get().execute_query()
+            logger.debug(f"Found drives: {[drive.name for drive in drives]}")
+
+            # Filter drives based on the requested drive name
+            if site_descriptor.drive_name:
+                drives = [
+                    drive
+                    for drive in drives
+                    if drive.name == site_descriptor.drive_name
+                    or (
+                        drive.name == "Documents"
+                        and site_descriptor.drive_name == "Shared Documents"
+                    )
+                ]
+                if not drives:
+                    logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
+                    return []
+
+            # Process each matching drive
+            for drive in drives:
+                try:
+                    root_folder = drive.root
+                    if site_descriptor.folder_path:
+                        # If a specific folder is requested, navigate to it
+                        for folder_part in site_descriptor.folder_path.split("/"):
+                            root_folder = root_folder.get_by_path(folder_part)
+
+                    # Get all items recursively
+                    query = root_folder.get_files(True, 1000)
+                    if filter_str:
+                        query = query.filter(filter_str)
+                    driveitems = query.execute_query()
+                    logger.debug(
+                        f"Found {len(driveitems)} items in drive '{drive.name}'"
+                    )
+
+                    # Use "Shared Documents" as the library name for the default "Documents" drive
+                    drive_name = (
+                        "Shared Documents" if drive.name == "Documents" else drive.name
+                    )
+
+                    if site_descriptor.folder_path:
+                        # Filter items to ensure they're in the specified folder or its subfolders
+                        # The path will be in format: /drives/{drive_id}/root:/folder/path
+                        filtered_driveitems = [
+                            (item, drive_name)
+                            for item in driveitems
+                            if any(
+                                path_part == site_descriptor.folder_path
+                                or path_part.startswith(
+                                    site_descriptor.folder_path + "/"
+                                )
+                                for path_part in item.parent_reference.path.split(
+                                    "root:/"
+                                )[1].split("/")
+                            )
+                        ]
+                        if len(filtered_driveitems) == 0:
+                            all_paths = [
+                                item.parent_reference.path for item in driveitems
+                            ]
+                            logger.warning(
+                                f"Nothing found for folder '{site_descriptor.folder_path}' "
+                                f"in any of valid paths: {all_paths}"
+                            )
+                        final_driveitems.extend(filtered_driveitems)
+                    else:
+                        final_driveitems.extend(
+                            [(item, drive_name) for item in driveitems]
+                        )
+                except Exception as e:
+                    # Some drives might not be accessible
+                    logger.warning(f"Failed to process drive: {str(e)}")
+
+        except Exception as e:
+            # Sites include things that do not contain drives so this fails
+            # but this is fine, as there are no actual documents in those
+            logger.warning(f"Failed to process site: {str(e)}")
+
+        return final_driveitems

-    def _populate_sitedata_sites(self) -> None:
-        if self.graph_client is None:
-            raise ConnectorMissingCredentialError("Sharepoint")
-
-        if self.site_data:
-            for element in self.site_data:
-                element.sites = [
-                    self.graph_client.sites.get_by_url(element.url)
-                    .get()
-                    .execute_query()
-                ]
-        else:
-            sites = self.graph_client.sites.get_all().execute_query()
-            self.site_data = [
-                SiteData(url=None, folder=None, sites=sites, driveitems=[])
-            ]
+    def _fetch_sites(self) -> list[SiteDescriptor]:
+        sites = self.graph_client.sites.get_all().execute_query()
+        site_descriptors = [
+            SiteDescriptor(
+                url=sites.resource_url,
+                drive_name=None,
+                folder_path=None,
+            )
+        ]
+        return site_descriptors

     def _fetch_from_sharepoint(
         self, start: datetime | None = None, end: datetime | None = None
     ) -> GenerateDocumentsOutput:
-        if self.graph_client is None:
-            raise ConnectorMissingCredentialError("Sharepoint")
-
-        self._populate_sitedata_sites()
-        self._populate_sitedata_driveitems(start=start, end=end)
+        site_descriptors = self.site_descriptors or self._fetch_sites()

         # goes over all urls, converts them into Document objects and then yields them in batches
         doc_batch: list[Document] = []
-        for element in self.site_data:
-            for driveitem in element.driveitems:
+        for site_descriptor in site_descriptors:
+            driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
+            for driveitem, drive_name in driveitems:
                 logger.debug(f"Processing: {driveitem.web_url}")
-                doc_batch.append(_convert_driveitem_to_document(driveitem))
+                doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))

                 if len(doc_batch) >= self.batch_size:
                     yield doc_batch
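The folder filter keys off the parent_reference path the Graph API reports for each item; assuming the /drives/{drive-id}/root:/folder/subfolder format noted in the comment, the membership test reduces to the following sketch (the drive ID is made up):

    # Worked sketch of the folder-membership check used above.
    def in_folder(parent_path: str, folder_path: str) -> bool:
        path_parts = parent_path.split("root:/")[1].split("/")
        return any(
            part == folder_path or part.startswith(folder_path + "/")
            for part in path_parts
        )

    assert in_folder("/drives/b!abc123/root:/test/nested with spaces", "test")
    assert not in_folder("/drives/b!abc123/root:/other", "test")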
@@ -202,7 +273,7 @@ class SharepointConnector(LoadConnector, PollConnector):
             )
             return token

-        self.graph_client = GraphClient(_acquire_token_func)
+        self._graph_client = GraphClient(_acquire_token_func)
         return None

     def load_from_state(self) -> GenerateDocumentsOutput:
@@ -211,19 +282,19 @@ class SharepointConnector(LoadConnector, PollConnector):
     def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> GenerateDocumentsOutput:
-        start_datetime = datetime.utcfromtimestamp(start)
-        end_datetime = datetime.utcfromtimestamp(end)
+        start_datetime = datetime.fromtimestamp(start, timezone.utc)
+        end_datetime = datetime.fromtimestamp(end, timezone.utc)
         return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)


 if __name__ == "__main__":
-    connector = SharepointConnector(sites=os.environ["SITES"].split(","))
+    connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))

     connector.load_credentials(
         {
-            "sp_client_id": os.environ["SP_CLIENT_ID"],
-            "sp_client_secret": os.environ["SP_CLIENT_SECRET"],
-            "sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
+            "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
+            "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
+            "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
         }
     )
     document_batches = connector.load_from_state()
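The poll_source change is more than stylistic: datetime.utcfromtimestamp is deprecated as of Python 3.12 and returns a naive datetime, so its isoformat() carries no UTC offset, whereas the timezone-aware form does — which matters once the value is embedded in the last_modified_datetime filter string. A quick illustration:

    from datetime import datetime, timezone

    ts = 1735074169  # epoch seconds for 2024-12-24T21:02:49Z

    naive = datetime.utcfromtimestamp(ts)             # deprecated; tzinfo is None
    aware = datetime.fromtimestamp(ts, timezone.utc)  # tzinfo is timezone.utc

    assert naive.isoformat() == "2024-12-24T21:02:49"        # no offset
    assert aware.isoformat() == "2024-12-24T21:02:49+00:00"  # explicit UTC offset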
Slack bot message prefiltering (prefilter_requests):

@@ -537,30 +537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
             # Let the tag flow handle this case, don't reply twice
             return False

-        if event.get("bot_profile"):
+        # Check if this is a bot message (either via bot_profile or bot_message subtype)
+        is_bot_message = bool(
+            event.get("bot_profile") or event.get("subtype") == "bot_message"
+        )
+        if is_bot_message:
             channel_name, _ = get_channel_name_from_id(
                 client=client.web_client, channel_id=channel
             )

             with get_session_with_tenant(client.tenant_id) as db_session:
                 slack_channel_config = get_slack_channel_config_for_bot_and_channel(
                     db_session=db_session,
                     slack_bot_id=client.slack_bot_id,
                     channel_name=channel_name,
                 )

                 # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
                 if (not bot_tag_id or bot_tag_id not in msg) and (
                     not slack_channel_config
                     or not slack_channel_config.channel_config.get("respond_to_bots")
                 ):
-                    channel_specific_logger.info("Ignoring message from bot")
+                    channel_specific_logger.info(
+                        "Ignoring message from bot since respond_to_bots is disabled"
+                    )
                     return False

         # Ignore things like channel_join, channel_leave, etc.
         # NOTE: "file_share" is just a message with a file attachment, so we
         # should not ignore it
         message_subtype = event.get("subtype")
-        if message_subtype not in [None, "file_share"]:
+        if message_subtype not in [None, "file_share", "bot_message"]:
             channel_specific_logger.info(
                 f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
             )
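The widened check matters because Slack flags some bot posts only with the bot_message subtype rather than a bot_profile; both now route through the respond_to_bots gate, and bot_message is whitelisted alongside file_share in the subtype filter so those messages are not dropped later. A sketch of the predicate against hypothetical event payloads:

    # Sketch of the widened bot detection (event payloads are hypothetical).
    def is_bot_message(event: dict) -> bool:
        return bool(event.get("bot_profile") or event.get("subtype") == "bot_message")

    assert is_bot_message({"bot_profile": {"name": "some-app"}})
    assert is_bot_message({"subtype": "bot_message", "text": "hi"})
    assert not is_bot_message({"subtype": "file_share", "user": "U123"})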
Airtable connector daily test:

@@ -1,8 +1,8 @@
 import os
 from unittest.mock import MagicMock
-from unittest.mock import patch

 import pytest
+from pydantic import BaseModel

 from onyx.configs.constants import DocumentSource
 from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -10,25 +10,24 @@ from onyx.connectors.models import Document
 from onyx.connectors.models import Section


-@pytest.fixture(
-    params=[
-        ("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]),
-        ("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]),
-    ]
-)
-def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
-    param_type, table_identifier = request.param
-    connector = AirtableConnector(
-        base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
-        table_name_or_id=table_identifier,
-    )
-
-    connector.load_credentials(
-        {
-            "airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"],
-        }
-    )
-    return connector
+class AirtableConfig(BaseModel):
+    base_id: str
+    table_identifier: str
+    access_token: str
+
+
+@pytest.fixture(params=[True, False])
+def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig:
+    table_identifier = (
+        os.environ["AIRTABLE_TEST_TABLE_NAME"]
+        if request.param
+        else os.environ["AIRTABLE_TEST_TABLE_ID"]
+    )
+    return AirtableConfig(
+        base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
+        table_identifier=table_identifier,
+        access_token=os.environ["AIRTABLE_ACCESS_TOKEN"],
+    )


 def create_test_document(
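The params=[True, False] fixture makes pytest run every dependent test twice, once resolving the table by name and once by ID. A stripped-down illustration of the mechanism:

    import pytest

    @pytest.fixture(params=[True, False])
    def use_table_name(request: pytest.FixtureRequest) -> bool:
        return request.param

    def test_runs_twice(use_table_name: bool) -> None:
        # Collected as test_runs_twice[True] and test_runs_twice[False].
        assert isinstance(use_table_name, bool)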
@@ -46,18 +45,37 @@ def create_test_document(
     assignee: str,
     days_since_status_change: int | None,
     attachments: list[tuple[str, str]] | None = None,
+    all_fields_as_metadata: bool = False,
 ) -> Document:
-    link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
-    sections = [
-        Section(
-            text=f"Title:\n------------------------\n{title}\n------------------------",
-            link=f"{link_base}/{id}",
-        ),
-        Section(
-            text=f"Description:\n------------------------\n{description}\n------------------------",
-            link=f"{link_base}/{id}",
-        ),
-    ]
+    base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
+    table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
+    missing_vars = []
+    if not base_id:
+        missing_vars.append("AIRTABLE_TEST_BASE_ID")
+    if not table_id:
+        missing_vars.append("AIRTABLE_TEST_TABLE_ID")
+
+    if missing_vars:
+        raise RuntimeError(
+            f"Required environment variables not set: {', '.join(missing_vars)}. "
+            "These variables are required to run Airtable connector tests."
+        )
+
+    link_base = f"https://airtable.com/{base_id}/{table_id}"
+    sections = []
+
+    if not all_fields_as_metadata:
+        sections.extend(
+            [
+                Section(
+                    text=f"Title:\n------------------------\n{title}\n------------------------",
+                    link=f"{link_base}/{id}",
+                ),
+                Section(
+                    text=f"Description:\n------------------------\n{description}\n------------------------",
+                    link=f"{link_base}/{id}",
+                ),
+            ]
+        )
+
     if attachments:
         for attachment_text, attachment_link in attachments:
@@ -68,26 +86,36 @@ def create_test_document(
             ),
         )

+    metadata: dict[str, str | list[str]] = {
+        # "Category": category,
+        "Assignee": assignee,
+        "Submitted by": submitted_by,
+        "Priority": priority,
+        "Status": status,
+        "Created time": created_time,
+        "ID": ticket_id,
+        "Status last changed": status_last_changed,
+        **(
+            {"Days since status change": str(days_since_status_change)}
+            if days_since_status_change is not None
+            else {}
+        ),
+    }
+
+    if all_fields_as_metadata:
+        metadata.update(
+            {
+                "Title": title,
+                "Description": description,
+            }
+        )
+
     return Document(
         id=f"airtable__{id}",
         sections=sections,
         source=DocumentSource.AIRTABLE,
-        semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}",
-        metadata={
-            # "Category": category,
-            "Assignee": assignee,
-            "Submitted by": submitted_by,
-            "Priority": priority,
-            "Status": status,
-            "Created time": created_time,
-            "ID": ticket_id,
-            "Status last changed": status_last_changed,
-            **(
-                {"Days since status change": str(days_since_status_change)}
-                if days_since_status_change is not None
-                else {}
-            ),
-        },
+        semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}",
+        metadata=metadata,
         doc_updated_at=None,
         primary_owners=None,
         secondary_owners=None,
@@ -97,15 +125,75 @@ def create_test_document(
     )


-@patch(
-    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
-    return_value=None,
-)
-def test_airtable_connector_basic(
-    mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
+def compare_documents(
+    actual_docs: list[Document], expected_docs: list[Document]
 ) -> None:
-    doc_batch_generator = airtable_connector.load_from_state()
+    """Utility function to compare actual and expected documents, ignoring order."""
+    actual_docs_dict = {doc.id: doc for doc in actual_docs}
+    expected_docs_dict = {doc.id: doc for doc in expected_docs}
+
+    assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch"
+
+    for doc_id in actual_docs_dict:
+        actual = actual_docs_dict[doc_id]
+        expected = expected_docs_dict[doc_id]
+
+        assert (
+            actual.source == expected.source
+        ), f"Source mismatch for document {doc_id}"
+        assert (
+            actual.semantic_identifier == expected.semantic_identifier
+        ), f"Semantic identifier mismatch for document {doc_id}"
+        assert (
+            actual.metadata == expected.metadata
+        ), f"Metadata mismatch for document {doc_id}"
+        assert (
+            actual.doc_updated_at == expected.doc_updated_at
+        ), f"Updated at mismatch for document {doc_id}"
+        assert (
+            actual.primary_owners == expected.primary_owners
+        ), f"Primary owners mismatch for document {doc_id}"
+        assert (
+            actual.secondary_owners == expected.secondary_owners
+        ), f"Secondary owners mismatch for document {doc_id}"
+        assert actual.title == expected.title, f"Title mismatch for document {doc_id}"
+        assert (
+            actual.from_ingestion_api == expected.from_ingestion_api
+        ), f"Ingestion API flag mismatch for document {doc_id}"
+        assert (
+            actual.additional_info == expected.additional_info
+        ), f"Additional info mismatch for document {doc_id}"
+
+        # Compare sections
+        assert len(actual.sections) == len(
+            expected.sections
+        ), f"Number of sections mismatch for document {doc_id}"
+        for i, (actual_section, expected_section) in enumerate(
+            zip(actual.sections, expected.sections)
+        ):
+            assert (
+                actual_section.text == expected_section.text
+            ), f"Section {i} text mismatch for document {doc_id}"
+            assert (
+                actual_section.link == expected_section.link
+            ), f"Section {i} link mismatch for document {doc_id}"
+
+
+def test_airtable_connector_basic(
+    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
+) -> None:
+    """Test behavior when all non-attachment fields are treated as metadata."""
+    connector = AirtableConnector(
+        base_id=airtable_config.base_id,
+        table_name_or_id=airtable_config.table_identifier,
+        treat_all_non_attachment_fields_as_metadata=False,
+    )
+    connector.load_credentials(
+        {
+            "airtable_access_token": airtable_config.access_token,
+        }
+    )
+    doc_batch_generator = connector.load_from_state()
     doc_batch = next(doc_batch_generator)
     with pytest.raises(StopIteration):
         next(doc_batch_generator)
@@ -119,15 +207,62 @@ def test_airtable_connector_basic(
             description="The internet connection is very slow.",
             priority="Medium",
             status="In Progress",
-            # Link to another record is skipped for now
-            # category="Data Science",
             ticket_id="2",
             created_time="2024-12-24T21:02:49.000Z",
             status_last_changed="2024-12-24T21:02:49.000Z",
             days_since_status_change=0,
             assignee="Chris Weaver (chris@onyx.app)",
             submitted_by="Chris Weaver (chris@onyx.app)",
+            all_fields_as_metadata=False,
         ),
+        create_test_document(
+            id="reccSlIA4pZEFxPBg",
+            title="Printer Issue",
+            description="The office printer is not working.",
+            priority="High",
+            status="Open",
+            ticket_id="1",
+            created_time="2024-12-24T21:02:49.000Z",
+            status_last_changed="2024-12-24T21:02:49.000Z",
+            days_since_status_change=0,
+            assignee="Chris Weaver (chris@onyx.app)",
+            submitted_by="Chris Weaver (chris@onyx.app)",
+            attachments=[
+                (
+                    "Test.pdf:\ntesting!!!",
+                    "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
+                )
+            ],
+            all_fields_as_metadata=False,
+        ),
+    ]
+
+    # Compare documents using the utility function
+    compare_documents(doc_batch, expected_docs)
+
+
+def test_airtable_connector_all_metadata(
+    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
+) -> None:
+    connector = AirtableConnector(
+        base_id=airtable_config.base_id,
+        table_name_or_id=airtable_config.table_identifier,
+        treat_all_non_attachment_fields_as_metadata=True,
+    )
+    connector.load_credentials(
+        {
+            "airtable_access_token": airtable_config.access_token,
+        }
+    )
+    doc_batch_generator = connector.load_from_state()
+    doc_batch = next(doc_batch_generator)
+    with pytest.raises(StopIteration):
+        next(doc_batch_generator)
+
+    # NOTE: one of the rows has no attachments -> no content -> no document
+    assert len(doc_batch) == 1
+
+    expected_docs = [
         create_test_document(
             id="reccSlIA4pZEFxPBg",
             title="Printer Issue",
@@ -149,50 +284,9 @@ def test_airtable_connector_basic(
                     "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
                 )
             ],
+            all_fields_as_metadata=True,
         ),
     ]

-    # Compare each document field by field
-    for actual, expected in zip(doc_batch, expected_docs):
-        assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
-        assert (
-            actual.source == expected.source
-        ), f"Source mismatch for document {actual.id}"
-        assert (
-            actual.semantic_identifier == expected.semantic_identifier
-        ), f"Semantic identifier mismatch for document {actual.id}"
-        assert (
-            actual.metadata == expected.metadata
-        ), f"Metadata mismatch for document {actual.id}"
-        assert (
-            actual.doc_updated_at == expected.doc_updated_at
-        ), f"Updated at mismatch for document {actual.id}"
-        assert (
-            actual.primary_owners == expected.primary_owners
-        ), f"Primary owners mismatch for document {actual.id}"
-        assert (
-            actual.secondary_owners == expected.secondary_owners
-        ), f"Secondary owners mismatch for document {actual.id}"
-        assert (
-            actual.title == expected.title
-        ), f"Title mismatch for document {actual.id}"
-        assert (
-            actual.from_ingestion_api == expected.from_ingestion_api
-        ), f"Ingestion API flag mismatch for document {actual.id}"
-        assert (
-            actual.additional_info == expected.additional_info
-        ), f"Additional info mismatch for document {actual.id}"
-
-        # Compare sections
-        assert len(actual.sections) == len(
-            expected.sections
-        ), f"Number of sections mismatch for document {actual.id}"
-        for i, (actual_section, expected_section) in enumerate(
-            zip(actual.sections, expected.sections)
-        ):
-            assert (
-                actual_section.text == expected_section.text
-            ), f"Section {i} text mismatch for document {actual.id}"
-            assert (
-                actual_section.link == expected_section.link
-            ), f"Section {i} link mismatch for document {actual.id}"
+    # Compare documents using the utility function
+    compare_documents(doc_batch, expected_docs)
backend/tests/daily/connectors/conftest.py (new file):

@@ -0,0 +1,14 @@
+from collections.abc import Generator
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+
+@pytest.fixture
+def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
+    with patch(
+        "onyx.file_processing.extract_file_text.get_unstructured_api_key",
+        return_value=None,
+    ) as mock:
+        yield mock
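Because this lives in a conftest.py in the shared connectors test directory, pytest injects the fixture by parameter name into any test module below it — the Airtable and SharePoint suites alike — with no import needed. A minimal consumer:

    from unittest.mock import MagicMock

    # Any test under backend/tests/daily/connectors/ can request the fixture
    # by name; while it is active, extract_file_text sees no Unstructured API
    # key and falls back to local parsing.
    def test_uses_local_parsing(mock_get_unstructured_api_key: MagicMock) -> None:
        assert mock_get_unstructured_api_key is not None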
SharePoint connector daily test (new file):

@@ -0,0 +1,178 @@
+import os
+from dataclasses import dataclass
+from datetime import datetime
+from datetime import timezone
+from unittest.mock import MagicMock
+
+import pytest
+
+from onyx.configs.constants import DocumentSource
+from onyx.connectors.models import Document
+from onyx.connectors.sharepoint.connector import SharepointConnector
+
+
+@dataclass
+class ExpectedDocument:
+    semantic_identifier: str
+    content: str
+    folder_path: str | None = None
+    library: str = "Shared Documents"  # Default to main library
+
+
+EXPECTED_DOCUMENTS = [
+    ExpectedDocument(
+        semantic_identifier="test1.docx",
+        content="test1",
+        folder_path="test",
+    ),
+    ExpectedDocument(
+        semantic_identifier="test2.docx",
+        content="test2",
+        folder_path="test/nested with spaces",
+    ),
+    ExpectedDocument(
+        semantic_identifier="should-not-index-on-specific-folder.docx",
+        content="should-not-index-on-specific-folder",
+        folder_path=None,  # root folder
+    ),
+    ExpectedDocument(
+        semantic_identifier="other.docx",
+        content="other",
+        folder_path=None,
+        library="Other Library",
+    ),
+]
+
+
+def verify_document_metadata(doc: Document) -> None:
+    """Verify common metadata that should be present on all documents."""
+    assert isinstance(doc.doc_updated_at, datetime)
+    assert doc.doc_updated_at.tzinfo == timezone.utc
+    assert doc.source == DocumentSource.SHAREPOINT
+    assert doc.primary_owners is not None
+    assert len(doc.primary_owners) == 1
+    owner = doc.primary_owners[0]
+    assert owner.display_name is not None
+    assert owner.email is not None
+
+
+def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
+    """Verify a document matches its expected content."""
+    assert doc.semantic_identifier == expected.semantic_identifier
+    assert len(doc.sections) == 1
+    assert expected.content in doc.sections[0].text
+    verify_document_metadata(doc)
+
+
+def find_document(documents: list[Document], semantic_identifier: str) -> Document:
+    """Find a document by its semantic identifier."""
+    matching_docs = [
+        d for d in documents if d.semantic_identifier == semantic_identifier
+    ]
+    assert (
+        len(matching_docs) == 1
+    ), f"Expected exactly one document with identifier {semantic_identifier}"
+    return matching_docs[0]
+
+
+@pytest.fixture
+def sharepoint_credentials() -> dict[str, str]:
+    return {
+        "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
+        "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
+        "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
+    }
+
+
+def test_sharepoint_connector_specific_folder(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the test site URL and specific folder
+    connector = SharepointConnector(
+        sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
+    )
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+
+    # Should only find documents in the test folder
+    test_folder_docs = [
+        doc
+        for doc in EXPECTED_DOCUMENTS
+        if doc.folder_path and doc.folder_path.startswith("test")
+    ]
+    assert len(found_documents) == len(
+        test_folder_docs
+    ), "Should only find documents in test folder"
+
+    # Verify each expected document
+    for expected in test_folder_docs:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)
+
+
+def test_sharepoint_connector_root_folder(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the base site URL
+    connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+
+    assert len(found_documents) == len(
+        EXPECTED_DOCUMENTS
+    ), "Should find all documents in main library"
+
+    # Verify each expected document
+    for expected in EXPECTED_DOCUMENTS:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)
+
+
+def test_sharepoint_connector_other_library(
+    mock_get_unstructured_api_key: MagicMock,
+    sharepoint_credentials: dict[str, str],
+) -> None:
+    # Initialize connector with the other library
+    connector = SharepointConnector(
+        sites=[
+            os.environ["SHAREPOINT_SITE"] + "/Other Library",
+        ]
+    )
+
+    # Load credentials
+    connector.load_credentials(sharepoint_credentials)
+
+    # Get all documents
+    document_batches = list(connector.load_from_state())
+    found_documents: list[Document] = [
+        doc for batch in document_batches for doc in batch
+    ]
+    expected_documents: list[ExpectedDocument] = [
+        doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
+    ]
+
+    # Should find all documents in `Other Library`
+    assert len(found_documents) == len(
+        expected_documents
+    ), "Should find all documents in `Other Library`"
+
+    # Verify each expected document
+    for expected in expected_documents:
+        doc = find_document(found_documents, expected.semantic_identifier)
+        verify_document_content(doc, expected)
ChatPage component:

@@ -76,13 +76,7 @@ import {
 import { buildFilters } from "@/lib/search/utils";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import Dropzone from "react-dropzone";
-import {
-  checkLLMSupportsImageInput,
-  getFinalLLM,
-  destructureValue,
-  getLLMProviderOverrideForPersona,
-} from "@/lib/llm/utils";
+import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
 import { ChatInputBar } from "./input/ChatInputBar";
 import { useChatContext } from "@/components/context/ChatContext";
 import { v4 as uuidv4 } from "uuid";
@@ -203,6 +197,12 @@ export function ChatPage({

   const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open

+  const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
+
+  const selectedChatSession = chatSessions.find(
+    (chatSession) => chatSession.id === existingChatSessionId
+  );
+
   useEffect(() => {
     if (user?.is_anonymous_user) {
       Cookies.set(
@@ -240,12 +240,6 @@ export function ChatPage({
     }
   };

-  const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
-
-  const selectedChatSession = chatSessions.find(
-    (chatSession) => chatSession.id === existingChatSessionId
-  );
-
   const chatSessionIdRef = useRef<string | null>(existingChatSessionId);

   // Only updates on session load (ie. rename / switching chat session)
@@ -293,12 +287,6 @@ export function ChatPage({
     );
   };

-  const llmOverrideManager = useLlmOverride(
-    llmProviders,
-    user?.preferences.default_model,
-    selectedChatSession
-  );
-
   const [alternativeAssistant, setAlternativeAssistant] =
     useState<Persona | null>(null);

@@ -307,12 +295,27 @@ export function ChatPage({

   const { recentAssistants, refreshRecentAssistants } = useAssistants();

-  const liveAssistant: Persona | undefined =
-    alternativeAssistant ||
-    selectedAssistant ||
-    recentAssistants[0] ||
-    finalAssistants[0] ||
-    availableAssistants[0];
+  const liveAssistant: Persona | undefined = useMemo(
+    () =>
+      alternativeAssistant ||
+      selectedAssistant ||
+      recentAssistants[0] ||
+      finalAssistants[0] ||
+      availableAssistants[0],
+    [
+      alternativeAssistant,
+      selectedAssistant,
+      recentAssistants,
+      finalAssistants,
+      availableAssistants,
+    ]
+  );
+
+  const llmOverrideManager = useLlmOverride(
+    llmProviders,
+    selectedChatSession,
+    liveAssistant
+  );

   const noAssistants = liveAssistant == null || liveAssistant == undefined;

@@ -320,24 +323,6 @@ export function ChatPage({
   const uniqueSources = Array.from(new Set(availableSources));
   const sources = uniqueSources.map((source) => getSourceMetadata(source));

-  // always set the model override for the chat session, when an assistant, llm provider, or user preference exists
-  useEffect(() => {
-    if (noAssistants) return;
-    const personaDefault = getLLMProviderOverrideForPersona(
-      liveAssistant,
-      llmProviders
-    );
-
-    if (personaDefault) {
-      llmOverrideManager.updateLLMOverride(personaDefault);
-    } else if (user?.preferences.default_model) {
-      llmOverrideManager.updateLLMOverride(
-        destructureValue(user?.preferences.default_model)
-      );
-    }
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [liveAssistant, user?.preferences.default_model]);
-
   const stopGenerating = () => {
     const currentSession = currentSessionId();
     const controller = abortControllers.get(currentSession);
@@ -419,7 +404,6 @@ export function ChatPage({
     filterManager.setTimeRange(null);

     // reset LLM overrides (based on chat session!)
-    llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
     llmOverrideManager.updateTemperature(null);

     // remove uploaded files
@@ -1283,13 +1267,11 @@ export function ChatPage({
           modelProvider:
             modelOverRide?.name ||
             llmOverrideManager.llmOverride.name ||
-            llmOverrideManager.globalDefault.name ||
             undefined,
           modelVersion:
             modelOverRide?.modelName ||
             llmOverrideManager.llmOverride.modelName ||
             searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
-            llmOverrideManager.globalDefault.modelName ||
             undefined,
           temperature: llmOverrideManager.temperature || undefined,
           systemPromptOverride:
@@ -1952,6 +1934,7 @@ export function ChatPage({
     };
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [router]);
+
   const [sharedChatSession, setSharedChatSession] =
     useState<ChatSession | null>();

@@ -2059,7 +2042,9 @@ export function ChatPage({
            {(settingsToggled || userSettingsToggled) && (
              <UserSettingsModal
                setPopup={setPopup}
-               setLlmOverride={llmOverrideManager.setGlobalDefault}
+               setLlmOverride={(newOverride) =>
+                 llmOverrideManager.updateLLMOverride(newOverride)
+               }
                defaultModel={user?.preferences.default_model!}
                llmProviders={llmProviders}
                onClose={() => {
@@ -2749,6 +2734,7 @@ export function ChatPage({
              </button>
            </div>
          )}
+
          <ChatInputBar
            toggleDocumentSidebar={toggleDocumentSidebar}
            availableSources={sources}
LLMPopover component:

@@ -40,8 +40,8 @@ export default function LLMPopover({
   currentAssistant,
 }: LLMPopoverProps) {
   const [isOpen, setIsOpen] = useState(false);
-  const { llmOverride, updateLLMOverride, globalDefault } = llmOverrideManager;
-  const currentLlm = llmOverride.modelName || globalDefault.modelName;
+  const { llmOverride, updateLLMOverride } = llmOverrideManager;
+  const currentLlm = llmOverride.modelName;

   const llmOptionsByProvider: {
     [provider: string]: {
UserSettingsModal component:

@@ -1,13 +1,5 @@
-import {
-  Dispatch,
-  SetStateAction,
-  useContext,
-  useEffect,
-  useRef,
-  useState,
-} from "react";
+import { useContext, useEffect, useRef } from "react";
 import { Modal } from "@/components/Modal";
-import Text from "@/components/ui/text";
 import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
 import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";

@@ -33,7 +25,7 @@ export function UserSettingsModal({
 }: {
   setPopup: (popupSpec: PopupSpec | null) => void;
   llmProviders: LLMProviderDescriptor[];
-  setLlmOverride?: Dispatch<SetStateAction<LlmOverride>>;
+  setLlmOverride?: (newOverride: LlmOverride) => void;
   onClose: () => void;
   defaultModel: string | null;
 }) {
@@ -1,3 +1,3 @@
|
|||||||
export const SEARCH_TOOL_NAME = "SearchTool";
|
export const SEARCH_TOOL_NAME = "run_search";
|
||||||
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
|
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
|
||||||
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
|
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
|
@@ -1106,6 +1106,14 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
       name: "table_name_or_id",
       optional: false,
     },
+    {
+      type: "checkbox",
+      label: "Treat all fields except attachments as metadata",
+      name: "treat_all_non_attachment_fields_as_metadata",
+      description:
+        "Choose this if the primary content to index are attachments and all other columns are metadata for these attachments.",
+      optional: false,
+    },
   ],
   advanced_values: [],
   overrideDefaultFreq: 60 * 60 * 24,
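Note: the new checkbox maps to the connector's `treat_all_non_attachment_fields_as_metadata` option. A hedged sketch of the resulting connector-specific config (only the flag name and `table_name_or_id` come from this diff; the other keys and values are illustrative):

    // Illustrative config for an Airtable connector where attachments are the
    // indexed content and every other column becomes metadata on them.
    const airtableConfig = {
      base_id: "appXXXXXXXXXXXXXX", // hypothetical base id
      table_name_or_id: "tblXXXXXXXXXXXXXX", // hypothetical table id
      treat_all_non_attachment_fields_as_metadata: true,
    };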
@@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher";
 import { useContext, useEffect, useState } from "react";
 import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
 import { Filters, SourceMetadata } from "./search/interfaces";
-import { destructureValue, structureValue } from "./llm/utils";
+import {
+  destructureValue,
+  findProviderForModel,
+  structureValue,
+} from "./llm/utils";
 import { ChatSession } from "@/app/chat/interfaces";
 import { AllUsersResponse } from "./types";
 import { Credential } from "./connectors/credentials";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
-import { PersonaLabel } from "@/app/admin/assistants/interfaces";
+import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
 import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
 import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
 import { getSourceMetadata } from "./sources";
 import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
+import { useUser } from "@/components/user/UserProvider";
 
 const CREDENTIAL_URL = "/api/manage/admin/credential";
 
@@ -355,82 +360,141 @@ export interface LlmOverride {
 export interface LlmOverrideManager {
   llmOverride: LlmOverride;
   updateLLMOverride: (newOverride: LlmOverride) => void;
-  globalDefault: LlmOverride;
-  setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
   temperature: number | null;
   updateTemperature: (temperature: number | null) => void;
   updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
   imageFilesPresent: boolean;
   updateImageFilesPresent: (present: boolean) => void;
+  liveAssistant: Persona | null;
 }
 
+/*
+LLM Override is as follows (i.e. this order)
+- User override (explicitly set in the chat input bar)
+- User preference (defaults to system wide default if no preference set)
+
+On switching to an existing or new chat session or a different assistant:
+- If we have a live assistant after any switch with a model override, use that- otherwise use the above hierarchy
+
+Thus, the input should be
+- User preference
+- LLM Providers (which contain the system wide default)
+- Current assistant
+
+Changes take place as
+- liveAssistant or currentChatSession changes (and the associated model override is set)
+- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
+
+If we have a live assistant, we should use that model override
+*/
+
 export function useLlmOverride(
   llmProviders: LLMProviderDescriptor[],
-  globalModel?: string | null,
   currentChatSession?: ChatSession,
-  defaultTemperature?: number
+  liveAssistant?: Persona
 ): LlmOverrideManager {
+  const { user } = useUser();
+
+  const [chatSession, setChatSession] = useState<ChatSession | null>(null);
+
+  const llmOverrideUpdate = () => {
+    if (!chatSession && currentChatSession) {
+      setChatSession(currentChatSession || null);
+      return;
+    }
+
+    if (liveAssistant?.llm_model_version_override) {
+      setLlmOverride(
+        getValidLlmOverride(liveAssistant.llm_model_version_override)
+      );
+    } else if (currentChatSession?.current_alternate_model) {
+      setLlmOverride(
+        getValidLlmOverride(currentChatSession.current_alternate_model)
+      );
+    } else if (user?.preferences?.default_model) {
+      setLlmOverride(getValidLlmOverride(user.preferences.default_model));
+      return;
+    } else {
+      const defaultProvider = llmProviders.find(
+        (provider) => provider.is_default_provider
+      );
+
+      if (defaultProvider) {
+        setLlmOverride({
+          name: defaultProvider.name,
+          provider: defaultProvider.provider,
+          modelName: defaultProvider.default_model_name,
+        });
+      }
+    }
+    setChatSession(currentChatSession || null);
+  };
+
   const getValidLlmOverride = (
     overrideModel: string | null | undefined
   ): LlmOverride => {
     if (overrideModel) {
       const model = destructureValue(overrideModel);
-      const provider = llmProviders.find(
-        (p) =>
-          p.model_names.includes(model.modelName) &&
-          p.provider === model.provider
-      );
+      if (!(model.modelName && model.modelName.length > 0)) {
+        const provider = llmProviders.find((p) =>
+          p.model_names.includes(overrideModel)
+        );
+        if (provider) {
+          return {
+            modelName: overrideModel,
+            name: provider.name,
+            provider: provider.provider,
+          };
+        }
+      }
+
+      const provider = llmProviders.find((p) =>
+        p.model_names.includes(model.modelName)
+      );
 
       if (provider) {
         return { ...model, name: provider.name };
       }
     }
     return { name: "", provider: "", modelName: "" };
   };
 
   const [imageFilesPresent, setImageFilesPresent] = useState(false);
 
   const updateImageFilesPresent = (present: boolean) => {
     setImageFilesPresent(present);
   };
 
-  const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
-    getValidLlmOverride(globalModel)
-  );
-  const updateLLMOverride = (newOverride: LlmOverride) => {
-    setLlmOverride(
-      getValidLlmOverride(
-        structureValue(
-          newOverride.name,
-          newOverride.provider,
-          newOverride.modelName
-        )
-      )
-    );
-  };
-
-  const [llmOverride, setLlmOverride] = useState<LlmOverride>(
-    currentChatSession && currentChatSession.current_alternate_model
-      ? getValidLlmOverride(currentChatSession.current_alternate_model)
-      : { name: "", provider: "", modelName: "" }
-  );
+  const [llmOverride, setLlmOverride] = useState<LlmOverride>({
+    name: "",
+    provider: "",
+    modelName: "",
+  });
+
+  // Manually set the override
+  const updateLLMOverride = (newOverride: LlmOverride) => {
+    const provider =
+      newOverride.provider ||
+      findProviderForModel(llmProviders, newOverride.modelName);
+    const structuredValue = structureValue(
+      newOverride.name,
+      provider,
+      newOverride.modelName
+    );
+    setLlmOverride(getValidLlmOverride(structuredValue));
+  };
 
   const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
-    setLlmOverride(
-      chatSession && chatSession.current_alternate_model
-        ? getValidLlmOverride(chatSession.current_alternate_model)
-        : globalDefault
-    );
+    if (chatSession && chatSession.current_alternate_model?.length > 0) {
+      setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
+    }
   };
 
-  const [temperature, setTemperature] = useState<number | null>(
-    defaultTemperature !== undefined ? defaultTemperature : 0
-  );
+  const [temperature, setTemperature] = useState<number | null>(0);
 
   useEffect(() => {
-    setGlobalDefault(getValidLlmOverride(globalModel));
-  }, [globalModel, llmProviders]);
+    llmOverrideUpdate();
+  }, [liveAssistant, currentChatSession]);
 
-  useEffect(() => {
-    setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
-  }, [defaultTemperature]);
-
   useEffect(() => {
     if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
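Note: the comment block added above encodes the selection order that `llmOverrideUpdate` implements. A condensed sketch of that precedence, using the field names from the hunk (`resolveModel` is a hypothetical helper, not part of this change):

    // Hypothetical condensation of llmOverrideUpdate's precedence chain.
    function resolveModel(
      liveAssistant: Persona | undefined,
      currentChatSession: ChatSession | undefined,
      userDefaultModel: string | null | undefined,
      llmProviders: LLMProviderDescriptor[]
    ): string | null {
      return (
        liveAssistant?.llm_model_version_override ?? // assistant override wins
        currentChatSession?.current_alternate_model ?? // then the session's model
        userDefaultModel ?? // then the user preference
        llmProviders.find((p) => p.is_default_provider)
          ?.default_model_name ?? // finally the system-wide default
        null
      );
    }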
@@ -450,12 +514,11 @@ export function useLlmOverride(
     updateModelOverrideForChatSession,
     llmOverride,
     updateLLMOverride,
-    globalDefault,
-    setGlobalDefault,
     temperature,
     updateTemperature,
     imageFilesPresent,
     updateImageFilesPresent,
+    liveAssistant: liveAssistant ?? null,
   };
 }
 
@@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => {
     modelName,
   };
 };
+
+export const findProviderForModel = (
+  llmProviders: LLMProviderDescriptor[],
+  modelName: string
+): string => {
+  const provider = llmProviders.find((p) => p.model_names.includes(modelName));
+  return provider ? provider.provider : "";
+};
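Note: a short usage sketch of the new `findProviderForModel` helper, mirroring how `updateLLMOverride` backfills a missing provider before structuring the value (the model and provider names are illustrative):

    // Recover the provider for a bare model name; an empty string means no
    // configured provider lists that model.
    const provider = findProviderForModel(llmProviders, "gpt-4o"); // e.g. "openai"
    const structured = structureValue("OpenAI", provider, "gpt-4o");
    // structured can then be validated via getValidLlmOverride(structured).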
@@ -358,7 +358,8 @@ export type ConfigurableSources = Exclude<
 
 export const oauthSupportedSources: ConfigurableSources[] = [
   ValidSources.Slack,
-  ValidSources.GoogleDrive,
+  // NOTE: temporarily disabled until our GDrive App is approved
+  // ValidSources.GoogleDrive,
 ];
 
 export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];