From 57b4639709671d4b1bc1ba2f826c59656f368abf Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 28 Jan 2025 16:52:00 -0800 Subject: [PATCH 1/6] fix name --- web/src/app/chat/tools/constants.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/app/chat/tools/constants.ts b/web/src/app/chat/tools/constants.ts index c646ee9d9..35624493b 100644 --- a/web/src/app/chat/tools/constants.ts +++ b/web/src/app/chat/tools/constants.ts @@ -1,3 +1,3 @@ -export const SEARCH_TOOL_NAME = "SearchTool"; +export const SEARCH_TOOL_NAME = "run_search"; export const INTERNET_SEARCH_TOOL_NAME = "run_internet_search"; export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation"; From d903e5912a0069d49256cd53396ffe296737060a Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:28:32 -0800 Subject: [PATCH 2/6] feat: add option to treat all non-attachment fields as metadata in Airtable connector (#3817) * feat: add option to treat all non-attachment fields as metadata in Airtable connector - Added new UI option 'treat_all_non_attachment_fields_as_metadata' - Updated backend logic to support treating all fields except attachments as metadata - Added tests for both default and all-metadata behaviors Co-Authored-By: Chris Weaver * fix: handle missing environment variables gracefully in airtable tests Co-Authored-By: Chris Weaver * fix: clean up test file and handle environment variables properly Co-Authored-By: Chris Weaver * fix: add missing test fixture and fix formatting Co-Authored-By: Chris Weaver * chore: fix black formatting Co-Authored-By: Chris Weaver * fix: add type annotation for metadata dict in airtable tests Co-Authored-By: Chris Weaver * fix: add type annotation for mock_get_api_key fixture Co-Authored-By: Chris Weaver * fix: update Generator import to use collections.abc Co-Authored-By: Chris Weaver * refactor: make treat_all_non_attachment_fields_as_metadata a direct required parameter - Move parameter from connector_config to direct class parameter - Place parameter right under table_name_or_id argument - Make parameter required in UI with no default value - Update tests to use new parameter structure Co-Authored-By: Chris Weaver * chore: fix black formatting Co-Authored-By: Chris Weaver * chore: rename _METADATA_FIELD_TYPES to DEFAULT_METADATA_FIELD_TYPES and clarify usage Co-Authored-By: Chris Weaver * chore: fix black formatting in docstring Co-Authored-By: Chris Weaver * test: make airtable tests fail loudly on missing env vars Co-Authored-By: Chris Weaver * style: fix black formatting in test file Co-Authored-By: Chris Weaver * style: add required newline between test functions Co-Authored-By: Chris Weaver * test: update error message pattern in parameter validation test Co-Authored-By: Chris Weaver * style: fix black formatting in test file Co-Authored-By: Chris Weaver * test: fix error message pattern in parameter validation test Co-Authored-By: Chris Weaver * style: fix line length in test file Co-Authored-By: Chris Weaver * test: simplify error message pattern in parameter validation test Co-Authored-By: Chris Weaver * test: add type validation test for treat_all_non_attachment_fields_as_metadata Co-Authored-By: Chris Weaver * fix: add missing required parameter in test Co-Authored-By: Chris Weaver * fix: remove parameter from test to properly validate it is required Co-Authored-By: Chris Weaver * fix: add type validation for 
treat_all_non_attachment_fields_as_metadata parameter Co-Authored-By: Chris Weaver * style: fix black formatting in airtable_connector.py Co-Authored-By: Chris Weaver * fix: update type validation test to handle mypy errors Co-Authored-By: Chris Weaver * fix: specify mypy ignore type for call-arg Co-Authored-By: Chris Weaver * Also handle rows w/o sections * style: fix black formatting in test assertion Co-Authored-By: Chris Weaver * add TODO * Remove unnecessary check * Fix test * Do not break existing airtable connectors --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Chris Weaver Co-authored-by: Weves --- .../connectors/airtable/airtable_connector.py | 29 +- .../airtable/test_airtable_basic.py | 301 ++++++++++++------ web/src/lib/connectors/connectors.tsx | 8 + 3 files changed, 233 insertions(+), 105 deletions(-) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index 898fb0f31..777f2137f 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -20,9 +20,9 @@ from onyx.utils.logger import setup_logger logger = setup_logger() # NOTE: all are made lowercase to avoid case sensitivity issues -# these are the field types that are considered metadata rather -# than sections -_METADATA_FIELD_TYPES = { +# These field types are considered metadata by default when +# treat_all_non_attachment_fields_as_metadata is False +DEFAULT_METADATA_FIELD_TYPES = { "singlecollaborator", "collaborator", "createdby", @@ -60,12 +60,16 @@ class AirtableConnector(LoadConnector): self, base_id: str, table_name_or_id: str, + treat_all_non_attachment_fields_as_metadata: bool = False, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.base_id = base_id self.table_name_or_id = table_name_or_id self.batch_size = batch_size self.airtable_client: AirtableApi | None = None + self.treat_all_non_attachment_fields_as_metadata = ( + treat_all_non_attachment_fields_as_metadata + ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.airtable_client = AirtableApi(credentials["airtable_access_token"]) @@ -166,8 +170,14 @@ class AirtableConnector(LoadConnector): return [(str(field_info), default_link)] def _should_be_metadata(self, field_type: str) -> bool: - """Determine if a field type should be treated as metadata.""" - return field_type.lower() in _METADATA_FIELD_TYPES + """Determine if a field type should be treated as metadata. + + When treat_all_non_attachment_fields_as_metadata is True, all fields except + attachments are treated as metadata. Otherwise, only fields with types listed + in DEFAULT_METADATA_FIELD_TYPES are treated as metadata.""" + if self.treat_all_non_attachment_fields_as_metadata: + return field_type.lower() != "multipleattachments" + return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES def _process_field( self, @@ -233,7 +243,7 @@ class AirtableConnector(LoadConnector): record: RecordDict, table_schema: TableSchema, primary_field_name: str | None, - ) -> Document: + ) -> Document | None: """Process a single Airtable record into a Document. 
Args: @@ -277,6 +287,10 @@ class AirtableConnector(LoadConnector): sections.extend(field_sections) metadata.update(field_metadata) + if not sections: + logger.warning(f"No sections found for record {record_id}") + return None + semantic_id = ( f"{table_name}: {primary_field_value}" if primary_field_value @@ -320,7 +334,8 @@ class AirtableConnector(LoadConnector): table_schema=table_schema, primary_field_name=primary_field_name, ) - record_documents.append(document) + if document: + record_documents.append(document) if len(record_documents) >= self.batch_size: yield record_documents diff --git a/backend/tests/daily/connectors/airtable/test_airtable_basic.py b/backend/tests/daily/connectors/airtable/test_airtable_basic.py index 078c7206e..bb5e312f1 100644 --- a/backend/tests/daily/connectors/airtable/test_airtable_basic.py +++ b/backend/tests/daily/connectors/airtable/test_airtable_basic.py @@ -1,8 +1,10 @@ import os +from collections.abc import Generator from unittest.mock import MagicMock from unittest.mock import patch import pytest +from pydantic import BaseModel from onyx.configs.constants import DocumentSource from onyx.connectors.airtable.airtable_connector import AirtableConnector @@ -10,25 +12,24 @@ from onyx.connectors.models import Document from onyx.connectors.models import Section -@pytest.fixture( - params=[ - ("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]), - ("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]), - ] -) -def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector: - param_type, table_identifier = request.param - connector = AirtableConnector( - base_id=os.environ["AIRTABLE_TEST_BASE_ID"], - table_name_or_id=table_identifier, - ) +class AirtableConfig(BaseModel): + base_id: str + table_identifier: str + access_token: str - connector.load_credentials( - { - "airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"], - } + +@pytest.fixture(params=[True, False]) +def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig: + table_identifier = ( + os.environ["AIRTABLE_TEST_TABLE_NAME"] + if request.param + else os.environ["AIRTABLE_TEST_TABLE_ID"] + ) + return AirtableConfig( + base_id=os.environ["AIRTABLE_TEST_BASE_ID"], + table_identifier=table_identifier, + access_token=os.environ["AIRTABLE_ACCESS_TOKEN"], ) - return connector def create_test_document( @@ -46,18 +47,37 @@ def create_test_document( assignee: str, days_since_status_change: int | None, attachments: list[tuple[str, str]] | None = None, + all_fields_as_metadata: bool = False, ) -> Document: - link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}" - sections = [ - Section( - text=f"Title:\n------------------------\n{title}\n------------------------", - link=f"{link_base}/{id}", - ), - Section( - text=f"Description:\n------------------------\n{description}\n------------------------", - link=f"{link_base}/{id}", - ), - ] + base_id = os.environ.get("AIRTABLE_TEST_BASE_ID") + table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID") + missing_vars = [] + if not base_id: + missing_vars.append("AIRTABLE_TEST_BASE_ID") + if not table_id: + missing_vars.append("AIRTABLE_TEST_TABLE_ID") + + if missing_vars: + raise RuntimeError( + f"Required environment variables not set: {', '.join(missing_vars)}. " + "These variables are required to run Airtable connector tests." 
+ ) + link_base = f"https://airtable.com/{base_id}/{table_id}" + sections = [] + + if not all_fields_as_metadata: + sections.extend( + [ + Section( + text=f"Title:\n------------------------\n{title}\n------------------------", + link=f"{link_base}/{id}", + ), + Section( + text=f"Description:\n------------------------\n{description}\n------------------------", + link=f"{link_base}/{id}", + ), + ] + ) if attachments: for attachment_text, attachment_link in attachments: @@ -68,26 +88,36 @@ def create_test_document( ), ) + metadata: dict[str, str | list[str]] = { + # "Category": category, + "Assignee": assignee, + "Submitted by": submitted_by, + "Priority": priority, + "Status": status, + "Created time": created_time, + "ID": ticket_id, + "Status last changed": status_last_changed, + **( + {"Days since status change": str(days_since_status_change)} + if days_since_status_change is not None + else {} + ), + } + + if all_fields_as_metadata: + metadata.update( + { + "Title": title, + "Description": description, + } + ) + return Document( id=f"airtable__{id}", sections=sections, source=DocumentSource.AIRTABLE, - semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}", - metadata={ - # "Category": category, - "Assignee": assignee, - "Submitted by": submitted_by, - "Priority": priority, - "Status": status, - "Created time": created_time, - "ID": ticket_id, - "Status last changed": status_last_changed, - **( - {"Days since status change": str(days_since_status_change)} - if days_since_status_change is not None - else {} - ), - }, + semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}", + metadata=metadata, doc_updated_at=None, primary_owners=None, secondary_owners=None, @@ -97,15 +127,84 @@ def create_test_document( ) -@patch( - "onyx.file_processing.extract_file_text.get_unstructured_api_key", - return_value=None, -) -def test_airtable_connector_basic( - mock_get_api_key: MagicMock, airtable_connector: AirtableConnector -) -> None: - doc_batch_generator = airtable_connector.load_from_state() +@pytest.fixture +def mock_get_api_key() -> Generator[MagicMock, None, None]: + with patch( + "onyx.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, + ) as mock: + yield mock + +def compare_documents( + actual_docs: list[Document], expected_docs: list[Document] +) -> None: + """Utility function to compare actual and expected documents, ignoring order.""" + actual_docs_dict = {doc.id: doc for doc in actual_docs} + expected_docs_dict = {doc.id: doc for doc in expected_docs} + + assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch" + + for doc_id in actual_docs_dict: + actual = actual_docs_dict[doc_id] + expected = expected_docs_dict[doc_id] + + assert ( + actual.source == expected.source + ), f"Source mismatch for document {doc_id}" + assert ( + actual.semantic_identifier == expected.semantic_identifier + ), f"Semantic identifier mismatch for document {doc_id}" + assert ( + actual.metadata == expected.metadata + ), f"Metadata mismatch for document {doc_id}" + assert ( + actual.doc_updated_at == expected.doc_updated_at + ), f"Updated at mismatch for document {doc_id}" + assert ( + actual.primary_owners == expected.primary_owners + ), f"Primary owners mismatch for document {doc_id}" + assert ( + actual.secondary_owners == expected.secondary_owners + ), f"Secondary owners mismatch for document {doc_id}" + assert actual.title == expected.title, f"Title mismatch for document {doc_id}" + assert ( + 
actual.from_ingestion_api == expected.from_ingestion_api + ), f"Ingestion API flag mismatch for document {doc_id}" + assert ( + actual.additional_info == expected.additional_info + ), f"Additional info mismatch for document {doc_id}" + + # Compare sections + assert len(actual.sections) == len( + expected.sections + ), f"Number of sections mismatch for document {doc_id}" + for i, (actual_section, expected_section) in enumerate( + zip(actual.sections, expected.sections) + ): + assert ( + actual_section.text == expected_section.text + ), f"Section {i} text mismatch for document {doc_id}" + assert ( + actual_section.link == expected_section.link + ), f"Section {i} link mismatch for document {doc_id}" + + +def test_airtable_connector_basic( + mock_get_api_key: MagicMock, airtable_config: AirtableConfig +) -> None: + """Test behavior when all non-attachment fields are treated as metadata.""" + connector = AirtableConnector( + base_id=airtable_config.base_id, + table_name_or_id=airtable_config.table_identifier, + treat_all_non_attachment_fields_as_metadata=False, + ) + connector.load_credentials( + { + "airtable_access_token": airtable_config.access_token, + } + ) + doc_batch_generator = connector.load_from_state() doc_batch = next(doc_batch_generator) with pytest.raises(StopIteration): next(doc_batch_generator) @@ -119,15 +218,62 @@ def test_airtable_connector_basic( description="The internet connection is very slow.", priority="Medium", status="In Progress", - # Link to another record is skipped for now - # category="Data Science", ticket_id="2", created_time="2024-12-24T21:02:49.000Z", status_last_changed="2024-12-24T21:02:49.000Z", days_since_status_change=0, assignee="Chris Weaver (chris@onyx.app)", submitted_by="Chris Weaver (chris@onyx.app)", + all_fields_as_metadata=False, ), + create_test_document( + id="reccSlIA4pZEFxPBg", + title="Printer Issue", + description="The office printer is not working.", + priority="High", + status="Open", + ticket_id="1", + created_time="2024-12-24T21:02:49.000Z", + status_last_changed="2024-12-24T21:02:49.000Z", + days_since_status_change=0, + assignee="Chris Weaver (chris@onyx.app)", + submitted_by="Chris Weaver (chris@onyx.app)", + attachments=[ + ( + "Test.pdf:\ntesting!!!", + "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide", + ) + ], + all_fields_as_metadata=False, + ), + ] + + # Compare documents using the utility function + compare_documents(doc_batch, expected_docs) + + +def test_airtable_connector_all_metadata( + mock_get_api_key: MagicMock, airtable_config: AirtableConfig +) -> None: + connector = AirtableConnector( + base_id=airtable_config.base_id, + table_name_or_id=airtable_config.table_identifier, + treat_all_non_attachment_fields_as_metadata=True, + ) + connector.load_credentials( + { + "airtable_access_token": airtable_config.access_token, + } + ) + doc_batch_generator = connector.load_from_state() + doc_batch = next(doc_batch_generator) + with pytest.raises(StopIteration): + next(doc_batch_generator) + + # NOTE: one of the rows has no attachments -> no content -> no document + assert len(doc_batch) == 1 + + expected_docs = [ create_test_document( id="reccSlIA4pZEFxPBg", title="Printer Issue", @@ -149,50 +295,9 @@ def test_airtable_connector_basic( "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide", ) ], + all_fields_as_metadata=True, ), ] - # Compare each 
document field by field - for actual, expected in zip(doc_batch, expected_docs): - assert actual.id == expected.id, f"ID mismatch for document {actual.id}" - assert ( - actual.source == expected.source - ), f"Source mismatch for document {actual.id}" - assert ( - actual.semantic_identifier == expected.semantic_identifier - ), f"Semantic identifier mismatch for document {actual.id}" - assert ( - actual.metadata == expected.metadata - ), f"Metadata mismatch for document {actual.id}" - assert ( - actual.doc_updated_at == expected.doc_updated_at - ), f"Updated at mismatch for document {actual.id}" - assert ( - actual.primary_owners == expected.primary_owners - ), f"Primary owners mismatch for document {actual.id}" - assert ( - actual.secondary_owners == expected.secondary_owners - ), f"Secondary owners mismatch for document {actual.id}" - assert ( - actual.title == expected.title - ), f"Title mismatch for document {actual.id}" - assert ( - actual.from_ingestion_api == expected.from_ingestion_api - ), f"Ingestion API flag mismatch for document {actual.id}" - assert ( - actual.additional_info == expected.additional_info - ), f"Additional info mismatch for document {actual.id}" - - # Compare sections - assert len(actual.sections) == len( - expected.sections - ), f"Number of sections mismatch for document {actual.id}" - for i, (actual_section, expected_section) in enumerate( - zip(actual.sections, expected.sections) - ): - assert ( - actual_section.text == expected_section.text - ), f"Section {i} text mismatch for document {actual.id}" - assert ( - actual_section.link == expected_section.link - ), f"Section {i} link mismatch for document {actual.id}" + # Compare documents using the utility function + compare_documents(doc_batch, expected_docs) diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 7f6f0050b..715310468 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1106,6 +1106,14 @@ For example, specifying .*-support.* as a "channel" will cause the connector to name: "table_name_or_id", optional: false, }, + { + type: "checkbox", + label: "Treat all fields except attachments as metadata", + name: "treat_all_non_attachment_fields_as_metadata", + description: + "Choose this if the primary content to index are attachments and all other columns are metadata for these attachments.", + optional: false, + }, ], advanced_values: [], overrideDefaultFreq: 60 * 60 * 24, From 7e9b12403a7a95a095d5a06c457b7e0d4771b545 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:29:23 -0800 Subject: [PATCH 3/6] Allow Slack workflow messages when respond_to_bots is enabled (#3819) * Allow workflow 'bot_message' subtype when respond_to_bots is enabled Co-Authored-By: Chris Weaver * refactor: consolidate bot message checks to avoid redundant code Co-Authored-By: Chris Weaver * style: fix black formatting Co-Authored-By: Chris Weaver * Remove unnecessary call --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Chris Weaver Co-authored-by: Weves --- backend/onyx/onyxbot/slack/listener.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py index 106891d41..721140fa1 100644 --- a/backend/onyx/onyxbot/slack/listener.py +++ b/backend/onyx/onyxbot/slack/listener.py @@ -537,30 
+537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) - # Let the tag flow handle this case, don't reply twice return False - if event.get("bot_profile"): + # Check if this is a bot message (either via bot_profile or bot_message subtype) + is_bot_message = bool( + event.get("bot_profile") or event.get("subtype") == "bot_message" + ) + if is_bot_message: channel_name, _ = get_channel_name_from_id( client=client.web_client, channel_id=channel ) - with get_session_with_tenant(client.tenant_id) as db_session: slack_channel_config = get_slack_channel_config_for_bot_and_channel( db_session=db_session, slack_bot_id=client.slack_bot_id, channel_name=channel_name, ) + # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message if (not bot_tag_id or bot_tag_id not in msg) and ( not slack_channel_config or not slack_channel_config.channel_config.get("respond_to_bots") ): - channel_specific_logger.info("Ignoring message from bot") + channel_specific_logger.info( + "Ignoring message from bot since respond_to_bots is disabled" + ) return False # Ignore things like channel_join, channel_leave, etc. # NOTE: "file_share" is just a message with a file attachment, so we # should not ignore it message_subtype = event.get("subtype") - if message_subtype not in [None, "file_share"]: + if message_subtype not in [None, "file_share", "bot_message"]: channel_specific_logger.info( f"Ignoring message with subtype '{message_subtype}' since it is a special message type" ) From 601037abb5403868210558356a79a84c66058b9f Mon Sep 17 00:00:00 2001 From: pablonyx Date: Tue, 28 Jan 2025 17:42:28 -0800 Subject: [PATCH 4/6] Customer love (#3813) * additional logs * disable gdrive oauth * Revert "additional ogs" This reverts commit 1bd7f9d433bcb6708c7d29945ee695fd7ed97382. 
--- web/src/lib/types.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 42848ba3f..414b23b46 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -358,7 +358,8 @@ export type ConfigurableSources = Exclude< export const oauthSupportedSources: ConfigurableSources[] = [ ValidSources.Slack, - ValidSources.GoogleDrive, + // NOTE: temporarily disabled until our GDrive App is approved + // ValidSources.GoogleDrive, ]; export type OAuthSupportedSource = (typeof oauthSupportedSources)[number]; From 2701f83634b826045bc33ac382d1cdf94daee8d0 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Tue, 28 Jan 2025 18:44:50 -0800 Subject: [PATCH 5/6] llm provider re-org (#3810) * nit * clean up logic * update --- web/src/app/chat/ChatPage.tsx | 80 ++++------ web/src/app/chat/input/LLMPopover.tsx | 4 +- web/src/app/chat/modal/UserSettingsModal.tsx | 12 +- web/src/lib/hooks.ts | 153 +++++++++++++------ web/src/lib/llm/utils.ts | 8 + 5 files changed, 153 insertions(+), 104 deletions(-) diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 9d0e3fea4..4ab259381 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -76,13 +76,7 @@ import { import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; -import { - checkLLMSupportsImageInput, - getFinalLLM, - destructureValue, - getLLMProviderOverrideForPersona, -} from "@/lib/llm/utils"; - +import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils"; import { ChatInputBar } from "./input/ChatInputBar"; import { useChatContext } from "@/components/context/ChatContext"; import { v4 as uuidv4 } from "uuid"; @@ -203,6 +197,12 @@ export function ChatPage({ const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open + const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null; + + const selectedChatSession = chatSessions.find( + (chatSession) => chatSession.id === existingChatSessionId + ); + useEffect(() => { if (user?.is_anonymous_user) { Cookies.set( @@ -240,12 +240,6 @@ export function ChatPage({ } }; - const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null; - - const selectedChatSession = chatSessions.find( - (chatSession) => chatSession.id === existingChatSessionId - ); - const chatSessionIdRef = useRef(existingChatSessionId); // Only updates on session load (ie. 
rename / switching chat session) @@ -293,12 +287,6 @@ export function ChatPage({ ); }; - const llmOverrideManager = useLlmOverride( - llmProviders, - user?.preferences.default_model, - selectedChatSession - ); - const [alternativeAssistant, setAlternativeAssistant] = useState(null); @@ -307,12 +295,27 @@ export function ChatPage({ const { recentAssistants, refreshRecentAssistants } = useAssistants(); - const liveAssistant: Persona | undefined = - alternativeAssistant || - selectedAssistant || - recentAssistants[0] || - finalAssistants[0] || - availableAssistants[0]; + const liveAssistant: Persona | undefined = useMemo( + () => + alternativeAssistant || + selectedAssistant || + recentAssistants[0] || + finalAssistants[0] || + availableAssistants[0], + [ + alternativeAssistant, + selectedAssistant, + recentAssistants, + finalAssistants, + availableAssistants, + ] + ); + + const llmOverrideManager = useLlmOverride( + llmProviders, + selectedChatSession, + liveAssistant + ); const noAssistants = liveAssistant == null || liveAssistant == undefined; @@ -320,24 +323,6 @@ export function ChatPage({ const uniqueSources = Array.from(new Set(availableSources)); const sources = uniqueSources.map((source) => getSourceMetadata(source)); - // always set the model override for the chat session, when an assistant, llm provider, or user preference exists - useEffect(() => { - if (noAssistants) return; - const personaDefault = getLLMProviderOverrideForPersona( - liveAssistant, - llmProviders - ); - - if (personaDefault) { - llmOverrideManager.updateLLMOverride(personaDefault); - } else if (user?.preferences.default_model) { - llmOverrideManager.updateLLMOverride( - destructureValue(user?.preferences.default_model) - ); - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [liveAssistant, user?.preferences.default_model]); - const stopGenerating = () => { const currentSession = currentSessionId(); const controller = abortControllers.get(currentSession); @@ -419,7 +404,6 @@ export function ChatPage({ filterManager.setTimeRange(null); // reset LLM overrides (based on chat session!) 
- llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession); llmOverrideManager.updateTemperature(null); // remove uploaded files @@ -1283,13 +1267,11 @@ export function ChatPage({ modelProvider: modelOverRide?.name || llmOverrideManager.llmOverride.name || - llmOverrideManager.globalDefault.name || undefined, modelVersion: modelOverRide?.modelName || llmOverrideManager.llmOverride.modelName || searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || - llmOverrideManager.globalDefault.modelName || undefined, temperature: llmOverrideManager.temperature || undefined, systemPromptOverride: @@ -1952,6 +1934,7 @@ export function ChatPage({ }; // eslint-disable-next-line react-hooks/exhaustive-deps }, [router]); + const [sharedChatSession, setSharedChatSession] = useState(); @@ -2059,7 +2042,9 @@ export function ChatPage({ {(settingsToggled || userSettingsToggled) && ( + llmOverrideManager.updateLLMOverride(newOverride) + } defaultModel={user?.preferences.default_model!} llmProviders={llmProviders} onClose={() => { @@ -2749,6 +2734,7 @@ export function ChatPage({ )} + void; llmProviders: LLMProviderDescriptor[]; - setLlmOverride?: Dispatch>; + setLlmOverride?: (newOverride: LlmOverride) => void; onClose: () => void; defaultModel: string | null; }) { diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts index 0b67ab9d4..b4ee55747 100644 --- a/web/src/lib/hooks.ts +++ b/web/src/lib/hooks.ts @@ -13,16 +13,21 @@ import { errorHandlingFetcher } from "./fetcher"; import { useContext, useEffect, useState } from "react"; import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector"; import { Filters, SourceMetadata } from "./search/interfaces"; -import { destructureValue, structureValue } from "./llm/utils"; +import { + destructureValue, + findProviderForModel, + structureValue, +} from "./llm/utils"; import { ChatSession } from "@/app/chat/interfaces"; import { AllUsersResponse } from "./types"; import { Credential } from "./connectors/credentials"; import { SettingsContext } from "@/components/settings/SettingsProvider"; -import { PersonaLabel } from "@/app/admin/assistants/interfaces"; +import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces"; import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces"; import { isAnthropic } from "@/app/admin/configuration/llm/interfaces"; import { getSourceMetadata } from "./sources"; import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants"; +import { useUser } from "@/components/user/UserProvider"; const CREDENTIAL_URL = "/api/manage/admin/credential"; @@ -355,82 +360,141 @@ export interface LlmOverride { export interface LlmOverrideManager { llmOverride: LlmOverride; updateLLMOverride: (newOverride: LlmOverride) => void; - globalDefault: LlmOverride; - setGlobalDefault: React.Dispatch>; temperature: number | null; updateTemperature: (temperature: number | null) => void; updateModelOverrideForChatSession: (chatSession?: ChatSession) => void; imageFilesPresent: boolean; updateImageFilesPresent: (present: boolean) => void; + liveAssistant: Persona | null; } + +/* +LLM Override is as follows (i.e. 
this order) +- User override (explicitly set in the chat input bar) +- User preference (defaults to system wide default if no preference set) + +On switching to an existing or new chat session or a different assistant: +- If we have a live assistant after any switch with a model override, use that- otherwise use the above hierarchy + +Thus, the input should be +- User preference +- LLM Providers (which contain the system wide default) +- Current assistant + +Changes take place as +- liveAssistant or currentChatSession changes (and the associated model override is set) +- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant) + +If we have a live assistant, we should use that model override +*/ + export function useLlmOverride( llmProviders: LLMProviderDescriptor[], - globalModel?: string | null, currentChatSession?: ChatSession, - defaultTemperature?: number + liveAssistant?: Persona ): LlmOverrideManager { + const { user } = useUser(); + + const [chatSession, setChatSession] = useState(null); + + const llmOverrideUpdate = () => { + if (!chatSession && currentChatSession) { + setChatSession(currentChatSession || null); + return; + } + + if (liveAssistant?.llm_model_version_override) { + setLlmOverride( + getValidLlmOverride(liveAssistant.llm_model_version_override) + ); + } else if (currentChatSession?.current_alternate_model) { + setLlmOverride( + getValidLlmOverride(currentChatSession.current_alternate_model) + ); + } else if (user?.preferences?.default_model) { + setLlmOverride(getValidLlmOverride(user.preferences.default_model)); + return; + } else { + const defaultProvider = llmProviders.find( + (provider) => provider.is_default_provider + ); + + if (defaultProvider) { + setLlmOverride({ + name: defaultProvider.name, + provider: defaultProvider.provider, + modelName: defaultProvider.default_model_name, + }); + } + } + setChatSession(currentChatSession || null); + }; + const getValidLlmOverride = ( overrideModel: string | null | undefined ): LlmOverride => { if (overrideModel) { const model = destructureValue(overrideModel); - const provider = llmProviders.find( - (p) => - p.model_names.includes(model.modelName) && - p.provider === model.provider + if (!(model.modelName && model.modelName.length > 0)) { + const provider = llmProviders.find((p) => + p.model_names.includes(overrideModel) + ); + if (provider) { + return { + modelName: overrideModel, + name: provider.name, + provider: provider.provider, + }; + } + } + + const provider = llmProviders.find((p) => + p.model_names.includes(model.modelName) ); + if (provider) { return { ...model, name: provider.name }; } } return { name: "", provider: "", modelName: "" }; }; + const [imageFilesPresent, setImageFilesPresent] = useState(false); const updateImageFilesPresent = (present: boolean) => { setImageFilesPresent(present); }; - const [globalDefault, setGlobalDefault] = useState( - getValidLlmOverride(globalModel) - ); - const updateLLMOverride = (newOverride: LlmOverride) => { - setLlmOverride( - getValidLlmOverride( - structureValue( - newOverride.name, - newOverride.provider, - newOverride.modelName - ) - ) - ); - }; + const [llmOverride, setLlmOverride] = useState({ + name: "", + provider: "", + modelName: "", + }); - const [llmOverride, setLlmOverride] = useState( - currentChatSession && currentChatSession.current_alternate_model - ? 
getValidLlmOverride(currentChatSession.current_alternate_model) - : { name: "", provider: "", modelName: "" } - ); + // Manually set the override + const updateLLMOverride = (newOverride: LlmOverride) => { + const provider = + newOverride.provider || + findProviderForModel(llmProviders, newOverride.modelName); + const structuredValue = structureValue( + newOverride.name, + provider, + newOverride.modelName + ); + setLlmOverride(getValidLlmOverride(structuredValue)); + }; const updateModelOverrideForChatSession = (chatSession?: ChatSession) => { - setLlmOverride( - chatSession && chatSession.current_alternate_model - ? getValidLlmOverride(chatSession.current_alternate_model) - : globalDefault - ); + if (chatSession && chatSession.current_alternate_model?.length > 0) { + setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model)); + } }; - const [temperature, setTemperature] = useState( - defaultTemperature !== undefined ? defaultTemperature : 0 - ); + const [temperature, setTemperature] = useState(0); useEffect(() => { - setGlobalDefault(getValidLlmOverride(globalModel)); - }, [globalModel, llmProviders]); - - useEffect(() => { - setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0); - }, [defaultTemperature]); + llmOverrideUpdate(); + }, [liveAssistant, currentChatSession]); useEffect(() => { if (isAnthropic(llmOverride.provider, llmOverride.modelName)) { @@ -450,12 +514,11 @@ export function useLlmOverride( updateModelOverrideForChatSession, llmOverride, updateLLMOverride, - globalDefault, - setGlobalDefault, temperature, updateTemperature, imageFilesPresent, updateImageFilesPresent, + liveAssistant: liveAssistant ?? null, }; } diff --git a/web/src/lib/llm/utils.ts b/web/src/lib/llm/utils.ts index 3eca6cacc..1880385e0 100644 --- a/web/src/lib/llm/utils.ts +++ b/web/src/lib/llm/utils.ts @@ -143,3 +143,11 @@ export const destructureValue = (value: string): LlmOverride => { modelName, }; }; + +export const findProviderForModel = ( + llmProviders: LLMProviderDescriptor[], + modelName: string +): string => { + const provider = llmProviders.find((p) => p.model_names.includes(modelName)); + return provider ? 
provider.provider : ""; +}; From 028e877342e2e89d5ba76c281892b6ac74c28a2a Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Tue, 28 Jan 2025 20:06:09 -0800 Subject: [PATCH 6/6] Sharepoint fixes (#3826) * Sharepoint connector fixes * Refactor sharepoint to be better * Improve env variable naming * Fix * Add new secrets * Fix unstructured failure --- .../workflows/pr-python-connector-tests.yml | 6 + .../onyx/connectors/sharepoint/connector.py | 219 ++++++++++++------ .../airtable/test_airtable_basic.py | 15 +- backend/tests/daily/connectors/conftest.py | 14 ++ .../sharepoint/test_sharepoint_connector.py | 178 ++++++++++++++ 5 files changed, 345 insertions(+), 87 deletions(-) create mode 100644 backend/tests/daily/connectors/conftest.py create mode 100644 backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index c3947b233..81ab06665 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -39,6 +39,12 @@ env: AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }} AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }} AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }} + # Sharepoint + SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }} + SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }} + SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }} + SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }} + jobs: connectors-check: # See https://runs-on.com/runners/linux/ diff --git a/backend/onyx/connectors/sharepoint/connector.py b/backend/onyx/connectors/sharepoint/connector.py index 88874db6b..5747df03b 100644 --- a/backend/onyx/connectors/sharepoint/connector.py +++ b/backend/onyx/connectors/sharepoint/connector.py @@ -1,17 +1,14 @@ import io import os -from dataclasses import dataclass -from dataclasses import field from datetime import datetime from datetime import timezone from typing import Any -from typing import Optional from urllib.parse import unquote import msal # type: ignore from office365.graph_client import GraphClient # type: ignore from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore -from office365.onedrive.sites.site import Site # type: ignore +from pydantic import BaseModel from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource @@ -30,16 +27,25 @@ from onyx.utils.logger import setup_logger logger = setup_logger() -@dataclass -class SiteData: - url: str | None - folder: Optional[str] - sites: list = field(default_factory=list) - driveitems: list = field(default_factory=list) +class SiteDescriptor(BaseModel): + """Data class for storing SharePoint site information. + + Args: + url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests) + drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library") + If None, all drives will be accessed. + folder_path: The folder path within the drive to access (e.g. "test/nested with spaces") + If None, all folders will be accessed. 
+ """ + + url: str + drive_name: str | None + folder_path: str | None def _convert_driveitem_to_document( driveitem: DriveItem, + drive_name: str, ) -> Document: file_text = extract_file_text( file=io.BytesIO(driveitem.get_content().execute_query().value), @@ -59,7 +65,7 @@ def _convert_driveitem_to_document( email=driveitem.last_modified_by.user.email, ) ], - metadata={}, + metadata={"drive": drive_name}, ) return doc @@ -71,106 +77,171 @@ class SharepointConnector(LoadConnector, PollConnector): sites: list[str] = [], ) -> None: self.batch_size = batch_size - self.graph_client: GraphClient | None = None - self.site_data: list[SiteData] = self._extract_site_and_folder(sites) + self._graph_client: GraphClient | None = None + self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info( + sites + ) + + @property + def graph_client(self) -> GraphClient: + if self._graph_client is None: + raise ConnectorMissingCredentialError("Sharepoint") + + return self._graph_client @staticmethod - def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]: + def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]: site_data_list = [] for url in site_urls: parts = url.strip().split("/") if "sites" in parts: sites_index = parts.index("sites") site_url = "/".join(parts[: sites_index + 2]) - folder = ( - "/".join(unquote(part) for part in parts[sites_index + 2 :]) - if len(parts) > sites_index + 2 - else None - ) - # Handling for new URL structure - if folder and folder.startswith("Shared Documents/"): - folder = folder[len("Shared Documents/") :] + remaining_parts = parts[sites_index + 2 :] + + # Extract drive name and folder path + if remaining_parts: + drive_name = unquote(remaining_parts[0]) + folder_path = ( + "/".join(unquote(part) for part in remaining_parts[1:]) + if len(remaining_parts) > 1 + else None + ) + else: + drive_name = None + folder_path = None + site_data_list.append( - SiteData(url=site_url, folder=folder, sites=[], driveitems=[]) + SiteDescriptor( + url=site_url, + drive_name=drive_name, + folder_path=folder_path, + ) ) return site_data_list - def _populate_sitedata_driveitems( + def _fetch_driveitems( self, + site_descriptor: SiteDescriptor, start: datetime | None = None, end: datetime | None = None, - ) -> None: + ) -> list[tuple[DriveItem, str]]: filter_str = "" if start is not None and end is not None: - filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}" + filter_str = ( + f"last_modified_datetime ge {start.isoformat()} and " + f"last_modified_datetime le {end.isoformat()}" + ) - for element in self.site_data: - sites: list[Site] = [] - for site in element.sites: - site_sublist = site.lists.get().execute_query() - sites.extend(site_sublist) + final_driveitems: list[tuple[DriveItem, str]] = [] + try: + site = self.graph_client.sites.get_by_url(site_descriptor.url) - for site in sites: + # Get all drives in the site + drives = site.drives.get().execute_query() + logger.debug(f"Found drives: {[drive.name for drive in drives]}") + + # Filter drives based on the requested drive name + if site_descriptor.drive_name: + drives = [ + drive + for drive in drives + if drive.name == site_descriptor.drive_name + or ( + drive.name == "Documents" + and site_descriptor.drive_name == "Shared Documents" + ) + ] + if not drives: + logger.warning(f"Drive '{site_descriptor.drive_name}' not found") + return [] + + # Process each matching drive + for drive in drives: try: - query = 
site.drive.root.get_files(True, 1000) + root_folder = drive.root + if site_descriptor.folder_path: + # If a specific folder is requested, navigate to it + for folder_part in site_descriptor.folder_path.split("/"): + root_folder = root_folder.get_by_path(folder_part) + + # Get all items recursively + query = root_folder.get_files(True, 1000) if filter_str: query = query.filter(filter_str) driveitems = query.execute_query() - if element.folder: - expected_path = f"/root:/{element.folder}" + logger.debug( + f"Found {len(driveitems)} items in drive '{drive.name}'" + ) + + # Use "Shared Documents" as the library name for the default "Documents" drive + drive_name = ( + "Shared Documents" if drive.name == "Documents" else drive.name + ) + + if site_descriptor.folder_path: + # Filter items to ensure they're in the specified folder or its subfolders + # The path will be in format: /drives/{drive_id}/root:/folder/path filtered_driveitems = [ - item + (item, drive_name) for item in driveitems - if item.parent_reference.path.endswith(expected_path) + if any( + path_part == site_descriptor.folder_path + or path_part.startswith( + site_descriptor.folder_path + "/" + ) + for path_part in item.parent_reference.path.split( + "root:/" + )[1].split("/") + ) ] if len(filtered_driveitems) == 0: all_paths = [ item.parent_reference.path for item in driveitems ] logger.warning( - f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}" + f"Nothing found for folder '{site_descriptor.folder_path}' " + f"in; any of valid paths: {all_paths}" ) - element.driveitems.extend(filtered_driveitems) + final_driveitems.extend(filtered_driveitems) else: - element.driveitems.extend(driveitems) + final_driveitems.extend( + [(item, drive_name) for item in driveitems] + ) + except Exception as e: + # Some drives might not be accessible + logger.warning(f"Failed to process drive: {str(e)}") - except Exception: - # Sites include things that do not contain .drive.root so this fails - # but this is fine, as there are no actually documents in those - pass + except Exception as e: + # Sites include things that do not contain drives so this fails + # but this is fine, as there are no actual documents in those + logger.warning(f"Failed to process site: {str(e)}") - def _populate_sitedata_sites(self) -> None: - if self.graph_client is None: - raise ConnectorMissingCredentialError("Sharepoint") + return final_driveitems - if self.site_data: - for element in self.site_data: - element.sites = [ - self.graph_client.sites.get_by_url(element.url) - .get() - .execute_query() - ] - else: - sites = self.graph_client.sites.get_all().execute_query() - self.site_data = [ - SiteData(url=None, folder=None, sites=sites, driveitems=[]) - ] + def _fetch_sites(self) -> list[SiteDescriptor]: + sites = self.graph_client.sites.get_all().execute_query() + site_descriptors = [ + SiteDescriptor( + url=sites.resource_url, + drive_name=None, + folder_path=None, + ) + ] + return site_descriptors def _fetch_from_sharepoint( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: - if self.graph_client is None: - raise ConnectorMissingCredentialError("Sharepoint") - - self._populate_sitedata_sites() - self._populate_sitedata_driveitems(start=start, end=end) + site_descriptors = self.site_descriptors or self._fetch_sites() # goes over all urls, converts them into Document objects and then yields them in batches doc_batch: list[Document] = [] - for element in self.site_data: - for driveitem in 
element.driveitems: + for site_descriptor in site_descriptors: + driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end) + for driveitem, drive_name in driveitems: logger.debug(f"Processing: {driveitem.web_url}") - doc_batch.append(_convert_driveitem_to_document(driveitem)) + doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name)) if len(doc_batch) >= self.batch_size: yield doc_batch @@ -197,7 +268,7 @@ class SharepointConnector(LoadConnector, PollConnector): ) return token - self.graph_client = GraphClient(_acquire_token_func) + self._graph_client = GraphClient(_acquire_token_func) return None def load_from_state(self) -> GenerateDocumentsOutput: @@ -206,19 +277,19 @@ class SharepointConnector(LoadConnector, PollConnector): def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: - start_datetime = datetime.utcfromtimestamp(start) - end_datetime = datetime.utcfromtimestamp(end) + start_datetime = datetime.fromtimestamp(start, timezone.utc) + end_datetime = datetime.fromtimestamp(end, timezone.utc) return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime) if __name__ == "__main__": - connector = SharepointConnector(sites=os.environ["SITES"].split(",")) + connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(",")) connector.load_credentials( { - "sp_client_id": os.environ["SP_CLIENT_ID"], - "sp_client_secret": os.environ["SP_CLIENT_SECRET"], - "sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"], + "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"], + "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"], + "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"], } ) document_batches = connector.load_from_state() diff --git a/backend/tests/daily/connectors/airtable/test_airtable_basic.py b/backend/tests/daily/connectors/airtable/test_airtable_basic.py index bb5e312f1..6610d91d6 100644 --- a/backend/tests/daily/connectors/airtable/test_airtable_basic.py +++ b/backend/tests/daily/connectors/airtable/test_airtable_basic.py @@ -1,7 +1,5 @@ import os -from collections.abc import Generator from unittest.mock import MagicMock -from unittest.mock import patch import pytest from pydantic import BaseModel @@ -127,15 +125,6 @@ def create_test_document( ) -@pytest.fixture -def mock_get_api_key() -> Generator[MagicMock, None, None]: - with patch( - "onyx.file_processing.extract_file_text.get_unstructured_api_key", - return_value=None, - ) as mock: - yield mock - - def compare_documents( actual_docs: list[Document], expected_docs: list[Document] ) -> None: @@ -191,7 +180,7 @@ def compare_documents( def test_airtable_connector_basic( - mock_get_api_key: MagicMock, airtable_config: AirtableConfig + mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig ) -> None: """Test behavior when all non-attachment fields are treated as metadata.""" connector = AirtableConnector( @@ -253,7 +242,7 @@ def test_airtable_connector_basic( def test_airtable_connector_all_metadata( - mock_get_api_key: MagicMock, airtable_config: AirtableConfig + mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig ) -> None: connector = AirtableConnector( base_id=airtable_config.base_id, diff --git a/backend/tests/daily/connectors/conftest.py b/backend/tests/daily/connectors/conftest.py new file mode 100644 index 000000000..88a00b57a --- /dev/null +++ b/backend/tests/daily/connectors/conftest.py @@ -0,0 +1,14 @@ +from collections.abc import Generator +from 
unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]: + with patch( + "onyx.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, + ) as mock: + yield mock diff --git a/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py new file mode 100644 index 000000000..8fc40564f --- /dev/null +++ b/backend/tests/daily/connectors/sharepoint/test_sharepoint_connector.py @@ -0,0 +1,178 @@ +import os +from dataclasses import dataclass +from datetime import datetime +from datetime import timezone +from unittest.mock import MagicMock + +import pytest + +from onyx.configs.constants import DocumentSource +from onyx.connectors.models import Document +from onyx.connectors.sharepoint.connector import SharepointConnector + + +@dataclass +class ExpectedDocument: + semantic_identifier: str + content: str + folder_path: str | None = None + library: str = "Shared Documents" # Default to main library + + +EXPECTED_DOCUMENTS = [ + ExpectedDocument( + semantic_identifier="test1.docx", + content="test1", + folder_path="test", + ), + ExpectedDocument( + semantic_identifier="test2.docx", + content="test2", + folder_path="test/nested with spaces", + ), + ExpectedDocument( + semantic_identifier="should-not-index-on-specific-folder.docx", + content="should-not-index-on-specific-folder", + folder_path=None, # root folder + ), + ExpectedDocument( + semantic_identifier="other.docx", + content="other", + folder_path=None, + library="Other Library", + ), +] + + +def verify_document_metadata(doc: Document) -> None: + """Verify common metadata that should be present on all documents.""" + assert isinstance(doc.doc_updated_at, datetime) + assert doc.doc_updated_at.tzinfo == timezone.utc + assert doc.source == DocumentSource.SHAREPOINT + assert doc.primary_owners is not None + assert len(doc.primary_owners) == 1 + owner = doc.primary_owners[0] + assert owner.display_name is not None + assert owner.email is not None + + +def verify_document_content(doc: Document, expected: ExpectedDocument) -> None: + """Verify a document matches its expected content.""" + assert doc.semantic_identifier == expected.semantic_identifier + assert len(doc.sections) == 1 + assert expected.content in doc.sections[0].text + verify_document_metadata(doc) + + +def find_document(documents: list[Document], semantic_identifier: str) -> Document: + """Find a document by its semantic identifier.""" + matching_docs = [ + d for d in documents if d.semantic_identifier == semantic_identifier + ] + assert ( + len(matching_docs) == 1 + ), f"Expected exactly one document with identifier {semantic_identifier}" + return matching_docs[0] + + +@pytest.fixture +def sharepoint_credentials() -> dict[str, str]: + return { + "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"], + "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"], + "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"], + } + + +def test_sharepoint_connector_specific_folder( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the test site URL and specific folder + connector = SharepointConnector( + sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"] + ) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = 
list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + + # Should only find documents in the test folder + test_folder_docs = [ + doc + for doc in EXPECTED_DOCUMENTS + if doc.folder_path and doc.folder_path.startswith("test") + ] + assert len(found_documents) == len( + test_folder_docs + ), "Should only find documents in test folder" + + # Verify each expected document + for expected in test_folder_docs: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected) + + +def test_sharepoint_connector_root_folder( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the base site URL + connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]]) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + + assert len(found_documents) == len( + EXPECTED_DOCUMENTS + ), "Should find all documents in main library" + + # Verify each expected document + for expected in EXPECTED_DOCUMENTS: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected) + + +def test_sharepoint_connector_other_library( + mock_get_unstructured_api_key: MagicMock, + sharepoint_credentials: dict[str, str], +) -> None: + # Initialize connector with the other library + connector = SharepointConnector( + sites=[ + os.environ["SHAREPOINT_SITE"] + "/Other Library", + ] + ) + + # Load credentials + connector.load_credentials(sharepoint_credentials) + + # Get all documents + document_batches = list(connector.load_from_state()) + found_documents: list[Document] = [ + doc for batch in document_batches for doc in batch + ] + expected_documents: list[ExpectedDocument] = [ + doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library" + ] + + # Should find all documents in `Other Library` + assert len(found_documents) == len( + expected_documents + ), "Should find all documents in `Other Library`" + + # Verify each expected document + for expected in expected_documents: + doc = find_document(found_documents, expected.semantic_identifier) + verify_document_content(doc, expected)
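
Illustrative note (not part of the patch series above): the SharePoint refactor in PATCH 6/6 replaces the old folder-only parsing with a SiteDescriptor that carries the base site URL, a drive (library) name, and an optional folder path. A minimal standalone sketch of that URL split, mirroring the logic and the example URLs in the SiteDescriptor docstring from the diff, looks roughly like this; the function name split_sharepoint_url is hypothetical and only used here for illustration.

    # Sketch only -- mirrors SharepointConnector._extract_site_and_drive_info
    # from the patch; not code that ships in the PR.
    from urllib.parse import unquote


    def split_sharepoint_url(url: str) -> tuple[str, str | None, str | None]:
        parts = url.strip().split("/")
        if "sites" not in parts:
            # No "/sites/<name>" segment: treat the whole URL as the site.
            return url, None, None
        sites_index = parts.index("sites")
        # Base site URL, e.g. https://.../sites/sharepoint-tests
        site_url = "/".join(parts[: sites_index + 2])
        remaining = parts[sites_index + 2 :]
        # First remaining segment is the drive (library) name, the rest is the folder path.
        drive_name = unquote(remaining[0]) if remaining else None
        folder_path = (
            "/".join(unquote(p) for p in remaining[1:]) if len(remaining) > 1 else None
        )
        return site_url, drive_name, folder_path


    # Example (matches the docstring in the patch):
    # split_sharepoint_url(
    #     "https://danswerai.sharepoint.com/sites/sharepoint-tests/Shared Documents/test"
    # )
    # -> ("https://danswerai.sharepoint.com/sites/sharepoint-tests",
    #     "Shared Documents", "test")

This is why the new tests can point the connector at either the bare site URL (index every drive), a "/Shared Documents/test" suffix (one folder), or an "/Other Library" suffix (a non-default drive).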