Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/sharepoint_app_init

# Conflicts:
#	backend/onyx/connectors/sharepoint/connector.py

Commit 58e5deba01
@@ -39,6 +39,12 @@ env:
  AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
  AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
  # Sharepoint
  SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
@@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
# These field types are considered metadata by default when
# treat_all_non_attachment_fields_as_metadata is False
DEFAULT_METADATA_FIELD_TYPES = {
    "singlecollaborator",
    "collaborator",
    "createdby",
@@ -60,12 +60,16 @@ class AirtableConnector(LoadConnector):
        self,
        base_id: str,
        table_name_or_id: str,
        treat_all_non_attachment_fields_as_metadata: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.batch_size = batch_size
        self.airtable_client: AirtableApi | None = None
        self.treat_all_non_attachment_fields_as_metadata = (
            treat_all_non_attachment_fields_as_metadata
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.airtable_client = AirtableApi(credentials["airtable_access_token"])
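
A quick usage sketch of the new flag (all values are placeholders, not from the commit); this mirrors what the integration tests further down do:

connector = AirtableConnector(
    base_id="appXXXXXXXXXXXXXX",  # placeholder Airtable base id
    table_name_or_id="tblXXXXXXXXXXXXXX",  # placeholder table id or name
    treat_all_non_attachment_fields_as_metadata=True,
)
connector.load_credentials({"airtable_access_token": "pat..."})  # placeholder token
for doc_batch in connector.load_from_state():
    print(f"indexed {len(doc_batch)} documents")
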
@@ -166,8 +170,14 @@ class AirtableConnector(LoadConnector):
        return [(str(field_info), default_link)]

    def _should_be_metadata(self, field_type: str) -> bool:
        """Determine if a field type should be treated as metadata."""
        return field_type.lower() in _METADATA_FIELD_TYPES
        """Determine if a field type should be treated as metadata.

        When treat_all_non_attachment_fields_as_metadata is True, all fields except
        attachments are treated as metadata. Otherwise, only fields with types listed
        in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
        if self.treat_all_non_attachment_fields_as_metadata:
            return field_type.lower() != "multipleattachments"
        return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES

    def _process_field(
        self,
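
For illustration, a standalone sketch of how the toggle changes behavior (not part of the commit; the default set shown here is abbreviated):

def should_be_metadata(field_type: str, treat_all_as_metadata: bool) -> bool:
    # Abbreviated stand-in for AirtableConnector._should_be_metadata
    default_metadata_field_types = {"singlecollaborator", "collaborator", "createdby"}
    if treat_all_as_metadata:
        return field_type.lower() != "multipleattachments"
    return field_type.lower() in default_metadata_field_types

assert should_be_metadata("singleLineText", treat_all_as_metadata=True)
assert not should_be_metadata("singleLineText", treat_all_as_metadata=False)
assert not should_be_metadata("multipleAttachments", treat_all_as_metadata=True)
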
@@ -233,7 +243,7 @@ class AirtableConnector(LoadConnector):
        record: RecordDict,
        table_schema: TableSchema,
        primary_field_name: str | None,
    ) -> Document:
    ) -> Document | None:
        """Process a single Airtable record into a Document.

        Args:
@@ -277,6 +287,10 @@ class AirtableConnector(LoadConnector):
            sections.extend(field_sections)
            metadata.update(field_metadata)

        if not sections:
            logger.warning(f"No sections found for record {record_id}")
            return None

        semantic_id = (
            f"{table_name}: {primary_field_value}"
            if primary_field_value
@@ -320,7 +334,8 @@ class AirtableConnector(LoadConnector):
                table_schema=table_schema,
                primary_field_name=primary_field_name,
            )
            record_documents.append(document)
            if document:
                record_documents.append(document)

            if len(record_documents) >= self.batch_size:
                yield record_documents
@@ -1,17 +1,14 @@
import io
import os
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from urllib.parse import unquote

import msal  # type: ignore
from office365.graph_client import GraphClient  # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem  # type: ignore
from office365.onedrive.sites.site import Site  # type: ignore
from pydantic import BaseModel

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


@dataclass
class SiteData:
    url: str | None
    folder: Optional[str]
    sites: list = field(default_factory=list)
    driveitems: list = field(default_factory=list)
class SiteDescriptor(BaseModel):
    """Data class for storing SharePoint site information.

    Args:
        url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
        drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
            If None, all drives will be accessed.
        folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
            If None, all folders will be accessed.
    """

    url: str
    drive_name: str | None
    folder_path: str | None


def _convert_driveitem_to_document(
    driveitem: DriveItem,
    drive_name: str,
) -> Document:
    file_text = extract_file_text(
        file=io.BytesIO(driveitem.get_content().execute_query().value),
@@ -59,7 +65,7 @@ def _convert_driveitem_to_document(
                email=driveitem.last_modified_by.user.email,
            )
        ],
        metadata={},
        metadata={"drive": drive_name},
    )
    return doc
@@ -71,107 +77,172 @@ class SharepointConnector(LoadConnector, PollConnector):
        sites: list[str] = [],
    ) -> None:
        self.batch_size = batch_size
        self.graph_client: GraphClient | None = None
        self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
        self._graph_client: GraphClient | None = None
        self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
            sites
        )
        self.msal_app: msal.ConfidentialClientApplication | None = None

    @property
    def graph_client(self) -> GraphClient:
        if self._graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")

        return self._graph_client

    @staticmethod
    def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
    def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
        site_data_list = []
        for url in site_urls:
            parts = url.strip().split("/")
            if "sites" in parts:
                sites_index = parts.index("sites")
                site_url = "/".join(parts[: sites_index + 2])
                folder = (
                    "/".join(unquote(part) for part in parts[sites_index + 2 :])
                    if len(parts) > sites_index + 2
                    else None
                )
                # Handling for new URL structure
                if folder and folder.startswith("Shared Documents/"):
                    folder = folder[len("Shared Documents/") :]
                remaining_parts = parts[sites_index + 2 :]

                # Extract drive name and folder path
                if remaining_parts:
                    drive_name = unquote(remaining_parts[0])
                    folder_path = (
                        "/".join(unquote(part) for part in remaining_parts[1:])
                        if len(remaining_parts) > 1
                        else None
                    )
                else:
                    drive_name = None
                    folder_path = None

                site_data_list.append(
                    SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
                    SiteDescriptor(
                        url=site_url,
                        drive_name=drive_name,
                        folder_path=folder_path,
                    )
                )
        return site_data_list

    def _populate_sitedata_driveitems(
    def _fetch_driveitems(
        self,
        site_descriptor: SiteDescriptor,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> None:
    ) -> list[tuple[DriveItem, str]]:
        filter_str = ""
        if start is not None and end is not None:
            filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
            filter_str = (
                f"last_modified_datetime ge {start.isoformat()} and "
                f"last_modified_datetime le {end.isoformat()}"
            )

        for element in self.site_data:
            sites: list[Site] = []
            for site in element.sites:
                site_sublist = site.lists.get().execute_query()
                sites.extend(site_sublist)
        final_driveitems: list[tuple[DriveItem, str]] = []
        try:
            site = self.graph_client.sites.get_by_url(site_descriptor.url)

            for site in sites:
            # Get all drives in the site
            drives = site.drives.get().execute_query()
            logger.debug(f"Found drives: {[drive.name for drive in drives]}")

            # Filter drives based on the requested drive name
            if site_descriptor.drive_name:
                drives = [
                    drive
                    for drive in drives
                    if drive.name == site_descriptor.drive_name
                    or (
                        drive.name == "Documents"
                        and site_descriptor.drive_name == "Shared Documents"
                    )
                ]
                if not drives:
                    logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
                    return []

            # Process each matching drive
            for drive in drives:
                try:
                    query = site.drive.root.get_files(True, 1000)
                    root_folder = drive.root
                    if site_descriptor.folder_path:
                        # If a specific folder is requested, navigate to it
                        for folder_part in site_descriptor.folder_path.split("/"):
                            root_folder = root_folder.get_by_path(folder_part)

                    # Get all items recursively
                    query = root_folder.get_files(True, 1000)
                    if filter_str:
                        query = query.filter(filter_str)
                    driveitems = query.execute_query()
                    if element.folder:
                        expected_path = f"/root:/{element.folder}"
                    logger.debug(
                        f"Found {len(driveitems)} items in drive '{drive.name}'"
                    )

                    # Use "Shared Documents" as the library name for the default "Documents" drive
                    drive_name = (
                        "Shared Documents" if drive.name == "Documents" else drive.name
                    )

                    if site_descriptor.folder_path:
                        # Filter items to ensure they're in the specified folder or its subfolders
                        # The path will be in format: /drives/{drive_id}/root:/folder/path
                        filtered_driveitems = [
                            item
                            (item, drive_name)
                            for item in driveitems
                            if item.parent_reference.path.endswith(expected_path)
                            if any(
                                path_part == site_descriptor.folder_path
                                or path_part.startswith(
                                    site_descriptor.folder_path + "/"
                                )
                                for path_part in item.parent_reference.path.split(
                                    "root:/"
                                )[1].split("/")
                            )
                        ]
                        if len(filtered_driveitems) == 0:
                            all_paths = [
                                item.parent_reference.path for item in driveitems
                            ]
                            logger.warning(
                                f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}"
                                f"Nothing found for folder '{site_descriptor.folder_path}' "
                                f"in; any of valid paths: {all_paths}"
                            )
                        element.driveitems.extend(filtered_driveitems)
                        final_driveitems.extend(filtered_driveitems)
                    else:
                        element.driveitems.extend(driveitems)
                        final_driveitems.extend(
                            [(item, drive_name) for item in driveitems]
                        )
                except Exception as e:
                    # Some drives might not be accessible
                    logger.warning(f"Failed to process drive: {str(e)}")

            except Exception:
                # Sites include things that do not contain .drive.root so this fails
                # but this is fine, as there are no actually documents in those
                pass
        except Exception as e:
            # Sites include things that do not contain drives so this fails
            # but this is fine, as there are no actual documents in those
            logger.warning(f"Failed to process site: {str(e)}")

    def _populate_sitedata_sites(self) -> None:
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")
        return final_driveitems

        if self.site_data:
            for element in self.site_data:
                element.sites = [
                    self.graph_client.sites.get_by_url(element.url)
                    .get()
                    .execute_query()
                ]
        else:
            sites = self.graph_client.sites.get_all().execute_query()
            self.site_data = [
                SiteData(url=None, folder=None, sites=sites, driveitems=[])
            ]
    def _fetch_sites(self) -> list[SiteDescriptor]:
        sites = self.graph_client.sites.get_all().execute_query()
        site_descriptors = [
            SiteDescriptor(
                url=sites.resource_url,
                drive_name=None,
                folder_path=None,
            )
        ]
        return site_descriptors

    def _fetch_from_sharepoint(
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")

        self._populate_sitedata_sites()
        self._populate_sitedata_driveitems(start=start, end=end)
        site_descriptors = self.site_descriptors or self._fetch_sites()

        # goes over all urls, converts them into Document objects and then yields them in batches
        doc_batch: list[Document] = []
        for element in self.site_data:
            for driveitem in element.driveitems:
        for site_descriptor in site_descriptors:
            driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
            for driveitem, drive_name in driveitems:
                logger.debug(f"Processing: {driveitem.web_url}")
                doc_batch.append(_convert_driveitem_to_document(driveitem))
                doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))

                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
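
For orientation, a simplified standalone sketch (not code from the commit) of how _extract_site_and_drive_info splits a SharePoint URL into the SiteDescriptor fields; the sample URL is the one used in the docstring above:

from urllib.parse import unquote

def split_site_url(url: str) -> tuple[str, str | None, str | None]:
    # Simplified mirror of SharepointConnector._extract_site_and_drive_info
    parts = url.strip().split("/")
    sites_index = parts.index("sites")
    site_url = "/".join(parts[: sites_index + 2])
    remaining = parts[sites_index + 2 :]
    drive_name = unquote(remaining[0]) if remaining else None
    folder_path = "/".join(unquote(p) for p in remaining[1:]) or None
    return site_url, drive_name, folder_path

site, drive, folder = split_site_url(
    "https://danswerai.sharepoint.com/sites/sharepoint-tests/Shared Documents/test"
)
assert site == "https://danswerai.sharepoint.com/sites/sharepoint-tests"
assert drive == "Shared Documents"
assert folder == "test"
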
@@ -202,7 +273,7 @@ class SharepointConnector(LoadConnector, PollConnector):
            )
            return token

        self.graph_client = GraphClient(_acquire_token_func)
        self._graph_client = GraphClient(_acquire_token_func)
        return None

    def load_from_state(self) -> GenerateDocumentsOutput:
@@ -211,19 +282,19 @@ class SharepointConnector(LoadConnector, PollConnector):
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_datetime = datetime.utcfromtimestamp(start)
        end_datetime = datetime.utcfromtimestamp(end)
        start_datetime = datetime.fromtimestamp(start, timezone.utc)
        end_datetime = datetime.fromtimestamp(end, timezone.utc)
        return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)


if __name__ == "__main__":
    connector = SharepointConnector(sites=os.environ["SITES"].split(","))
    connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))

    connector.load_credentials(
        {
            "sp_client_id": os.environ["SP_CLIENT_ID"],
            "sp_client_secret": os.environ["SP_CLIENT_SECRET"],
            "sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
            "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
            "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
            "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
        }
    )
    document_batches = connector.load_from_state()
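
Side note on the poll_source change: datetime.utcfromtimestamp returns a naive datetime (and is deprecated in newer Python versions), while datetime.fromtimestamp(ts, timezone.utc) returns a timezone-aware one. A minimal sketch:

from datetime import datetime, timezone

ts = 1735074169.0  # seconds since the Unix epoch

naive = datetime.utcfromtimestamp(ts)             # tzinfo is None
aware = datetime.fromtimestamp(ts, timezone.utc)  # tzinfo is timezone.utc

assert naive.tzinfo is None
assert aware.tzinfo == timezone.utc
assert aware.isoformat().endswith("+00:00")  # offset now appears in isoformat()
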
@@ -537,30 +537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
            # Let the tag flow handle this case, don't reply twice
            return False

        if event.get("bot_profile"):
        # Check if this is a bot message (either via bot_profile or bot_message subtype)
        is_bot_message = bool(
            event.get("bot_profile") or event.get("subtype") == "bot_message"
        )
        if is_bot_message:
            channel_name, _ = get_channel_name_from_id(
                client=client.web_client, channel_id=channel
            )

            with get_session_with_tenant(client.tenant_id) as db_session:
                slack_channel_config = get_slack_channel_config_for_bot_and_channel(
                    db_session=db_session,
                    slack_bot_id=client.slack_bot_id,
                    channel_name=channel_name,
                )

                # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
                if (not bot_tag_id or bot_tag_id not in msg) and (
                    not slack_channel_config
                    or not slack_channel_config.channel_config.get("respond_to_bots")
                ):
                    channel_specific_logger.info("Ignoring message from bot")
                    channel_specific_logger.info(
                        "Ignoring message from bot since respond_to_bots is disabled"
                    )
                    return False

        # Ignore things like channel_join, channel_leave, etc.
        # NOTE: "file_share" is just a message with a file attachment, so we
        # should not ignore it
        message_subtype = event.get("subtype")
        if message_subtype not in [None, "file_share"]:
        if message_subtype not in [None, "file_share", "bot_message"]:
            channel_specific_logger.info(
                f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
            )
@@ -1,8 +1,8 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from onyx.configs.constants import DocumentSource
from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -10,25 +10,24 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section


@pytest.fixture(
    params=[
        ("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]),
        ("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]),
    ]
)
def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
    param_type, table_identifier = request.param
    connector = AirtableConnector(
        base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
        table_name_or_id=table_identifier,
    )
class AirtableConfig(BaseModel):
    base_id: str
    table_identifier: str
    access_token: str

    connector.load_credentials(
        {
            "airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"],
        }

@pytest.fixture(params=[True, False])
def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig:
    table_identifier = (
        os.environ["AIRTABLE_TEST_TABLE_NAME"]
        if request.param
        else os.environ["AIRTABLE_TEST_TABLE_ID"]
    )
    return AirtableConfig(
        base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
        table_identifier=table_identifier,
        access_token=os.environ["AIRTABLE_ACCESS_TOKEN"],
    )
    return connector


def create_test_document(
@@ -46,18 +45,37 @@ def create_test_document(
    assignee: str,
    days_since_status_change: int | None,
    attachments: list[tuple[str, str]] | None = None,
    all_fields_as_metadata: bool = False,
) -> Document:
    link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
    sections = [
        Section(
            text=f"Title:\n------------------------\n{title}\n------------------------",
            link=f"{link_base}/{id}",
        ),
        Section(
            text=f"Description:\n------------------------\n{description}\n------------------------",
            link=f"{link_base}/{id}",
        ),
    ]
    base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
    table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
    missing_vars = []
    if not base_id:
        missing_vars.append("AIRTABLE_TEST_BASE_ID")
    if not table_id:
        missing_vars.append("AIRTABLE_TEST_TABLE_ID")

    if missing_vars:
        raise RuntimeError(
            f"Required environment variables not set: {', '.join(missing_vars)}. "
            "These variables are required to run Airtable connector tests."
        )
    link_base = f"https://airtable.com/{base_id}/{table_id}"
    sections = []

    if not all_fields_as_metadata:
        sections.extend(
            [
                Section(
                    text=f"Title:\n------------------------\n{title}\n------------------------",
                    link=f"{link_base}/{id}",
                ),
                Section(
                    text=f"Description:\n------------------------\n{description}\n------------------------",
                    link=f"{link_base}/{id}",
                ),
            ]
        )

    if attachments:
        for attachment_text, attachment_link in attachments:
@@ -68,26 +86,36 @@ def create_test_document(
                ),
            )

    metadata: dict[str, str | list[str]] = {
        # "Category": category,
        "Assignee": assignee,
        "Submitted by": submitted_by,
        "Priority": priority,
        "Status": status,
        "Created time": created_time,
        "ID": ticket_id,
        "Status last changed": status_last_changed,
        **(
            {"Days since status change": str(days_since_status_change)}
            if days_since_status_change is not None
            else {}
        ),
    }

    if all_fields_as_metadata:
        metadata.update(
            {
                "Title": title,
                "Description": description,
            }
        )

    return Document(
        id=f"airtable__{id}",
        sections=sections,
        source=DocumentSource.AIRTABLE,
        semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}",
        metadata={
            # "Category": category,
            "Assignee": assignee,
            "Submitted by": submitted_by,
            "Priority": priority,
            "Status": status,
            "Created time": created_time,
            "ID": ticket_id,
            "Status last changed": status_last_changed,
            **(
                {"Days since status change": str(days_since_status_change)}
                if days_since_status_change is not None
                else {}
            ),
        },
        semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}",
        metadata=metadata,
        doc_updated_at=None,
        primary_owners=None,
        secondary_owners=None,
@@ -97,15 +125,75 @@ def create_test_document(
    )


@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_airtable_connector_basic(
    mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
def compare_documents(
    actual_docs: list[Document], expected_docs: list[Document]
) -> None:
    doc_batch_generator = airtable_connector.load_from_state()
    """Utility function to compare actual and expected documents, ignoring order."""
    actual_docs_dict = {doc.id: doc for doc in actual_docs}
    expected_docs_dict = {doc.id: doc for doc in expected_docs}

    assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch"

    for doc_id in actual_docs_dict:
        actual = actual_docs_dict[doc_id]
        expected = expected_docs_dict[doc_id]

        assert (
            actual.source == expected.source
        ), f"Source mismatch for document {doc_id}"
        assert (
            actual.semantic_identifier == expected.semantic_identifier
        ), f"Semantic identifier mismatch for document {doc_id}"
        assert (
            actual.metadata == expected.metadata
        ), f"Metadata mismatch for document {doc_id}"
        assert (
            actual.doc_updated_at == expected.doc_updated_at
        ), f"Updated at mismatch for document {doc_id}"
        assert (
            actual.primary_owners == expected.primary_owners
        ), f"Primary owners mismatch for document {doc_id}"
        assert (
            actual.secondary_owners == expected.secondary_owners
        ), f"Secondary owners mismatch for document {doc_id}"
        assert actual.title == expected.title, f"Title mismatch for document {doc_id}"
        assert (
            actual.from_ingestion_api == expected.from_ingestion_api
        ), f"Ingestion API flag mismatch for document {doc_id}"
        assert (
            actual.additional_info == expected.additional_info
        ), f"Additional info mismatch for document {doc_id}"

        # Compare sections
        assert len(actual.sections) == len(
            expected.sections
        ), f"Number of sections mismatch for document {doc_id}"
        for i, (actual_section, expected_section) in enumerate(
            zip(actual.sections, expected.sections)
        ):
            assert (
                actual_section.text == expected_section.text
            ), f"Section {i} text mismatch for document {doc_id}"
            assert (
                actual_section.link == expected_section.link
            ), f"Section {i} link mismatch for document {doc_id}"


def test_airtable_connector_basic(
    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
    """Test behavior when all non-attachment fields are treated as metadata."""
    connector = AirtableConnector(
        base_id=airtable_config.base_id,
        table_name_or_id=airtable_config.table_identifier,
        treat_all_non_attachment_fields_as_metadata=False,
    )
    connector.load_credentials(
        {
            "airtable_access_token": airtable_config.access_token,
        }
    )
    doc_batch_generator = connector.load_from_state()
    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)
@@ -119,15 +207,62 @@ def test_airtable_connector_basic(
            description="The internet connection is very slow.",
            priority="Medium",
            status="In Progress",
            # Link to another record is skipped for now
            # category="Data Science",
            ticket_id="2",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            all_fields_as_metadata=False,
        ),
        create_test_document(
            id="reccSlIA4pZEFxPBg",
            title="Printer Issue",
            description="The office printer is not working.",
            priority="High",
            status="Open",
            ticket_id="1",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            attachments=[
                (
                    "Test.pdf:\ntesting!!!",
                    "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
                )
            ],
            all_fields_as_metadata=False,
        ),
    ]

    # Compare documents using the utility function
    compare_documents(doc_batch, expected_docs)


def test_airtable_connector_all_metadata(
    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
    connector = AirtableConnector(
        base_id=airtable_config.base_id,
        table_name_or_id=airtable_config.table_identifier,
        treat_all_non_attachment_fields_as_metadata=True,
    )
    connector.load_credentials(
        {
            "airtable_access_token": airtable_config.access_token,
        }
    )
    doc_batch_generator = connector.load_from_state()
    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    # NOTE: one of the rows has no attachments -> no content -> no document
    assert len(doc_batch) == 1

    expected_docs = [
        create_test_document(
            id="reccSlIA4pZEFxPBg",
            title="Printer Issue",
@@ -149,50 +284,9 @@ def test_airtable_connector_basic(
                    "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
                )
            ],
            all_fields_as_metadata=True,
        ),
    ]

    # Compare each document field by field
    for actual, expected in zip(doc_batch, expected_docs):
        assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
        assert (
            actual.source == expected.source
        ), f"Source mismatch for document {actual.id}"
        assert (
            actual.semantic_identifier == expected.semantic_identifier
        ), f"Semantic identifier mismatch for document {actual.id}"
        assert (
            actual.metadata == expected.metadata
        ), f"Metadata mismatch for document {actual.id}"
        assert (
            actual.doc_updated_at == expected.doc_updated_at
        ), f"Updated at mismatch for document {actual.id}"
        assert (
            actual.primary_owners == expected.primary_owners
        ), f"Primary owners mismatch for document {actual.id}"
        assert (
            actual.secondary_owners == expected.secondary_owners
        ), f"Secondary owners mismatch for document {actual.id}"
        assert (
            actual.title == expected.title
        ), f"Title mismatch for document {actual.id}"
        assert (
            actual.from_ingestion_api == expected.from_ingestion_api
        ), f"Ingestion API flag mismatch for document {actual.id}"
        assert (
            actual.additional_info == expected.additional_info
        ), f"Additional info mismatch for document {actual.id}"

        # Compare sections
        assert len(actual.sections) == len(
            expected.sections
        ), f"Number of sections mismatch for document {actual.id}"
        for i, (actual_section, expected_section) in enumerate(
            zip(actual.sections, expected.sections)
        ):
            assert (
                actual_section.text == expected_section.text
            ), f"Section {i} text mismatch for document {actual.id}"
            assert (
                actual_section.link == expected_section.link
            ), f"Section {i} link mismatch for document {actual.id}"
    # Compare documents using the utility function
    compare_documents(doc_batch, expected_docs)
backend/tests/daily/connectors/conftest.py (new file, 14 lines)
@@ -0,0 +1,14 @@
from collections.abc import Generator
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest


@pytest.fixture
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
    with patch(
        "onyx.file_processing.extract_file_text.get_unstructured_api_key",
        return_value=None,
    ) as mock:
        yield mock
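
This shared fixture patches get_unstructured_api_key to return None so the connector tests below exercise the local text-extraction path; a test opts in simply by naming the fixture as a parameter. A usage sketch (the test itself is illustrative, not from the repo):

from unittest.mock import MagicMock

def test_uses_local_extraction(mock_get_unstructured_api_key: MagicMock) -> None:
    # While this test runs, onyx.file_processing.extract_file_text.get_unstructured_api_key
    # is replaced by a MagicMock configured with return_value=None.
    assert mock_get_unstructured_api_key.return_value is None
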
@@ -0,0 +1,178 @@
import os
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock

import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.sharepoint.connector import SharepointConnector


@dataclass
class ExpectedDocument:
    semantic_identifier: str
    content: str
    folder_path: str | None = None
    library: str = "Shared Documents"  # Default to main library


EXPECTED_DOCUMENTS = [
    ExpectedDocument(
        semantic_identifier="test1.docx",
        content="test1",
        folder_path="test",
    ),
    ExpectedDocument(
        semantic_identifier="test2.docx",
        content="test2",
        folder_path="test/nested with spaces",
    ),
    ExpectedDocument(
        semantic_identifier="should-not-index-on-specific-folder.docx",
        content="should-not-index-on-specific-folder",
        folder_path=None,  # root folder
    ),
    ExpectedDocument(
        semantic_identifier="other.docx",
        content="other",
        folder_path=None,
        library="Other Library",
    ),
]


def verify_document_metadata(doc: Document) -> None:
    """Verify common metadata that should be present on all documents."""
    assert isinstance(doc.doc_updated_at, datetime)
    assert doc.doc_updated_at.tzinfo == timezone.utc
    assert doc.source == DocumentSource.SHAREPOINT
    assert doc.primary_owners is not None
    assert len(doc.primary_owners) == 1
    owner = doc.primary_owners[0]
    assert owner.display_name is not None
    assert owner.email is not None


def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
    """Verify a document matches its expected content."""
    assert doc.semantic_identifier == expected.semantic_identifier
    assert len(doc.sections) == 1
    assert expected.content in doc.sections[0].text
    verify_document_metadata(doc)


def find_document(documents: list[Document], semantic_identifier: str) -> Document:
    """Find a document by its semantic identifier."""
    matching_docs = [
        d for d in documents if d.semantic_identifier == semantic_identifier
    ]
    assert (
        len(matching_docs) == 1
    ), f"Expected exactly one document with identifier {semantic_identifier}"
    return matching_docs[0]


@pytest.fixture
def sharepoint_credentials() -> dict[str, str]:
    return {
        "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
        "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
        "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
    }


def test_sharepoint_connector_specific_folder(
    mock_get_unstructured_api_key: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    # Initialize connector with the test site URL and specific folder
    connector = SharepointConnector(
        sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
    )

    # Load credentials
    connector.load_credentials(sharepoint_credentials)

    # Get all documents
    document_batches = list(connector.load_from_state())
    found_documents: list[Document] = [
        doc for batch in document_batches for doc in batch
    ]

    # Should only find documents in the test folder
    test_folder_docs = [
        doc
        for doc in EXPECTED_DOCUMENTS
        if doc.folder_path and doc.folder_path.startswith("test")
    ]
    assert len(found_documents) == len(
        test_folder_docs
    ), "Should only find documents in test folder"

    # Verify each expected document
    for expected in test_folder_docs:
        doc = find_document(found_documents, expected.semantic_identifier)
        verify_document_content(doc, expected)


def test_sharepoint_connector_root_folder(
    mock_get_unstructured_api_key: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    # Initialize connector with the base site URL
    connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])

    # Load credentials
    connector.load_credentials(sharepoint_credentials)

    # Get all documents
    document_batches = list(connector.load_from_state())
    found_documents: list[Document] = [
        doc for batch in document_batches for doc in batch
    ]

    assert len(found_documents) == len(
        EXPECTED_DOCUMENTS
    ), "Should find all documents in main library"

    # Verify each expected document
    for expected in EXPECTED_DOCUMENTS:
        doc = find_document(found_documents, expected.semantic_identifier)
        verify_document_content(doc, expected)


def test_sharepoint_connector_other_library(
    mock_get_unstructured_api_key: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    # Initialize connector with the other library
    connector = SharepointConnector(
        sites=[
            os.environ["SHAREPOINT_SITE"] + "/Other Library",
        ]
    )

    # Load credentials
    connector.load_credentials(sharepoint_credentials)

    # Get all documents
    document_batches = list(connector.load_from_state())
    found_documents: list[Document] = [
        doc for batch in document_batches for doc in batch
    ]
    expected_documents: list[ExpectedDocument] = [
        doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
    ]

    # Should find all documents in `Other Library`
    assert len(found_documents) == len(
        expected_documents
    ), "Should find all documents in `Other Library`"

    # Verify each expected document
    for expected in expected_documents:
        doc = find_document(found_documents, expected.semantic_identifier)
        verify_document_content(doc, expected)
@@ -76,13 +76,7 @@ import {
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import {
  checkLLMSupportsImageInput,
  getFinalLLM,
  destructureValue,
  getLLMProviderOverrideForPersona,
} from "@/lib/llm/utils";

import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
@@ -203,6 +197,12 @@ export function ChatPage({

  const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open

  const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;

  const selectedChatSession = chatSessions.find(
    (chatSession) => chatSession.id === existingChatSessionId
  );

  useEffect(() => {
    if (user?.is_anonymous_user) {
      Cookies.set(
@@ -240,12 +240,6 @@ export function ChatPage({
    }
  };

  const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;

  const selectedChatSession = chatSessions.find(
    (chatSession) => chatSession.id === existingChatSessionId
  );

  const chatSessionIdRef = useRef<string | null>(existingChatSessionId);

  // Only updates on session load (ie. rename / switching chat session)
@@ -293,12 +287,6 @@ export function ChatPage({
    );
  };

  const llmOverrideManager = useLlmOverride(
    llmProviders,
    user?.preferences.default_model,
    selectedChatSession
  );

  const [alternativeAssistant, setAlternativeAssistant] =
    useState<Persona | null>(null);
@@ -307,12 +295,27 @@ export function ChatPage({

  const { recentAssistants, refreshRecentAssistants } = useAssistants();

  const liveAssistant: Persona | undefined =
    alternativeAssistant ||
    selectedAssistant ||
    recentAssistants[0] ||
    finalAssistants[0] ||
    availableAssistants[0];
  const liveAssistant: Persona | undefined = useMemo(
    () =>
      alternativeAssistant ||
      selectedAssistant ||
      recentAssistants[0] ||
      finalAssistants[0] ||
      availableAssistants[0],
    [
      alternativeAssistant,
      selectedAssistant,
      recentAssistants,
      finalAssistants,
      availableAssistants,
    ]
  );

  const llmOverrideManager = useLlmOverride(
    llmProviders,
    selectedChatSession,
    liveAssistant
  );

  const noAssistants = liveAssistant == null || liveAssistant == undefined;
@@ -320,24 +323,6 @@ export function ChatPage({
  const uniqueSources = Array.from(new Set(availableSources));
  const sources = uniqueSources.map((source) => getSourceMetadata(source));

  // always set the model override for the chat session, when an assistant, llm provider, or user preference exists
  useEffect(() => {
    if (noAssistants) return;
    const personaDefault = getLLMProviderOverrideForPersona(
      liveAssistant,
      llmProviders
    );

    if (personaDefault) {
      llmOverrideManager.updateLLMOverride(personaDefault);
    } else if (user?.preferences.default_model) {
      llmOverrideManager.updateLLMOverride(
        destructureValue(user?.preferences.default_model)
      );
    }
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [liveAssistant, user?.preferences.default_model]);

  const stopGenerating = () => {
    const currentSession = currentSessionId();
    const controller = abortControllers.get(currentSession);
@@ -419,7 +404,6 @@ export function ChatPage({
    filterManager.setTimeRange(null);

    // reset LLM overrides (based on chat session!)
    llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
    llmOverrideManager.updateTemperature(null);

    // remove uploaded files
@@ -1283,13 +1267,11 @@ export function ChatPage({
          modelProvider:
            modelOverRide?.name ||
            llmOverrideManager.llmOverride.name ||
            llmOverrideManager.globalDefault.name ||
            undefined,
          modelVersion:
            modelOverRide?.modelName ||
            llmOverrideManager.llmOverride.modelName ||
            searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
            llmOverrideManager.globalDefault.modelName ||
            undefined,
          temperature: llmOverrideManager.temperature || undefined,
          systemPromptOverride:
@@ -1952,6 +1934,7 @@ export function ChatPage({
    };
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [router]);

  const [sharedChatSession, setSharedChatSession] =
    useState<ChatSession | null>();
@@ -2059,7 +2042,9 @@ export function ChatPage({
        {(settingsToggled || userSettingsToggled) && (
          <UserSettingsModal
            setPopup={setPopup}
            setLlmOverride={llmOverrideManager.setGlobalDefault}
            setLlmOverride={(newOverride) =>
              llmOverrideManager.updateLLMOverride(newOverride)
            }
            defaultModel={user?.preferences.default_model!}
            llmProviders={llmProviders}
            onClose={() => {
@@ -2749,6 +2734,7 @@ export function ChatPage({
                    </button>
                  </div>
                )}

                <ChatInputBar
                  toggleDocumentSidebar={toggleDocumentSidebar}
                  availableSources={sources}
@@ -40,8 +40,8 @@ export default function LLMPopover({
  currentAssistant,
}: LLMPopoverProps) {
  const [isOpen, setIsOpen] = useState(false);
  const { llmOverride, updateLLMOverride, globalDefault } = llmOverrideManager;
  const currentLlm = llmOverride.modelName || globalDefault.modelName;
  const { llmOverride, updateLLMOverride } = llmOverrideManager;
  const currentLlm = llmOverride.modelName;

  const llmOptionsByProvider: {
    [provider: string]: {
@@ -1,13 +1,5 @@
import {
  Dispatch,
  SetStateAction,
  useContext,
  useEffect,
  useRef,
  useState,
} from "react";
import { useContext, useEffect, useRef } from "react";
import { Modal } from "@/components/Modal";
import Text from "@/components/ui/text";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
@@ -33,7 +25,7 @@ export function UserSettingsModal({
}: {
  setPopup: (popupSpec: PopupSpec | null) => void;
  llmProviders: LLMProviderDescriptor[];
  setLlmOverride?: Dispatch<SetStateAction<LlmOverride>>;
  setLlmOverride?: (newOverride: LlmOverride) => void;
  onClose: () => void;
  defaultModel: string | null;
}) {
@@ -1,3 +1,3 @@
export const SEARCH_TOOL_NAME = "SearchTool";
export const SEARCH_TOOL_NAME = "run_search";
export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search";
export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation";
@@ -1106,6 +1106,14 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
        name: "table_name_or_id",
        optional: false,
      },
      {
        type: "checkbox",
        label: "Treat all fields except attachments as metadata",
        name: "treat_all_non_attachment_fields_as_metadata",
        description:
          "Choose this if the primary content to index are attachments and all other columns are metadata for these attachments.",
        optional: false,
      },
    ],
    advanced_values: [],
    overrideDefaultFreq: 60 * 60 * 24,
@@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { Filters, SourceMetadata } from "./search/interfaces";
import { destructureValue, structureValue } from "./llm/utils";
import {
  destructureValue,
  findProviderForModel,
  structureValue,
} from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces";
import { AllUsersResponse } from "./types";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaLabel } from "@/app/admin/assistants/interfaces";
import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";

const CREDENTIAL_URL = "/api/manage/admin/credential";
||||
@@ -355,82 +360,141 @@ export interface LlmOverride {
export interface LlmOverrideManager {
  llmOverride: LlmOverride;
  updateLLMOverride: (newOverride: LlmOverride) => void;
  globalDefault: LlmOverride;
  setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
  temperature: number | null;
  updateTemperature: (temperature: number | null) => void;
  updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
  imageFilesPresent: boolean;
  updateImageFilesPresent: (present: boolean) => void;
  liveAssistant: Persona | null;
}

/*
LLM Override is as follows (i.e. this order)
- User override (explicitly set in the chat input bar)
- User preference (defaults to system wide default if no preference set)

On switching to an existing or new chat session or a different assistant:
- If we have a live assistant after any switch with a model override, use that- otherwise use the above hierarchy

Thus, the input should be
- User preference
- LLM Providers (which contain the system wide default)
- Current assistant

Changes take place as
- liveAssistant or currentChatSession changes (and the associated model override is set)
- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)

If we have a live assistant, we should use that model override
*/

export function useLlmOverride(
  llmProviders: LLMProviderDescriptor[],
  globalModel?: string | null,
  currentChatSession?: ChatSession,
  defaultTemperature?: number
  liveAssistant?: Persona
): LlmOverrideManager {
  const { user } = useUser();

  const [chatSession, setChatSession] = useState<ChatSession | null>(null);

  const llmOverrideUpdate = () => {
    if (!chatSession && currentChatSession) {
      setChatSession(currentChatSession || null);
      return;
    }

    if (liveAssistant?.llm_model_version_override) {
      setLlmOverride(
        getValidLlmOverride(liveAssistant.llm_model_version_override)
      );
    } else if (currentChatSession?.current_alternate_model) {
      setLlmOverride(
        getValidLlmOverride(currentChatSession.current_alternate_model)
      );
    } else if (user?.preferences?.default_model) {
      setLlmOverride(getValidLlmOverride(user.preferences.default_model));
      return;
    } else {
      const defaultProvider = llmProviders.find(
        (provider) => provider.is_default_provider
      );

      if (defaultProvider) {
        setLlmOverride({
          name: defaultProvider.name,
          provider: defaultProvider.provider,
          modelName: defaultProvider.default_model_name,
        });
      }
    }
    setChatSession(currentChatSession || null);
  };

  const getValidLlmOverride = (
    overrideModel: string | null | undefined
  ): LlmOverride => {
    if (overrideModel) {
      const model = destructureValue(overrideModel);
      const provider = llmProviders.find(
        (p) =>
          p.model_names.includes(model.modelName) &&
          p.provider === model.provider
      if (!(model.modelName && model.modelName.length > 0)) {
        const provider = llmProviders.find((p) =>
          p.model_names.includes(overrideModel)
        );
        if (provider) {
          return {
            modelName: overrideModel,
            name: provider.name,
            provider: provider.provider,
          };
        }
      }

      const provider = llmProviders.find((p) =>
        p.model_names.includes(model.modelName)
      );

      if (provider) {
        return { ...model, name: provider.name };
      }
    }
    return { name: "", provider: "", modelName: "" };
  };

  const [imageFilesPresent, setImageFilesPresent] = useState(false);

  const updateImageFilesPresent = (present: boolean) => {
    setImageFilesPresent(present);
  };

  const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
    getValidLlmOverride(globalModel)
  );
  const updateLLMOverride = (newOverride: LlmOverride) => {
    setLlmOverride(
      getValidLlmOverride(
        structureValue(
          newOverride.name,
          newOverride.provider,
          newOverride.modelName
        )
      )
    );
  };
  const [llmOverride, setLlmOverride] = useState<LlmOverride>({
    name: "",
    provider: "",
    modelName: "",
  });

  const [llmOverride, setLlmOverride] = useState<LlmOverride>(
    currentChatSession && currentChatSession.current_alternate_model
      ? getValidLlmOverride(currentChatSession.current_alternate_model)
      : { name: "", provider: "", modelName: "" }
  );
  // Manually set the override
  const updateLLMOverride = (newOverride: LlmOverride) => {
    const provider =
      newOverride.provider ||
      findProviderForModel(llmProviders, newOverride.modelName);
    const structuredValue = structureValue(
      newOverride.name,
      provider,
      newOverride.modelName
    );
    setLlmOverride(getValidLlmOverride(structuredValue));
  };

  const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
    setLlmOverride(
      chatSession && chatSession.current_alternate_model
        ? getValidLlmOverride(chatSession.current_alternate_model)
        : globalDefault
    );
    if (chatSession && chatSession.current_alternate_model?.length > 0) {
      setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
    }
  };

  const [temperature, setTemperature] = useState<number | null>(
    defaultTemperature !== undefined ? defaultTemperature : 0
  );
  const [temperature, setTemperature] = useState<number | null>(0);

  useEffect(() => {
    setGlobalDefault(getValidLlmOverride(globalModel));
  }, [globalModel, llmProviders]);

  useEffect(() => {
    setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
  }, [defaultTemperature]);
    llmOverrideUpdate();
  }, [liveAssistant, currentChatSession]);

  useEffect(() => {
    if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
@@ -450,12 +514,11 @@ export function useLlmOverride(
    updateModelOverrideForChatSession,
    llmOverride,
    updateLLMOverride,
    globalDefault,
    setGlobalDefault,
    temperature,
    updateTemperature,
    imageFilesPresent,
    updateImageFilesPresent,
    liveAssistant: liveAssistant ?? null,
  };
}
@@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => {
    modelName,
  };
};

export const findProviderForModel = (
  llmProviders: LLMProviderDescriptor[],
  modelName: string
): string => {
  const provider = llmProviders.find((p) => p.model_names.includes(modelName));
  return provider ? provider.provider : "";
};
@@ -358,7 +358,8 @@ export type ConfigurableSources = Exclude<

export const oauthSupportedSources: ConfigurableSources[] = [
  ValidSources.Slack,
  ValidSources.GoogleDrive,
  // NOTE: temporarily disabled until our GDrive App is approved
  // ValidSources.GoogleDrive,
];

export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];