From 4dafc3aa6d2ad1753c8e3a14fc4269f2d3aa10bb Mon Sep 17 00:00:00 2001
From: evan-danswer
Date: Tue, 18 Mar 2025 21:14:05 -0700
Subject: [PATCH 01/18] Update README.md

---
 backend/tests/integration/README.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/backend/tests/integration/README.md b/backend/tests/integration/README.md
index 414c8114d..5093d87e6 100644
--- a/backend/tests/integration/README.md
+++ b/backend/tests/integration/README.md
@@ -28,6 +28,14 @@ The idea is that each test can use the manager class to create (.create()) a "te
 pytest -s tests/integration/tests/path_to/test_file.py::test_function_name
 ```
 
+Running some single tests requires the `mock_connector_server` container to be running. If the above doesn't work,
+navigate to `backend/tests/integration/mock_services` and run
+```sh
+docker compose -f docker-compose.mock-it-services.yml -p mock-it-services-stack up -d
+```
+You will have to modify the networks section of the docker-compose file to `<stack name>_default` if you brought up the standard
+onyx services with a name different from the default `onyx-stack`.
+
 ## Guidelines for Writing Integration Tests
 
 - As authentication is currently required for all tests, each test should start by creating a user.

From ae774105e31ff36b9ab833763d358bc5dbe621b2 Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Wed, 19 Mar 2025 11:26:49 -0700
Subject: [PATCH 02/18] Fix slack connector creation (#4303)

* Make it fail fast + succeed validation if rate limiting is happening
* Add logging + reduce spam
---
 backend/onyx/connectors/slack/connector.py    | 26 ++++++++++++++-----
 backend/onyx/server/manage/slack_bot.py       |  7 ++++-
 .../channels/SlackChannelConfigFormFields.tsx |  4 +++
 3 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py
index 138c3b422..66da040d9 100644
--- a/backend/onyx/connectors/slack/connector.py
+++ b/backend/onyx/connectors/slack/connector.py
@@ -480,6 +480,7 @@ def _process_message(
 
 class SlackConnector(SlimConnector, CheckpointConnector):
     MAX_WORKERS = 2
+    FAST_TIMEOUT = 1
 
     def __init__(
         self,
@@ -493,7 +494,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
         self.channel_regex_enabled = channel_regex_enabled
         self.batch_size = batch_size
         self.client: WebClient | None = None
-
+        self.fast_client: WebClient | None = None  # just used for efficiency
         self.text_cleaner: SlackTextCleaner | None = None
         self.user_cache: dict[str, BasicExpertInfo | None] = {}
 
@@ -501,6 +502,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         bot_token = credentials["slack_bot_token"]
         self.client = WebClient(token=bot_token)
+        # use for requests that must return quickly (e.g. realtime flows where user is waiting)
+        self.fast_client = WebClient(
+            token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
+        )
         self.text_cleaner = SlackTextCleaner(client=self.client)
         return None
 
@@ -676,12 +681,12 @@ class SlackConnector(SlimConnector, CheckpointConnector):
         2. Ensure the bot has enough scope to list channels.
         3. Check that every channel specified in self.channels exists (only when regex is not enabled).
""" - if self.client is None: + if self.fast_client is None: raise ConnectorMissingCredentialError("Slack credentials not loaded.") try: # 1) Validate connection to workspace - auth_response = self.client.auth_test() + auth_response = self.fast_client.auth_test() if not auth_response.get("ok", False): error_msg = auth_response.get( "error", "Unknown error from Slack auth_test" @@ -689,7 +694,7 @@ class SlackConnector(SlimConnector, CheckpointConnector): raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}") # 2) Minimal test to confirm listing channels works - test_resp = self.client.conversations_list( + test_resp = self.fast_client.conversations_list( limit=1, types=["public_channel"] ) if not test_resp.get("ok", False): @@ -709,7 +714,7 @@ class SlackConnector(SlimConnector, CheckpointConnector): # 3) If channels are specified and regex is not enabled, verify each is accessible if self.channels and not self.channel_regex_enabled: accessible_channels = get_channels( - client=self.client, + client=self.fast_client, exclude_archived=True, get_public=True, get_private=True, @@ -729,7 +734,16 @@ class SlackConnector(SlimConnector, CheckpointConnector): except SlackApiError as e: slack_error = e.response.get("error", "") - if slack_error == "missing_scope": + if slack_error == "ratelimited": + # Handle rate limiting specifically + retry_after = int(e.response.headers.get("Retry-After", 1)) + logger.warning( + f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. " + "Proceeding with validation, but be aware that connector operations might be throttled." + ) + # Continue validation without failing - the connector is likely valid but just rate limited + return + elif slack_error == "missing_scope": raise InsufficientPermissionsError( "Slack bot token lacks the necessary scope to list/access channels. " "Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)." 
diff --git a/backend/onyx/server/manage/slack_bot.py b/backend/onyx/server/manage/slack_bot.py index d47f8d828..0da66a5f6 100644 --- a/backend/onyx/server/manage/slack_bot.py +++ b/backend/onyx/server/manage/slack_bot.py @@ -32,10 +32,14 @@ from onyx.server.manage.models import SlackChannelConfig from onyx.server.manage.models import SlackChannelConfigCreationRequest from onyx.server.manage.validate_tokens import validate_app_token from onyx.server.manage.validate_tokens import validate_bot_token +from onyx.utils.logger import setup_logger from onyx.utils.telemetry import create_milestone_and_report from shared_configs.contextvars import get_current_tenant_id +logger = setup_logger() + + router = APIRouter(prefix="/manage") @@ -376,7 +380,7 @@ def get_all_channels_from_slack_api( status_code=404, detail="Bot token not found for the given bot ID" ) - client = WebClient(token=tokens["bot_token"]) + client = WebClient(token=tokens["bot_token"], timeout=1) all_channels = [] next_cursor = None current_page = 0 @@ -431,6 +435,7 @@ def get_all_channels_from_slack_api( except SlackApiError as e: # Handle rate limiting or other API errors + logger.exception("Error fetching channels from Slack API") raise HTTPException( status_code=500, detail=f"Error fetching channels from Slack API: {str(e)}", diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx index 1e82f04b9..9222205b8 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx @@ -184,6 +184,10 @@ export function SlackChannelConfigFormFields({ name: channel.name, value: channel.id, })); + }, + { + shouldRetryOnError: false, // don't spam the Slack API + dedupingInterval: 60000, // Limit re-fetching to once per minute } ); From 06624a988d96b0515fc84fa54bc1e1062d664c82 Mon Sep 17 00:00:00 2001 From: evan-danswer Date: Wed, 19 Mar 2025 11:49:35 -0700 Subject: [PATCH 03/18] Gdrive checkpointed connector (#4262) * WIP rebased * style * WIP, testing theory * fix type issue * fixed filtering bug * fix silliness * correct serialization and validation of threadsafedict * concurrent drive access * nits * nit * oauth bug fix * testing fix * fix slim retrieval * fix integration tests * fix testing change * CW comments * nit * guarantee completion stage existence * fix default values --- .../indexing/checkpointing_utils.py | 15 +- .../onyx/background/indexing/run_indexing.py | 4 +- backend/onyx/configs/app_configs.py | 2 + backend/onyx/connectors/connector_runner.py | 2 +- .../onyx/connectors/google_drive/connector.py | 858 ++++++++++++++---- .../connectors/google_drive/doc_conversion.py | 35 +- .../connectors/google_drive/file_retrieval.py | 121 ++- .../onyx/connectors/google_drive/models.py | 134 +++ .../connectors/google_utils/google_utils.py | 21 +- backend/onyx/connectors/interfaces.py | 33 +- .../connectors/mock_connector/connector.py | 27 +- backend/onyx/connectors/models.py | 10 +- backend/onyx/connectors/slack/connector.py | 121 +-- backend/onyx/tools/message.py | 1 - backend/onyx/utils/lazy.py | 13 + backend/onyx/utils/threadpool_concurrency.py | 154 ++++ .../google_drive/consts_and_utils.py | 44 +- .../google_drive/test_admin_oauth.py | 31 +- .../connectors/google_drive/test_sections.py | 7 +- .../google_drive/test_service_acct.py | 35 +- .../connectors/google_drive/test_slim_docs.py | 6 +- .../google_drive/test_user_1_oauth.py | 27 
+- .../tests/indexing/test_checkpointing.py | 60 +- .../onyx/utils/test_threadpool_concurrency.py | 240 +++++ 24 files changed, 1560 insertions(+), 441 deletions(-) create mode 100644 backend/onyx/utils/lazy.py diff --git a/backend/onyx/background/indexing/checkpointing_utils.py b/backend/onyx/background/indexing/checkpointing_utils.py index 254481e14..23fdf64e3 100644 --- a/backend/onyx/background/indexing/checkpointing_utils.py +++ b/backend/onyx/background/indexing/checkpointing_utils.py @@ -6,6 +6,8 @@ from sqlalchemy import and_ from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin +from onyx.connectors.interfaces import BaseConnector +from onyx.connectors.interfaces import CheckpointConnector from onyx.connectors.models import ConnectorCheckpoint from onyx.db.engine import get_db_current_time from onyx.db.index_attempt import get_index_attempt @@ -16,7 +18,6 @@ from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger from onyx.utils.object_size_check import deep_getsizeof - logger = setup_logger() _NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20 @@ -52,7 +53,7 @@ def save_checkpoint( def load_checkpoint( - db_session: Session, index_attempt_id: int + db_session: Session, index_attempt_id: int, connector: BaseConnector ) -> ConnectorCheckpoint | None: """Load a checkpoint for a given index attempt from the file store""" checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id) @@ -60,6 +61,8 @@ def load_checkpoint( try: checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb") checkpoint_data = checkpoint_io.read().decode("utf-8") + if isinstance(connector, CheckpointConnector): + return connector.validate_checkpoint_json(checkpoint_data) return ConnectorCheckpoint.model_validate_json(checkpoint_data) except RuntimeError: return None @@ -71,6 +74,7 @@ def get_latest_valid_checkpoint( search_settings_id: int, window_start: datetime, window_end: datetime, + connector: BaseConnector, ) -> ConnectorCheckpoint: """Get the latest valid checkpoint for a given connector credential pair""" checkpoint_candidates = get_recent_completed_attempts_for_cc_pair( @@ -105,7 +109,7 @@ def get_latest_valid_checkpoint( f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start " "from scratch." ) - return ConnectorCheckpoint.build_dummy_checkpoint() + return connector.build_dummy_checkpoint() # assumes latest checkpoint is the furthest along. This only isn't true # if something else has gone wrong. 
@@ -113,12 +117,13 @@ def get_latest_valid_checkpoint( checkpoint_candidates[0] if checkpoint_candidates else None ) - checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + checkpoint = connector.build_dummy_checkpoint() if latest_valid_checkpoint_candidate: try: previous_checkpoint = load_checkpoint( db_session=db_session, index_attempt_id=latest_valid_checkpoint_candidate.id, + connector=connector, ) except Exception: logger.exception( @@ -193,7 +198,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None: def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None: """Check if the checkpoint content size exceeds the limit (200MB)""" - content_size = deep_getsizeof(checkpoint.checkpoint_content) + content_size = deep_getsizeof(checkpoint.model_dump()) if content_size > 200_000_000: # 200MB in bytes raise ValueError( f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit" diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 12b485b87..54345bf09 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.factory import instantiate_connector -from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import IndexAttemptMetadata @@ -405,7 +404,7 @@ def _run_indexing( # the beginning in order to avoid weird interactions between # checkpointing / failure handling. 
if index_attempt.from_beginning: - checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + checkpoint = connector_runner.connector.build_dummy_checkpoint() else: checkpoint = get_latest_valid_checkpoint( db_session=db_session_temp, @@ -413,6 +412,7 @@ def _run_indexing( search_settings_id=index_attempt.search_settings_id, window_start=window_start, window_end=window_end, + connector=connector_runner.connector, ) unresolved_errors = get_index_attempt_errors_for_cc_pair( diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 74be29c6d..211ed3fc1 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -158,6 +158,8 @@ try: except ValueError: INDEX_BATCH_SIZE = 16 +MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4)) + # Below are intended to match the env variables names used by the official postgres docker image # https://hub.docker.com/_/postgres POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres" diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 8cb48a3d6..6cb3272b1 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -132,7 +132,7 @@ class ConnectorRunner: ) else: - finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + finished_checkpoint = self.connector.build_dummy_checkpoint() finished_checkpoint.has_more = False if isinstance(self.connector, PollConnector): diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index f7a27a826..dcc14df06 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -1,16 +1,23 @@ +import copy +import threading from collections.abc import Callable +from collections.abc import Generator from collections.abc import Iterator from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor +from enum import Enum from functools import partial from typing import Any +from typing import Protocol from urllib.parse import urlparse from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from googleapiclient.errors import HttpError # type: ignore +from typing_extensions import override from onyx.configs.app_configs import INDEX_BATCH_SIZE +from onyx.configs.app_configs import MAX_DRIVE_WORKERS from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError @@ -24,12 +31,18 @@ from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive from onyx.connectors.google_drive.file_retrieval import get_root_folder_id +from onyx.connectors.google_drive.models import DriveRetrievalStage +from onyx.connectors.google_drive.models import GoogleDriveCheckpoint from onyx.connectors.google_drive.models import GoogleDriveFileType +from onyx.connectors.google_drive.models import RetrievedDriveFile +from onyx.connectors.google_drive.models import StageCompletion from onyx.connectors.google_utils.google_auth import get_google_creds from onyx.connectors.google_utils.google_utils import 
execute_paginated_retrieval +from onyx.connectors.google_utils.google_utils import GoogleFields from onyx.connectors.google_utils.resources import get_admin_service from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.google_utils.resources import get_google_docs_service +from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) @@ -37,21 +50,28 @@ from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_S from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE from onyx.connectors.google_utils.shared_constants import USER_FIELDS -from onyx.connectors.interfaces import GenerateDocumentsOutput +from onyx.connectors.interfaces import CheckpointConnector from onyx.connectors.interfaces import GenerateSlimDocumentOutput -from onyx.connectors.interfaces import LoadConnector -from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError +from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure +from onyx.connectors.models import EntityFailure from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from onyx.utils.lazy import lazy_eval from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder +from onyx.utils.threadpool_concurrency import parallel_yield +from onyx.utils.threadpool_concurrency import ThreadSafeDict logger = setup_logger() # TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html # All file retrievals could be batched and made at once +BATCHES_PER_CHECKPOINT = 10 + def _extract_str_list_from_comma_str(string: str | None) -> list[str]: if not string: @@ -67,10 +87,16 @@ def _convert_single_file( creds: Any, primary_admin_email: str, file: dict[str, Any], -) -> Any: +) -> Document | ConnectorFailure | None: user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email - user_drive_service = get_drive_service(creds, user_email=user_email) - docs_service = get_google_docs_service(creds, user_email=user_email) + + # Only construct these services when needed + user_drive_service = lazy_eval( + lambda: get_drive_service(creds, user_email=user_email) + ) + docs_service = lazy_eval( + lambda: get_google_docs_service(creds, user_email=user_email) + ) return convert_drive_item_to_document( file=file, drive_service=user_drive_service, @@ -78,23 +104,6 @@ def _convert_single_file( ) -def _process_files_batch( - files: list[GoogleDriveFileType], - convert_func: Callable[[GoogleDriveFileType], Any], - batch_size: int, -) -> GenerateDocumentsOutput: - doc_batch = [] - with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor: - for doc in executor.map(convert_func, files): - if doc: - doc_batch.append(doc) - if len(doc_batch) >= batch_size: - yield doc_batch - doc_batch = [] - if doc_batch: - yield doc_batch - - def _clean_requested_drive_ids( requested_drive_ids: set[str], requested_folder_ids: set[str], @@ -113,7 +122,39 @@ def _clean_requested_drive_ids( return valid_requested_drive_ids, filtered_folder_ids -class 
GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): +class CredentialedRetrievalMethod(Protocol): + def __call__( + self, + is_slim: bool, + checkpoint: GoogleDriveCheckpoint, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[RetrievedDriveFile]: + ... + + +def add_retrieval_info( + drive_files: Iterator[GoogleDriveFileType], + user_email: str, + completion_stage: DriveRetrievalStage, + parent_id: str | None = None, +) -> Iterator[RetrievedDriveFile]: + for file in drive_files: + yield RetrievedDriveFile( + drive_file=file, + user_email=user_email, + parent_id=parent_id, + completion_stage=completion_stage, + ) + + +class DriveIdStatus(Enum): + AVAILABLE = "available" + IN_PROGRESS = "in_progress" + FINISHED = "finished" + + +class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpoint]): def __init__( self, include_shared_drives: bool = False, @@ -146,13 +187,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): if continue_on_failure is not None: logger.warning("The 'continue_on_failure' parameter is deprecated.") - if ( - not include_shared_drives - and not include_my_drives - and not include_files_shared_with_me - and not shared_folder_urls - and not my_drive_emails - and not shared_drive_urls + if not any( + ( + include_shared_drives, + include_my_drives, + include_files_shared_with_me, + shared_folder_urls, + my_drive_emails, + shared_drive_urls, + ) ): raise ConnectorValidationError( "Nothing to index. Please specify at least one of the following: " @@ -222,15 +265,12 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): ) return self._creds + # TODO: ensure returned new_creds_dict is actually persisted when this is called? def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: try: self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] except KeyError: - raise ValueError( - "Primary admin email missing, " - "should not call this property " - "before calling load_credentials" - ) + raise ValueError("Credentials json missing primary admin key") self._creds, new_creds_dict = get_google_creds( credentials=credentials, @@ -287,22 +327,91 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): if not all_drive_ids: logger.warning( - "No drives found even though we are indexing shared drives was requested." + "No drives found even though indexing shared drives was requested." 
) return all_drive_ids + def make_drive_id_iterator( + self, drive_ids: set[str], checkpoint: GoogleDriveCheckpoint + ) -> Callable[[str], Iterator[str]]: + cv = threading.Condition() + drive_id_status = { + drive_id: DriveIdStatus.FINISHED + if drive_id in self._retrieved_ids + else DriveIdStatus.AVAILABLE + for drive_id in drive_ids + } + + def _get_available_drive_id( + processed_ids: set[str], thread_id: str + ) -> tuple[str | None, bool]: + found_future_work = False + for drive_id, status in drive_id_status.items(): + if drive_id in self._retrieved_ids: + drive_id_status[drive_id] = DriveIdStatus.FINISHED + continue + if drive_id in processed_ids: + continue + + if status == DriveIdStatus.AVAILABLE: + return drive_id, True + elif status == DriveIdStatus.IN_PROGRESS: + found_future_work = True + return None, found_future_work + + def drive_id_iterator(thread_id: str) -> Iterator[str]: + completion = checkpoint.completion_map[thread_id] + # continue iterating until this thread has no more work to do + while True: + # this locks operations on _retrieved_ids and drive_id_status + with cv: + available_drive_id, found_future_work = _get_available_drive_id( + completion.processed_drive_ids, thread_id + ) + + # wait while there is no work currently available but still drives that may need processing + while available_drive_id is None and found_future_work: + cv.wait() + available_drive_id, found_future_work = _get_available_drive_id( + completion.processed_drive_ids, thread_id + ) + + # if there is no work available and no future work, we are done + if available_drive_id is None: + return + + drive_id_status[available_drive_id] = DriveIdStatus.IN_PROGRESS + + yield available_drive_id + with cv: + completion.processed_drive_ids.add(available_drive_id) + drive_id_status[available_drive_id] = ( + DriveIdStatus.FINISHED + if available_drive_id in self._retrieved_ids + else DriveIdStatus.AVAILABLE + ) + # wake up other threads waiting for work + cv.notify_all() + + return drive_id_iterator + def _impersonate_user_for_retrieval( self, user_email: str, is_slim: bool, - filtered_drive_ids: set[str], + checkpoint: GoogleDriveCheckpoint, + concurrent_drive_itr: Callable[[str], Iterator[str]], filtered_folder_ids: set[str], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: + ) -> Iterator[RetrievedDriveFile]: logger.info(f"Impersonating user {user_email}") - + curr_stage = checkpoint.completion_map[user_email] + resuming = True + if curr_stage.stage == DriveRetrievalStage.START: + curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES + resuming = False drive_service = get_drive_service(self.creds, user_email) # validate that the user has access to the drive APIs by performing a simple @@ -325,62 +434,111 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): # drive if any of the following are true: # - include_my_drives is true # - the current user's email is in the requested emails - if self.include_my_drives or user_email in self._requested_my_drive_emails: - logger.info(f"Getting all files in my drive as '{user_email}'") - yield from get_all_files_in_my_drive( - service=drive_service, - update_traversed_ids_func=self._update_traversed_parent_ids, - is_slim=is_slim, - start=start, - end=end, - ) + if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES: + if self.include_my_drives or user_email in self._requested_my_drive_emails: + logger.info(f"Getting all files in my drive as '{user_email}'") - 
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids - for drive_id in remaining_drive_ids: - logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'") - yield from get_files_in_shared_drive( - service=drive_service, - drive_id=drive_id, - is_slim=is_slim, - update_traversed_ids_func=self._update_traversed_parent_ids, - start=start, - end=end, - ) + yield from add_retrieval_info( + get_all_files_in_my_drive( + service=drive_service, + update_traversed_ids_func=self._update_traversed_parent_ids, + is_slim=is_slim, + start=curr_stage.completed_until if resuming else start, + end=end, + ), + user_email, + DriveRetrievalStage.MY_DRIVE_FILES, + ) + curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES - remaining_folders = filtered_folder_ids - self._retrieved_ids - for folder_id in remaining_folders: - logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'") - yield from crawl_folders_for_files( - service=drive_service, - parent_id=folder_id, - traversed_parent_ids=self._retrieved_ids, - update_traversed_ids_func=self._update_traversed_parent_ids, - start=start, - end=end, - ) + if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES: + + def _yield_from_drive( + drive_id: str, drive_start: SecondsSinceUnixEpoch | None + ) -> Iterator[RetrievedDriveFile]: + yield from add_retrieval_info( + get_files_in_shared_drive( + service=drive_service, + drive_id=drive_id, + is_slim=is_slim, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=drive_start, + end=end, + ), + user_email, + DriveRetrievalStage.SHARED_DRIVE_FILES, + parent_id=drive_id, + ) + + # resume from a checkpoint + if resuming: + drive_id = curr_stage.completed_until_parent_id + assert drive_id is not None, "drive id not set in checkpoint" + resume_start = curr_stage.completed_until + yield from _yield_from_drive(drive_id, resume_start) + # Don't enter resuming case for folder retrieval + resuming = False + + for drive_id in concurrent_drive_itr(user_email): + logger.info( + f"Getting files in shared drive '{drive_id}' as '{user_email}'" + ) + yield from _yield_from_drive(drive_id, start) + curr_stage.stage = DriveRetrievalStage.FOLDER_FILES + + if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES: + + def _yield_from_folder_crawl( + folder_id: str, folder_start: SecondsSinceUnixEpoch | None + ) -> Iterator[RetrievedDriveFile]: + yield from crawl_folders_for_files( + service=drive_service, + parent_id=folder_id, + is_slim=is_slim, + user_email=user_email, + traversed_parent_ids=self._retrieved_ids, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=folder_start, + end=end, + ) + + # resume from a checkpoint + if resuming: + folder_id = curr_stage.completed_until_parent_id + assert folder_id is not None, "folder id not set in checkpoint" + resume_start = curr_stage.completed_until + yield from _yield_from_folder_crawl(folder_id, resume_start) + + remaining_folders = filtered_folder_ids - self._retrieved_ids + for folder_id in remaining_folders: + logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'") + yield from _yield_from_folder_crawl(folder_id, start) + curr_stage.stage = DriveRetrievalStage.DONE def _manage_service_account_retrieval( self, is_slim: bool, + checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: - all_org_emails: list[str] = self._get_all_user_emails() + ) -> Iterator[RetrievedDriveFile]: + if 
checkpoint.completion_stage == DriveRetrievalStage.START: + checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS - all_drive_ids: set[str] = self.get_all_drive_ids() + if checkpoint.completion_stage == DriveRetrievalStage.USER_EMAILS: + all_org_emails: list[str] = self._get_all_user_emails() + if not is_slim: + checkpoint.user_emails = all_org_emails + checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS + else: + assert checkpoint.user_emails is not None, "user emails not set" + all_org_emails = checkpoint.user_emails - drive_ids_to_retrieve: set[str] = set() - folder_ids_to_retrieve: set[str] = set() - if self._requested_shared_drive_ids or self._requested_folder_ids: - drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids( - requested_drive_ids=self._requested_shared_drive_ids, - requested_folder_ids=self._requested_folder_ids, - all_drive_ids_available=all_drive_ids, - ) - elif self.include_shared_drives: - drive_ids_to_retrieve = all_drive_ids + drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids( + checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES + ) - # checkpoint - we've found all users and drives, now time to actually start + # we've found all users and drives, now time to actually start # fetching stuff logger.info(f"Found {len(all_org_emails)} users to impersonate") logger.debug(f"Users: {all_org_emails}") @@ -389,24 +547,217 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve") logger.debug(f"Folders: {folder_ids_to_retrieve}") - # Process users in parallel using ThreadPoolExecutor - with ThreadPoolExecutor(max_workers=10) as executor: - future_to_email = { - executor.submit( - self._impersonate_user_for_retrieval, - email, - is_slim, + drive_id_iterator = self.make_drive_id_iterator( + drive_ids_to_retrieve, checkpoint + ) + + for email in all_org_emails: + checkpoint.completion_map[email] = StageCompletion( + stage=DriveRetrievalStage.START, + completed_until=0, + ) + user_retrieval_gens = [ + self._impersonate_user_for_retrieval( + email, + is_slim, + checkpoint, + drive_id_iterator, + folder_ids_to_retrieve, + start, + end, + ) + for email in all_org_emails + ] + yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS) + + remaining_folders = ( + drive_ids_to_retrieve | folder_ids_to_retrieve + ) - self._retrieved_ids + if remaining_folders: + logger.warning( + f"Some folders/drives were not retrieved. 
IDs: {remaining_folders}" + ) + assert all( + checkpoint.completion_map[user_email].stage == DriveRetrievalStage.DONE + for user_email in all_org_emails + ), "some users did not complete retrieval" + checkpoint.completion_stage = DriveRetrievalStage.DONE + + def _determine_retrieval_ids( + self, + checkpoint: GoogleDriveCheckpoint, + is_slim: bool, + next_stage: DriveRetrievalStage, + ) -> tuple[set[str], set[str]]: + all_drive_ids = self.get_all_drive_ids() + drive_ids_to_retrieve: set[str] = set() + folder_ids_to_retrieve: set[str] = set() + if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS: + if self._requested_shared_drive_ids or self._requested_folder_ids: + ( drive_ids_to_retrieve, folder_ids_to_retrieve, - start, - end, - ): email - for email in all_org_emails - } + ) = _clean_requested_drive_ids( + requested_drive_ids=self._requested_shared_drive_ids, + requested_folder_ids=self._requested_folder_ids, + all_drive_ids_available=all_drive_ids, + ) + elif self.include_shared_drives: + drive_ids_to_retrieve = all_drive_ids - # Yield results as they complete - for future in as_completed(future_to_email): - yield from future.result() + if not is_slim: + checkpoint.drive_ids_to_retrieve = list(drive_ids_to_retrieve) + checkpoint.folder_ids_to_retrieve = list(folder_ids_to_retrieve) + checkpoint.completion_stage = next_stage + else: + assert ( + checkpoint.drive_ids_to_retrieve is not None + ), "drive ids to retrieve not set" + assert ( + checkpoint.folder_ids_to_retrieve is not None + ), "folder ids to retrieve not set" + # When loading from a checkpoint, load the previously cached drive and folder ids + drive_ids_to_retrieve = set(checkpoint.drive_ids_to_retrieve) + folder_ids_to_retrieve = set(checkpoint.folder_ids_to_retrieve) + + return drive_ids_to_retrieve, folder_ids_to_retrieve + + def _oauth_retrieval_all_files( + self, + is_slim: bool, + drive_service: GoogleDriveService, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[RetrievedDriveFile]: + if not self.include_files_shared_with_me and not self.include_my_drives: + return + + logger.info( + f"Getting shared files/my drive files for OAuth " + f"with include_files_shared_with_me={self.include_files_shared_with_me}, " + f"include_my_drives={self.include_my_drives}, " + f"include_shared_drives={self.include_shared_drives}." + f"Using '{self.primary_admin_email}' as the account." 
+ ) + yield from add_retrieval_info( + get_all_files_for_oauth( + service=drive_service, + include_files_shared_with_me=self.include_files_shared_with_me, + include_my_drives=self.include_my_drives, + include_shared_drives=self.include_shared_drives, + is_slim=is_slim, + start=start, + end=end, + ), + self.primary_admin_email, + DriveRetrievalStage.OAUTH_FILES, + ) + + def _oauth_retrieval_drives( + self, + is_slim: bool, + drive_service: GoogleDriveService, + drive_ids_to_retrieve: set[str], + checkpoint: GoogleDriveCheckpoint, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[RetrievedDriveFile]: + def _yield_from_drive( + drive_id: str, drive_start: SecondsSinceUnixEpoch | None + ) -> Iterator[RetrievedDriveFile]: + yield from add_retrieval_info( + get_files_in_shared_drive( + service=drive_service, + drive_id=drive_id, + is_slim=is_slim, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=drive_start, + end=end, + ), + self.primary_admin_email, + DriveRetrievalStage.SHARED_DRIVE_FILES, + parent_id=drive_id, + ) + + # If we are resuming from a checkpoint, we need to finish retrieving the files from the last drive we retrieved + if ( + checkpoint.completion_map[self.primary_admin_email].stage + == DriveRetrievalStage.SHARED_DRIVE_FILES + ): + drive_id = checkpoint.completion_map[ + self.primary_admin_email + ].completed_until_parent_id + assert drive_id is not None, "drive id not set in checkpoint" + resume_start = checkpoint.completion_map[ + self.primary_admin_email + ].completed_until + yield from _yield_from_drive(drive_id, resume_start) + + for drive_id in drive_ids_to_retrieve: + if drive_id in self._retrieved_ids: + logger.info( + f"Skipping drive '{drive_id}' as it has already been retrieved" + ) + continue + logger.info( + f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'" + ) + yield from _yield_from_drive(drive_id, start) + + def _oauth_retrieval_folders( + self, + is_slim: bool, + drive_service: GoogleDriveService, + drive_ids_to_retrieve: set[str], + folder_ids_to_retrieve: set[str], + checkpoint: GoogleDriveCheckpoint, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[RetrievedDriveFile]: + """ + If there are any remaining folder ids to retrieve found earlier in the + retrieval process, we recursively descend the file tree and retrieve all + files in the folder(s). + """ + # Even if no folders were requested, we still check if any drives were requested + # that could be folders. 
+ remaining_folders = folder_ids_to_retrieve - self._retrieved_ids + + def _yield_from_folder_crawl( + folder_id: str, folder_start: SecondsSinceUnixEpoch | None + ) -> Iterator[RetrievedDriveFile]: + yield from crawl_folders_for_files( + service=drive_service, + parent_id=folder_id, + is_slim=is_slim, + user_email=self.primary_admin_email, + traversed_parent_ids=self._retrieved_ids, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=folder_start, + end=end, + ) + + # resume from a checkpoint + if ( + checkpoint.completion_map[self.primary_admin_email].stage + == DriveRetrievalStage.FOLDER_FILES + ): + folder_id = checkpoint.completion_map[ + self.primary_admin_email + ].completed_until_parent_id + assert folder_id is not None, "folder id not set in checkpoint" + resume_start = checkpoint.completion_map[ + self.primary_admin_email + ].completed_until + yield from _yield_from_folder_crawl(folder_id, resume_start) + + # the times stored in the completion_map aren't used due to the crawling behavior + # instead, the traversed_parent_ids are used to determine what we have left to retrieve + for folder_id in remaining_folders: + logger.info( + f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'" + ) + yield from _yield_from_folder_crawl(folder_id, start) remaining_folders = ( drive_ids_to_retrieve | folder_ids_to_retrieve @@ -416,31 +767,64 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): f"Some folders/drives were not retrieved. IDs: {remaining_folders}" ) + def _checkpointed_retrieval( + self, + retrieval_method: CredentialedRetrievalMethod, + is_slim: bool, + checkpoint: GoogleDriveCheckpoint, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[RetrievedDriveFile]: + drive_files = retrieval_method( + is_slim=is_slim, + checkpoint=checkpoint, + start=start, + end=end, + ) + if is_slim: + yield from drive_files + return + + for file in drive_files: + if file.error is not None: + checkpoint.completion_map[file.user_email].update( + stage=file.completion_stage, + completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value], + completed_until_parent_id=file.parent_id, + ) + yield file + def _manage_oauth_retrieval( self, is_slim: bool, + checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: + ) -> Iterator[RetrievedDriveFile]: + if checkpoint.completion_stage == DriveRetrievalStage.START: + checkpoint.completion_stage = DriveRetrievalStage.OAUTH_FILES + checkpoint.completion_map[self.primary_admin_email] = StageCompletion( + stage=DriveRetrievalStage.START, + completed_until=0, + completed_until_parent_id=None, + ) + drive_service = get_drive_service(self.creds, self.primary_admin_email) - if self.include_files_shared_with_me or self.include_my_drives: - logger.info( - f"Getting shared files/my drive files for OAuth " - f"with include_files_shared_with_me={self.include_files_shared_with_me}, " - f"include_my_drives={self.include_my_drives}, " - f"include_shared_drives={self.include_shared_drives}." - f"Using '{self.primary_admin_email}' as the account." 
- ) - yield from get_all_files_for_oauth( - service=drive_service, - include_files_shared_with_me=self.include_files_shared_with_me, - include_my_drives=self.include_my_drives, - include_shared_drives=self.include_shared_drives, + if checkpoint.completion_stage == DriveRetrievalStage.OAUTH_FILES: + completion = checkpoint.completion_map[self.primary_admin_email] + all_files_start = start + # if resuming from a checkpoint + if completion.stage == DriveRetrievalStage.OAUTH_FILES: + all_files_start = completion.completed_until + + yield from self._oauth_retrieval_all_files( + drive_service=drive_service, is_slim=is_slim, - start=start, + start=all_files_start, end=end, ) + checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS all_requested = ( self.include_files_shared_with_me @@ -449,96 +833,110 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): ) if all_requested: # If all 3 are true, we already yielded from get_all_files_for_oauth + checkpoint.completion_stage = DriveRetrievalStage.DONE return - all_drive_ids = self.get_all_drive_ids() - drive_ids_to_retrieve: set[str] = set() - folder_ids_to_retrieve: set[str] = set() - if self._requested_shared_drive_ids or self._requested_folder_ids: - drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids( - requested_drive_ids=self._requested_shared_drive_ids, - requested_folder_ids=self._requested_folder_ids, - all_drive_ids_available=all_drive_ids, - ) - elif self.include_shared_drives: - drive_ids_to_retrieve = all_drive_ids + drive_ids_to_retrieve, folder_ids_to_retrieve = self._determine_retrieval_ids( + checkpoint, is_slim, DriveRetrievalStage.SHARED_DRIVE_FILES + ) - for drive_id in drive_ids_to_retrieve: - logger.info( - f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'" - ) - yield from get_files_in_shared_drive( - service=drive_service, - drive_id=drive_id, + if checkpoint.completion_stage == DriveRetrievalStage.SHARED_DRIVE_FILES: + yield from self._oauth_retrieval_drives( is_slim=is_slim, - update_traversed_ids_func=self._update_traversed_parent_ids, + drive_service=drive_service, + drive_ids_to_retrieve=drive_ids_to_retrieve, + checkpoint=checkpoint, start=start, end=end, ) - # Even if no folders were requested, we still check if any drives were requested - # that could be folders. - remaining_folders = folder_ids_to_retrieve - self._retrieved_ids - for folder_id in remaining_folders: - logger.info( - f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'" - ) - yield from crawl_folders_for_files( - service=drive_service, - parent_id=folder_id, - traversed_parent_ids=self._retrieved_ids, - update_traversed_ids_func=self._update_traversed_parent_ids, + checkpoint.completion_stage = DriveRetrievalStage.FOLDER_FILES + + if checkpoint.completion_stage == DriveRetrievalStage.FOLDER_FILES: + yield from self._oauth_retrieval_folders( + is_slim=is_slim, + drive_service=drive_service, + drive_ids_to_retrieve=drive_ids_to_retrieve, + folder_ids_to_retrieve=folder_ids_to_retrieve, + checkpoint=checkpoint, start=start, end=end, ) - remaining_folders = ( - drive_ids_to_retrieve | folder_ids_to_retrieve - ) - self._retrieved_ids - if remaining_folders: - logger.warning( - f"Some folders/drives were not retrieved. 
IDs: {remaining_folders}" - ) + checkpoint.completion_stage = DriveRetrievalStage.DONE def _fetch_drive_items( self, is_slim: bool, + checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> Iterator[GoogleDriveFileType]: + ) -> Iterator[RetrievedDriveFile]: retrieval_method = ( self._manage_service_account_retrieval if isinstance(self.creds, ServiceAccountCredentials) else self._manage_oauth_retrieval ) - drive_files = retrieval_method( + + return self._checkpointed_retrieval( + retrieval_method=retrieval_method, is_slim=is_slim, + checkpoint=checkpoint, start=start, end=end, ) - return drive_files - def _extract_docs_from_google_drive( self, + checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - ) -> GenerateDocumentsOutput: - # Create a larger process pool for file conversion - with ThreadPoolExecutor(max_workers=8) as executor: - # Prepare a partial function with the credentials and admin email - convert_func = partial( - _convert_single_file, - self.creds, - self.primary_admin_email, - ) + ) -> Iterator[list[Document | ConnectorFailure]]: + try: + # Create a larger process pool for file conversion + with ThreadPoolExecutor(max_workers=8) as executor: + # Prepare a partial function with the credentials and admin email + convert_func = partial( + _convert_single_file, + self.creds, + self.primary_admin_email, + ) - # Fetch files in batches - files_batch: list[GoogleDriveFileType] = [] - for file in self._fetch_drive_items(is_slim=False, start=start, end=end): - files_batch.append(file) + # Fetch files in batches + batches_complete = 0 + files_batch: list[GoogleDriveFileType] = [] + for retrieved_file in self._fetch_drive_items( + is_slim=False, + checkpoint=checkpoint, + start=start, + end=end, + ): + if retrieved_file.error is not None: + failure_stage = retrieved_file.completion_stage.value + failure_message = ( + f"retrieval failure during stage: {failure_stage}," + ) + failure_message += f"user: {retrieved_file.user_email}," + failure_message += ( + f"parent drive/folder: {retrieved_file.parent_id}," + ) + failure_message += f"error: {retrieved_file.error}" + logger.error(failure_message) + yield [ + ConnectorFailure( + failed_entity=EntityFailure( + entity_id=failure_stage, + ), + failure_message=failure_message, + exception=retrieved_file.error, + ) + ] + continue + files_batch.append(retrieved_file.drive_file) + + if len(files_batch) < self.batch_size: + continue - if len(files_batch) >= self.batch_size: # Process the batch futures = [ executor.submit(convert_func, file) for file in files_batch @@ -550,44 +948,92 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): if doc is not None: documents.append(doc) except Exception as e: - logger.error(f"Error converting file: {e}") + error_str = f"Error converting file: {e}" + logger.error(error_str) + yield [ + ConnectorFailure( + failed_document=DocumentFailure( + document_id=retrieved_file.drive_file["id"], + document_link=retrieved_file.drive_file[ + "webViewLink" + ], + ), + failure_message=error_str, + exception=e, + ) + ] if documents: yield documents + batches_complete += 1 files_batch = [] - # Process any remaining files - if files_batch: - futures = [executor.submit(convert_func, file) for file in files_batch] - documents = [] - for future in as_completed(futures): - try: - doc = future.result() - if doc is not None: - documents.append(doc) - except Exception as e: - 
logger.error(f"Error converting file: {e}") + if batches_complete > BATCHES_PER_CHECKPOINT: + checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids + return # create a new checkpoint - if documents: - yield documents + # Process any remaining files + if files_batch: + futures = [ + executor.submit(convert_func, file) for file in files_batch + ] + documents = [] + for future in as_completed(futures): + try: + doc = future.result() + if doc is not None: + documents.append(doc) + except Exception as e: + error_str = f"Error converting file: {e}" + logger.error(error_str) + yield [ + ConnectorFailure( + failed_document=DocumentFailure( + document_id=retrieved_file.drive_file["id"], + document_link=retrieved_file.drive_file[ + "webViewLink" + ], + ), + failure_message=error_str, + exception=e, + ) + ] - def load_from_state(self) -> GenerateDocumentsOutput: + if documents: + yield documents + except Exception as e: + logger.exception(f"Error extracting documents from Google Drive: {e}") + raise e + + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GoogleDriveCheckpoint, + ) -> Generator[Document | ConnectorFailure, None, GoogleDriveCheckpoint]: + """ + Entrypoint for the connector; first run is with an empty checkpoint. + """ + if self._creds is None or self._primary_admin_email is None: + raise RuntimeError( + "Credentials missing, should not call this method before calling load_credentials" + ) + + checkpoint = copy.deepcopy(checkpoint) + self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids try: - yield from self._extract_docs_from_google_drive() - except Exception as e: - if MISSING_SCOPES_ERROR_STR in str(e): - raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e - raise e - - def poll_source( - self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: - try: - yield from self._extract_docs_from_google_drive(start, end) + for doc_list in self._extract_docs_from_google_drive( + checkpoint, start, end + ): + yield from doc_list except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e + checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids + if checkpoint.completion_stage == DriveRetrievalStage.DONE: + checkpoint.has_more = False + return checkpoint def _extract_slim_docs_from_google_drive( self, @@ -597,11 +1043,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): ) -> GenerateSlimDocumentOutput: slim_batch = [] for file in self._fetch_drive_items( + checkpoint=self.build_dummy_checkpoint(), is_slim=True, start=start, end=end, ): - if doc := build_slim_document(file): + if file.error is not None: + raise file.error + if doc := build_slim_document(file.drive_file): slim_batch.append(doc) if len(slim_batch) >= SLIM_BATCH_SIZE: yield slim_batch @@ -677,3 +1126,16 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): raise ConnectorValidationError( f"Unexpected error during Google Drive validation: {e}" ) + + @override + def build_dummy_checkpoint(self) -> GoogleDriveCheckpoint: + return GoogleDriveCheckpoint( + retrieved_folder_and_drive_ids=set(), + completion_stage=DriveRetrievalStage.START, + completion_map=ThreadSafeDict(), + has_more=True, + ) + + @override + def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint: + return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) diff --git 
a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index b1c3d8c1a..bc4b83677 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -1,4 +1,5 @@ import io +from collections.abc import Callable from datetime import datetime from typing import cast @@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_drive.section_extraction import get_document_sections from onyx.connectors.google_utils.resources import GoogleDocsService from onyx.connectors.google_utils.resources import GoogleDriveService +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure from onyx.connectors.models import ImageSection from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection @@ -202,12 +205,15 @@ def _extract_sections_basic( def convert_drive_item_to_document( file: GoogleDriveFileType, - drive_service: GoogleDriveService, - docs_service: GoogleDocsService, -) -> Document | None: + drive_service: Callable[[], GoogleDriveService], + docs_service: Callable[[], GoogleDocsService], +) -> Document | ConnectorFailure | None: """ Main entry point for converting a Google Drive file => Document object. """ + doc_id = "" + sections: list[TextSection | ImageSection] = [] + try: # skip shortcuts or folders if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]: @@ -215,13 +221,11 @@ def convert_drive_item_to_document( return None # If it's a Google Doc, we might do advanced parsing - sections: list[TextSection | ImageSection] = [] - - # Try to get sections using the advanced method first if file.get("mimeType") == GDriveMimeType.DOC.value: try: + # get_document_sections is the advanced approach for Google Docs doc_sections = get_document_sections( - docs_service=docs_service, doc_id=file.get("id", "") + docs_service=docs_service(), doc_id=file.get("id", "") ) if doc_sections: sections = cast(list[TextSection | ImageSection], doc_sections) @@ -232,7 +236,7 @@ def convert_drive_item_to_document( # If we don't have sections yet, use the basic extraction method if not sections: - sections = _extract_sections_basic(file, drive_service) + sections = _extract_sections_basic(file, drive_service()) # If we still don't have any sections, skip this file if not sections: @@ -257,8 +261,19 @@ def convert_drive_item_to_document( ), ) except Exception as e: - logger.error(f"Error converting file {file.get('name')}: {e}") - return None + error_str = f"Error converting file '{file.get('name')}' to Document: {e}" + logger.exception(error_str) + return ConnectorFailure( + failed_document=DocumentFailure( + document_id=doc_id, + document_link=sections[0].link + if sections + else None, # TODO: see if this is the best way to get a link + ), + failed_entity=None, + failure_message=error_str, + exception=e, + ) def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None: diff --git a/backend/onyx/connectors/google_drive/file_retrieval.py b/backend/onyx/connectors/google_drive/file_retrieval.py index 4e459bd3b..3fcda064c 100644 --- a/backend/onyx/connectors/google_drive/file_retrieval.py +++ b/backend/onyx/connectors/google_drive/file_retrieval.py @@ -1,17 +1,22 @@ from collections.abc import Callable from collections.abc import Iterator from datetime import datetime -from typing import Any from 
googleapiclient.discovery import Resource # type: ignore from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE +from onyx.connectors.google_drive.models import DriveRetrievalStage from onyx.connectors.google_drive.models import GoogleDriveFileType +from onyx.connectors.google_drive.models import RetrievedDriveFile from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval +from onyx.connectors.google_utils.google_utils import GoogleFields +from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY +from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.utils.logger import setup_logger + logger = setup_logger() FILE_FIELDS = ( @@ -32,10 +37,12 @@ def _generate_time_range_filter( time_range_filter = "" if start is not None: time_start = datetime.utcfromtimestamp(start).isoformat() + "Z" - time_range_filter += f" and modifiedTime >= '{time_start}'" + time_range_filter += ( + f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'" + ) if end is not None: time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z" - time_range_filter += f" and modifiedTime <= '{time_stop}'" + time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'" return time_range_filter @@ -66,9 +73,9 @@ def _get_folders_in_parent( def _get_files_in_parent( service: Resource, parent_id: str, + is_slim: bool, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, - is_slim: bool = False, ) -> Iterator[GoogleDriveFileType]: query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents" query += " and trashed = false" @@ -83,6 +90,7 @@ def _get_files_in_parent( includeItemsFromAllDrives=True, fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=query, + **({} if is_slim else {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}), ): yield file @@ -90,30 +98,50 @@ def _get_files_in_parent( def crawl_folders_for_files( service: Resource, parent_id: str, + is_slim: bool, + user_email: str, traversed_parent_ids: set[str], update_traversed_ids_func: Callable[[str], None], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, -) -> Iterator[GoogleDriveFileType]: +) -> Iterator[RetrievedDriveFile]: """ This function starts crawling from any folder. It is slower though. 
""" - if parent_id in traversed_parent_ids: - logger.info(f"Skipping subfolder since already traversed: {parent_id}") - return - - found_files = False - for file in _get_files_in_parent( - service=service, - start=start, - end=end, - parent_id=parent_id, - ): - found_files = True - yield file - - if found_files: - update_traversed_ids_func(parent_id) + logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id) + if parent_id not in traversed_parent_ids: + logger.info("Parent id not in traversed parent ids, getting files") + found_files = False + file = {} + try: + for file in _get_files_in_parent( + service=service, + parent_id=parent_id, + is_slim=is_slim, + start=start, + end=end, + ): + found_files = True + logger.info(f"Found file: {file['name']}") + yield RetrievedDriveFile( + drive_file=file, + user_email=user_email, + parent_id=parent_id, + completion_stage=DriveRetrievalStage.FOLDER_FILES, + ) + except Exception as e: + logger.error(f"Error getting files in parent {parent_id}: {e}") + yield RetrievedDriveFile( + drive_file=file, + user_email=user_email, + parent_id=parent_id, + completion_stage=DriveRetrievalStage.FOLDER_FILES, + error=e, + ) + if found_files: + update_traversed_ids_func(parent_id) + else: + logger.info(f"Skipping subfolder files since already traversed: {parent_id}") for subfolder in _get_folders_in_parent( service=service, @@ -123,6 +151,8 @@ def crawl_folders_for_files( yield from crawl_folders_for_files( service=service, parent_id=subfolder["id"], + is_slim=is_slim, + user_email=user_email, traversed_parent_ids=traversed_parent_ids, update_traversed_ids_func=update_traversed_ids_func, start=start, @@ -133,16 +163,19 @@ def crawl_folders_for_files( def get_files_in_shared_drive( service: Resource, drive_id: str, - is_slim: bool = False, + is_slim: bool, update_traversed_ids_func: Callable[[str], None] = lambda _: None, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: + kwargs = {} + if not is_slim: + kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value + # If we know we are going to folder crawl later, we can cache the folders here # Get all folders being queried and add them to the traversed set folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" folder_query += " and trashed = false" - found_folders = False for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", @@ -155,15 +188,13 @@ def get_files_in_shared_drive( q=folder_query, ): update_traversed_ids_func(file["id"]) - found_folders = True - if found_folders: - update_traversed_ids_func(drive_id) # Get all files in the shared drive file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" file_query += " and trashed = false" file_query += _generate_time_range_filter(start, end) - yield from execute_paginated_retrieval( + + for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", continue_on_404_or_403=True, @@ -173,16 +204,26 @@ def get_files_in_shared_drive( includeItemsFromAllDrives=True, fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=file_query, - ) + **kwargs, + ): + # If we found any files, mark this drive as traversed. When a user has access to a drive, + # they have access to all the files in the drive. Also not a huge deal if we re-traverse + # empty drives. 
+ update_traversed_ids_func(drive_id) + yield file def get_all_files_in_my_drive( - service: Any, + service: GoogleDriveService, update_traversed_ids_func: Callable, - is_slim: bool = False, + is_slim: bool, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: + kwargs = {} + if not is_slim: + kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value + # If we know we are going to folder crawl later, we can cache the folders here # Get all folders being queried and add them to the traversed set folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" @@ -196,7 +237,7 @@ def get_all_files_in_my_drive( fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=folder_query, ): - update_traversed_ids_func(file["id"]) + update_traversed_ids_func(file[GoogleFields.ID]) found_folders = True if found_folders: update_traversed_ids_func(get_root_folder_id(service)) @@ -209,22 +250,28 @@ def get_all_files_in_my_drive( yield from execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", + continue_on_404_or_403=False, corpora="user", fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=file_query, + **kwargs, ) def get_all_files_for_oauth( - service: Any, + service: GoogleDriveService, include_files_shared_with_me: bool, include_my_drives: bool, # One of the above 2 should be true include_shared_drives: bool, - is_slim: bool = False, + is_slim: bool, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: + kwargs = {} + if not is_slim: + kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value + should_get_all = ( include_shared_drives and include_my_drives and include_files_shared_with_me ) @@ -243,11 +290,13 @@ def get_all_files_for_oauth( yield from execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", + continue_on_404_or_403=False, corpora=corpora, includeItemsFromAllDrives=should_get_all, supportsAllDrives=should_get_all, fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, q=file_query, + **kwargs, ) @@ -255,4 +304,8 @@ def get_all_files_for_oauth( def get_root_folder_id(service: Resource) -> str: # we dont paginate here because there is only one root folder per user # https://developers.google.com/drive/api/guides/v2-to-v3-reference - return service.files().get(fileId="root", fields="id").execute()["id"] + return ( + service.files() + .get(fileId="root", fields=GoogleFields.ID.value) + .execute()[GoogleFields.ID.value] + ) diff --git a/backend/onyx/connectors/google_drive/models.py b/backend/onyx/connectors/google_drive/models.py index 7cf32450a..9887296e5 100644 --- a/backend/onyx/connectors/google_drive/models.py +++ b/backend/onyx/connectors/google_drive/models.py @@ -1,6 +1,15 @@ from enum import Enum from typing import Any +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import field_serializer +from pydantic import field_validator + +from onyx.connectors.interfaces import ConnectorCheckpoint +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.utils.threadpool_concurrency import ThreadSafeDict + class GDriveMimeType(str, Enum): DOC = "application/vnd.google-apps.document" @@ -20,3 +29,128 @@ class GDriveMimeType(str, Enum): GoogleDriveFileType = dict[str, Any] + + +TOKEN_EXPIRATION_TIME = 3600 # 1 hour + + +# These correspond to The major stages of retrieval for google drive. 
+# The stages for the oauth flow are: +# get_all_files_for_oauth(), +# get_all_drive_ids(), +# get_files_in_shared_drive(), +# crawl_folders_for_files() +# +# The stages for the service account flow are roughly: +# get_all_user_emails(), +# get_all_drive_ids(), +# get_files_in_shared_drive(), +# Then for each user: +# get_files_in_my_drive() +# get_files_in_shared_drive() +# crawl_folders_for_files() +class DriveRetrievalStage(str, Enum): + START = "start" + DONE = "done" + # OAuth specific stages + OAUTH_FILES = "oauth_files" + + # Service account specific stages + USER_EMAILS = "user_emails" + MY_DRIVE_FILES = "my_drive_files" + + # Used for both oauth and service account flows + DRIVE_IDS = "drive_ids" + SHARED_DRIVE_FILES = "shared_drive_files" + FOLDER_FILES = "folder_files" + + +class StageCompletion(BaseModel): + """ + Describes the point in the retrieval+indexing process that the + connector is at. completed_until is the timestamp of the latest + file that has been retrieved or error that has been yielded. + Optional fields are used for retrieval stages that need more information + for resuming than just the timestamp of the latest file. + """ + + stage: DriveRetrievalStage + completed_until: SecondsSinceUnixEpoch + completed_until_parent_id: str | None = None + + # only used for shared drives + processed_drive_ids: set[str] = set() + + def update( + self, + stage: DriveRetrievalStage, + completed_until: SecondsSinceUnixEpoch, + completed_until_parent_id: str | None = None, + ) -> None: + self.stage = stage + self.completed_until = completed_until + self.completed_until_parent_id = completed_until_parent_id + + +class RetrievedDriveFile(BaseModel): + """ + Describes a file that has been retrieved from google drive. + user_email is the email of the user that the file was retrieved + by impersonating. If an error worthy of being reported is encountered, + error should be set and later propagated as a ConnectorFailure. + """ + + # The stage at which this file was retrieved + completion_stage: DriveRetrievalStage + + # The file that was retrieved + drive_file: GoogleDriveFileType + + # The email of the user that the file was retrieved by impersonating + user_email: str + + # The id of the parent folder or drive of the file + parent_id: str | None = None + + # Any unexpected error that occurred while retrieving the file. + # In particular, this is not used for 403/404 errors, which are expected + # in the context of impersonating all the users to try to retrieve all + # files from all their Drives and Folders. + error: Exception | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class GoogleDriveCheckpoint(ConnectorCheckpoint): + # Checkpoint version of _retrieved_ids + retrieved_folder_and_drive_ids: set[str] + + # Describes the point in the retrieval+indexing process that the + # checkpoint is at. when this is set to a given stage, the connector + # has finished yielding all values from the previous stage. + completion_stage: DriveRetrievalStage + + # The latest timestamp of a file that has been retrieved per user email. + # StageCompletion is used to track the completion of each stage, but the + # timestamp part is not used for folder crawling. 
+ completion_map: ThreadSafeDict[str, StageCompletion] + + # cached version of the drive and folder ids to retrieve + drive_ids_to_retrieve: list[str] | None = None + folder_ids_to_retrieve: list[str] | None = None + + # cached user emails + user_emails: list[str] | None = None + + @field_serializer("completion_map") + def serialize_completion_map( + self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any + ) -> dict[str, StageCompletion]: + return completion_map._dict + + @field_validator("completion_map", mode="before") + def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]: + assert isinstance(v, dict) or isinstance(v, ThreadSafeDict) + return ThreadSafeDict( + {k: StageCompletion.model_validate(v) for k, v in v.items()} + ) diff --git a/backend/onyx/connectors/google_utils/google_utils.py b/backend/onyx/connectors/google_utils/google_utils.py index f17ebe6ad..60ee3373c 100644 --- a/backend/onyx/connectors/google_utils/google_utils.py +++ b/backend/onyx/connectors/google_utils/google_utils.py @@ -4,6 +4,7 @@ from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from datetime import timezone +from enum import Enum from typing import Any from googleapiclient.errors import HttpError # type: ignore @@ -20,6 +21,20 @@ logger = setup_logger() # long retry period (~20 minutes of trying every minute) add_retries = retry_builder(tries=50, max_delay=30) +NEXT_PAGE_TOKEN_KEY = "nextPageToken" +PAGE_TOKEN_KEY = "pageToken" +ORDER_BY_KEY = "orderBy" + + +# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more +class GoogleFields(str, Enum): + ID = "id" + CREATED_TIME = "createdTime" + MODIFIED_TIME = "modifiedTime" + NAME = "name" + SIZE = "size" + PARENTS = "parents" + def _execute_with_retry(request: Any) -> Any: max_attempts = 10 @@ -90,11 +105,11 @@ def execute_paginated_retrieval( retrieval_function: The specific list function to call (e.g., service.files().list) **kwargs: Arguments to pass to the list function """ - next_page_token = "" + next_page_token = kwargs.get(PAGE_TOKEN_KEY, "") while next_page_token is not None: request_kwargs = kwargs.copy() if next_page_token: - request_kwargs["pageToken"] = next_page_token + request_kwargs[PAGE_TOKEN_KEY] = next_page_token try: results = retrieval_function(**request_kwargs).execute() @@ -117,7 +132,7 @@ def execute_paginated_retrieval( logger.exception("Error executing request:") raise e - next_page_token = results.get("nextPageToken") + next_page_token = results.get(NEXT_PAGE_TOKEN_KEY) if list_key: for item in results.get(list_key, []): yield item diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index 683881853..4d8d591c2 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -4,9 +4,11 @@ from collections.abc import Iterator from types import TracebackType from typing import Any from typing import Generic +from typing import TypeAlias from typing import TypeVar from pydantic import BaseModel +from typing_extensions import override from onyx.configs.constants import DocumentSource from onyx.connectors.models import ConnectorCheckpoint @@ -19,7 +21,6 @@ SecondsSinceUnixEpoch = float GenerateDocumentsOutput = Iterator[list[Document]] GenerateSlimDocumentOutput = Iterator[list[SlimDocument]] -CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint] class BaseConnector(abc.ABC): @@ -57,6 +58,9 @@ class 
BaseConnector(abc.ABC): Default is a no-op (always successful). """ + def build_dummy_checkpoint(self) -> ConnectorCheckpoint: + return ConnectorCheckpoint(has_more=True) + # Large set update or reindex, generally pulling a complete state or from a savestate file class LoadConnector(BaseConnector): @@ -74,6 +78,8 @@ class PollConnector(BaseConnector): raise NotImplementedError +# Slim connectors can retrieve just the ids and +# permission syncing information for connected documents class SlimConnector(BaseConnector): @abc.abstractmethod def retrieve_all_slim_documents( @@ -186,14 +192,21 @@ class EventConnector(BaseConnector): raise NotImplementedError -class CheckpointConnector(BaseConnector): +CT = TypeVar("CT", bound=ConnectorCheckpoint) +# TODO: find a reasonable way to parameterize the return type of the generator +CheckpointOutput: TypeAlias = Generator[ + Document | ConnectorFailure, None, ConnectorCheckpoint +] + + +class CheckpointConnector(BaseConnector, Generic[CT]): @abc.abstractmethod def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, - checkpoint: ConnectorCheckpoint, - ) -> CheckpointOutput: + checkpoint: CT, + ) -> Generator[Document | ConnectorFailure, None, CT]: """Yields back documents or failures. Final return is the new checkpoint. Final return can be access via either: @@ -214,3 +227,15 @@ class CheckpointConnector(BaseConnector): ``` """ raise NotImplementedError + + # Ideally return type should be CT, but that's not possible if + # we want to override build_dummy_checkpoint and have BaseConnector + # return a base ConnectorCheckpoint + @override + def build_dummy_checkpoint(self) -> ConnectorCheckpoint: + raise NotImplementedError + + @abc.abstractmethod + def validate_checkpoint_json(self, checkpoint_json: str) -> CT: + """Validate the checkpoint json and return the checkpoint object""" + raise NotImplementedError diff --git a/backend/onyx/connectors/mock_connector/connector.py b/backend/onyx/connectors/mock_connector/connector.py index 2cd670323..009fee2c8 100644 --- a/backend/onyx/connectors/mock_connector/connector.py +++ b/backend/onyx/connectors/mock_connector/connector.py @@ -1,10 +1,11 @@ +from collections.abc import Generator from typing import Any import httpx from pydantic import BaseModel +from typing_extensions import override from onyx.connectors.interfaces import CheckpointConnector -from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure @@ -15,14 +16,18 @@ from onyx.utils.logger import setup_logger logger = setup_logger() +class MockConnectorCheckpoint(ConnectorCheckpoint): + last_document_id: str | None = None + + class SingleConnectorYield(BaseModel): documents: list[Document] - checkpoint: ConnectorCheckpoint + checkpoint: MockConnectorCheckpoint failures: list[ConnectorFailure] unhandled_exception: str | None = None -class MockConnector(CheckpointConnector): +class MockConnector(CheckpointConnector[MockConnectorCheckpoint]): def __init__( self, mock_server_host: str, @@ -48,7 +53,7 @@ class MockConnector(CheckpointConnector): def _get_mock_server_url(self, endpoint: str) -> str: return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}" - def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None: + def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None: response = self.client.post( 
self._get_mock_server_url("add-checkpoint"), json=checkpoint.model_dump(mode="json"), @@ -59,8 +64,8 @@ class MockConnector(CheckpointConnector): self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, - checkpoint: ConnectorCheckpoint, - ) -> CheckpointOutput: + checkpoint: MockConnectorCheckpoint, + ) -> Generator[Document | ConnectorFailure, None, MockConnectorCheckpoint]: if self.connector_yields is None: raise ValueError("No connector yields configured") @@ -84,3 +89,13 @@ class MockConnector(CheckpointConnector): yield failure return current_yield.checkpoint + + @override + def build_dummy_checkpoint(self) -> ConnectorCheckpoint: + return MockConnectorCheckpoint( + has_more=True, + last_document_id=None, + ) + + def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint: + return MockConnectorCheckpoint.model_validate_json(checkpoint_json) diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index 00335cded..6ba15b6b1 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -1,4 +1,3 @@ -import json from datetime import datetime from enum import Enum from typing import Any @@ -232,21 +231,16 @@ class IndexAttemptMetadata(BaseModel): class ConnectorCheckpoint(BaseModel): # TODO: maybe move this to something disk-based to handle extremely large checkpoints? - checkpoint_content: dict has_more: bool - @classmethod - def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint": - return ConnectorCheckpoint(checkpoint_content={}, has_more=True) - def __str__(self) -> str: """String representation of the checkpoint, with truncation for large checkpoint content.""" MAX_CHECKPOINT_CONTENT_CHARS = 1000 - content_str = json.dumps(self.checkpoint_content) + content_str = self.model_dump_json() if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS: content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..." 
- return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})" + return content_str class DocumentFailure(BaseModel): diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index 66da040d9..e8af11721 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -10,10 +10,11 @@ from datetime import datetime from datetime import timezone from typing import Any from typing import cast -from typing import TypedDict +from pydantic import BaseModel from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from typing_extensions import override from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS from onyx.configs.app_configs import INDEX_BATCH_SIZE @@ -23,7 +24,6 @@ from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointConnector -from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector @@ -56,8 +56,8 @@ MessageType = dict[str, Any] ThreadType = list[MessageType] -class SlackCheckpointContent(TypedDict): - channel_ids: list[str] +class SlackCheckpoint(ConnectorCheckpoint): + channel_ids: list[str] | None channel_completion_map: dict[str, str] current_channel: ChannelType | None seen_thread_ts: list[str] @@ -434,6 +434,12 @@ def _get_all_doc_ids( yield slim_doc_batch +class ProcessedSlackMessage(BaseModel): + doc: Document | None + thread_ts: str | None + failure: ConnectorFailure | None + + def _process_message( message: MessageType, client: WebClient, @@ -442,7 +448,7 @@ def _process_message( user_cache: dict[str, BasicExpertInfo | None], seen_thread_ts: set[str], msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, -) -> tuple[Document | None, str | None, ConnectorFailure | None]: +) -> ProcessedSlackMessage: thread_ts = message.get("thread_ts") try: # causes random failures for testing checkpointing / continue on failure @@ -459,13 +465,13 @@ def _process_message( seen_thread_ts=seen_thread_ts, msg_filter_func=msg_filter_func, ) - return (doc, thread_ts, None) + return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None) except Exception as e: logger.exception(f"Error processing message {message['ts']}") - return ( - None, - thread_ts, - ConnectorFailure( + return ProcessedSlackMessage( + doc=None, + thread_ts=thread_ts, + failure=ConnectorFailure( failed_document=DocumentFailure( document_id=_build_doc_id( channel_id=channel["id"], thread_ts=(thread_ts or message["ts"]) @@ -478,7 +484,7 @@ def _process_message( ) -class SlackConnector(SlimConnector, CheckpointConnector): +class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): MAX_WORKERS = 2 FAST_TIMEOUT = 1 @@ -529,8 +535,8 @@ class SlackConnector(SlimConnector, CheckpointConnector): self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, - checkpoint: ConnectorCheckpoint, - ) -> CheckpointOutput: + checkpoint: SlackCheckpoint, + ) -> Generator[Document | ConnectorFailure, None, SlackCheckpoint]: """Rough outline: Step 1: Get all channels, yield back Checkpoint. 
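# Editorial aside, not part of the patch: a minimal sketch of how a caller can
# drive a CheckpointConnector generator like load_from_checkpoint above.
# Documents and failures are yielded one at a time, and the updated checkpoint
# is the generator's return value, which surfaces as StopIteration.value when
# the generator is advanced manually. All names here (connector, start, end,
# checkpoint, _example_run_checkpoint_step) are placeholders for illustration.
def _example_run_checkpoint_step(connector, start, end, checkpoint):
    gen = connector.load_from_checkpoint(start, end, checkpoint)
    yielded = []
    while True:
        try:
            yielded.append(next(gen))  # Document or ConnectorFailure
        except StopIteration as stop:
            # stop.value is the new checkpoint returned by the connector
            return yielded, stop.value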
@@ -546,49 +552,36 @@ class SlackConnector(SlimConnector, CheckpointConnector): if self.client is None or self.text_cleaner is None: raise ConnectorMissingCredentialError("Slack") - checkpoint_content = cast( - SlackCheckpointContent, - ( - copy.deepcopy(checkpoint.checkpoint_content) - or { - "channel_ids": None, - "channel_completion_map": {}, - "current_channel": None, - "seen_thread_ts": [], - } - ), - ) + checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint)) # if this is the very first time we've called this, need to # get all relevant channels and save them into the checkpoint - if checkpoint_content["channel_ids"] is None: + if checkpoint.channel_ids is None: raw_channels = get_channels(self.client) filtered_channels = filter_channels( raw_channels, self.channels, self.channel_regex_enabled ) + checkpoint.channel_ids = [c["id"] for c in filtered_channels] if len(filtered_channels) == 0: + checkpoint.has_more = False return checkpoint - checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels] - checkpoint_content["current_channel"] = filtered_channels[0] - checkpoint = ConnectorCheckpoint( - checkpoint_content=checkpoint_content, # type: ignore - has_more=True, - ) + checkpoint.current_channel = filtered_channels[0] + checkpoint.has_more = True return checkpoint - final_channel_ids = checkpoint_content["channel_ids"] - channel = checkpoint_content["current_channel"] + final_channel_ids = checkpoint.channel_ids + channel = checkpoint.current_channel if channel is None: - raise ValueError("current_channel key not found in checkpoint") + raise ValueError("current_channel key not set in checkpoint") channel_id = channel["id"] if channel_id not in final_channel_ids: raise ValueError(f"Channel {channel_id} not found in checkpoint") oldest = str(start) if start else None - latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end)) - seen_thread_ts = set(checkpoint_content["seen_thread_ts"]) + latest = checkpoint.channel_completion_map.get(channel_id, str(end)) + seen_thread_ts = set(checkpoint.seen_thread_ts) try: logger.debug( f"Getting messages for channel {channel} within range {oldest} - {latest}" @@ -600,7 +593,7 @@ class SlackConnector(SlimConnector, CheckpointConnector): # Process messages in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor: - futures: list[Future] = [] + futures: list[Future[ProcessedSlackMessage]] = [] for message in message_batch: # Capture the current context so that the thread gets the current tenant ID current_context = contextvars.copy_context() @@ -618,7 +611,10 @@ class SlackConnector(SlimConnector, CheckpointConnector): ) for future in as_completed(futures): - doc, thread_ts, failures = future.result() + processed_slack_message = future.result() + doc = processed_slack_message.doc + thread_ts = processed_slack_message.thread_ts + failure = processed_slack_message.failure if doc: # handle race conditions here since this is single # threaded. 
Multi-threaded _process_message reads from this @@ -628,36 +624,31 @@ class SlackConnector(SlimConnector, CheckpointConnector): if thread_ts not in seen_thread_ts: yield doc - if thread_ts: - seen_thread_ts.add(thread_ts) - elif failures: - for failure in failures: - yield failure + assert thread_ts, "found non-None doc with None thread_ts" + seen_thread_ts.add(thread_ts) + elif failure: + yield failure - checkpoint_content["seen_thread_ts"] = list(seen_thread_ts) - checkpoint_content["channel_completion_map"][channel["id"]] = new_latest + checkpoint.seen_thread_ts = list(seen_thread_ts) + checkpoint.channel_completion_map[channel["id"]] = new_latest if has_more_in_channel: - checkpoint_content["current_channel"] = channel + checkpoint.current_channel = channel else: new_channel_id = next( ( channel_id for channel_id in final_channel_ids - if channel_id - not in checkpoint_content["channel_completion_map"] + if channel_id not in checkpoint.channel_completion_map ), None, ) if new_channel_id: new_channel = _get_channel_by_id(self.client, new_channel_id) - checkpoint_content["current_channel"] = new_channel + checkpoint.current_channel = new_channel else: - checkpoint_content["current_channel"] = None + checkpoint.current_channel = None - checkpoint = ConnectorCheckpoint( - checkpoint_content=checkpoint_content, # type: ignore - has_more=checkpoint_content["current_channel"] is not None, - ) + checkpoint.has_more = checkpoint.current_channel is not None return checkpoint except Exception as e: @@ -766,6 +757,20 @@ class SlackConnector(SlimConnector, CheckpointConnector): f"Unexpected error during Slack settings validation: {e}" ) + @override + def build_dummy_checkpoint(self) -> SlackCheckpoint: + return SlackCheckpoint( + channel_ids=None, + channel_completion_map={}, + current_channel=None, + seen_thread_ts=[], + has_more=True, + ) + + @override + def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint: + return SlackCheckpoint.model_validate_json(checkpoint_json) + if __name__ == "__main__": import os @@ -780,9 +785,11 @@ if __name__ == "__main__": current = time.time() one_day_ago = current - 24 * 60 * 60 # 1 day - checkpoint = ConnectorCheckpoint.build_dummy_checkpoint() + checkpoint = connector.build_dummy_checkpoint() - gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint) + gen = connector.load_from_checkpoint( + one_day_ago, current, cast(SlackCheckpoint, checkpoint) + ) try: for document_or_failure in gen: if isinstance(document_or_failure, Document): diff --git a/backend/onyx/tools/message.py b/backend/onyx/tools/message.py index bb71d56be..761ac2713 100644 --- a/backend/onyx/tools/message.py +++ b/backend/onyx/tools/message.py @@ -21,7 +21,6 @@ def build_tool_message( ) -# TODO: does this NEED to be BaseModel__v1? 
class ToolCallSummary(BaseModel): tool_call_request: AIMessage tool_call_result: ToolMessage diff --git a/backend/onyx/utils/lazy.py b/backend/onyx/utils/lazy.py new file mode 100644 index 000000000..5df9207ad --- /dev/null +++ b/backend/onyx/utils/lazy.py @@ -0,0 +1,13 @@ +from collections.abc import Callable +from functools import lru_cache +from typing import TypeVar + +R = TypeVar("R") + + +def lazy_eval(func: Callable[[], R]) -> Callable[[], R]: + @lru_cache(maxsize=1) + def lazy_func() -> R: + return func() + + return lazy_func diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py index fd8b70174..67d90b02f 100644 --- a/backend/onyx/utils/threadpool_concurrency.py +++ b/backend/onyx/utils/threadpool_concurrency.py @@ -1,18 +1,148 @@ +import collections.abc import contextvars +import copy import threading import uuid from collections.abc import Callable +from collections.abc import Iterator +from collections.abc import MutableMapping from concurrent.futures import as_completed +from concurrent.futures import FIRST_COMPLETED +from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import wait from typing import Any from typing import Generic +from typing import overload from typing import TypeVar +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + from onyx.utils.logger import setup_logger logger = setup_logger() R = TypeVar("R") +KT = TypeVar("KT") # Key type +VT = TypeVar("VT") # Value type +_T = TypeVar("_T") # Default type + + +class ThreadSafeDict(MutableMapping[KT, VT]): + """ + A thread-safe dictionary implementation that uses a lock to ensure thread safety. + Implements the MutableMapping interface to provide a complete dictionary-like interface. 
+ + Example usage: + # Create a thread-safe dictionary + safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict() + + # Basic operations (atomic) + safe_dict["key"] = 1 + value = safe_dict["key"] + del safe_dict["key"] + + # Bulk operations (atomic) + safe_dict.update({"key1": 1, "key2": 2}) + """ + + def __init__(self, input_dict: dict[KT, VT] | None = None) -> None: + self._dict: dict[KT, VT] = input_dict or {} + self.lock = threading.Lock() + + def __getitem__(self, key: KT) -> VT: + with self.lock: + return self._dict[key] + + def __setitem__(self, key: KT, value: VT) -> None: + with self.lock: + self._dict[key] = value + + def __delitem__(self, key: KT) -> None: + with self.lock: + del self._dict[key] + + def __iter__(self) -> Iterator[KT]: + # Return a snapshot of keys to avoid potential modification during iteration + with self.lock: + return iter(list(self._dict.keys())) + + def __len__(self) -> int: + with self.lock: + return len(self._dict) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + cls.validate, handler(dict[KT, VT]) + ) + + @classmethod + def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]": + if isinstance(v, dict): + return ThreadSafeDict(v) + return v + + def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]": + return ThreadSafeDict(copy.deepcopy(self._dict)) + + def clear(self) -> None: + """Remove all items from the dictionary atomically.""" + with self.lock: + self._dict.clear() + + def copy(self) -> dict[KT, VT]: + """Return a shallow copy of the dictionary atomically.""" + with self.lock: + return self._dict.copy() + + @overload + def get(self, key: KT) -> VT | None: + ... + + @overload + def get(self, key: KT, default: VT | _T) -> VT | _T: + ... 
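    # Note: the @overload stubs above only refine the static signature for
    # callers; the single runtime implementation below handles both forms and
    # takes the lock like every other accessor.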
+ + def get(self, key: KT, default: Any = None) -> Any: + """Get a value with a default, atomically.""" + with self.lock: + return self._dict.get(key, default) + + def pop(self, key: KT, default: Any = None) -> Any: + """Remove and return a value with optional default, atomically.""" + with self.lock: + if default is None: + return self._dict.pop(key) + return self._dict.pop(key, default) + + def setdefault(self, key: KT, default: VT) -> VT: + """Set a default value if key is missing, atomically.""" + with self.lock: + return self._dict.setdefault(key, default) + + def update(self, *args: Any, **kwargs: VT) -> None: + """Update the dictionary atomically from another mapping or from kwargs.""" + with self.lock: + self._dict.update(*args, **kwargs) + + def items(self) -> collections.abc.ItemsView[KT, VT]: + """Return a view of (key, value) pairs atomically.""" + with self.lock: + return collections.abc.ItemsView(self) + + def keys(self) -> collections.abc.KeysView[KT]: + """Return a view of keys atomically.""" + with self.lock: + return collections.abc.KeysView(self) + + def values(self) -> collections.abc.ValuesView[VT]: + """Return a view of values atomically.""" + with self.lock: + return collections.abc.ValuesView(self) def run_functions_tuples_in_parallel( @@ -190,3 +320,27 @@ def wait_on_background(task: TimeoutThread[R]) -> R: raise task.exception return task.result + + +def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]: + return ind, next(g, None) + + +def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_index: dict[Future[tuple[int, R | None]], int] = { + executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens) + } + + next_ind = len(gens) + while future_to_index: + done, _ = wait(future_to_index, return_when=FIRST_COMPLETED) + for future in done: + ind, result = future.result() + if result is not None: + yield result + future_to_index[ + executor.submit(_next_or_none, ind, gens[ind]) + ] = next_ind + next_ind += 1 + del future_to_index[future] diff --git a/backend/tests/daily/connectors/google_drive/consts_and_utils.py b/backend/tests/daily/connectors/google_drive/consts_and_utils.py index 60bbca323..c6dad3d9f 100644 --- a/backend/tests/daily/connectors/google_drive/consts_and_utils.py +++ b/backend/tests/daily/connectors/google_drive/consts_and_utils.py @@ -1,5 +1,9 @@ +import time from collections.abc import Sequence +from onyx.connectors.connector_runner import CheckpointOutputWrapper +from onyx.connectors.google_drive.connector import GoogleDriveCheckpoint +from onyx.connectors.google_drive.connector import GoogleDriveConnector from onyx.connectors.models import Document from onyx.connectors.models import TextSection @@ -21,6 +25,7 @@ FOLDER_2_FILE_IDS = list(range(45, 50)) FOLDER_2_1_FILE_IDS = list(range(50, 55)) FOLDER_2_2_FILE_IDS = list(range(55, 60)) SECTIONS_FILE_IDS = [61] +FOLDER_3_FILE_IDS = list(range(62, 65)) PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS PUBLIC_FILE_IDS = list(range(55, 57)) @@ -54,6 +59,8 @@ SECTIONS_FOLDER_URL = ( "https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33" ) +SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA" + ADMIN_EMAIL = "admin@onyx-test.com" TEST_USER_1_EMAIL = "test_user_1@onyx-test.com" TEST_USER_2_EMAIL = "test_user_2@onyx-test.com" @@ -133,17 +140,19 @@ def filter_invalid_prefixes(names: set[str]) -> set[str]: return {name for name in names if 
name.startswith(_VALID_PREFIX)} -def print_discrepencies( +def print_discrepancies( expected: set[str], retrieved: set[str], ) -> None: if expected != retrieved: - print(expected) - print(retrieved) + expected_list = sorted(expected) + retrieved_list = sorted(retrieved) + print(expected_list) + print(retrieved_list) print("Extra:") - print(retrieved - expected) + print(sorted(retrieved - expected)) print("Missing:") - print(expected - retrieved) + print(sorted(expected - retrieved)) def _get_expected_file_content(file_id: int) -> str: @@ -164,6 +173,8 @@ def assert_retrieved_docs_match_expected( _get_expected_file_content(file_id) for file_id in expected_file_ids } + retrieved_docs.sort(key=lambda x: x.semantic_identifier) + for doc in retrieved_docs: print(f"doc.semantic_identifier: {doc.semantic_identifier}") @@ -190,15 +201,34 @@ def assert_retrieved_docs_match_expected( ) # Check file names - print_discrepencies( + print_discrepancies( expected=expected_file_names, retrieved=valid_retrieved_file_names, ) assert expected_file_names == valid_retrieved_file_names # Check file texts - print_discrepencies( + print_discrepancies( expected=expected_file_texts, retrieved=valid_retrieved_texts, ) assert expected_file_texts == valid_retrieved_texts + + +def load_all_docs(connector: GoogleDriveConnector) -> list[Document]: + retrieved_docs: list[Document] = [] + checkpoint = connector.build_dummy_checkpoint() + while checkpoint.has_more: + for doc, failure, next_checkpoint in CheckpointOutputWrapper()( + connector.load_from_checkpoint(0, time.time(), checkpoint) + ): + assert failure is None + if next_checkpoint is None: + assert isinstance( + doc, Document + ), f"Should not fail with {type(doc)} {doc}" + retrieved_docs.append(doc) + else: + assert isinstance(next_checkpoint, GoogleDriveCheckpoint) + checkpoint = next_checkpoint + return retrieved_docs diff --git a/backend/tests/daily/connectors/google_drive/test_admin_oauth.py b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py index ae854de9e..ca500ce41 100644 --- a/backend/tests/daily/connectors/google_drive/test_admin_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py @@ -1,10 +1,8 @@ -import time from collections.abc import Callable from unittest.mock import MagicMock from unittest.mock import patch from onyx.connectors.google_drive.connector import GoogleDriveConnector -from onyx.connectors.models import Document from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS @@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL +from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL @@ -47,9 +46,7 @@ def test_include_all( my_drive_emails=None, shared_drive_urls=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - 
retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should get everything in shared and admin's My Drive with oauth expected_file_ids = ( @@ -89,9 +86,7 @@ def test_include_shared_drives_only( my_drive_emails=None, shared_drive_urls=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should only get shared drives expected_file_ids = ( @@ -129,9 +124,7 @@ def test_include_my_drives_only( my_drive_emails=None, shared_drive_urls=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should only get primary_admins My Drive because we are impersonating them expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS @@ -160,9 +153,7 @@ def test_drive_one_only( my_drive_emails=None, shared_drive_urls=",".join([str(url) for url in drive_urls]), ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( SHARED_DRIVE_1_FILE_IDS @@ -196,9 +187,7 @@ def test_folder_and_shared_drive( my_drive_emails=None, shared_drive_urls=",".join([str(url) for url in drive_urls]), ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( SHARED_DRIVE_1_FILE_IDS @@ -243,9 +232,7 @@ def test_folders_only( my_drive_emails=None, shared_drive_urls=",".join([str(url) for url in shared_drive_urls]), ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( FOLDER_1_1_FILE_IDS @@ -281,9 +268,7 @@ def test_personal_folders_only( my_drive_emails=None, shared_drive_urls=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ADMIN_FOLDER_3_FILE_IDS assert_retrieved_docs_match_expected( diff --git a/backend/tests/daily/connectors/google_drive/test_sections.py b/backend/tests/daily/connectors/google_drive/test_sections.py index 14e858840..6dd22a0e5 100644 --- a/backend/tests/daily/connectors/google_drive/test_sections.py +++ b/backend/tests/daily/connectors/google_drive/test_sections.py @@ -1,11 +1,10 @@ -import time from collections.abc import Callable from unittest.mock import MagicMock from unittest.mock import patch from onyx.connectors.google_drive.connector import GoogleDriveConnector -from onyx.connectors.models import Document from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL +from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL @@ -37,9 +36,7 @@ def test_google_drive_sections( my_drive_emails=None, ) for connector in [oauth_connector, service_acct_connector]: - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Verify we got the 1 doc with sections assert len(retrieved_docs) == 1 diff --git 
a/backend/tests/daily/connectors/google_drive/test_service_acct.py b/backend/tests/daily/connectors/google_drive/test_service_acct.py index 69002158b..e610880fa 100644 --- a/backend/tests/daily/connectors/google_drive/test_service_acct.py +++ b/backend/tests/daily/connectors/google_drive/test_service_acct.py @@ -1,10 +1,8 @@ -import time from collections.abc import Callable from unittest.mock import MagicMock from unittest.mock import patch from onyx.connectors.google_drive.connector import GoogleDriveConnector -from onyx.connectors.models import Document from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS @@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL +from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL @@ -52,9 +51,7 @@ def test_include_all( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should get everything expected_file_ids = ( @@ -97,9 +94,7 @@ def test_include_shared_drives_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should only get shared drives expected_file_ids = ( @@ -137,9 +132,7 @@ def test_include_my_drives_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should only get everyone's My Drives expected_file_ids = ( @@ -174,9 +167,7 @@ def test_drive_one_only( shared_drive_urls=",".join([str(url) for url in urls]), my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # We ignore shared_drive_urls if include_shared_drives is False expected_file_ids = ( @@ -211,9 +202,7 @@ def test_folder_and_shared_drive( shared_folder_urls=",".join([str(url) for url in folder_urls]), my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # Should get everything except for the top level files in drive 2 expected_file_ids = ( @@ -259,9 +248,7 @@ def test_folders_only( shared_folder_urls=",".join([str(url) for url in folder_urls]), my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = 
load_all_docs(connector) expected_file_ids = ( FOLDER_1_1_FILE_IDS @@ -298,9 +285,7 @@ def test_specific_emails( shared_drive_urls=None, my_drive_emails=",".join([str(email) for email in my_drive_emails]), ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS assert_retrieved_docs_match_expected( @@ -330,9 +315,7 @@ def get_specific_folders_in_my_drive( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ADMIN_FOLDER_3_FILE_IDS assert_retrieved_docs_match_expected( diff --git a/backend/tests/daily/connectors/google_drive/test_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_slim_docs.py index 1248f6d73..45dadb0fe 100644 --- a/backend/tests/daily/connectors/google_drive/test_slim_docs.py +++ b/backend/tests/daily/connectors/google_drive/test_slim_docs.py @@ -22,7 +22,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_I from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS -from tests.daily.connectors.google_drive.consts_and_utils import print_discrepencies +from tests.daily.connectors.google_drive.consts_and_utils import print_discrepancies from tests.daily.connectors.google_drive.consts_and_utils import PUBLIC_RANGE from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS @@ -83,7 +83,7 @@ def assert_correct_access_for_user( expected_file_names = {file_name_template.format(i) for i in all_accessible_ids} filtered_retrieved_file_names = filter_invalid_prefixes(retrieved_file_names) - print_discrepencies(expected_file_names, filtered_retrieved_file_names) + print_discrepancies(expected_file_names, filtered_retrieved_file_names) assert expected_file_names == filtered_retrieved_file_names @@ -175,7 +175,7 @@ def test_all_permissions( # Should get everything filtered_retrieved_file_names = filter_invalid_prefixes(found_file_names) - print_discrepencies(expected_file_names, filtered_retrieved_file_names) + print_discrepancies(expected_file_names, filtered_retrieved_file_names) assert expected_file_names == filtered_retrieved_file_names group_map = get_group_map(google_drive_connector) diff --git a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py index d399868ce..cbb474b3a 100644 --- a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py @@ -1,10 +1,8 @@ -import time from collections.abc import Callable from unittest.mock import MagicMock from unittest.mock import patch from onyx.connectors.google_drive.connector import GoogleDriveConnector -from onyx.connectors.models import Document from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ( assert_retrieved_docs_match_expected, @@ -14,6 +12,7 @@ from 
tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL +from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS @@ -37,9 +36,7 @@ def test_all( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( # These are the files from my drive @@ -77,9 +74,7 @@ def test_shared_drives_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( # These are the files from shared drives @@ -112,9 +107,7 @@ def test_shared_with_me_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( # These are the files shared with me from admin @@ -145,9 +138,7 @@ def test_my_drive_only( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) # These are the files from my drive expected_file_ids = TEST_USER_1_FILE_IDS @@ -175,9 +166,7 @@ def test_shared_my_drive_folder( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = ( # this is a folder from admin's drive that is shared with me @@ -207,9 +196,7 @@ def test_shared_drive_folder( shared_drive_urls=None, my_drive_emails=None, ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) + retrieved_docs = load_all_docs(connector) expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS assert_retrieved_docs_match_expected( diff --git a/backend/tests/integration/tests/indexing/test_checkpointing.py b/backend/tests/integration/tests/indexing/test_checkpointing.py index 6aeac9f53..f1766fd93 100644 --- a/backend/tests/integration/tests/indexing/test_checkpointing.py +++ b/backend/tests/integration/tests/indexing/test_checkpointing.py @@ -7,7 +7,7 @@ import httpx import pytest from onyx.configs.constants import DocumentSource -from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import EntityFailure from onyx.connectors.models import InputType @@ -54,9 +54,9 @@ def test_mock_connector_basic_flow( json=[ { "documents": [test_doc.model_dump(mode="json")], - "checkpoint": 
ConnectorCheckpoint( - checkpoint_content={}, has_more=False - ).model_dump(mode="json"), + "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump( + mode="json" + ), "failures": [], } ], @@ -128,9 +128,9 @@ def test_mock_connector_with_failures( json=[ { "documents": [doc1.model_dump(mode="json")], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=False - ).model_dump(mode="json"), + "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump( + mode="json" + ), "failures": [doc2_failure.model_dump(mode="json")], } ], @@ -208,9 +208,9 @@ def test_mock_connector_failure_recovery( json=[ { "documents": [doc1.model_dump(mode="json")], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=False - ).model_dump(mode="json"), + "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump( + mode="json" + ), "failures": [ doc2_failure.model_dump(mode="json"), ConnectorFailure( @@ -292,9 +292,9 @@ def test_mock_connector_failure_recovery( doc1.model_dump(mode="json"), doc2.model_dump(mode="json"), ], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=False - ).model_dump(mode="json"), + "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump( + mode="json" + ), "failures": [], } ], @@ -372,23 +372,23 @@ def test_mock_connector_checkpoint_recovery( json=[ { "documents": [doc.model_dump(mode="json") for doc in docs_batch_1], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=True + "checkpoint": MockConnectorCheckpoint( + has_more=True, last_document_id=docs_batch_1[-1].id ).model_dump(mode="json"), "failures": [], }, { "documents": [doc2.model_dump(mode="json")], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=True + "checkpoint": MockConnectorCheckpoint( + has_more=True, last_document_id=doc2.id ).model_dump(mode="json"), "failures": [], }, { "documents": [], # should never hit this, unhandled exception happens first - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=False + "checkpoint": MockConnectorCheckpoint( + has_more=False, last_document_id=doc2.id ).model_dump(mode="json"), "failures": [], "unhandled_exception": "Simulated unhandled error", @@ -446,12 +446,16 @@ def test_mock_connector_checkpoint_recovery( initial_checkpoints = response.json() # Verify we got the expected checkpoints in order - assert len(initial_checkpoints) > 0 - assert ( - initial_checkpoints[0]["checkpoint_content"] == {} - ) # Initial empty checkpoint - assert initial_checkpoints[1]["checkpoint_content"] == {} - assert initial_checkpoints[2]["checkpoint_content"] == {} + assert len(initial_checkpoints) == 3 + assert initial_checkpoints[0] == { + "has_more": True, + "last_document_id": None, + } # Initial empty checkpoint + assert initial_checkpoints[1] == { + "has_more": True, + "last_document_id": docs_batch_1[-1].id, + } + assert initial_checkpoints[2] == {"has_more": True, "last_document_id": doc2.id} # Reset the mock server for the next run response = mock_server_client.post("/reset") @@ -463,8 +467,8 @@ def test_mock_connector_checkpoint_recovery( json=[ { "documents": [doc3.model_dump(mode="json")], - "checkpoint": ConnectorCheckpoint( - checkpoint_content={}, has_more=False + "checkpoint": MockConnectorCheckpoint( + has_more=False, last_document_id=doc3.id ).model_dump(mode="json"), "failures": [], } @@ -515,4 +519,4 @@ def test_mock_connector_checkpoint_recovery( # Verify the recovery run started from the last successful checkpoint assert len(recovery_checkpoints) 
== 1 - assert recovery_checkpoints[0]["checkpoint_content"] == {} + assert recovery_checkpoints[0] == {"has_more": True, "last_document_id": doc2.id} diff --git a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py index 8b9505bbc..5ad60d28c 100644 --- a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py +++ b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py @@ -1,10 +1,16 @@ import contextvars +import threading import time +from collections.abc import Generator +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor import pytest +from onyx.utils.threadpool_concurrency import parallel_yield from onyx.utils.threadpool_concurrency import run_in_background from onyx.utils.threadpool_concurrency import run_with_timeout +from onyx.utils.threadpool_concurrency import ThreadSafeDict from onyx.utils.threadpool_concurrency import wait_on_background # Create a context variable for testing @@ -148,3 +154,237 @@ def test_multiple_background_tasks() -> None: # Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s) assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations + + +def test_thread_safe_dict_basic_operations() -> None: + """Test basic operations of ThreadSafeDict""" + d = ThreadSafeDict[str, int]() + + # Test setting and getting + d["a"] = 1 + assert d["a"] == 1 + + # Test get with default + assert d.get("a", None) == 1 + assert d.get("b", 2) == 2 + + # Test deletion + del d["a"] + assert "a" not in d + + # Test length + d["x"] = 10 + d["y"] = 20 + assert len(d) == 2 + + # Test iteration + keys = sorted(d.keys()) + assert keys == ["x", "y"] + + # Test items and values + assert dict(d.items()) == {"x": 10, "y": 20} + assert sorted(d.values()) == [10, 20] + + +def test_thread_safe_dict_concurrent_access() -> None: + """Test ThreadSafeDict with concurrent access from multiple threads""" + d = ThreadSafeDict[str, int]() + num_threads = 10 + iterations = 1000 + + def increment_values() -> None: + for i in range(iterations): + key = str(i % 5) # Use 5 different keys + # Get current value or 0 if not exists, increment, then store + d[key] = d.get(key, 0) + 1 + + # Create and start threads + threads = [] + for _ in range(num_threads): + t = threading.Thread(target=increment_values) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Verify results + # Each key should have been incremented (num_threads * iterations) / 5 times + expected_value = (num_threads * iterations) // 5 + for i in range(5): + assert d[str(i)] == expected_value + + +def test_thread_safe_dict_bulk_operations() -> None: + """Test bulk operations of ThreadSafeDict""" + d = ThreadSafeDict[str, int]() + + # Test update with dict + d.update({"a": 1, "b": 2}) + assert dict(d.items()) == {"a": 1, "b": 2} + + # Test update with kwargs + d.update(c=3, d=4) + assert dict(d.items()) == {"a": 1, "b": 2, "c": 3, "d": 4} + + # Test clear + d.clear() + assert len(d) == 0 + + +def test_thread_safe_dict_concurrent_bulk_operations() -> None: + """Test ThreadSafeDict with concurrent bulk operations""" + d = ThreadSafeDict[str, int]() + num_threads = 5 + + def bulk_update(start: int) -> None: + # Each thread updates with its own range of numbers + updates = {str(i): i for i in range(start, start + 20)} + d.update(updates) + time.sleep(0.01) # Add some delay to increase chance of thread overlap + + # Run updates concurrently + with 
ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(bulk_update, i * 20) for i in range(num_threads)] + for future in futures: + future.result() + + # Verify results + assert len(d) == num_threads * 20 + # Verify all numbers from 0 to (num_threads * 20) are present + for i in range(num_threads * 20): + assert d[str(i)] == i + + +def test_thread_safe_dict_atomic_operations() -> None: + """Test atomic operations with ThreadSafeDict's lock""" + d = ThreadSafeDict[str, list[int]]() + d["numbers"] = [] + + def append_numbers(start: int) -> None: + numbers = d["numbers"] + with d.lock: + for i in range(start, start + 5): + numbers.append(i) + time.sleep(0.001) # Add delay to increase chance of thread overlap + d["numbers"] = numbers + + # Run concurrent append operations + threads = [] + for i in range(4): # 4 threads, each adding 5 numbers + t = threading.Thread(target=append_numbers, args=(i * 5,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Verify results + numbers = d["numbers"] + assert len(numbers) == 20 # 4 threads * 5 numbers each + assert sorted(numbers) == list(range(20)) # All numbers 0-19 should be present + + +def test_parallel_yield_basic() -> None: + """Test that parallel_yield correctly yields values from multiple generators.""" + + def make_gen(values: list[int], delay: float) -> Generator[int, None, None]: + for v in values: + time.sleep(delay) + yield v + + # Create generators with different delays + gen1 = make_gen([1, 4, 7], 0.1) # Slower generator + gen2 = make_gen([2, 5, 8], 0.05) # Faster generator + gen3 = make_gen([3, 6, 9], 0.15) # Slowest generator + + # Collect results with timestamps + results: list[tuple[float, int]] = [] + start_time = time.time() + + for value in parallel_yield([gen1, gen2, gen3]): + results.append((time.time() - start_time, value)) + + # Verify all values were yielded + assert sorted(v for _, v in results) == list(range(1, 10)) + + # Verify that faster generators yielded earlier + # Group results by generator (values 1,4,7 are gen1, 2,5,8 are gen2, 3,6,9 are gen3) + gen1_times = [t for t, v in results if v in (1, 4, 7)] + gen2_times = [t for t, v in results if v in (2, 5, 8)] + gen3_times = [t for t, v in results if v in (3, 6, 9)] + + # Average times for each generator + avg_gen1 = sum(gen1_times) / len(gen1_times) + avg_gen2 = sum(gen2_times) / len(gen2_times) + avg_gen3 = sum(gen3_times) / len(gen3_times) + + # Verify gen2 (fastest) has lowest average time + assert avg_gen2 < avg_gen1 + assert avg_gen2 < avg_gen3 + + +def test_parallel_yield_empty_generators() -> None: + """Test parallel_yield with empty generators.""" + + def empty_gen() -> Iterator[int]: + if False: + yield 0 # Makes this a generator function + + gens = [empty_gen() for _ in range(3)] + results = list(parallel_yield(gens)) + assert len(results) == 0 + + +def test_parallel_yield_different_lengths() -> None: + """Test parallel_yield with generators of different lengths.""" + + def make_gen(count: int) -> Iterator[int]: + for i in range(count): + yield i + time.sleep(0.01) # Small delay to ensure concurrent execution + + gens = [ + make_gen(1), # Yields: [0] + make_gen(3), # Yields: [0, 1, 2] + make_gen(2), # Yields: [0, 1] + ] + + results = list(parallel_yield(gens)) + assert len(results) == 6 # Total number of items from all generators + assert sorted(results) == [0, 0, 0, 1, 1, 2] + + +def test_parallel_yield_exception_handling() -> None: + """Test parallel_yield handles exceptions in generators properly.""" + + 
def failing_gen() -> Iterator[int]: + yield 1 + raise ValueError("Generator failure") + + def normal_gen() -> Iterator[int]: + yield 2 + yield 3 + + gens = [failing_gen(), normal_gen()] + + with pytest.raises(ValueError, match="Generator failure"): + list(parallel_yield(gens)) + + +def test_parallel_yield_non_blocking() -> None: + """Test parallel_yield with non-blocking generators (simple ranges).""" + + def range_gen(start: int, end: int) -> Iterator[int]: + for i in range(start, end): + yield i + + # Create three overlapping ranges + gens = [range_gen(0, 100), range_gen(100, 200), range_gen(200, 300)] + + results = list(parallel_yield(gens)) + + # Verify no values are missing + assert len(results) == 300 # Should have all values from 0 to 299 + assert sorted(results) == list(range(300)) From 0b87549f3519d12412f24209e29274011732a28b Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 19 Mar 2025 13:08:44 -0700 Subject: [PATCH 04/18] Feature/email whitelabeling (#4260) * work in progress * work in progress * WIP * refactor, use inline attachment for image (base64 encoding doesn't work) * pretty sure this belongs behind a multi_tenant check * code review / refactor --------- Co-authored-by: Richard Kuo (Danswer) --- .../ee/onyx/server/enterprise_settings/api.py | 44 +++-- .../onyx/server/enterprise_settings/store.py | 38 ++++- backend/onyx/auth/email_utils.py | 157 +++++++++++++++--- backend/onyx/configs/app_configs.py | 4 + backend/onyx/configs/constants.py | 5 + backend/onyx/file_store/file_store.py | 15 ++ backend/onyx/server/manage/users.py | 8 +- backend/onyx/server/runtime/onyx_runtime.py | 89 ++++++++++ backend/onyx/utils/file.py | 36 ++++ backend/requirements/default.txt | 1 + backend/static/images/logo.png | Bin 0 -> 6723 bytes backend/static/images/logotype.png | Bin 0 -> 45427 bytes 12 files changed, 357 insertions(+), 40 deletions(-) create mode 100644 backend/onyx/server/runtime/onyx_runtime.py create mode 100644 backend/onyx/utils/file.py create mode 100644 backend/static/images/logo.png create mode 100644 backend/static/images/logotype.png diff --git a/backend/ee/onyx/server/enterprise_settings/api.py b/backend/ee/onyx/server/enterprise_settings/api.py index dbc89acd2..d849ab399 100644 --- a/backend/ee/onyx/server/enterprise_settings/api.py +++ b/backend/ee/onyx/server/enterprise_settings/api.py @@ -15,8 +15,8 @@ from sqlalchemy.orm import Session from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload from ee.onyx.server.enterprise_settings.models import EnterpriseSettings -from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME -from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME +from ee.onyx.server.enterprise_settings.store import get_logo_filename +from ee.onyx.server.enterprise_settings.store import get_logotype_filename from ee.onyx.server.enterprise_settings.store import load_analytics_script from ee.onyx.server.enterprise_settings.store import load_settings from ee.onyx.server.enterprise_settings.store import store_analytics_script @@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager from onyx.auth.users import UserManager from onyx.db.engine import get_session from onyx.db.models import User -from onyx.file_store.file_store import get_default_file_store +from onyx.file_store.file_store import PostgresBackedFileStore from onyx.utils.logger import setup_logger admin_router = APIRouter(prefix="/admin/enterprise-settings") @@ -131,31 +131,49 @@ def put_logo( upload_logo(file=file, db_session=db_session, 
is_logotype=is_logotype) -def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response: +def fetch_logo_helper(db_session: Session) -> Response: try: - file_store = get_default_file_store(db_session) - filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME - file_io = file_store.read_file(filename, mode="b") - # NOTE: specifying "image/jpeg" here, but it still works for pngs - # TODO: do this properly - return Response(content=file_io.read(), media_type="image/jpeg") + file_store = PostgresBackedFileStore(db_session) + onyx_file = file_store.get_file_with_mime_type(get_logo_filename()) + if not onyx_file: + raise ValueError("get_onyx_file returned None!") except Exception: raise HTTPException( status_code=404, - detail=f"No {'logotype' if is_logotype else 'logo'} file found", + detail="No logo file found", ) + else: + return Response(content=onyx_file.data, media_type=onyx_file.mime_type) + + +def fetch_logotype_helper(db_session: Session) -> Response: + try: + file_store = PostgresBackedFileStore(db_session) + onyx_file = file_store.get_file_with_mime_type(get_logotype_filename()) + if not onyx_file: + raise ValueError("get_onyx_file returned None!") + except Exception: + raise HTTPException( + status_code=404, + detail="No logotype file found", + ) + else: + return Response(content=onyx_file.data, media_type=onyx_file.mime_type) @basic_router.get("/logotype") def fetch_logotype(db_session: Session = Depends(get_session)) -> Response: - return fetch_logo_or_logotype(is_logotype=True, db_session=db_session) + return fetch_logotype_helper(db_session) @basic_router.get("/logo") def fetch_logo( is_logotype: bool = False, db_session: Session = Depends(get_session) ) -> Response: - return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session) + if is_logotype: + return fetch_logotype_helper(db_session) + + return fetch_logo_helper(db_session) @admin_router.put("/custom-analytics-script") diff --git a/backend/ee/onyx/server/enterprise_settings/store.py b/backend/ee/onyx/server/enterprise_settings/store.py index 65a4dd5bf..f7d0e8535 100644 --- a/backend/ee/onyx/server/enterprise_settings/store.py +++ b/backend/ee/onyx/server/enterprise_settings/store.py @@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings from onyx.configs.constants import FileOrigin from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY +from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME from onyx.file_store.file_store import get_default_file_store from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError @@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger logger = setup_logger() +_LOGO_FILENAME = "__logo__" +_LOGOTYPE_FILENAME = "__logotype__" + def load_settings() -> EnterpriseSettings: + """Loads settings data directly from DB. This should be used primarily + for checking what is actually in the DB, aka for editing and saving back settings. + + Runtime settings actually used by the application should be checked with + load_runtime_settings as defaults may be applied at runtime. 
+ """ + dynamic_config_store = get_kv_store() try: settings = EnterpriseSettings( @@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings: def store_settings(settings: EnterpriseSettings) -> None: + """Stores settings directly to the kv store / db.""" + get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump()) +def load_runtime_settings() -> EnterpriseSettings: + """Loads settings from DB and applies any defaults or transformations for use + at runtime. + + Should not be stored back to the DB. + """ + enterprise_settings = load_settings() + if not enterprise_settings.application_name: + enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME + + return enterprise_settings + + _CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY") @@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script) -_LOGO_FILENAME = "__logo__" -_LOGOTYPE_FILENAME = "__logotype__" - - def is_valid_file_type(filename: str) -> bool: valid_extensions = (".png", ".jpg", ".jpeg") return filename.endswith(valid_extensions) @@ -116,3 +138,11 @@ def upload_logo( file_type=file_type, ) return True + + +def get_logo_filename() -> str: + return _LOGO_FILENAME + + +def get_logotype_filename() -> str: + return _LOGOTYPE_FILENAME diff --git a/backend/onyx/auth/email_utils.py b/backend/onyx/auth/email_utils.py index 15d3fc8ec..1c84eb9c6 100644 --- a/backend/onyx/auth/email_utils.py +++ b/backend/onyx/auth/email_utils.py @@ -1,5 +1,6 @@ import smtplib from datetime import datetime +from email.mime.image import MIMEImage from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.utils import formatdate @@ -13,8 +14,13 @@ from onyx.configs.app_configs import SMTP_SERVER from onyx.configs.app_configs import SMTP_USER from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import AuthType +from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME +from onyx.configs.constants import ONYX_SLACK_URL from onyx.configs.constants import TENANT_ID_COOKIE_NAME from onyx.db.models import User +from onyx.server.runtime.onyx_runtime import OnyxRuntime +from onyx.utils.file import FileWithMimeType +from onyx.utils.variable_functionality import fetch_versioned_implementation from shared_configs.configs import MULTI_TENANT HTML_EMAIL_TEMPLATE = """\ @@ -97,8 +103,8 @@ HTML_EMAIL_TEMPLATE = """\ Onyx Logo @@ -113,9 +119,8 @@ HTML_EMAIL_TEMPLATE = """\ - © {year} Onyx. All rights reserved. -
- Have questions? Join our Slack community here. + © {year} {application_name}. All rights reserved. + {slack_fragment} @@ -125,17 +130,27 @@ HTML_EMAIL_TEMPLATE = """\ def build_html_email( - heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None + application_name: str | None, + heading: str, + message: str, + cta_text: str | None = None, + cta_link: str | None = None, ) -> str: + slack_fragment = "" + if application_name == ONYX_DEFAULT_APPLICATION_NAME: + slack_fragment = f'
Have questions? Join our Slack community here.' + if cta_text and cta_link: cta_block = f'{cta_text}' else: cta_block = "" return HTML_EMAIL_TEMPLATE.format( + application_name=application_name, title=heading, heading=heading, message=message, cta_block=cta_block, + slack_fragment=slack_fragment, year=datetime.now().year, ) @@ -146,6 +161,7 @@ def send_email( html_body: str, text_body: str, mail_from: str = EMAIL_FROM, + inline_png: tuple[str, bytes] | None = None, ) -> None: if not EMAIL_CONFIGURED: raise ValueError("Email is not configured.") @@ -164,6 +180,12 @@ def send_email( msg.attach(part_text) msg.attach(part_html) + if inline_png: + img = MIMEImage(inline_png[1], _subtype="png") + img.add_header("Content-ID", inline_png[0]) # CID reference + img.add_header("Content-Disposition", "inline", filename=inline_png[0]) + msg.attach(img) + try: with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s: s.starttls() @@ -174,8 +196,21 @@ def send_email( def send_subscription_cancellation_email(user_email: str) -> None: + """This is templated but isn't meaningful for whitelabeling.""" + # Example usage of the reusable HTML - subject = "Your Onyx Subscription Has Been Canceled" + try: + load_runtime_settings_fn = fetch_versioned_implementation( + "onyx.server.enterprise_settings.store", "load_runtime_settings" + ) + settings = load_runtime_settings_fn() + application_name = settings.application_name + except ModuleNotFoundError: + application_name = ONYX_DEFAULT_APPLICATION_NAME + + onyx_file = OnyxRuntime.get_emailable_logo() + + subject = f"Your {application_name} Subscription Has Been Canceled" heading = "Subscription Canceled" message = ( "

We're sorry to see you go.
" @@ -184,23 +219,48 @@ def send_subscription_cancellation_email(user_email: str) -> None: ) cta_text = "Renew Subscription" cta_link = "https://www.onyx.app/pricing" - html_content = build_html_email(heading, message, cta_text, cta_link) + html_content = build_html_email( + application_name, + heading, + message, + cta_text, + cta_link, + ) text_content = ( "We're sorry to see you go.\n" "Your subscription has been canceled and will end on your next billing date.\n" "If you change your mind, visit https://www.onyx.app/pricing" ) - send_email(user_email, subject, html_content, text_content) + send_email( + user_email, + subject, + html_content, + text_content, + inline_png=("logo.png", onyx_file.data), + ) def send_user_email_invite( user_email: str, current_user: User, auth_type: AuthType ) -> None: - subject = "Invitation to Join Onyx Organization" + onyx_file: FileWithMimeType | None = None + + try: + load_runtime_settings_fn = fetch_versioned_implementation( + "onyx.server.enterprise_settings.store", "load_runtime_settings" + ) + settings = load_runtime_settings_fn() + application_name = settings.application_name + except ModuleNotFoundError: + application_name = ONYX_DEFAULT_APPLICATION_NAME + + onyx_file = OnyxRuntime.get_emailable_logo() + + subject = f"Invitation to Join {application_name} Organization" heading = "You've Been Invited!" # the exact action taken by the user, and thus the message, depends on the auth type - message = f"

You have been invited by {current_user.email} to join an organization on Onyx.
" + message = f"

You have been invited by {current_user.email} to join an organization on {application_name}.
" if auth_type == AuthType.CLOUD: message += ( "

To join the organization, please click the button below to set a password " @@ -226,19 +286,32 @@ def send_user_email_invite( cta_text = "Join Organization" cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}" - html_content = build_html_email(heading, message, cta_text, cta_link) + + html_content = build_html_email( + application_name, + heading, + message, + cta_text, + cta_link, + ) # text content is the fallback for clients that don't support HTML # not as critical, so not having special cases for each auth type text_content = ( - f"You have been invited by {current_user.email} to join an organization on Onyx.\n" + f"You have been invited by {current_user.email} to join an organization on {application_name}.\n" "To join the organization, please visit the following link:\n" f"{WEB_DOMAIN}/auth/signup?email={user_email}\n" ) if auth_type == AuthType.CLOUD: text_content += "You'll be asked to set a password or login with Google to complete your registration." - send_email(user_email, subject, html_content, text_content) + send_email( + user_email, + subject, + html_content, + text_content, + inline_png=("logo.png", onyx_file.data), + ) def send_forgot_password_email( @@ -248,14 +321,36 @@ def send_forgot_password_email( mail_from: str = EMAIL_FROM, ) -> None: # Builds a forgot password email with or without fancy HTML - subject = "Onyx Forgot Password" + try: + load_runtime_settings_fn = fetch_versioned_implementation( + "onyx.server.enterprise_settings.store", "load_runtime_settings" + ) + settings = load_runtime_settings_fn() + application_name = settings.application_name + except ModuleNotFoundError: + application_name = ONYX_DEFAULT_APPLICATION_NAME + + onyx_file = OnyxRuntime.get_emailable_logo() + + subject = f"{application_name} Forgot Password" link = f"{WEB_DOMAIN}/auth/reset-password?token={token}" if MULTI_TENANT: link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}" message = f"

Click the following link to reset your password:
{link}
" - html_content = build_html_email("Reset Your Password", message) + html_content = build_html_email( + application_name, + "Reset Your Password", + message, + ) text_content = f"Click the following link to reset your password: {link}" - send_email(user_email, subject, html_content, text_content, mail_from) + send_email( + user_email, + subject, + html_content, + text_content, + mail_from, + inline_png=("logo.png", onyx_file.data), + ) def send_user_verification_email( @@ -264,11 +359,33 @@ def send_user_verification_email( mail_from: str = EMAIL_FROM, ) -> None: # Builds a verification email - subject = "Onyx Email Verification" + try: + load_runtime_settings_fn = fetch_versioned_implementation( + "onyx.server.enterprise_settings.store", "load_runtime_settings" + ) + settings = load_runtime_settings_fn() + application_name = settings.application_name + except ModuleNotFoundError: + application_name = ONYX_DEFAULT_APPLICATION_NAME + + onyx_file = OnyxRuntime.get_emailable_logo() + + subject = f"{application_name} Email Verification" link = f"{WEB_DOMAIN}/auth/verify-email?token={token}" message = ( f"

Click the following link to verify your email address:
{link}
" ) - html_content = build_html_email("Verify Your Email", message) + html_content = build_html_email( + application_name, + "Verify Your Email", + message, + ) text_content = f"Click the following link to verify your email address: {link}" - send_email(user_email, subject, html_content, text_content, mail_from) + send_email( + user_email, + subject, + html_content, + text_content, + mail_from, + inline_png=("logo.png", onyx_file.data), + ) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 211ed3fc1..17e992582 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -33,6 +33,10 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int( ) # 1 day DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true" +# Controls whether to allow admin query history reports with: +# 1. associated user emails +# 2. anonymized user emails +# 3. no queries ONYX_QUERY_HISTORY_TYPE = QueryHistoryType( (os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower() ) diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 4fe52c8e7..80545c064 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -3,6 +3,10 @@ import socket from enum import auto from enum import Enum +ONYX_DEFAULT_APPLICATION_NAME = "Onyx" +ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" +ONYX_EMAILABLE_LOGO_MAX_DIM = 512 + SOURCE_TYPE = "source_type" # stored in the `metadata` of a chunk. Used to signify that this chunk should # not be used for QA. For example, Google Drive file types which can't be parsed @@ -40,6 +44,7 @@ DISABLED_GEN_AI_MSG = ( "You can still use Onyx as a search engine." 
) + DEFAULT_PERSONA_ID = 0 DEFAULT_CC_PAIR_ID = 1 diff --git a/backend/onyx/file_store/file_store.py b/backend/onyx/file_store/file_store.py index 9d86602dd..b042c8680 100644 --- a/backend/onyx/file_store/file_store.py +++ b/backend/onyx/file_store/file_store.py @@ -1,7 +1,9 @@ from abc import ABC from abc import abstractmethod +from typing import cast from typing import IO +import puremagic from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin @@ -12,6 +14,7 @@ from onyx.db.pg_file_store import delete_pgfilestore_by_file_name from onyx.db.pg_file_store import get_pgfilestore_by_file_name from onyx.db.pg_file_store import read_lobj from onyx.db.pg_file_store import upsert_pgfilestore +from onyx.utils.file import FileWithMimeType class FileStore(ABC): @@ -140,6 +143,18 @@ class PostgresBackedFileStore(FileStore): self.db_session.rollback() raise + def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None: + mime_type: str = "application/octet-stream" + try: + file_io = self.read_file(filename, mode="b") + file_content = file_io.read() + matches = puremagic.magic_string(file_content) + if matches: + mime_type = cast(str, matches[0].mime_type) + return FileWithMimeType(data=file_content, mime_type=mime_type) + except Exception: + return None + def get_default_file_store(db_session: Session) -> FileStore: # The only supported file store now is the Postgres File Store diff --git a/backend/onyx/server/manage/users.py b/backend/onyx/server/manage/users.py index 0a2e358ba..26be15f70 100644 --- a/backend/onyx/server/manage/users.py +++ b/backend/onyx/server/manage/users.py @@ -351,9 +351,11 @@ def remove_invited_user( user_emails = get_invited_users() remaining_users = [user for user in user_emails if user != user_email.user_email] - fetch_ee_implementation_or_noop( - "onyx.server.tenants.user_mapping", "remove_users_from_tenant", None - )([user_email.user_email], tenant_id) + if MULTI_TENANT: + fetch_ee_implementation_or_noop( + "onyx.server.tenants.user_mapping", "remove_users_from_tenant", None + )([user_email.user_email], tenant_id) + number_of_invited_users = write_invited_users(remaining_users) try: diff --git a/backend/onyx/server/runtime/onyx_runtime.py b/backend/onyx/server/runtime/onyx_runtime.py new file mode 100644 index 000000000..a77c27881 --- /dev/null +++ b/backend/onyx/server/runtime/onyx_runtime.py @@ -0,0 +1,89 @@ +import io + +from PIL import Image + +from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM +from onyx.db.engine import get_session_with_shared_schema +from onyx.file_store.file_store import PostgresBackedFileStore +from onyx.utils.file import FileWithMimeType +from onyx.utils.file import OnyxStaticFileManager +from onyx.utils.variable_functionality import ( + fetch_ee_implementation_or_noop, +) + + +class OnyxRuntime: + """Used by the application to get the final runtime value of a setting. + + Rationale: Settings and overrides may be persisted in multiple places, including the + DB, Redis, env vars, and default constants, etc. The logic to present a final + setting to the application should be centralized and in one place. + + Example: To get the logo for the application, one must check the DB for an override, + use the override if present, fall back to the filesystem if not present, and worry + about enterprise or not enterprise. 
+ """ + + @staticmethod + def _get_with_static_fallback( + db_filename: str | None, static_filename: str + ) -> FileWithMimeType: + onyx_file: FileWithMimeType | None = None + + if db_filename: + with get_session_with_shared_schema() as db_session: + file_store = PostgresBackedFileStore(db_session) + onyx_file = file_store.get_file_with_mime_type(db_filename) + + if not onyx_file: + onyx_file = OnyxStaticFileManager.get_static(static_filename) + + if not onyx_file: + raise RuntimeError( + f"Resource not found: db={db_filename} static={static_filename}" + ) + + return onyx_file + + @staticmethod + def get_logo() -> FileWithMimeType: + STATIC_FILENAME = "static/images/logo.png" + + db_filename: str | None = fetch_ee_implementation_or_noop( + "onyx.server.enterprise_settings.store", "get_logo_filename", None + ) + + return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME) + + @staticmethod + def get_emailable_logo() -> FileWithMimeType: + onyx_file = OnyxRuntime.get_logo() + + # check dimensions and resize downwards if necessary or if not PNG + image = Image.open(io.BytesIO(onyx_file.data)) + if ( + image.size[0] > ONYX_EMAILABLE_LOGO_MAX_DIM + or image.size[1] > ONYX_EMAILABLE_LOGO_MAX_DIM + or image.format != "PNG" + ): + image.thumbnail( + (ONYX_EMAILABLE_LOGO_MAX_DIM, ONYX_EMAILABLE_LOGO_MAX_DIM), + Image.LANCZOS, + ) # maintains aspect ratio + output_buffer = io.BytesIO() + image.save(output_buffer, format="PNG") + onyx_file = FileWithMimeType( + data=output_buffer.getvalue(), mime_type="image/png" + ) + + return onyx_file + + @staticmethod + def get_logotype() -> FileWithMimeType: + STATIC_FILENAME = "static/images/logotype.png" + + db_filename: str | None = fetch_ee_implementation_or_noop( + "onyx.server.enterprise_settings.store", "get_logotype_filename", None + ) + + return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME) diff --git a/backend/onyx/utils/file.py b/backend/onyx/utils/file.py new file mode 100644 index 000000000..f62077063 --- /dev/null +++ b/backend/onyx/utils/file.py @@ -0,0 +1,36 @@ +from typing import cast + +import puremagic +from pydantic import BaseModel + +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class FileWithMimeType(BaseModel): + data: bytes + mime_type: str + + +class OnyxStaticFileManager: + """Retrieve static resources with this class. Currently, these should all be located + in the static directory ... e.g. 
static/images/logo.png"""
+
+    @staticmethod
+    def get_static(filename: str) -> FileWithMimeType | None:
+        try:
+            mime_type: str = "application/octet-stream"
+            with open(filename, "rb") as f:
+                file_content = f.read()
+            matches = puremagic.magic_string(file_content)
+            if matches:
+                mime_type = cast(str, matches[0].mime_type)
+        except (OSError, FileNotFoundError, PermissionError) as e:
+            logger.error(f"Failed to read file {filename}: {e}")
+            return None
+        except Exception as e:
+            logger.error(f"Unexpected exception reading file {filename}: {e}")
+            return None
+
+        return FileWithMimeType(data=file_content, mime_type=mime_type)
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 9bae3b12e..b014d340f 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -52,6 +52,7 @@ openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5
 psycopg2-binary==2.9.9
+puremagic==1.28
 pyairtable==3.0.1
 pycryptodome==3.19.1
 pydantic==2.8.2
diff --git a/backend/static/images/logo.png b/backend/static/images/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..0a5aecd7c820d80019f185d3419b1e3f44dbb824
Binary files /dev/null and b/backend/static/images/logo.png differ
diff --git a/backend/static/images/logotype.png b/backend/static/images/logotype.png
new file mode 100644
index 0000000000000000000000000000000000000000..509e2c714bf72ce7fb68aa74b8e8dd63d7127cd2
Binary files /dev/null and b/backend/static/images/logotype.png differ
zPeJ3Ba-q6AemguphxCJL5eIUdP5M<5xlL=scALY|{Vt@9)WS*Z#NkU=)l|IBk3^V^r*$acgwd8LrZ8j`orJ!th$3b0H1} zA@I5d`g=~~Qi(pa$XJ%~*FH92PB5f-AX*`AX~PSi2Dzoa-ZBb!=0^7?*gA;F%a7*2 zehlUdk9E?dC9cRmxkqlpx1@+BtIYwW6d>~ddivP0;|YCbU9U3m1eV~rg0(TteWPrQiQOs_h6G{3gwqQ z^TKyMe(WP==Z`e^HMJAXM&zIS?EH$LNYMfNQ34wf9f+eaxqY&EB9u9WRtBO`fI;Z{ z>>-72kyPp>^D*`4dU$S5UFV$f44j>YJ^Obqf-!UUI`R^9WXQInEHPUd4VmTjr zUn=Z*DXp>srjnq|50Mj3eJ;Dvh!uYu*G=?K+N@R|e~oCQ8guc@ZKf1H|DNtT(wPE= zW`Eg3(hs6oY@8i@&kLXt+gR}FApU}C{+c&eZ5RX}?LRabW)>JN70rD=7tK{OF(B*- z>sDQ}i>&h^r-%3)zXlGkEW1g;N^;|v1uI$w>;rv3S@v&J{*c0vQYNutOoKZ}+R>U{ z49JGdf^WCm5Ay8G)aNKMMi_xh!?@`#V@3WD!$iPk5Eit*fj~#(n%cY5z)V|H-HQj1 zCiHpRlam;+hn;JmK%poZmfx81UBYyf0l=X3dRg|?x0e?LuGqLQ)w9xwdH!$vN5h?R zYO6mn8F?+VuH1irX=GmluY(lBxjDcWJn?oI3>%>wdb#tN5~}Ze<*C>(Ly3&3;3t zYfsLGR_L=4#IDjxL0kraQcK{0Xr>={<=2hPm&5=Cd%$qnZl-uhtei>S#c=0B@clQ< zWK6d~OnIMeNz>)>9KM4_E#Bj1hZ2_T*R~kK8iML)*VI4|ZVDRw`~%qzOMahm5iJW` zRLLoYRm=M;+aXH`ErPv!@~Uf~{uvA@x#%CEgpFU8043!uxt6yJtSp<^Of|SInOu>5 zTc;IZbk%Oz!Cr5ptXpCrJyXs;8j1-;))_)fOB2)ou zb7R(VZezYw0IGGW9~t9p%Y`;UX&QuXSib!&ZQWb8PRPDM9>i%cH$e?W+1Q&0pXP+s zf>Z@n8pb+9`1#q8O)ME=bZY@A!LN3r4NgwdwQFE5bxYJ$5|G>dpf2%u!coEzBxaya zHhTT!Ya9z187G`GmejdTi5IvissHQsB2a*FwoyZuXD*K{7Yg|N&A00>B54vM(7oHQ z57LQFS?N-iDsE(M6YjB9Ot^w1R4xbrY!JkgPfS27zOS)5)yprx7x+5J0NCV_mAj&U zsRe~Q?6lTZ%rjfFB6HyZpTGM4;mP=oe6yoXg#R}|QP|6O)0EpEszB8VXtA%M6GgHZ z0CvOuUYD<;A_;si0tpm(j)j}igT?9Z5qadqe$%%QvBPhI9juoExA)qe9Gun#+e}aA zPlsxR>2`1VS8amKDlRi);e`1(x9;-AiIq^8(MP80e%X(#z?*NMmpx{|*zlIH@BW-` zztJEOIuk0J)lBMsS_e z0fFH#coD?Nn!gP~^-JCtuKicf=M#Y~0dKCoc0U>N8=U|PPg^l!Dfn$HiIHv7Dxw1T zp@idmzq}hy=2As7Xf3va4!w(u62ZbD{$OFAnG1Uf2WevM5Iz$<_EgO`-ih)4NvL?I z2-UU4A&t^B3}@Oyu&hNE>KRod-hi zbV^s>UiQv~A{}7;AM>Zqy{z0re7F6<=<-`Kx9KFqw;AD5GYJ3|ciDY}+mxFz zJ3m#$C%C;?Yj0y3hVE#0&H&MKdaQyLOfnyF@ar(*&x`>#)r$g<^uNy2yw7OX=fJj= zfIvp~hVwFijkhK`?KG^YbXmA4(m75~4g5&$Z`eSBynEqK)7Av-+4P-q?X!N3V%iEU ze39L=G>e*8y?!@T0{6Ocg!FbhPni;Z%~|G&)Li!w)?l?Un7#j*&JxzWDE0 zVl-sfiB>AgXD3WMOVGSjNUB9>Dwsbov4i8&%#bVTzXM-3$7e?ewZXD^aUNt+b-QSD zGrBe~GuzFue0MIm`*`m3Z#HbjGl?m=J4g4>xtuby`=M8D%E3dD;JNgg#~stBfq!iI zJZ}S#_Di5|C%EXx!iB*pH-7!q+>QNI24|tgv9rb55i59k?vc4(109MFb%*MEQCm?d z@P3Nu!=c5InTH zt1)TD-d#80XhZG-YBqI}x964UK=t8=mXUJK^K&K2)DAXvVM|D{Y7gMy3ib=bj^ytZ z|20OP|KrZ%a<(}NzOw!&*O|90-$-Y-(!jd$6K3pMp24X`>6Pa8v|Pz-?dD8+VTIpP zS0&HS&zKboY_w3dHTQ=*=ar#F+3mT7o3O)dzho46g+7}Nn3Ou-{K@^wg3tSy&9X?) 
zt<|MurCMX=c40XevD-h_+&Y-w<=Tk4cXe*3i|yOCR0%EG2GJU{#{MO!j&Rk@m{WIz8t??*$tsXaN-qLNCq|$Vy{ia}sXM5c0Tt9^4%`&{1n+KYpL=yF z%=7o3zTx-nwtTsJ4^$RZ#p{h)PB)Sf5~H@M00V@ z-BMj5>U>+?RHm;S1Cj+syyQZ!3I+={y&~Y%Ug7RtWiP&*3u7iLn|(c`SiSGOy&f#pN!<$ty66PLh)i`V#IfvHtN79PutFc*Of z{r3k6rZCHAC8-*3DqqpBtu4^UonWh{5F^Xjw-T8&NWFq-vm)>Lha@lx+%V`b@wjQ0 z!z&pX|IcFJ3a|$qJtH7tjCS<{r8e-?Y${^a7GMq3dcl6@)^Kt<^Ti#^PA1>WH59_6eA#t&!ZqHB2rbl5Rkg4XeeT$ zg3^VEO7BVy7F-tuR)Vn7Y?Po>rG^q)=&(YhDh2^jK&k=h-?@o??=ycXcQSY8%&EU~ zX0G3%;e6VN0=>?k+O%0~9=wQn)x(#hw{{YhzKSZWM-aB#>(nUGPc0f2KCNe?i$`#L zUk<)m4a~#$j~tPeHPp^)XOv{|`Q}L*IUV?(8QyB=-G1x_sYI-{zA0ct3i@2O2u zA_x)ZFzQwG5No<`pn(Xx)LZgy$k{WDHI2I3(MmhWg%S|ovo*HH&1`#+dIo!<6X5wj zwbVXb_UpnF`Li~s+FGF-b!Of>JIltbB;fseM_N>j6uvCLNzCi${kKH((Gss> z!FQ#}gNl-OhqwJ(tHgHG4fTA>0hzpQ4}OSZZ${|0QD2W@{lwrC%V=!lxUcu(e$_Jl zDV#aO^yRaHXhFd-G~?kRd!Hd~s2!&?<}3Pk%g6O{rW+ z-4}v1PfC20cfuPpTG0|@avr4G!Wluc)|u1M>Lct#KXozDPSoMXTjVrIV=Y@Bw*`M2 z93T|%)V;f4TeZ<*!Zd|e+eKtqhc^Vf%VP>DQk5rP{U-_yn(5n0oGgs}1qOZayR(H3 zs#K=z5G+cvdwTrnRRPOZ*E}r?Nvdn?!xqUy|AldguouH^s;$?Y&7T7UzX$Z`S1s#`96JafqS|K6@Keu;M2jrquS`|WwFxXI^#st*k-;d-JT zUsMYpNZdUf-EHmo34g^G5cbZ^$TL!Xd27TrBbD+3PKZzl$En&11Z*eGbKYu2R)TY1 zU~&g#n6CDn8kOBR*9?46!;I5>0o+cYDbd5(i^3D8pR&Z!`LP3%7IpqunUWzT%+b1e zd4PbIU#Ir>|HpRlBt}A(>Y} z=K5ONoK7?@IyFR#!W5a0Liy9xW7*;VWf>3a5h|5&f^eAE!LrO0i6TiTtd-mH1^~|bwuT5)B-Ixrd*g%q8Q)9TGQd`Q~`_9$hu2vHCcEQyb-?C^R|G9 zk`n@nL`KR;O&KmXquBPAFSx0@JW|g#xBT?Xf5#~QdJfX#F9Uy^^;A?9?)VakE+SiaX`nwxqi3G-q@Tc*GZDnUkj(Z|9lzb=T#QJ}7>=8-_-Z zv(N3e+hRUjLr{EWxMg0`|@MXG*BnV{q&A#Jvsdf@85fQ`(D z4Wds%{brrF=2&4*v=)8UTrWbd7Tt5KVU;78HKL#NYW2M%%o!62%z^CY4l}-czOjsh zC5sD$6E#1S;c$`LsOR5dRTwJqo6^-rJMUs4Cbi+`xdS^D)+tcWVh>p?uyAiKh#*Vq z=rrY%Tp{j4=AH=osrlT#rB+GXWyiSWUPyvfe~;#D(0kY3LNv3M#V6NvepWN4bA9O` zWx8bF57RqY6Mw*QPI|gfg#ckygkRp>6uY)V5-i11`%z0b&sSEuZxm$ZI?w;9zHjbv z+{u2laHlO}X*-{MlGf?%F-n^vbLy5}Hq5pZI6dVAJi+?wMpfU}e@Dy0kvY$L(Zlu7lILfP4@6!3`7)e6y_kRc zY*N?@fg)sBr*T7To%>%QkBg%2zET>T3~v(M>&9Ex?h=}#h;G_h(~P=xyO+TI2i(xd zV?H2e{(|00X82zyNHgA}(0nPk!F=2y!ROJQ0MOLc5PxWxhv6rvD^ptQaHcO*x%=pf#3W^-s z3GW;{@ROd}8}stG^zDIQG?ev1ck`OGg`D@0_bB^>VR|FFFxu5+f(8a6JV(k!T@UV&6LhVGuy8y!ELlz;VF zf=b#0;CzS4N5QULQ#@ut#QRZ0#B{$ zN-@^C4D&MZ`3Z}4_ z*7Mqpg;7$2LJ3M!S4Kh?&X`>i_ymVG{z#Z!SnW%Z)C>B4B7z$6#ff-?b*$GkOcGM= z;CCzqRWc)PRQ%i@xlkS58st|z*^-p1%Cn(ppLMZ)0j22tfpxK!-+Qu$OsjY1%o!Y$((5^TLMJIaB9|Wh$sz&45 zbF6Llo$ET8rIyN{pQn#iH)K250J>2k-?n-}eZAtHrF>ey%lN5xpr_!_^txwxx@+!5 z{{OB+x~YEp)t7{YPi`7+yxv1|A2BjfIM^9&s)yAzONpyZ??0ll9l3la{bO&;62@z0 zAv*DLP33o$1zkoyFM@P!?b{3}-K zqA3)2k#A!yJF{CvA2Gx3&L{HbORB}rc<`u%n^ZNIJ0V9F^5s+fWR5v*#UejbWe0l) zo(hNCkFJGGjx02{Nq&AEKRESra&11uU`;P5vf*Mm<1S`ZR^#F%~mJF8^&o0`e zLi(4Uuh)W$Dx2%G!&5!pN5_=egOvp@eW9wMv{kq?uoHu*i0_nl(jzsG3X6F?5v$nP zOl|7^8@^wKHkHgTgvnR7HMeJl5XYlDltSH?$D!TQA*2{{XjF8PH&!&qI+*+=)p(MEHU>XDsa<)aGpe3<(}YMID1rj)Dnkc`36^8HU^$O_`uCGYz7; z!qF%)BBN^&LY-8P9b4c1_O~~1l8W{XbGHqbE(>A85y%rbp472?l@u00 zt#uf65Tg~AU~r{G#+`=_gX&v2V+~$t4u~@K3r0d z9rfg;I}0X_sz)i@;*J$w_2<$3=-my7^-V;^#OkP@0Y^{~;@F;JSkCM8DbL>qX&I=c zp@~Uk-L6Mndv1(jZj8!UbQNCgmF(0J1@lGqQ#a=@gbr#R7YMJ09^pB!fH!v8H{?Bb)PA!enio&aLFj!(X)c#aSer}_~I(HhE^{JzHfq41ct zgdnF5lkc88D0V#|8{RoZGG@e5K&e0jL;rO`)-038d=9;ilJEAu04aR&il3!p3?eQ0 ziwfly+6B6|sDh9ug>r(N7(}MzqBIUm_YZHwb<%G9=9(2Wk*b!~ezh0~aTBsUM4){m zE|-Qynjh1c%1Ob1Cqjcup=sHt@9zxV4hzLv7Dqv*;$efN7i30FzbdoUJS#YBo$@&v zbixhp!b1#)VK731D(>oEti8yXm|E=I1ebdr&k?*{QP2X=TsgxAk=~fHl3v76FUnM} zmN{0w!l8w?i#cP;WouSDNf2%|IAu=0t6j1smCsSUGRE%L)V(!o?M9C$H*S7R!;tU5 z?py}ie@~u(4;>!(m}zgH%he$!QvkRM5*H`HXFv)uJQOns;1*Dpaco)yF@rP}wDHsb zWCd*^28UBKpw|{2itd?9q_r?cY{In!WV|Kb=S2`#SI!;2;u{RbXf@194AQ*;gU^3| 
zAvEVAFS4ZNB6xCDMBIQgo{gNYc3ds#!wG(*JEoQ|AJW(vMaX(MD<1PH1jT-#^~pjX zSD!#Z#65l!#cY5wZS)w6j5j@vv6BM^fL2Iyu!G|vt|?|un?xUgt(L&xv-~TW^QxWV zWRHDzvZfbjuKK5P>_zwWFN9iqAV_b#4_cQji?vj{HN%>AW~j&@?l)l%pwU-fBn5Bj z`0>SZuDup6f<)+1&@KB5VFY;9vTT*0tR5@M6(Y&T`iWXs_vKf*H8i*Z2Z8juj424x z4}!s)7A0i#ox`Z4mH}r+3ne}L(^C6Rh{@?-5cm7g?+~f#+j~Bw6A%Gh4;X&Xt$LT# z(~p!oS$xb)UxzsI19LQam^5l4Dz%?bh2B?)ItZM8(zpuB03kvEp<94NmBQq}=|hs! ze=R{jEYWB2W7(EeTYQh}nvV0IB8VfzfYc)is>jam@~+v4=vrj>nEl?bN2${Hjt6CG zHznY#Yyo$z5kly(AcQHhK?mki5c=yMNsDzYap=1dcbkOi@u^Q<;h`d(1o)&Y?64RH ze>w0kFo5$N|F@O{KNW*^?oF8=EfBkY*W4ez2!YxR8yVW=$fhajzxnoXw3s2{t)p8 zT9EG^fYgR(r4Ru!+sDAi38F?sbLoIIU#c&cGl0YxTny3kNXj|B*grFLTXZ0~dQ3h1 z3sfa?hH|vXj5O*=fCB}EelWitZ>o|rim{Nx6jIdfAk+%ubR)~4rs|fIVS~LtU^{I1m_#H)}ZQ!AG-AafXGbd%B&gyqwzFAd!6GUMFP#*q9=Qj2l3v` zLm7YXO(jCDm@$%bGK0oi(gry~TVa$)x-LJ$a)e3K&`IC$OFt4X@?cvSqwR~=c{43e zuV=@>6yWw_@UiFXyFF|Caa*78S(|)QT4>6wiT$;HrXku#1N!7Ni9XpskY`08%KyRTpFI!kb1$XF&s_(}J(wKFI-3!;mb}uMU-z;c3&6P8$ICauu`)_nYi5GzUCwXHHHt%FYiqhAnuVfF5`I?mKMvo;7W`DNw!_8~< z>&Z4GzhOx+1L=sELV(+$G2QF0e1Wtac}>O;JKy0dF===0k82eA-$F@?3WFm9+MZPI z#J+QI(NGF_K+W-hpDItsVCx723u9(^6x!}H)t7^y|zTsG8Y8}uj zZGzz7DkxqpNS~vyh78mIU{3ld#Tkdb4N6ZyaZH_se$}gl>=?{mBfO&koJGk{3-?43 zU5=(9dnmSu04hFp1o~x2K_z9xVL7b_`eBUxip-FS-^8bPCi*o1iOaWxBlszPS}K!I zApHNyNUD~l+$gE$ReHZ zhVz&DnvmuZG-FZ8;*gTwPnl)FYp$rk5LOXXBEvCZw~R`rzH?b$A8QDdKHBE8QdQF- ztUk>r|8SEru=(1bn|$c(#Sa1H-}m6SPE80KJztHGpJFZlu#1XG`;3pDIm@RIxdfu$ z$H6RMAV(-&DNe;+d;;s=z(cVtP=fXABk&oDBlxb~2pThJ`YiXACMqtv+tJkjmAx2z zjzbbH-&dK}dwc4BsWCtIc+RZoCk4gYgl@r57s1Kjxsqui#!I|L@iqfezQ`OL`YnUr zdyUbYfh;gppZbu}u?3>;?w4DBRE!<&1B?b5?mh<-PXHJHJNQ4kfS;b#-T!h|006aI z4-tvYF?+J8)XEFP$Qfdn{vB32*2CXjT#TKG10S|Jo4XtB7YE?Zv1Aawanqz}%;=bD zwP2x4*~Gwxz-n1{s;6i{TfJUxjku|Z1sww z%-cKw@Nd5>z`xumCbU@@txm88(NnR7tjg};&8#BO8}_Dk4SwKrq7#S0t`UEf6^>vn zdCE}VzC96Tg#kDbK5t32;%@NI;xe(eZ=giM&`v!ZAlQM&ca1Vd%A~m@(SGv_c}ZcQ z!3#J~kxxFd?VB28MP9MmpuSbILO^_;J5YHo6oZm{UnE!8MysJ5ka2l?kgKHmv;QTb#Y45rk zwwS}%(H<@Z;y8-z!47y88jz?3LbZj<3CNbfA?kqN6u3m0y*2sSObFrhJ5?EDgm+Re zedcbvEIsJ9P-j{##MxoGMx*xq=lmQunTI&0uD_g!^fp0N4=`T(Z-s$m5wa@K1S&Yc zBFRm^>#v{mA5YxwWfZgU2IV^B}0 zabfOLK#n}QzJHev3^bUVK8k9gIRC$PX9inW-~Oe!J8ra56ad*08rX_O`eH()8FGzI_Lzk6ol%jg7B2hf*a{ppiW^2--BImhEHLfE^hfy~1mw?(*SAK^yK7ZMoW@4OhvH4IIs(G%H#PeqDH%$mk2RDZiPs>x`FLtQSfZ1g@$m&Rg*> z2?d-k0YOr-d1VUu?HH{$+SCFSUHX7i+U;H}r`@wj;rb%T=*diAEwn!rMp;eR4p-DFxM2_L+N() zU|0akfV_u}#Lv{Q_x$ps#_JUNdb6!MIt$R%tY#1yt$*2E&2h+uDf-`pKo)J|Ezs$? 
zRmXcTACB6NnAtm69(6!(s13?9xII2X51Lg6lZIP^Kakrl`!+)%R*d*j&f~%$Ad;|y zUQItq7qXfh7B7z^?8+&dVnujVSDfN904^3-Dyd+X)#M(|hFBdUMHn@1hky+og_R7X zNUy1^=|2K$9zkTMHJs}h2i!M=KKtKYmFR-0e6s4wh}J-tB>Ubrx2aTfL+h_m$e$_S z$ETl;wH-w{14%aLxcjf%K1f_i@Y%`*2+IBNX-kWjD=(u&-T$beyk|x!zy z7}uL->040u{m9Ouf^tR3xf-DER%FS)er`KL(cvYJE^>51&Ka2Bf>K-`a>b9zwT4K3 z0KVGDkUhoe(0xO8kkL%j9UoXTXB4quJebm8Dj`DN@?!zGLm(*rjU(47B(wQhz>Lv9 zAS#K((k+|d5Fz}#FvNqgDFM~!zXhAlwCFiwX}MAnE&wI3FW2BaHSpCXhIj%uw1PUv pf0ln}Kz@NWydQ4cZChN!_|A#{oFJ}1@iKxOKVoq>{}0D&{|~@Tcgz3) literal 0 HcmV?d00001 From f421c6010bb39451c4c7d1ff67797379ff4df2e5 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 19 Mar 2025 13:41:01 -0700 Subject: [PATCH 05/18] Checkpointed Jira connector (#4286) * Checkpointed Jira connector * nit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * typing improvements and test fixes * cleaner typing * remove default because it is from the future * mypy * Address EL comments --------- Co-authored-by: evan-danswer Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- backend/onyx/connectors/connector_runner.py | 28 +- .../onyx/connectors/google_drive/connector.py | 4 +- backend/onyx/connectors/interfaces.py | 24 +- .../connectors/mock_connector/connector.py | 6 +- .../onyx/connectors/onyx_jira/connector.py | 306 ++++++------ backend/onyx/connectors/slack/connector.py | 3 +- .../google_drive/consts_and_utils.py | 24 +- .../daily/connectors/jira/test_jira_basic.py | 17 +- backend/tests/daily/connectors/utils.py | 70 +++ .../jira/test_jira_checkpointing.py | 436 ++++++++++++++++++ .../jira/test_large_ticket_handling.py | 51 +- backend/tests/unit/onyx/connectors/utils.py | 55 +++ 12 files changed, 810 insertions(+), 214 deletions(-) create mode 100644 backend/tests/daily/connectors/utils.py create mode 100644 backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py create mode 100644 backend/tests/unit/onyx/connectors/utils.py diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 6cb3272b1..6acb88b10 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -2,6 +2,8 @@ import sys import time from collections.abc import Generator from datetime import datetime +from typing import Generic +from typing import TypeVar from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointConnector @@ -19,8 +21,10 @@ logger = setup_logger() TimeRange = tuple[datetime, datetime] +CT = TypeVar("CT", bound=ConnectorCheckpoint) -class CheckpointOutputWrapper: + +class CheckpointOutputWrapper(Generic[CT]): """ Wraps a CheckpointOutput generator to give things back in a more digestible format. The connector format is easier for the connector implementor (e.g. 
it enforces exactly @@ -29,20 +33,20 @@ class CheckpointOutputWrapper: """ def __init__(self) -> None: - self.next_checkpoint: ConnectorCheckpoint | None = None + self.next_checkpoint: CT | None = None def __call__( self, - checkpoint_connector_generator: CheckpointOutput, + checkpoint_connector_generator: CheckpointOutput[CT], ) -> Generator[ - tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None], + tuple[Document | None, ConnectorFailure | None, CT | None], None, None, ]: # grabs the final return value and stores it in the `next_checkpoint` variable def _inner_wrapper( - checkpoint_connector_generator: CheckpointOutput, - ) -> CheckpointOutput: + checkpoint_connector_generator: CheckpointOutput[CT], + ) -> CheckpointOutput[CT]: self.next_checkpoint = yield from checkpoint_connector_generator return self.next_checkpoint # not used @@ -64,7 +68,7 @@ class CheckpointOutputWrapper: yield None, None, self.next_checkpoint -class ConnectorRunner: +class ConnectorRunner(Generic[CT]): """ Handles: - Batching @@ -85,11 +89,9 @@ class ConnectorRunner: self.doc_batch: list[Document] = [] def run( - self, checkpoint: ConnectorCheckpoint + self, checkpoint: CT ) -> Generator[ - tuple[ - list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None - ], + tuple[list[Document] | None, ConnectorFailure | None, CT | None], None, None, ]: @@ -105,9 +107,9 @@ class ConnectorRunner: end=self.time_range[1].timestamp(), checkpoint=checkpoint, ) - next_checkpoint: ConnectorCheckpoint | None = None + next_checkpoint: CT | None = None # this is guaranteed to always run at least once with next_checkpoint being non-None - for document, failure, next_checkpoint in CheckpointOutputWrapper()( + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( checkpoint_connector_generator ): if document is not None: diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index dcc14df06..496d193ec 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -1,7 +1,6 @@ import copy import threading from collections.abc import Callable -from collections.abc import Generator from collections.abc import Iterator from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor @@ -51,6 +50,7 @@ from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTION from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE from onyx.connectors.google_utils.shared_constants import USER_FIELDS from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector @@ -1010,7 +1010,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GoogleDriveCheckpoint, - ) -> Generator[Document | ConnectorFailure, None, GoogleDriveCheckpoint]: + ) -> CheckpointOutput[GoogleDriveCheckpoint]: """ Entrypoint for the connector; first run is with an empty checkpoint. 
""" diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index 4d8d591c2..ae2109829 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -22,8 +22,10 @@ SecondsSinceUnixEpoch = float GenerateDocumentsOutput = Iterator[list[Document]] GenerateSlimDocumentOutput = Iterator[list[SlimDocument]] +CT = TypeVar("CT", bound=ConnectorCheckpoint) -class BaseConnector(abc.ABC): + +class BaseConnector(abc.ABC, Generic[CT]): REDIS_KEY_PREFIX = "da_connector_data:" # Common image file extensions supported across connectors IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"} @@ -58,8 +60,9 @@ class BaseConnector(abc.ABC): Default is a no-op (always successful). """ - def build_dummy_checkpoint(self) -> ConnectorCheckpoint: - return ConnectorCheckpoint(has_more=True) + def build_dummy_checkpoint(self) -> CT: + # TODO: find a way to make this work without type: ignore + return ConnectorCheckpoint(has_more=True) # type: ignore # Large set update or reindex, generally pulling a complete state or from a savestate file @@ -192,21 +195,17 @@ class EventConnector(BaseConnector): raise NotImplementedError -CT = TypeVar("CT", bound=ConnectorCheckpoint) -# TODO: find a reasonable way to parameterize the return type of the generator -CheckpointOutput: TypeAlias = Generator[ - Document | ConnectorFailure, None, ConnectorCheckpoint -] +CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT] -class CheckpointConnector(BaseConnector, Generic[CT]): +class CheckpointConnector(BaseConnector[CT]): @abc.abstractmethod def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: CT, - ) -> Generator[Document | ConnectorFailure, None, CT]: + ) -> CheckpointOutput[CT]: """Yields back documents or failures. Final return is the new checkpoint. 
Final return can be access via either: @@ -228,11 +227,8 @@ class CheckpointConnector(BaseConnector, Generic[CT]): """ raise NotImplementedError - # Ideally return type should be CT, but that's not possible if - # we want to override build_dummy_checkpoint and have BaseConnector - # return a base ConnectorCheckpoint @override - def build_dummy_checkpoint(self) -> ConnectorCheckpoint: + def build_dummy_checkpoint(self) -> CT: raise NotImplementedError @abc.abstractmethod diff --git a/backend/onyx/connectors/mock_connector/connector.py b/backend/onyx/connectors/mock_connector/connector.py index 009fee2c8..efc0d8477 100644 --- a/backend/onyx/connectors/mock_connector/connector.py +++ b/backend/onyx/connectors/mock_connector/connector.py @@ -1,4 +1,3 @@ -from collections.abc import Generator from typing import Any import httpx @@ -6,6 +5,7 @@ from pydantic import BaseModel from typing_extensions import override from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure @@ -65,7 +65,7 @@ class MockConnector(CheckpointConnector[MockConnectorCheckpoint]): start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: MockConnectorCheckpoint, - ) -> Generator[Document | ConnectorFailure, None, MockConnectorCheckpoint]: + ) -> CheckpointOutput[MockConnectorCheckpoint]: if self.connector_yields is None: raise ValueError("No connector yields configured") @@ -91,7 +91,7 @@ class MockConnector(CheckpointConnector[MockConnectorCheckpoint]): return current_yield.checkpoint @override - def build_dummy_checkpoint(self) -> ConnectorCheckpoint: + def build_dummy_checkpoint(self) -> MockConnectorCheckpoint: return MockConnectorCheckpoint( has_more=True, last_document_id=None, diff --git a/backend/onyx/connectors/onyx_jira/connector.py b/backend/onyx/connectors/onyx_jira/connector.py index 30caf3ea5..f085c5362 100644 --- a/backend/onyx/connectors/onyx_jira/connector.py +++ b/backend/onyx/connectors/onyx_jira/connector.py @@ -6,6 +6,7 @@ from typing import Any from jira import JIRA from jira.resources import Issue +from typing_extensions import override from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP @@ -15,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError -from onyx.connectors.interfaces import GenerateDocumentsOutput +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput -from onyx.connectors.interfaces import LoadConnector -from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document +from onyx.connectors.models import DocumentFailure from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection 
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info @@ -42,121 +45,112 @@ _JIRA_SLIM_PAGE_SIZE = 500 _JIRA_FULL_PAGE_SIZE = 50 -def _paginate_jql_search( +def _perform_jql_search( jira_client: JIRA, jql: str, + start: int, max_results: int, fields: str | None = None, ) -> Iterable[Issue]: - start = 0 - while True: - logger.debug( - f"Fetching Jira issues with JQL: {jql}, " - f"starting at {start}, max results: {max_results}" - ) - issues = jira_client.search_issues( - jql_str=jql, - startAt=start, - maxResults=max_results, - fields=fields, - ) + logger.debug( + f"Fetching Jira issues with JQL: {jql}, " + f"starting at {start}, max results: {max_results}" + ) + issues = jira_client.search_issues( + jql_str=jql, + startAt=start, + maxResults=max_results, + fields=fields, + ) - for issue in issues: - if isinstance(issue, Issue): - yield issue - else: - raise Exception(f"Found Jira object not of type Issue: {issue}") - - if len(issues) < max_results: - break - - start += max_results + for issue in issues: + if isinstance(issue, Issue): + yield issue + else: + raise RuntimeError(f"Found Jira object not of type Issue: {issue}") -def fetch_jira_issues_batch( +def process_jira_issue( jira_client: JIRA, - jql: str, - batch_size: int, + issue: Issue, comment_email_blacklist: tuple[str, ...] = (), labels_to_skip: set[str] | None = None, -) -> Iterable[Document]: - for issue in _paginate_jql_search( - jira_client=jira_client, - jql=jql, - max_results=batch_size, - ): - if labels_to_skip: - if any(label in issue.fields.labels for label in labels_to_skip): - logger.info( - f"Skipping {issue.key} because it has a label to skip. Found " - f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}." - ) - continue - - description = ( - issue.fields.description - if JIRA_API_VERSION == "2" - else extract_text_from_adf(issue.raw["fields"]["description"]) - ) - comments = get_comment_strs( - issue=issue, - comment_email_blacklist=comment_email_blacklist, - ) - ticket_content = f"{description}\n" + "\n".join( - [f"Comment: {comment}" for comment in comments if comment] - ) - - # Check ticket size - if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: +) -> Document | None: + if labels_to_skip: + if any(label in issue.fields.labels for label in labels_to_skip): logger.info( - f"Skipping {issue.key} because it exceeds the maximum size of " - f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." + f"Skipping {issue.key} because it has a label to skip. Found " + f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}." 
) - continue + return None - page_url = f"{jira_client.client_info()}/browse/{issue.key}" + description = ( + issue.fields.description + if JIRA_API_VERSION == "2" + else extract_text_from_adf(issue.raw["fields"]["description"]) + ) + comments = get_comment_strs( + issue=issue, + comment_email_blacklist=comment_email_blacklist, + ) + ticket_content = f"{description}\n" + "\n".join( + [f"Comment: {comment}" for comment in comments if comment] + ) - people = set() - try: - creator = best_effort_get_field_from_issue(issue, "creator") - if basic_expert_info := best_effort_basic_expert_info(creator): - people.add(basic_expert_info) - except Exception: - # Author should exist but if not, doesn't matter - pass - - try: - assignee = best_effort_get_field_from_issue(issue, "assignee") - if basic_expert_info := best_effort_basic_expert_info(assignee): - people.add(basic_expert_info) - except Exception: - # Author should exist but if not, doesn't matter - pass - - metadata_dict = {} - if priority := best_effort_get_field_from_issue(issue, "priority"): - metadata_dict["priority"] = priority.name - if status := best_effort_get_field_from_issue(issue, "status"): - metadata_dict["status"] = status.name - if resolution := best_effort_get_field_from_issue(issue, "resolution"): - metadata_dict["resolution"] = resolution.name - if labels := best_effort_get_field_from_issue(issue, "labels"): - metadata_dict["label"] = labels - - yield Document( - id=page_url, - sections=[TextSection(link=page_url, text=ticket_content)], - source=DocumentSource.JIRA, - semantic_identifier=f"{issue.key}: {issue.fields.summary}", - title=f"{issue.key} {issue.fields.summary}", - doc_updated_at=time_str_to_utc(issue.fields.updated), - primary_owners=list(people) or None, - # TODO add secondary_owners (commenters) if needed - metadata=metadata_dict, + # Check ticket size + if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: + logger.info( + f"Skipping {issue.key} because it exceeds the maximum size of " + f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." 
) + return None + + page_url = build_jira_url(jira_client, issue.key) + + people = set() + try: + creator = best_effort_get_field_from_issue(issue, "creator") + if basic_expert_info := best_effort_basic_expert_info(creator): + people.add(basic_expert_info) + except Exception: + # Author should exist but if not, doesn't matter + pass + + try: + assignee = best_effort_get_field_from_issue(issue, "assignee") + if basic_expert_info := best_effort_basic_expert_info(assignee): + people.add(basic_expert_info) + except Exception: + # Author should exist but if not, doesn't matter + pass + + metadata_dict = {} + if priority := best_effort_get_field_from_issue(issue, "priority"): + metadata_dict["priority"] = priority.name + if status := best_effort_get_field_from_issue(issue, "status"): + metadata_dict["status"] = status.name + if resolution := best_effort_get_field_from_issue(issue, "resolution"): + metadata_dict["resolution"] = resolution.name + if labels := best_effort_get_field_from_issue(issue, "labels"): + metadata_dict["labels"] = labels + + return Document( + id=page_url, + sections=[TextSection(link=page_url, text=ticket_content)], + source=DocumentSource.JIRA, + semantic_identifier=f"{issue.key}: {issue.fields.summary}", + title=f"{issue.key} {issue.fields.summary}", + doc_updated_at=time_str_to_utc(issue.fields.updated), + primary_owners=list(people) or None, + metadata=metadata_dict, + ) -class JiraConnector(LoadConnector, PollConnector, SlimConnector): +class JiraConnectorCheckpoint(ConnectorCheckpoint): + offset: int | None = None + + +class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector): def __init__( self, jira_base_url: str, @@ -200,33 +194,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): ) return None - def _get_jql_query(self) -> str: - """Get the JQL query based on whether a specific project is set""" - if self.jira_project: - return f"project = {self.quoted_jira_project}" - return "" # Empty string means all accessible projects - - def load_from_state(self) -> GenerateDocumentsOutput: - jql = self._get_jql_query() - - document_batch = [] - for doc in fetch_jira_issues_batch( - jira_client=self.jira_client, - jql=jql, - batch_size=_JIRA_FULL_PAGE_SIZE, - comment_email_blacklist=self.comment_email_blacklist, - labels_to_skip=self.labels_to_skip, - ): - document_batch.append(doc) - if len(document_batch) >= self.batch_size: - yield document_batch - document_batch = [] - - yield document_batch - - def poll_source( + def _get_jql_query( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch - ) -> GenerateDocumentsOutput: + ) -> str: + """Get the JQL query based on whether a specific project is set and time range""" start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime( "%Y-%m-%d %H:%M" ) @@ -234,25 +205,61 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): "%Y-%m-%d %H:%M" ) - base_jql = self._get_jql_query() - jql = ( - f"{base_jql} AND " if base_jql else "" - ) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'" + time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'" - document_batch = [] - for doc in fetch_jira_issues_batch( + if self.jira_project: + base_jql = f"project = {self.quoted_jira_project}" + return f"{base_jql} AND {time_jql}" + + return time_jql + + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: JiraConnectorCheckpoint, + ) -> CheckpointOutput[JiraConnectorCheckpoint]: + jql 
= self._get_jql_query(start, end) + + # Get the current offset from checkpoint or start at 0 + starting_offset = checkpoint.offset or 0 + current_offset = starting_offset + + for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, - batch_size=_JIRA_FULL_PAGE_SIZE, - comment_email_blacklist=self.comment_email_blacklist, - labels_to_skip=self.labels_to_skip, + start=current_offset, + max_results=_JIRA_FULL_PAGE_SIZE, ): - document_batch.append(doc) - if len(document_batch) >= self.batch_size: - yield document_batch - document_batch = [] + issue_key = issue.key + try: + if document := process_jira_issue( + jira_client=self.jira_client, + issue=issue, + comment_email_blacklist=self.comment_email_blacklist, + labels_to_skip=self.labels_to_skip, + ): + yield document - yield document_batch + except Exception as e: + yield ConnectorFailure( + failed_document=DocumentFailure( + document_id=issue_key, + document_link=build_jira_url(self.jira_client, issue_key), + ), + failure_message=f"Failed to process Jira issue: {str(e)}", + exception=e, + ) + + current_offset += 1 + + # Update checkpoint + checkpoint = JiraConnectorCheckpoint( + offset=current_offset, + # if we didn't retrieve a full batch, we're done + has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE, + ) + return checkpoint def retrieve_all_slim_documents( self, @@ -260,12 +267,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: - jql = self._get_jql_query() + jql = self._get_jql_query(start or 0, end or float("inf")) slim_doc_batch = [] - for issue in _paginate_jql_search( + for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, + start=0, max_results=_JIRA_SLIM_PAGE_SIZE, fields="key", ): @@ -334,6 +342,16 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector): raise RuntimeError(f"Unexpected Jira error during validation: {e}") + @override + def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint: + return JiraConnectorCheckpoint.model_validate_json(checkpoint_json) + + @override + def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint: + return JiraConnectorCheckpoint( + has_more=True, + ) + if __name__ == "__main__": import os @@ -350,5 +368,7 @@ if __name__ == "__main__": "jira_api_token": os.environ["JIRA_API_TOKEN"], } ) - document_batches = connector.load_from_state() + document_batches = connector.load_from_checkpoint( + 0, float("inf"), JiraConnectorCheckpoint(has_more=True) + ) print(next(document_batches)) diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index e8af11721..d4890e2e3 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -24,6 +24,7 @@ from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector @@ -536,7 +537,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, 
checkpoint: SlackCheckpoint, - ) -> Generator[Document | ConnectorFailure, None, SlackCheckpoint]: + ) -> CheckpointOutput[SlackCheckpoint]: """Rough outline: Step 1: Get all channels, yield back Checkpoint. diff --git a/backend/tests/daily/connectors/google_drive/consts_and_utils.py b/backend/tests/daily/connectors/google_drive/consts_and_utils.py index c6dad3d9f..570aaad14 100644 --- a/backend/tests/daily/connectors/google_drive/consts_and_utils.py +++ b/backend/tests/daily/connectors/google_drive/consts_and_utils.py @@ -1,11 +1,10 @@ import time from collections.abc import Sequence -from onyx.connectors.connector_runner import CheckpointOutputWrapper -from onyx.connectors.google_drive.connector import GoogleDriveCheckpoint from onyx.connectors.google_drive.connector import GoogleDriveConnector from onyx.connectors.models import Document from onyx.connectors.models import TextSection +from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector ALL_FILES = list(range(0, 60)) SHARED_DRIVE_FILES = list(range(20, 25)) @@ -216,19 +215,8 @@ def assert_retrieved_docs_match_expected( def load_all_docs(connector: GoogleDriveConnector) -> list[Document]: - retrieved_docs: list[Document] = [] - checkpoint = connector.build_dummy_checkpoint() - while checkpoint.has_more: - for doc, failure, next_checkpoint in CheckpointOutputWrapper()( - connector.load_from_checkpoint(0, time.time(), checkpoint) - ): - assert failure is None - if next_checkpoint is None: - assert isinstance( - doc, Document - ), f"Should not fail with {type(doc)} {doc}" - retrieved_docs.append(doc) - else: - assert isinstance(next_checkpoint, GoogleDriveCheckpoint) - checkpoint = next_checkpoint - return retrieved_docs + return load_all_docs_from_checkpoint_connector( + connector, + 0, + time.time(), + ) diff --git a/backend/tests/daily/connectors/jira/test_jira_basic.py b/backend/tests/daily/connectors/jira/test_jira_basic.py index cf7d14fbd..885d4f2ca 100644 --- a/backend/tests/daily/connectors/jira/test_jira_basic.py +++ b/backend/tests/daily/connectors/jira/test_jira_basic.py @@ -5,6 +5,7 @@ import pytest from onyx.configs.constants import DocumentSource from onyx.connectors.onyx_jira.connector import JiraConnector +from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector @pytest.fixture @@ -24,15 +25,13 @@ def jira_connector() -> JiraConnector: def test_jira_connector_basic(jira_connector: JiraConnector) -> None: - doc_batch_generator = jira_connector.poll_source(0, time.time()) - - doc_batch = next(doc_batch_generator) - with pytest.raises(StopIteration): - next(doc_batch_generator) - - assert len(doc_batch) == 1 - - doc = doc_batch[0] + docs = load_all_docs_from_checkpoint_connector( + connector=jira_connector, + start=0, + end=time.time(), + ) + assert len(docs) == 1 + doc = docs[0] assert doc.id == "https://danswerai.atlassian.net/browse/AS-2" assert doc.semantic_identifier == "AS-2: test123small" diff --git a/backend/tests/daily/connectors/utils.py b/backend/tests/daily/connectors/utils.py new file mode 100644 index 000000000..00f64fa0f --- /dev/null +++ b/backend/tests/daily/connectors/utils.py @@ -0,0 +1,70 @@ +from typing import cast +from typing import TypeVar + +from onyx.connectors.connector_runner import CheckpointOutputWrapper +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from 
onyx.connectors.models import Document + +_ITERATION_LIMIT = 100_000 + +CT = TypeVar("CT", bound=ConnectorCheckpoint) + + +def load_all_docs_from_checkpoint_connector( + connector: CheckpointConnector[CT], + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[Document]: + num_iterations = 0 + + checkpoint = cast(CT, connector.build_dummy_checkpoint()) + documents: list[Document] = [] + while checkpoint.has_more: + doc_batch_generator = CheckpointOutputWrapper[CT]()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + raise RuntimeError(f"Failed to load documents: {failure}") + if document is not None: + documents.append(document) + if next_checkpoint is not None: + checkpoint = next_checkpoint + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. Infinite loop?") + + return documents + + +def load_everything_from_checkpoint_connector( + connector: CheckpointConnector[CT], + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[Document | ConnectorFailure]: + """Like load_all_docs_from_checkpoint_connector but returns both documents and failures""" + num_iterations = 0 + + checkpoint = connector.build_dummy_checkpoint() + outputs: list[Document | ConnectorFailure] = [] + while checkpoint.has_more: + doc_batch_generator = CheckpointOutputWrapper[CT]()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + outputs.append(failure) + if document is not None: + outputs.append(document) + if next_checkpoint is not None: + checkpoint = next_checkpoint + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. 
Infinite loop?") + + return outputs diff --git a/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py b/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py new file mode 100644 index 000000000..001d5ad66 --- /dev/null +++ b/backend/tests/unit/onyx/connectors/jira/test_jira_checkpointing.py @@ -0,0 +1,436 @@ +import time +from collections.abc import Callable +from collections.abc import Generator +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import cast +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from jira import JIRA +from jira import JIRAError +from jira.resources import Issue + +from onyx.configs.constants import DocumentSource +from onyx.connectors.exceptions import ConnectorValidationError +from onyx.connectors.exceptions import CredentialExpiredError +from onyx.connectors.exceptions import InsufficientPermissionsError +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document +from onyx.connectors.models import SlimDocument +from onyx.connectors.onyx_jira.connector import JiraConnector +from onyx.connectors.onyx_jira.connector import JiraConnectorCheckpoint +from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector + +PAGE_SIZE = 2 + + +@pytest.fixture +def jira_base_url() -> str: + return "https://jira.example.com" + + +@pytest.fixture +def project_key() -> str: + return "TEST" + + +@pytest.fixture +def mock_jira_client() -> MagicMock: + """Create a mock JIRA client with proper typing""" + mock = MagicMock(spec=JIRA) + # Add proper return typing for search_issues method + mock.search_issues = MagicMock() + # Add proper return typing for project method + mock.project = MagicMock() + # Add proper return typing for projects method + mock.projects = MagicMock() + return mock + + +@pytest.fixture +def jira_connector( + jira_base_url: str, project_key: str, mock_jira_client: MagicMock +) -> Generator[JiraConnector, None, None]: + connector = JiraConnector( + jira_base_url=jira_base_url, + project_key=project_key, + comment_email_blacklist=["blacklist@example.com"], + labels_to_skip=["secret", "sensitive"], + ) + connector._jira_client = mock_jira_client + connector._jira_client.client_info.return_value = jira_base_url + with patch("onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", 2): + yield connector + + +@pytest.fixture +def create_mock_issue() -> Callable[..., MagicMock]: + def _create_mock_issue( + key: str = "TEST-123", + summary: str = "Test Issue", + updated: str = "2023-01-01T12:00:00.000+0000", + description: str = "Test Description", + labels: list[str] | None = None, + ) -> MagicMock: + """Helper to create a mock Issue object""" + mock_issue = MagicMock(spec=Issue) + # Create fields attribute first + mock_issue.fields = MagicMock() + mock_issue.key = key + mock_issue.fields.summary = summary + mock_issue.fields.updated = updated + mock_issue.fields.description = description + mock_issue.fields.labels = labels or [] + + # Set up creator and assignee for testing owner extraction + mock_issue.fields.creator = MagicMock() + mock_issue.fields.creator.displayName = "Test Creator" + mock_issue.fields.creator.emailAddress = "creator@example.com" + + mock_issue.fields.assignee = MagicMock() + mock_issue.fields.assignee.displayName = "Test Assignee" + mock_issue.fields.assignee.emailAddress = "assignee@example.com" + + # Set up priority, status, and resolution + mock_issue.fields.priority 
= MagicMock() + mock_issue.fields.priority.name = "High" + + mock_issue.fields.status = MagicMock() + mock_issue.fields.status.name = "In Progress" + + mock_issue.fields.resolution = MagicMock() + mock_issue.fields.resolution.name = "Fixed" + + # Add raw field for accessing through API version check + mock_issue.raw = {"fields": {"description": description}} + + return mock_issue + + return _create_mock_issue + + +def test_load_credentials(jira_connector: JiraConnector) -> None: + """Test loading credentials""" + with patch( + "onyx.connectors.onyx_jira.connector.build_jira_client" + ) as mock_build_client: + mock_build_client.return_value = jira_connector._jira_client + credentials = { + "jira_user_email": "user@example.com", + "jira_api_token": "token123", + } + + result = jira_connector.load_credentials(credentials) + + mock_build_client.assert_called_once_with( + credentials=credentials, jira_base=jira_connector.jira_base + ) + assert result is None + assert jira_connector._jira_client == mock_build_client.return_value + + +def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None: + """Test JQL query generation with project specified""" + start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp() + + query = jira_connector._get_jql_query(start, end) + + # Check that the project part and time part are both in the query + assert f'project = "{jira_connector.jira_project}"' in query + assert "updated >= '2023-01-01 00:00'" in query + assert "updated <= '2023-01-02 00:00'" in query + assert " AND " in query + + +def test_get_jql_query_without_project(jira_base_url: str) -> None: + """Test JQL query generation without project specified""" + # Create connector without project key + connector = JiraConnector(jira_base_url=jira_base_url) + + start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp() + end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp() + + query = connector._get_jql_query(start, end) + + # Check that only time part is in the query + assert "project =" not in query + assert "updated >= '2023-01-01 00:00'" in query + assert "updated <= '2023-01-02 00:00'" in query + + +def test_load_from_checkpoint_happy_path( + jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock] +) -> None: + """Test loading from checkpoint - happy path""" + # Set up mocked issues + mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1") + mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2") + mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3") + + # Only mock the search_issues method + jira_client = cast(JIRA, jira_connector._jira_client) + search_issues_mock = cast(MagicMock, jira_client.search_issues) + search_issues_mock.side_effect = [ + [mock_issue1, mock_issue2], + [mock_issue3], + [], + ] + + # Call load_from_checkpoint + end_time = time.time() + outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time) + + # Check that the documents were returned + assert len(outputs) == 2 + + checkpoint_output1 = outputs[0] + assert len(checkpoint_output1.items) == 2 + document1 = checkpoint_output1.items[0] + assert isinstance(document1, Document) + assert document1.id == "https://jira.example.com/browse/TEST-1" + document2 = checkpoint_output1.items[1] + assert isinstance(document2, Document) + assert document2.id == "https://jira.example.com/browse/TEST-2" + assert checkpoint_output1.next_checkpoint == JiraConnectorCheckpoint( + offset=2, + 
has_more=True, + ) + + checkpoint_output2 = outputs[1] + assert len(checkpoint_output2.items) == 1 + document3 = checkpoint_output2.items[0] + assert isinstance(document3, Document) + assert document3.id == "https://jira.example.com/browse/TEST-3" + assert checkpoint_output2.next_checkpoint == JiraConnectorCheckpoint( + offset=3, + has_more=False, + ) + + # Check that search_issues was called with the right parameters + assert search_issues_mock.call_count == 2 + args, kwargs = search_issues_mock.call_args_list[0] + assert kwargs["startAt"] == 0 + assert kwargs["maxResults"] == PAGE_SIZE + + args, kwargs = search_issues_mock.call_args_list[1] + assert kwargs["startAt"] == 2 + assert kwargs["maxResults"] == PAGE_SIZE + + +def test_load_from_checkpoint_with_issue_processing_error( + jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock] +) -> None: + """Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches""" + # Set up mocked issues for first batch + mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1") + mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2") + # Set up mocked issues for second batch + mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3") + mock_issue4 = create_mock_issue(key="TEST-4", summary="Issue 4") + + # Mock search_issues to return our mock issues in batches + jira_client = cast(JIRA, jira_connector._jira_client) + search_issues_mock = cast(MagicMock, jira_client.search_issues) + search_issues_mock.side_effect = [ + [mock_issue1, mock_issue2], # First batch + [mock_issue3, mock_issue4], # Second batch + [], # Empty batch to indicate end + ] + + # Mock process_jira_issue to succeed for some issues and fail for others + def mock_process_side_effect( + jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any + ) -> Document | None: + if issue.key in ["TEST-1", "TEST-3"]: + return Document( + id=f"https://jira.example.com/browse/{issue.key}", + sections=[], + source=DocumentSource.JIRA, + semantic_identifier=f"{issue.key}: {issue.fields.summary}", + title=f"{issue.key} {issue.fields.summary}", + metadata={}, + ) + else: + raise Exception(f"Processing error for {issue.key}") + + with patch( + "onyx.connectors.onyx_jira.connector.process_jira_issue" + ) as mock_process: + mock_process.side_effect = mock_process_side_effect + + # Call load_from_checkpoint + end_time = time.time() + outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time) + + assert len(outputs) == 3 + + # Check first batch + first_batch = outputs[0] + assert len(first_batch.items) == 2 + # First item should be successful + assert isinstance(first_batch.items[0], Document) + assert first_batch.items[0].id == "https://jira.example.com/browse/TEST-1" + # Second item should be a failure + assert isinstance(first_batch.items[1], ConnectorFailure) + assert first_batch.items[1].failed_document is not None + assert first_batch.items[1].failed_document.document_id == "TEST-2" + assert "Failed to process Jira issue" in first_batch.items[1].failure_message + # Check checkpoint indicates more items (full batch) + assert first_batch.next_checkpoint.has_more is True + assert first_batch.next_checkpoint.offset == 2 + + # Check second batch + second_batch = outputs[1] + assert len(second_batch.items) == 2 + # First item should be successful + assert isinstance(second_batch.items[0], Document) + assert second_batch.items[0].id == "https://jira.example.com/browse/TEST-3" + # Second item should be 
a failure + assert isinstance(second_batch.items[1], ConnectorFailure) + assert second_batch.items[1].failed_document is not None + assert second_batch.items[1].failed_document.document_id == "TEST-4" + assert "Failed to process Jira issue" in second_batch.items[1].failure_message + # Check checkpoint indicates more items + assert second_batch.next_checkpoint.has_more is True + assert second_batch.next_checkpoint.offset == 4 + + # Check third, empty batch + third_batch = outputs[2] + assert len(third_batch.items) == 0 + assert third_batch.next_checkpoint.has_more is False + assert third_batch.next_checkpoint.offset == 4 + + +def test_load_from_checkpoint_with_skipped_issue( + jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock] +) -> None: + """Test loading from checkpoint with an issue that should be skipped due to labels""" + LABEL_TO_SKIP = "secret" + jira_connector.labels_to_skip = {LABEL_TO_SKIP} + + # Set up mocked issue with a label to skip + mock_issue = create_mock_issue( + key="TEST-1", summary="Issue 1", labels=[LABEL_TO_SKIP] + ) + + # Mock search_issues to return our mock issue + jira_client = cast(JIRA, jira_connector._jira_client) + search_issues_mock = cast(MagicMock, jira_client.search_issues) + search_issues_mock.return_value = [mock_issue] + + # Call load_from_checkpoint + end_time = time.time() + outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time) + + assert len(outputs) == 1 + checkpoint_output = outputs[0] + # Check that no documents were returned + assert len(checkpoint_output.items) == 0 + + +def test_retrieve_all_slim_documents( + jira_connector: JiraConnector, create_mock_issue: Any +) -> None: + """Test retrieving all slim documents""" + # Set up mocked issues + mock_issue1 = create_mock_issue(key="TEST-1") + mock_issue2 = create_mock_issue(key="TEST-2") + + # Mock search_issues to return our mock issues + jira_client = cast(JIRA, jira_connector._jira_client) + search_issues_mock = cast(MagicMock, jira_client.search_issues) + search_issues_mock.return_value = [mock_issue1, mock_issue2] + + # Mock best_effort_get_field_from_issue to return the keys + with patch( + "onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue" + ) as mock_field: + mock_field.side_effect = ["TEST-1", "TEST-2"] + + # Mock build_jira_url to return URLs + with patch("onyx.connectors.onyx_jira.connector.build_jira_url") as mock_url: + mock_url.side_effect = [ + "https://jira.example.com/browse/TEST-1", + "https://jira.example.com/browse/TEST-2", + ] + + # Call retrieve_all_slim_documents + batches = list(jira_connector.retrieve_all_slim_documents(0, 100)) + + # Check that a batch with 2 documents was returned + assert len(batches) == 1 + assert len(batches[0]) == 2 + assert isinstance(batches[0][0], SlimDocument) + assert batches[0][0].id == "https://jira.example.com/browse/TEST-1" + assert batches[0][1].id == "https://jira.example.com/browse/TEST-2" + + # Check that search_issues was called with the right parameters + search_issues_mock.assert_called_once() + args, kwargs = search_issues_mock.call_args + assert kwargs["fields"] == "key" + + +@pytest.mark.parametrize( + "status_code,expected_exception,expected_message", + [ + ( + 401, + CredentialExpiredError, + "Jira credential appears to be expired or invalid", + ), + ( + 403, + InsufficientPermissionsError, + "Your Jira token does not have sufficient permissions", + ), + (404, ConnectorValidationError, "Jira project not found"), + ( + 429, + ConnectorValidationError, + 
"Validation failed due to Jira rate-limits being exceeded", + ), + ], +) +def test_validate_connector_settings_errors( + jira_connector: JiraConnector, + status_code: int, + expected_exception: type[Exception], + expected_message: str, +) -> None: + """Test validation with various error scenarios""" + error = JIRAError(status_code=status_code) + + jira_client = cast(JIRA, jira_connector._jira_client) + project_mock = cast(MagicMock, jira_client.project) + project_mock.side_effect = error + + with pytest.raises(expected_exception) as excinfo: + jira_connector.validate_connector_settings() + assert expected_message in str(excinfo.value) + + +def test_validate_connector_settings_with_project_success( + jira_connector: JiraConnector, +) -> None: + """Test successful validation with project specified""" + jira_client = cast(JIRA, jira_connector._jira_client) + project_mock = cast(MagicMock, jira_client.project) + project_mock.return_value = MagicMock() + jira_connector.validate_connector_settings() + project_mock.assert_called_once_with(jira_connector.jira_project) + + +def test_validate_connector_settings_without_project_success( + jira_base_url: str, +) -> None: + """Test successful validation without project specified""" + connector = JiraConnector(jira_base_url=jira_base_url) + connector._jira_client = MagicMock() + connector._jira_client.projects.return_value = [MagicMock()] + + connector.validate_connector_settings() + connector._jira_client.projects.assert_called_once() diff --git a/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py b/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py index c8ae925ce..badb72aef 100644 --- a/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py +++ b/backend/tests/unit/onyx/connectors/jira/test_large_ticket_handling.py @@ -7,7 +7,8 @@ import pytest from jira.resources import Issue from pytest_mock import MockFixture -from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch +from onyx.connectors.onyx_jira.connector import _perform_jql_search +from onyx.connectors.onyx_jira.connector import process_jira_issue @pytest.fixture @@ -79,14 +80,22 @@ def test_fetch_jira_issues_batch_small_ticket( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 1 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 1 - assert docs[0].id.endswith("/SMALL-1") - assert docs[0].sections[0].text is not None - assert "Small description" in docs[0].sections[0].text - assert "Small comment 1" in docs[0].sections[0].text - assert "Small comment 2" in docs[0].sections[0].text + doc = docs[0] + assert doc is not None # Type assertion for mypy + assert doc.id.endswith("/SMALL-1") + assert doc.sections[0].text is not None + assert "Small description" in doc.sections[0].text + assert "Small comment 1" in doc.sections[0].text + assert "Small comment 2" in doc.sections[0].text def test_fetch_jira_issues_batch_large_ticket( @@ -96,7 +105,13 @@ def test_fetch_jira_issues_batch_large_ticket( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 
50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 1 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 0 # The large ticket should be skipped @@ -109,10 +124,18 @@ def test_fetch_jira_issues_batch_mixed_tickets( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 2 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 1 # Only the small ticket should be included - assert docs[0].id.endswith("/SMALL-1") + doc = docs[0] + assert doc is not None # Type assertion for mypy + assert doc.id.endswith("/SMALL-1") @patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50) @@ -124,6 +147,12 @@ def test_fetch_jira_issues_batch_custom_size_limit( ) -> None: mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large] - docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50)) + # First get the issues via pagination + issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50)) + assert len(issues) == 2 + + # Then process each issue + docs = [process_jira_issue(mock_jira_client, issue) for issue in issues] + docs = [doc for doc in docs if doc is not None] # Filter out None values assert len(docs) == 0 # Both tickets should be skipped due to the low size limit diff --git a/backend/tests/unit/onyx/connectors/utils.py b/backend/tests/unit/onyx/connectors/utils.py new file mode 100644 index 000000000..023a347a4 --- /dev/null +++ b/backend/tests/unit/onyx/connectors/utils.py @@ -0,0 +1,55 @@ +from typing import cast +from typing import Generic +from typing import TypeVar + +from pydantic import BaseModel + +from onyx.connectors.connector_runner import CheckpointOutputWrapper +from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import SecondsSinceUnixEpoch +from onyx.connectors.models import ConnectorCheckpoint +from onyx.connectors.models import ConnectorFailure +from onyx.connectors.models import Document + +_ITERATION_LIMIT = 100_000 + + +CT = TypeVar("CT", bound=ConnectorCheckpoint) + + +class SingleConnectorCallOutput(BaseModel, Generic[CT]): + items: list[Document | ConnectorFailure] + next_checkpoint: CT + + +def load_everything_from_checkpoint_connector( + connector: CheckpointConnector[CT], + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, +) -> list[SingleConnectorCallOutput[CT]]: + num_iterations = 0 + + checkpoint = cast(CT, connector.build_dummy_checkpoint()) + outputs: list[SingleConnectorCallOutput[CT]] = [] + while checkpoint.has_more: + items: list[Document | ConnectorFailure] = [] + doc_batch_generator = CheckpointOutputWrapper[CT]()( + connector.load_from_checkpoint(start, end, checkpoint) + ) + for document, failure, next_checkpoint in doc_batch_generator: + if failure is not None: + items.append(failure) + if document is not None: + items.append(document) + if next_checkpoint is not None: + checkpoint = 
next_checkpoint + + outputs.append( + SingleConnectorCallOutput(items=items, next_checkpoint=checkpoint) + ) + + num_iterations += 1 + if num_iterations > _ITERATION_LIMIT: + raise RuntimeError("Too many iterations. Infinite loop?") + + return outputs From 72bf427cc22905ba5281a624f7387b5dbdbbe495 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 19 Mar 2025 14:15:06 -0700 Subject: [PATCH 06/18] Address invalid connector state (#4304) * Address invalid connector state * Fixes * Address mypy * Address RK comment --- .../tasks/doc_permission_syncing/tasks.py | 8 +--- .../tasks/external_group_syncing/tasks.py | 17 ++----- .../onyx/background/indexing/run_indexing.py | 44 ++++++++++++++++--- backend/onyx/db/constants.py | 2 + 4 files changed, 44 insertions(+), 27 deletions(-) diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index ba2b68aa1..6ec3f1a05 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -46,7 +46,6 @@ from onyx.configs.constants import OnyxRedisSignals from onyx.connectors.factory import validate_ccpair_for_user from onyx.db.connector import mark_cc_pair_as_permissions_synced from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id -from onyx.db.connector_credential_pair import update_connector_credential_pair from onyx.db.document import upsert_document_by_connector_credential_pair from onyx.db.engine import get_session_with_current_tenant from onyx.db.enums import AccessType @@ -420,12 +419,7 @@ def connector_permission_sync_generator_task( task_logger.exception( f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}" ) - update_connector_credential_pair( - db_session=db_session, - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - status=ConnectorCredentialPairStatus.INVALID, - ) + # TODO: add some notification to the admins here raise source_type = cc_pair.connector.source diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 1599e0ae1..00cd6ea91 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.factory import validate_ccpair_for_user from onyx.db.connector import mark_cc_pair_as_external_group_synced from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id -from onyx.db.connector_credential_pair import update_connector_credential_pair from onyx.db.engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus @@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task( task_logger.exception( f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}" ) - update_connector_credential_pair( - db_session=db_session, - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - status=ConnectorCredentialPairStatus.INVALID, - ) + # TODO: add some notification to the admins here raise source_type = cc_pair.connector.source @@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task( try: external_user_groups 
= ext_group_sync_func(tenant_id, cc_pair) except ConnectorValidationError as e: - msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}" - update_connector_credential_pair( - db_session=db_session, - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - status=ConnectorCredentialPairStatus.INVALID, + # TODO: add some notification to the admins here + logger.exception( + f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}" ) raise e diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 54345bf09..7802e3d20 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -31,8 +31,11 @@ from onyx.connectors.models import TextSection from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import get_last_successful_attempt_time from onyx.db.connector_credential_pair import update_connector_credential_pair +from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX from onyx.db.engine import get_session_with_current_tenant from onyx.db.enums import ConnectorCredentialPairStatus +from onyx.db.enums import IndexingStatus +from onyx.db.enums import IndexModelStatus from onyx.db.index_attempt import create_index_attempt_error from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair @@ -45,8 +48,6 @@ from onyx.db.index_attempt import transition_attempt_to_in_progress from onyx.db.index_attempt import update_docs_indexed from onyx.db.models import IndexAttempt from onyx.db.models import IndexAttemptError -from onyx.db.models import IndexingStatus -from onyx.db.models import IndexModelStatus from onyx.document_index.factory import get_default_document_index from onyx.httpx.httpx_pool import HttpxPool from onyx.indexing.embedder import DefaultIndexingEmbedder @@ -386,6 +387,7 @@ def _run_indexing( net_doc_change = 0 document_count = 0 chunk_count = 0 + index_attempt: IndexAttempt | None = None try: with get_session_with_current_tenant() as db_session_temp: index_attempt = get_index_attempt(db_session_temp, index_attempt_id) @@ -596,16 +598,44 @@ def _run_indexing( mark_attempt_canceled( index_attempt_id, db_session_temp, - reason=str(e), + reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}", ) if ctx.is_primary: - update_connector_credential_pair( + if not index_attempt: + # should always be set by now + raise RuntimeError("Should never happen.") + + VALIDATION_ERROR_THRESHOLD = 5 + + recent_index_attempts = get_recent_completed_attempts_for_cc_pair( + cc_pair_id=ctx.cc_pair_id, + search_settings_id=index_attempt.search_settings_id, + limit=VALIDATION_ERROR_THRESHOLD, db_session=db_session_temp, - connector_id=ctx.connector_id, - credential_id=ctx.credential_id, - status=ConnectorCredentialPairStatus.INVALID, ) + num_validation_errors = len( + [ + index_attempt + for index_attempt in recent_index_attempts + if index_attempt.error_msg + and index_attempt.error_msg.startswith( + CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX + ) + ] + ) + + if num_validation_errors >= VALIDATION_ERROR_THRESHOLD: + logger.warning( + f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation" + f" errors. Marking the CC Pair as invalid." 
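For readability, the threshold check added in the hunk above distills to the sketch below. The prefix constant, the threshold of five, and the counting logic match the diff; the helper name and the plain list input are illustrative only, not part of the patch.

```python
# Illustrative distillation: mark_attempt_canceled writes a reason starting with
# this prefix, so later runs can count recent attempts that failed validation.
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:"
VALIDATION_ERROR_THRESHOLD = 5


def should_mark_invalid(recent_error_msgs: list[str | None]) -> bool:
    """True when enough recent completed attempts ended in connector validation errors."""
    num_validation_errors = sum(
        1
        for msg in recent_error_msgs
        if msg and msg.startswith(CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX)
    )
    return num_validation_errors >= VALIDATION_ERROR_THRESHOLD


# Only once the threshold is hit does the CC pair get flagged as INVALID.
history = [f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX} expired token"] * VALIDATION_ERROR_THRESHOLD
assert should_mark_invalid(history)
assert not should_mark_invalid(history[:-1])
```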
+ ) + update_connector_credential_pair( + db_session=db_session_temp, + connector_id=ctx.connector_id, + credential_id=ctx.credential_id, + status=ConnectorCredentialPairStatus.INVALID, + ) memory_tracer.stop() raise e diff --git a/backend/onyx/db/constants.py b/backend/onyx/db/constants.py index 58573d342..61d7149c3 100644 --- a/backend/onyx/db/constants.py +++ b/backend/onyx/db/constants.py @@ -1,2 +1,4 @@ SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__" DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL" + +CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:" From 5dda53eec37e8b64a94fecf2d2232da4b395fce2 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Wed, 19 Mar 2025 16:16:05 -0700 Subject: [PATCH 07/18] Notion improvement (#4306) * Notion connector improvements * Enable recursive index by default * Small tweak --- backend/onyx/configs/app_configs.py | 4 +- backend/onyx/connectors/notion/connector.py | 86 ++++++++++++--------- 2 files changed, 53 insertions(+), 37 deletions(-) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 17e992582..c3a0b4e80 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -347,8 +347,8 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get( HtmlBasedConnectorTransformLinksStrategy.STRIP, ) -NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = ( - os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower() +NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = ( + os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower() == "true" ) diff --git a/backend/onyx/connectors/notion/connector.py b/backend/onyx/connectors/notion/connector.py index 678dd32e4..f5521fa2c 100644 --- a/backend/onyx/connectors/notion/connector.py +++ b/backend/onyx/connectors/notion/connector.py @@ -1,16 +1,16 @@ from collections.abc import Generator -from dataclasses import dataclass -from dataclasses import fields from datetime import datetime from datetime import timezone from typing import Any +from typing import cast from typing import Optional import requests +from pydantic import BaseModel from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE -from onyx.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP +from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rl_requests, @@ -25,6 +25,7 @@ from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document +from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.batching import batch_generator from onyx.utils.logger import setup_logger @@ -38,8 +39,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds # TODO: Tables need to be ingested, Pages need to have their metadata ingested -@dataclass -class NotionPage: +class NotionPage(BaseModel): """Represents a Notion Page object""" id: str @@ -49,17 +49,10 @@ class NotionPage: properties: dict[str, Any] url: str - database_name: str | None # Only applicable to the database type page (wiki) - - def __init__(self, **kwargs: dict[str, Any]) -> None: - names = set([f.name for f in fields(self)]) - for k, v in 
kwargs.items(): - if k in names: - setattr(self, k, v) + database_name: str | None = None # Only applicable to the database type page (wiki) -@dataclass -class NotionBlock: +class NotionBlock(BaseModel): """Represents a Notion Block object""" id: str # Used for the URL @@ -69,20 +62,13 @@ class NotionBlock: prefix: str -@dataclass -class NotionSearchResponse: +class NotionSearchResponse(BaseModel): """Represents the response from the Notion Search API""" results: list[dict[str, Any]] next_cursor: Optional[str] has_more: bool = False - def __init__(self, **kwargs: dict[str, Any]) -> None: - names = set([f.name for f in fields(self)]) - for k, v in kwargs.items(): - if k in names: - setattr(self, k, v) - class NotionConnector(LoadConnector, PollConnector): """Notion Page connector that reads all Notion pages @@ -95,7 +81,7 @@ class NotionConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, - recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP, + recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP, root_page_id: str | None = None, ) -> None: """Initialize with parameters.""" @@ -464,23 +450,53 @@ class NotionConnector(LoadConnector, PollConnector): page_blocks, child_page_ids = self._read_blocks(page.id) all_child_page_ids.extend(child_page_ids) - if not page_blocks: - continue + # okay to mark here since there's no way for this to not succeed + # without a critical failure + self.indexed_pages.add(page.id) - page_title = ( - self._read_page_title(page) or f"Untitled Page with ID {page.id}" - ) + raw_page_title = self._read_page_title(page) + page_title = raw_page_title or f"Untitled Page with ID {page.id}" + + if not page_blocks: + if not raw_page_title: + logger.warning( + f"No blocks OR title found for page with ID '{page.id}'. Skipping." 
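The custom `__init__` methods removed here existed only to drop unknown keys from raw Notion API payloads before populating the dataclasses. Pydantic's `BaseModel` ignores extra fields by default, so the conversion preserves that behavior; a small standalone check (the extra `object` key is just an example payload field, not taken from the patch):

```python
from typing import Any
from typing import Optional

from pydantic import BaseModel


class NotionSearchResponse(BaseModel):
    """Copy of the model in the hunk above, for demonstration only."""

    results: list[dict[str, Any]]
    next_cursor: Optional[str]
    has_more: bool = False


# Unknown keys in the raw API response are silently dropped, matching the old
# dataclass __init__ that filtered kwargs against the declared field names.
raw = {"results": [], "next_cursor": None, "has_more": False, "object": "list"}
resp = NotionSearchResponse(**raw)
print(resp.has_more)  # False
```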
+ ) + continue + + logger.debug(f"No blocks found for page with ID '{page.id}'") + """ + Something like: + + TITLE + + PROP1: PROP1_VALUE + PROP2: PROP2_VALUE + """ + text = page_title + if page.properties: + text += "\n\n" + "\n".join( + [f"{key}: {value}" for key, value in page.properties.items()] + ) + sections = [ + TextSection( + link=f"{page.url}", + text=text, + ) + ] + else: + sections = [ + TextSection( + link=f"{page.url}#{block.id.replace('-', '')}", + text=block.prefix + block.text, + ) + for block in page_blocks + ] yield ( Document( id=page.id, - sections=[ - TextSection( - link=f"{page.url}#{block.id.replace('-', '')}", - text=block.prefix + block.text, - ) - for block in page_blocks - ], + sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.NOTION, semantic_identifier=page_title, doc_updated_at=datetime.fromisoformat( From 85ebadc8ebab9a96a11f5376d429ee3a567aeb68 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 19 Mar 2025 18:13:02 -0700 Subject: [PATCH 08/18] sanitize llm keys and handle updates properly (#4270) * sanitize llm keys and handle updates properly * fix llm provider testing * fix test * mypy * fix default model editing --------- Co-authored-by: Richard Kuo (Danswer) Co-authored-by: Richard Kuo --- .../ee/onyx/server/tenants/provisioning.py | 6 ++- backend/onyx/db/llm.py | 30 +++++++---- backend/onyx/llm/factory.py | 10 ++-- backend/onyx/server/manage/llm/api.py | 51 ++++++++++++++----- backend/onyx/server/manage/llm/models.py | 10 ++-- backend/onyx/setup.py | 1 + .../common_utils/managers/llm_provider.py | 14 ++--- .../tests/llm_provider/test_llm_provider.py | 32 ++++++++++++ .../app/admin/assistants/AssistantEditor.tsx | 4 +- web/src/app/admin/assistants/lib.ts | 4 +- .../llm/ConfiguredLLMProviderDisplay.tsx | 8 +-- .../llm/CustomLLMProviderUpdateForm.tsx | 6 +-- .../configuration/llm/LLMConfiguration.tsx | 8 +-- .../llm/LLMProviderUpdateForm.tsx | 10 ++-- .../app/admin/configuration/llm/interfaces.ts | 4 +- .../components/initialSetup/welcome/lib.ts | 4 +- .../assistants/fetchPersonaEditorInfoSS.ts | 6 +-- 17 files changed, 146 insertions(+), 62 deletions(-) diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index 19880339d..9a1b425b7 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="claude-3-5-sonnet-20241022", model_names=ANTHROPIC_MODEL_NAMES, display_model_names=["claude-3-5-sonnet-20241022"], + api_key_changed=True, ) try: full_provider = upsert_llm_provider(anthropic_provider, db_session) @@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None: ) if OPENAI_DEFAULT_API_KEY: - open_provider = LLMProviderUpsertRequest( + openai_provider = LLMProviderUpsertRequest( name="OpenAI", provider=OPENAI_PROVIDER_NAME, api_key=OPENAI_DEFAULT_API_KEY, @@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None: fast_default_model_name="gpt-4o-mini", model_names=OPEN_AI_MODEL_NAMES, display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], + api_key_changed=True, ) try: - full_provider = upsert_llm_provider(open_provider, db_session) + full_provider = upsert_llm_provider(openai_provider, db_session) update_default_provider(full_provider.id, db_session) except Exception as e: logger.error(f"Failed to configure OpenAI provider: {e}") diff --git 
a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index e5b1602b7..7a70462ad 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup from onyx.llm.utils import model_supports_image_input from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from shared_configs.enums import EmbeddingProvider @@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider( def upsert_llm_provider( llm_provider: LLMProviderUpsertRequest, db_session: Session, -) -> FullLLMProvider: +) -> LLMProviderView: existing_llm_provider = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name) ) @@ -98,7 +98,7 @@ def upsert_llm_provider( group_ids=llm_provider.groups, db_session=db_session, ) - full_llm_provider = FullLLMProvider.from_model(existing_llm_provider) + full_llm_provider = LLMProviderView.from_model(existing_llm_provider) db_session.commit() @@ -132,6 +132,16 @@ def fetch_existing_llm_providers( return list(db_session.scalars(stmt).all()) +def fetch_existing_llm_provider( + provider_name: str, db_session: Session +) -> LLMProviderModel | None: + provider_model = db_session.scalar( + select(LLMProviderModel).where(LLMProviderModel.name == provider_name) + ) + + return provider_model + + def fetch_existing_llm_providers_for_user( db_session: Session, user: User | None = None, @@ -177,7 +187,7 @@ def fetch_embedding_provider( ) -def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: +def fetch_default_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.is_default_provider == True # noqa: E712 @@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) -def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None: +def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.is_default_vision_provider == True # noqa: E712 @@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) -def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None: +def fetch_llm_provider_view( + db_session: Session, provider_name: str +) -> LLMProviderView | None: provider_model = db_session.scalar( select(LLMProviderModel).where(LLMProviderModel.name == provider_name) ) if not provider_model: return None - return FullLLMProvider.from_model(provider_model) + return LLMProviderView.from_model(provider_model) def remove_embedding_provider( diff --git a/backend/onyx/llm/factory.py b/backend/onyx/llm/factory.py index 3d0bb6b3b..c77518f51 100644 --- a/backend/onyx/llm/factory.py +++ b/backend/onyx/llm/factory.py @@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant from onyx.db.llm import fetch_default_provider from onyx.db.llm 
import fetch_default_vision_provider from onyx.db.llm import fetch_existing_llm_providers -from onyx.db.llm import fetch_provider +from onyx.db.llm import fetch_llm_provider_view from onyx.db.models import Persona from onyx.llm.chat_llm import DefaultMultiLLM from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.interfaces import LLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import model_supports_image_input -from onyx.server.manage.llm.models import FullLLMProvider +from onyx.server.manage.llm.models import LLMProviderView from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger from onyx.utils.long_term_log import LongTermLogger @@ -62,7 +62,7 @@ def get_llms_for_persona( ) with get_session_context_manager() as db_session: - llm_provider = fetch_provider(db_session, provider_name) + llm_provider = fetch_llm_provider_view(db_session, provider_name) if not llm_provider: raise ValueError("No LLM provider found") @@ -106,7 +106,7 @@ def get_default_llm_with_vision( if DISABLE_GENERATIVE_AI: raise GenAIDisabledException() - def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM: + def create_vision_llm(provider: LLMProviderView, model: str) -> LLM: """Helper to create an LLM if the provider supports image input.""" return get_llm( provider=provider.provider, @@ -148,7 +148,7 @@ def get_default_llm_with_vision( provider.default_vision_model, provider.provider ): return create_vision_llm( - FullLLMProvider.from_model(provider), provider.default_vision_model + LLMProviderView.from_model(provider), provider.default_vision_model ) return None diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index ceafca2e3..0a5ceb036 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -9,9 +9,9 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accessible_user from onyx.db.engine import get_session +from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers_for_user -from onyx.db.llm import fetch_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import update_default_provider from onyx.db.llm import update_default_vision_provider @@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor from onyx.llm.utils import litellm_exception_to_error_msg from onyx.llm.utils import model_supports_image_input from onyx.llm.utils import test_llm -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import LLMProviderDescriptor from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from onyx.server.manage.llm.models import TestLLMRequest from onyx.server.manage.llm.models import VisionProviderResponse from onyx.utils.logger import setup_logger @@ -49,11 +49,27 @@ def fetch_llm_options( def test_llm_configuration( test_llm_request: TestLLMRequest, _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), ) -> None: + """Test regular llm and fast llm settings""" + + # the api key is sanitized if we are testing a provider already in the system + + test_api_key = test_llm_request.api_key + if test_llm_request.name: + # NOTE: we are querying by name. 
we probably should be querying by an invariant id, but + # as it turns out the name is not editable in the UI and other code also keys off name, + # so we won't rock the boat just yet. + existing_provider = fetch_existing_llm_provider( + test_llm_request.name, db_session + ) + if existing_provider: + test_api_key = existing_provider.api_key + llm = get_llm( provider=test_llm_request.provider, model=test_llm_request.default_model_name, - api_key=test_llm_request.api_key, + api_key=test_api_key, api_base=test_llm_request.api_base, api_version=test_llm_request.api_version, custom_config=test_llm_request.custom_config, @@ -69,7 +85,7 @@ def test_llm_configuration( fast_llm = get_llm( provider=test_llm_request.provider, model=test_llm_request.fast_default_model_name, - api_key=test_llm_request.api_key, + api_key=test_api_key, api_base=test_llm_request.api_base, api_version=test_llm_request.api_version, custom_config=test_llm_request.custom_config, @@ -119,11 +135,17 @@ def test_default_provider( def list_llm_providers( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[FullLLMProvider]: - return [ - FullLLMProvider.from_model(llm_provider_model) - for llm_provider_model in fetch_existing_llm_providers(db_session) - ] +) -> list[LLMProviderView]: + llm_provider_list: list[LLMProviderView] = [] + for llm_provider_model in fetch_existing_llm_providers(db_session): + full_llm_provider = LLMProviderView.from_model(llm_provider_model) + if full_llm_provider.api_key: + full_llm_provider.api_key = ( + full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:] + ) + llm_provider_list.append(full_llm_provider) + + return llm_provider_list @admin_router.put("/provider") @@ -135,11 +157,11 @@ def put_llm_provider( ), _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> FullLLMProvider: +) -> LLMProviderView: # validate request (e.g. 
if we're intending to create but the name already exists we should throw an error) # NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache # the result - existing_provider = fetch_provider(db_session, llm_provider.name) + existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session) if existing_provider and is_creation: raise HTTPException( status_code=400, @@ -161,6 +183,11 @@ def put_llm_provider( llm_provider.fast_default_model_name ) + # the llm api key is sanitized when returned to clients, so the only time we + # should get a real key is when it is explicitly changed + if existing_provider and not llm_provider.api_key_changed: + llm_provider.api_key = existing_provider.api_key + try: return upsert_llm_provider( llm_provider=llm_provider, @@ -234,7 +261,7 @@ def get_vision_capable_providers( # Only include providers with at least one vision-capable model if vision_models: - provider_dict = FullLLMProvider.from_model(provider).model_dump() + provider_dict = LLMProviderView.from_model(provider).model_dump() provider_dict["vision_models"] = vision_models logger.info( f"Vision provider: {provider.provider} with models: {vision_models}" diff --git a/backend/onyx/server/manage/llm/models.py b/backend/onyx/server/manage/llm/models.py index 3172f5adf..9d5544d96 100644 --- a/backend/onyx/server/manage/llm/models.py +++ b/backend/onyx/server/manage/llm/models.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: class TestLLMRequest(BaseModel): # provider level + name: str | None = None provider: str api_key: str | None = None api_base: str | None = None @@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider): # should only be used for a "custom" provider # for default providers, the built-in model names are used model_names: list[str] | None = None + api_key_changed: bool = False -class FullLLMProvider(LLMProvider): +class LLMProviderView(LLMProvider): + """Stripped down representation of LLMProvider for display / limited access info only""" + id: int is_default_provider: bool | None = None is_default_vision_provider: bool | None = None model_names: list[str] @classmethod - def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider": + def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView": return cls( id=llm_provider_model.id, name=llm_provider_model.name, @@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider): ) -class VisionProviderResponse(FullLLMProvider): +class VisionProviderResponse(LLMProviderView): """Response model for vision providers endpoint, including vision-specific fields.""" vision_models: list[str] diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index 1dff601ef..750b35d8d 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None: groups=[], display_model_names=OPEN_AI_MODEL_NAMES, model_names=OPEN_AI_MODEL_NAMES, + api_key_changed=True, ) new_llm_provider = upsert_llm_provider( llm_provider=model_req, db_session=db_session diff --git a/backend/tests/integration/common_utils/managers/llm_provider.py b/backend/tests/integration/common_utils/managers/llm_provider.py index 44d4ce501..33d29e42e 100644 --- a/backend/tests/integration/common_utils/managers/llm_provider.py +++ b/backend/tests/integration/common_utils/managers/llm_provider.py @@ -3,8 +3,8 @@ from uuid import uuid4 import requests -from onyx.server.manage.llm.models import FullLLMProvider from onyx.server.manage.llm.models import 
LLMProviderUpsertRequest +from onyx.server.manage.llm.models import LLMProviderView from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestLLMProvider @@ -39,6 +39,7 @@ class LLMProviderManager: groups=groups or [], display_model_names=None, model_names=None, + api_key_changed=True, ) llm_response = requests.put( @@ -90,7 +91,7 @@ class LLMProviderManager: @staticmethod def get_all( user_performing_action: DATestUser | None = None, - ) -> list[FullLLMProvider]: + ) -> list[LLMProviderView]: response = requests.get( f"{API_SERVER_URL}/admin/llm/provider", headers=user_performing_action.headers @@ -98,7 +99,7 @@ class LLMProviderManager: else GENERAL_HEADERS, ) response.raise_for_status() - return [FullLLMProvider(**ug) for ug in response.json()] + return [LLMProviderView(**ug) for ug in response.json()] @staticmethod def verify( @@ -111,18 +112,19 @@ class LLMProviderManager: if llm_provider.id == fetched_llm_provider.id: if verify_deleted: raise ValueError( - f"User group {llm_provider.id} found but should be deleted" + f"LLM Provider {llm_provider.id} found but should be deleted" ) fetched_llm_groups = set(fetched_llm_provider.groups) llm_provider_groups = set(llm_provider.groups) + + # NOTE: returned api keys are sanitized and should not match if ( fetched_llm_groups == llm_provider_groups and llm_provider.provider == fetched_llm_provider.provider - and llm_provider.api_key == fetched_llm_provider.api_key and llm_provider.default_model_name == fetched_llm_provider.default_model_name and llm_provider.is_public == fetched_llm_provider.is_public ): return if not verify_deleted: - raise ValueError(f"User group {llm_provider.id} not found") + raise ValueError(f"LLM Provider {llm_provider.id} not found") diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py index 1b7d4207e..7c72382f1 100644 --- a/backend/tests/integration/tests/llm_provider/test_llm_provider.py +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: json={ "name": str(uuid.uuid4()), "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None: assert provider_data["model_names"] == _DEFAULT_MODELS assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] assert provider_data["display_model_names"] is None + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that returned key is sanitized def test_update_llm_provider_model_names(reset: None) -> None: @@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None: json={ "name": name, "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": [_DEFAULT_MODELS[0]], "is_public": True, "groups": [], + "api_key_changed": True, }, ) assert response.status_code == 200 @@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None: "id": created_provider["id"], "name": name, "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", 
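The `sk-0****0000` expectations in these integration tests come from the masking applied in `list_llm_providers` above: only the first and last four characters of a stored key are returned to clients. A tiny standalone check of that rule (the helper name is illustrative, not the repo's API):

```python
def mask_api_key(api_key: str) -> str:
    # Mirrors the sanitization in list_llm_providers: first 4 chars + "****" + last 4.
    return api_key[:4] + "****" + api_key[-4:]


assert mask_api_key("sk-" + "0" * 48) == "sk-0****0000"
assert mask_api_key("sk-" + "0" * 47 + "1") == "sk-0****0001"
```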
"default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, @@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None: provider_data = _get_provider_by_id(admin_user, created_provider["id"]) assert provider_data is not None assert provider_data["model_names"] == _DEFAULT_MODELS + assert ( + provider_data["api_key"] == "sk-0****0000" + ) # test that key was NOT updated due to api_key_changed not being set + + # Update with api_key_changed properly set + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "id": created_provider["id"], + "name": name, + "provider": created_provider["provider"], + "api_key": "sk-000000000000000000000000000000000000000000000001", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + "api_key_changed": True, + }, + ) + assert response.status_code == 200 + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is not None + assert provider_data["api_key"] == "sk-0****0001" # test that key was updated def test_delete_llm_provider(reset: None) -> None: @@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None: json={ "name": "test-provider-delete", "provider": "openai", + "api_key": "sk-000000000000000000000000000000000000000000000000", "default_model_name": _DEFAULT_MODELS[0], "model_names": _DEFAULT_MODELS, "is_public": True, diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 4ed8c3be9..7137102f5 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -61,7 +61,7 @@ import { import { buildImgUrl } from "@/app/chat/files/images/utils"; import { useAssistants } from "@/components/context/AssistantsContext"; import { debounce } from "lodash"; -import { FullLLMProvider } from "../configuration/llm/interfaces"; +import { LLMProviderView } from "../configuration/llm/interfaces"; import StarterMessagesList from "./StarterMessageList"; import { Switch, SwitchField } from "@/components/ui/switch"; @@ -123,7 +123,7 @@ export function AssistantEditor({ documentSets: DocumentSet[]; user: User | null; defaultPublic: boolean; - llmProviders: FullLLMProvider[]; + llmProviders: LLMProviderView[]; tools: ToolSnapshot[]; shouldAddAssistantToUserPreferences?: boolean; admin?: boolean; diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index a6494782f..70dc8035b 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -1,4 +1,4 @@ -import { FullLLMProvider } from "../configuration/llm/interfaces"; +import { LLMProviderView } from "../configuration/llm/interfaces"; import { Persona, StarterMessage } from "./interfaces"; interface PersonaUpsertRequest { @@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) { } export function providersContainImageGeneratingSupport( - providers: FullLLMProvider[] + providers: LLMProviderView[] ) { return providers.some((provider) => provider.provider === "openai"); } diff --git a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx index 3146d1da7..16b3e7863 100644 --- a/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx +++ 
b/web/src/app/admin/configuration/llm/ConfiguredLLMProviderDisplay.tsx @@ -1,5 +1,5 @@ import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { Modal } from "@/components/Modal"; import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; @@ -19,7 +19,7 @@ function LLMProviderUpdateModal({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; }) { @@ -61,7 +61,7 @@ function LLMProviderDisplay({ shouldMarkAsDefault, }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined; - existingLlmProvider: FullLLMProvider; + existingLlmProvider: LLMProviderView; shouldMarkAsDefault?: boolean; }) { const [formIsVisible, setFormIsVisible] = useState(false); @@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({ existingLlmProviders, llmProviderDescriptors, }: { - existingLlmProviders: FullLLMProvider[]; + existingLlmProviders: LLMProviderView[]; llmProviderDescriptors: WellKnownLLMProviderDescriptor[]; }) { existingLlmProviders = existingLlmProviders.sort((a, b) => { diff --git a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx index 0b175554d..1bdef47e4 100644 --- a/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.tsx @@ -21,7 +21,7 @@ import { } from "@/components/admin/connectors/Field"; import { useState } from "react"; import { useSWRConfig } from "swr"; -import { FullLLMProvider } from "./interfaces"; +import { LLMProviderView } from "./interfaces"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; @@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({ hideSuccess, }: { onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; hideSuccess?: boolean; @@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({ } if (shouldMarkAsDefault) { - const newLlmProvider = (await response.json()) as FullLLMProvider; + const newLlmProvider = (await response.json()) as LLMProviderView; const setDefaultResponse = await fetch( `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, { diff --git a/web/src/app/admin/configuration/llm/LLMConfiguration.tsx b/web/src/app/admin/configuration/llm/LLMConfiguration.tsx index 933efa597..56645df4f 100644 --- a/web/src/app/admin/configuration/llm/LLMConfiguration.tsx +++ b/web/src/app/admin/configuration/llm/LLMConfiguration.tsx @@ -9,7 +9,7 @@ import Text from "@/components/ui/text"; import Title from "@/components/ui/title"; import { Button } from "@/components/ui/button"; import { ThreeDotsLoader } from "@/components/Loading"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { LLMProviderUpdateForm } 
from "./LLMProviderUpdateForm"; import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; @@ -25,7 +25,7 @@ function LLMProviderUpdateModal({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; setPopup?: (popup: PopupSpec) => void; }) { @@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({ function AddCustomLLMProvider({ existingLlmProviders, }: { - existingLlmProviders: FullLLMProvider[]; + existingLlmProviders: LLMProviderView[]; }) { const [formIsVisible, setFormIsVisible] = useState(false); @@ -130,7 +130,7 @@ export function LLMConfiguration() { const { data: llmProviderDescriptors } = useSWR< WellKnownLLMProviderDescriptor[] >("/api/admin/llm/built-in/options", errorHandlingFetcher); - const { data: existingLlmProviders } = useSWR( + const { data: existingLlmProviders } = useSWR( LLM_PROVIDERS_ADMIN_URL, errorHandlingFetcher ); diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index f8cb5b6cb..cb2881a31 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -14,7 +14,7 @@ import { import { useState } from "react"; import { useSWRConfig } from "swr"; import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks"; -import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec } from "@/components/admin/connectors/Popup"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; @@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({ }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor; onClose: () => void; - existingLlmProvider?: FullLLMProvider; + existingLlmProvider?: LLMProviderView; shouldMarkAsDefault?: boolean; hideAdvanced?: boolean; setPopup?: (popup: PopupSpec) => void; @@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({ defaultModelsByProvider[llmProviderDescriptor.name] || [], deployment_name: existingLlmProvider?.deployment_name, + api_key_changed: false, }; // Setup validation schema if required @@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({ is_public: Yup.boolean().required(), groups: Yup.array().of(Yup.number()), display_model_names: Yup.array().of(Yup.string()), + api_key_changed: Yup.boolean(), }); return ( @@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({ onSubmit={async (values, { setSubmitting }) => { setSubmitting(true); + values.api_key_changed = values.api_key !== initialValues.api_key; + // test the configuration if (!isEqual(values, initialValues)) { setIsTesting(true); @@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({ } if (shouldMarkAsDefault) { - const newLlmProvider = (await response.json()) as FullLLMProvider; + const newLlmProvider = (await response.json()) as LLMProviderView; const setDefaultResponse = await fetch( `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, { diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 641a372d2..80971e0cc 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -53,14 +53,14 @@ export interface LLMProvider { is_default_vision_provider: boolean | null; } -export interface 
FullLLMProvider extends LLMProvider { +export interface LLMProviderView extends LLMProvider { id: number; is_default_provider: boolean | null; model_names: string[]; icon?: React.FC<{ size?: number; className?: string }>; } -export interface VisionProvider extends FullLLMProvider { +export interface VisionProvider extends LLMProviderView { vision_models: string[]; } diff --git a/web/src/components/initialSetup/welcome/lib.ts b/web/src/components/initialSetup/welcome/lib.ts index 5cbe54cc3..822b9f1ea 100644 --- a/web/src/components/initialSetup/welcome/lib.ts +++ b/web/src/components/initialSetup/welcome/lib.ts @@ -1,5 +1,5 @@ import { - FullLLMProvider, + LLMProviderView, WellKnownLLMProviderDescriptor, } from "@/app/admin/configuration/llm/interfaces"; import { User } from "@/lib/types"; @@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) { const [providerResponse, optionsResponse, defaultCheckResponse] = await Promise.all(tasks); - let providers: FullLLMProvider[] = []; + let providers: LLMProviderView[] = []; if (providerResponse?.ok) { providers = await providerResponse.json(); } diff --git a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts index 2c00e0dfc..3d4a85845 100644 --- a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts +++ b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts @@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types"; import { getCurrentUserSS } from "../userSS"; import { fetchSS } from "../utilsSS"; import { - FullLLMProvider, + LLMProviderView, getProviderIcon, } from "@/app/admin/configuration/llm/interfaces"; import { ToolSnapshot } from "../tools/interfaces"; @@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS( { ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; - llmProviders: FullLLMProvider[]; + llmProviders: LLMProviderView[]; user: User | null; existingPersona: Persona | null; tools: ToolSnapshot[]; @@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS( ]; } - const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[]; + const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[]; if (personaId && personaResponse && !personaResponse.ok) { return [null, `Failed to fetch Persona - ${await personaResponse.text()}`]; From 2a01c854a0ece32878c490044fe927a5d488905a Mon Sep 17 00:00:00 2001 From: Weves Date: Wed, 19 Mar 2025 17:16:52 -0700 Subject: [PATCH 09/18] Fix cases where the bot is disabled --- backend/onyx/onyxbot/slack/listener.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py index 75f662fd6..e6ac548a2 100644 --- a/backend/onyx/onyxbot/slack/listener.py +++ b/backend/onyx/onyxbot/slack/listener.py @@ -41,6 +41,7 @@ from onyx.db.engine import get_session_with_current_tenant from onyx.db.engine import get_session_with_tenant from onyx.db.models import SlackBot from onyx.db.search_settings import get_current_search_settings +from onyx.db.slack_bot import fetch_slack_bot from onyx.db.slack_bot import fetch_slack_bots from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.natural_language_processing.search_nlp_models import EmbeddingModel @@ -519,6 +520,25 @@ class SlackbotHandler: def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool: """True to keep going, False to ignore this Slack request""" + + # skip cases where the bot is 
disabled in the web UI + bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client) + with get_session_with_current_tenant() as db_session: + slack_bot = fetch_slack_bot( + db_session=db_session, slack_bot_id=client.slack_bot_id + ) + if not slack_bot: + logger.error( + f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request." + ) + return False + + if not slack_bot.enabled: + logger.info( + f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request." + ) + return False + if req.type == "events_api": # Verify channel is valid event = cast(dict[str, Any], req.payload.get("event", {})) From 91c9be37c0dd548a8a5f0b615201c64b9f48dd87 Mon Sep 17 00:00:00 2001 From: Weves Date: Wed, 19 Mar 2025 17:34:15 -0700 Subject: [PATCH 10/18] Fix loader --- web/src/app/admin/bots/[bot-id]/page.tsx | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/web/src/app/admin/bots/[bot-id]/page.tsx b/web/src/app/admin/bots/[bot-id]/page.tsx index aea206d0e..e56558762 100644 --- a/web/src/app/admin/bots/[bot-id]/page.tsx +++ b/web/src/app/admin/bots/[bot-id]/page.tsx @@ -6,11 +6,9 @@ import { ErrorCallout } from "@/components/ErrorCallout"; import { ThreeDotsLoader } from "@/components/Loading"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { usePopup } from "@/components/admin/connectors/Popup"; -import Link from "next/link"; import { SlackChannelConfigsTable } from "./SlackChannelConfigsTable"; import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks"; import { ExistingSlackBotForm } from "../SlackBotUpdateForm"; -import { FiPlusSquare } from "react-icons/fi"; import { Separator } from "@/components/ui/separator"; function SlackBotEditPage({ @@ -37,7 +35,11 @@ function SlackBotEditPage({ } = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"])); if (isSlackBotLoading || isSlackChannelConfigsLoading) { - return ; + return ( +
+ +
+ ); } if (slackBotError || !slackBot) { @@ -67,7 +69,7 @@ function SlackBotEditPage({ } return ( -
+ <> @@ -86,8 +88,18 @@ function SlackBotEditPage({ setPopup={setPopup} />
- + ); } -export default SlackBotEditPage; +export default function Page({ + params, +}: { + params: Promise<{ "bot-id": string }>; +}) { + return ( +
+ +
+ ); +} From 15dd1e72cade63ab0b423fcbbf9f96ac00047fc6 Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 20 Mar 2025 08:34:03 -0700 Subject: [PATCH 11/18] Remove slack channel validation --- backend/onyx/connectors/slack/connector.py | 39 ++++++++++++---------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index d4890e2e3..d220020b9 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -704,25 +704,28 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): ) # 3) If channels are specified and regex is not enabled, verify each is accessible - if self.channels and not self.channel_regex_enabled: - accessible_channels = get_channels( - client=self.fast_client, - exclude_archived=True, - get_public=True, - get_private=True, - ) - # For quick lookups by name or ID, build a map: - accessible_channel_names = {ch["name"] for ch in accessible_channels} - accessible_channel_ids = {ch["id"] for ch in accessible_channels} + # NOTE: removed this for now since it may be too slow for large workspaces which may + # have some automations which create a lot of channels (100k+) - for user_channel in self.channels: - if ( - user_channel not in accessible_channel_names - and user_channel not in accessible_channel_ids - ): - raise ConnectorValidationError( - f"Channel '{user_channel}' not found or inaccessible in this workspace." - ) + # if self.channels and not self.channel_regex_enabled: + # accessible_channels = get_channels( + # client=self.fast_client, + # exclude_archived=True, + # get_public=True, + # get_private=True, + # ) + # # For quick lookups by name or ID, build a map: + # accessible_channel_names = {ch["name"] for ch in accessible_channels} + # accessible_channel_ids = {ch["id"] for ch in accessible_channels} + + # for user_channel in self.channels: + # if ( + # user_channel not in accessible_channel_names + # and user_channel not in accessible_channel_ids + # ): + # raise ConnectorValidationError( + # f"Channel '{user_channel}' not found or inaccessible in this workspace." 
+ # ) except SlackApiError as e: slack_error = e.response.get("error", "") From 0292ca2445a8d0fb8089d1f1e9e67d9cadb826d6 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Thu, 20 Mar 2025 09:56:05 -0700 Subject: [PATCH 12/18] Add option to control # of slack threads (#4310) --- backend/onyx/configs/app_configs.py | 3 +++ backend/onyx/connectors/slack/connector.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index c3a0b4e80..14da10664 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -420,6 +420,9 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID") LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET") +# Slack specific configs +SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2) + DASK_JOB_CLIENT_ENABLED = ( os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" ) diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index d220020b9..83e52c410 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -18,6 +18,7 @@ from typing_extensions import override from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS from onyx.configs.app_configs import INDEX_BATCH_SIZE +from onyx.configs.app_configs import SLACK_NUM_THREADS from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError @@ -486,7 +487,6 @@ def _process_message( class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): - MAX_WORKERS = 2 FAST_TIMEOUT = 1 def __init__( @@ -496,10 +496,12 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): # regexes, and will only index channels that fully match the regexes channel_regex_enabled: bool = False, batch_size: int = INDEX_BATCH_SIZE, + num_threads: int = SLACK_NUM_THREADS, ) -> None: self.channels = channels self.channel_regex_enabled = channel_regex_enabled self.batch_size = batch_size + self.num_threads = num_threads self.client: WebClient | None = None self.fast_client: WebClient | None = None # just used for efficiency @@ -593,7 +595,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): new_latest = message_batch[-1]["ts"] if message_batch else latest # Process messages in parallel using ThreadPoolExecutor - with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor: + with ThreadPoolExecutor(max_workers=self.num_threads) as executor: futures: list[Future[ProcessedSlackMessage]] = [] for message in message_batch: # Capture the current context so that the thread gets the current tenant ID From 6d330131fd1383128bfcc90d34c2ac0948a99249 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Thu, 20 Mar 2025 16:10:28 -0700 Subject: [PATCH 13/18] =?UTF-8?q?wire=20off=20image=20downloading=20for=20?= =?UTF-8?q?confluence=20and=20gdrive=20if=20not=20enabled=20i=E2=80=A6=20(?= =?UTF-8?q?#4305)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * wire off image downloading for confluence and gdrive if not enabled in settings * fix partial func * fix confluence basic test * add test for skipping/allowing images * review comments * skip allow images test * mock function using the db * mock at the 
proper level --------- Co-authored-by: Richard Kuo (Onyx) --- .../onyx/connectors/confluence/connector.py | 19 ++-- backend/onyx/connectors/confluence/utils.py | 15 +++- backend/onyx/connectors/factory.py | 3 + .../onyx/connectors/google_drive/connector.py | 7 ++ .../connectors/google_drive/doc_conversion.py | 8 +- backend/onyx/connectors/interfaces.py | 4 + .../confluence/test_confluence_basic.py | 86 +++++++++++++++---- 7 files changed, 117 insertions(+), 25 deletions(-) diff --git a/backend/onyx/connectors/confluence/connector.py b/backend/onyx/connectors/confluence/connector.py index ce2a81fca..76f694868 100644 --- a/backend/onyx/connectors/confluence/connector.py +++ b/backend/onyx/connectors/confluence/connector.py @@ -114,6 +114,7 @@ class ConfluenceConnector( self.timezone_offset = timezone_offset self._confluence_client: OnyxConfluence | None = None self._fetched_titles: set[str] = set() + self.allow_images = False # Remove trailing slash from wiki_base if present self.wiki_base = wiki_base.rstrip("/") @@ -158,6 +159,9 @@ class ConfluenceConnector( "max_backoff_seconds": 60, } + def set_allow_images(self, value: bool) -> None: + self.allow_images = value + @property def confluence_client(self) -> OnyxConfluence: if self._confluence_client is None: @@ -233,7 +237,9 @@ class ConfluenceConnector( # Extract basic page information page_id = page["id"] page_title = page["title"] - page_url = f"{self.wiki_base}{page['_links']['webui']}" + page_url = build_confluence_document_id( + self.wiki_base, page["_links"]["webui"], self.is_cloud + ) # Get the page content page_content = extract_text_from_confluence_html( @@ -264,6 +270,7 @@ class ConfluenceConnector( self.confluence_client, attachment, page_id, + self.allow_images, ) if result and result.text: @@ -304,13 +311,14 @@ class ConfluenceConnector( if "version" in page and "by" in page["version"]: author = page["version"]["by"] display_name = author.get("displayName", "Unknown") - primary_owners.append(BasicExpertInfo(display_name=display_name)) + email = author.get("email", "unknown@domain.invalid") + primary_owners.append( + BasicExpertInfo(display_name=display_name, email=email) + ) # Create the document return Document( - id=build_confluence_document_id( - self.wiki_base, page["_links"]["webui"], self.is_cloud - ), + id=page_url, sections=sections, source=DocumentSource.CONFLUENCE, semantic_identifier=page_title, @@ -373,6 +381,7 @@ class ConfluenceConnector( confluence_client=self.confluence_client, attachment=attachment, page_id=page["id"], + allow_images=self.allow_images, ) if response is None: continue diff --git a/backend/onyx/connectors/confluence/utils.py b/backend/onyx/connectors/confluence/utils.py index fcd98ead0..931bba845 100644 --- a/backend/onyx/connectors/confluence/utils.py +++ b/backend/onyx/connectors/confluence/utils.py @@ -112,6 +112,7 @@ def process_attachment( confluence_client: "OnyxConfluence", attachment: dict[str, Any], parent_content_id: str | None, + allow_images: bool, ) -> AttachmentProcessingResult: """ Processes a Confluence attachment. 
If it's a document, extracts text, @@ -119,7 +120,7 @@ def process_attachment( """ try: # Get the media type from the attachment metadata - media_type = attachment.get("metadata", {}).get("mediaType", "") + media_type: str = attachment.get("metadata", {}).get("mediaType", "") # Validate the attachment type if not validate_attachment_filetype(attachment): return AttachmentProcessingResult( @@ -138,7 +139,14 @@ def process_attachment( attachment_size = attachment["extensions"]["fileSize"] - if not media_type.startswith("image/"): + if media_type.startswith("image/"): + if not allow_images: + return AttachmentProcessingResult( + text=None, + file_name=None, + error="Image downloading is not enabled", + ) + else: if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: logger.warning( f"Skipping {attachment_link} due to size. " @@ -294,6 +302,7 @@ def convert_attachment_to_content( confluence_client: "OnyxConfluence", attachment: dict[str, Any], page_id: str, + allow_images: bool, ) -> tuple[str | None, str | None] | None: """ Facade function which: @@ -309,7 +318,7 @@ def convert_attachment_to_content( ) return None - result = process_attachment(confluence_client, attachment, page_id) + result = process_attachment(confluence_client, attachment, page_id, allow_images) if result.error is not None: logger.warning( f"Attachment {attachment['title']} encountered error: {result.error}" diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index 2f0b10743..01a329f82 100644 --- a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DocumentSource +from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled from onyx.connectors.airtable.airtable_connector import AirtableConnector from onyx.connectors.asana.connector import AsanaConnector from onyx.connectors.axero.connector import AxeroConnector @@ -184,6 +185,8 @@ def instantiate_connector( if new_credentials is not None: backend_update_credential_json(credential, new_credentials, db_session) + connector.set_allow_images(get_image_extraction_and_analysis_enabled()) + return connector diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index 496d193ec..07993a1d4 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -86,6 +86,7 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]: def _convert_single_file( creds: Any, primary_admin_email: str, + allow_images: bool, file: dict[str, Any], ) -> Document | ConnectorFailure | None: user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email @@ -101,6 +102,7 @@ def _convert_single_file( file=file, drive_service=user_drive_service, docs_service=docs_service, + allow_images=allow_images, ) @@ -234,6 +236,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._retrieved_ids: set[str] = set() + self.allow_images = False + + def set_allow_images(self, value: bool) -> None: + self.allow_images = value @property def primary_admin_email(self) -> str: @@ -900,6 +906,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo _convert_single_file, self.creds, self.primary_admin_email, + 
self.allow_images, ) # Fetch files in batches diff --git a/backend/onyx/connectors/google_drive/doc_conversion.py b/backend/onyx/connectors/google_drive/doc_conversion.py index bc4b83677..cfa246c82 100644 --- a/backend/onyx/connectors/google_drive/doc_conversion.py +++ b/backend/onyx/connectors/google_drive/doc_conversion.py @@ -79,6 +79,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool: def _extract_sections_basic( file: dict[str, str], service: GoogleDriveService, + allow_images: bool, ) -> list[TextSection | ImageSection]: """Extract text and images from a Google Drive file.""" file_id = file["id"] @@ -87,6 +88,10 @@ def _extract_sections_basic( link = file.get("webViewLink", "") try: + # skip images if not explicitly enabled + if not allow_images and is_gdrive_image_mime_type(mime_type): + return [] + # For Google Docs, Sheets, and Slides, export as plain text if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT: export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type] @@ -207,6 +212,7 @@ def convert_drive_item_to_document( file: GoogleDriveFileType, drive_service: Callable[[], GoogleDriveService], docs_service: Callable[[], GoogleDocsService], + allow_images: bool, ) -> Document | ConnectorFailure | None: """ Main entry point for converting a Google Drive file => Document object. @@ -236,7 +242,7 @@ def convert_drive_item_to_document( # If we don't have sections yet, use the basic extraction method if not sections: - sections = _extract_sections_basic(file, drive_service()) + sections = _extract_sections_basic(file, drive_service(), allow_images) # If we still don't have any sections, skip this file if not sections: diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index ae2109829..daa4b07c6 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -60,6 +60,10 @@ class BaseConnector(abc.ABC, Generic[CT]): Default is a no-op (always successful). 
""" + def set_allow_images(self, value: bool) -> None: + """Implement if the underlying connector wants to skip/allow image downloading + based on the application level image analysis setting.""" + def build_dummy_checkpoint(self) -> CT: # TODO: find a way to make this work without type: ignore return ConnectorCheckpoint(has_more=True) # type: ignore diff --git a/backend/tests/daily/connectors/confluence/test_confluence_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_basic.py index b675f3035..499907d87 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_basic.py @@ -1,5 +1,6 @@ import os import time +from typing import Any from unittest.mock import MagicMock from unittest.mock import patch @@ -7,15 +8,16 @@ import pytest from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.confluence.utils import AttachmentProcessingResult from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider from onyx.connectors.models import Document @pytest.fixture -def confluence_connector() -> ConfluenceConnector: +def confluence_connector(space: str) -> ConfluenceConnector: connector = ConfluenceConnector( wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"], - space=os.environ["CONFLUENCE_TEST_SPACE"], + space=space, is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true", page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), ) @@ -32,14 +34,15 @@ def confluence_connector() -> ConfluenceConnector: return connector +@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]]) @patch( "onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None, ) -@pytest.mark.skip(reason="Skipping this test") def test_confluence_connector_basic( mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector ) -> None: + confluence_connector.set_allow_images(False) doc_batch_generator = confluence_connector.poll_source(0, time.time()) doc_batch = next(doc_batch_generator) @@ -50,15 +53,14 @@ def test_confluence_connector_basic( page_within_a_page_doc: Document | None = None page_doc: Document | None = None - txt_doc: Document | None = None for doc in doc_batch: if doc.semantic_identifier == "DailyConnectorTestSpace Home": page_doc = doc - elif ".txt" in doc.semantic_identifier: - txt_doc = doc elif doc.semantic_identifier == "Page Within A Page": page_within_a_page_doc = doc + else: + pass assert page_within_a_page_doc is not None assert page_within_a_page_doc.semantic_identifier == "Page Within A Page" @@ -79,7 +81,7 @@ def test_confluence_connector_basic( assert page_doc.metadata["labels"] == ["testlabel"] assert page_doc.primary_owners assert page_doc.primary_owners[0].email == "hagen@danswer.ai" - assert len(page_doc.sections) == 1 + assert len(page_doc.sections) == 2 # page text + attachment text page_section = page_doc.sections[0] assert page_section.text == "test123 " + page_within_a_page_text @@ -88,13 +90,65 @@ def test_confluence_connector_basic( == "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview" ) - assert txt_doc is not None - assert txt_doc.semantic_identifier == "small-file.txt" - assert len(txt_doc.sections) == 1 - assert txt_doc.sections[0].text == "small" - assert txt_doc.primary_owners - assert txt_doc.primary_owners[0].email == "chris@onyx.app" - assert ( - txt_doc.sections[0].link - == 
"https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt" + text_attachment_section = page_doc.sections[1] + assert text_attachment_section.text == "small" + assert text_attachment_section.link + assert text_attachment_section.link.endswith("small-file.txt") + + +@pytest.mark.parametrize("space", ["MI"]) +@patch( + "onyx.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_confluence_connector_skip_images( + mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector +) -> None: + confluence_connector.set_allow_images(False) + doc_batch_generator = confluence_connector.poll_source(0, time.time()) + + doc_batch = next(doc_batch_generator) + with pytest.raises(StopIteration): + next(doc_batch_generator) + + assert len(doc_batch) == 8 + assert sum(len(doc.sections) for doc in doc_batch) == 8 + + +def mock_process_image_attachment( + *args: Any, **kwargs: Any +) -> AttachmentProcessingResult: + """We need this mock to bypass DB access happening in the connector. Which shouldn't + be done as a rule to begin with, but life is not perfect. Fix it later""" + + return AttachmentProcessingResult( + text="Hi_text", + file_name="Hi_filename", + error=None, ) + + +@pytest.mark.parametrize("space", ["MI"]) +@patch( + "onyx.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +@patch( + "onyx.connectors.confluence.utils._process_image_attachment", + side_effect=mock_process_image_attachment, +) +def test_confluence_connector_allow_images( + mock_get_api_key: MagicMock, + mock_process_image_attachment: MagicMock, + confluence_connector: ConfluenceConnector, +) -> None: + confluence_connector.set_allow_images(True) + + doc_batch_generator = confluence_connector.poll_source(0, time.time()) + + doc_batch = next(doc_batch_generator) + with pytest.raises(StopIteration): + next(doc_batch_generator) + + assert len(doc_batch) == 8 + assert sum(len(doc.sections) for doc in doc_batch) == 12 From 775c847f822c526eea3f34328356319c90f267e5 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Thu, 20 Mar 2025 17:23:55 -0700 Subject: [PATCH 14/18] Reduce drive retries (#4312) * Reduce drive retries * timestamp format fix --------- Co-authored-by: Evan Lohn --- backend/onyx/connectors/google_drive/connector.py | 4 +++- .../onyx/connectors/google_drive/file_retrieval.py | 5 +++-- .../onyx/connectors/google_utils/google_utils.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index 07993a1d4..a9f1f9469 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -1104,7 +1104,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo drive_service.files().list(pageSize=1, fields="files(id)").execute() if isinstance(self._creds, ServiceAccountCredentials): - retry_builder()(get_root_folder_id)(drive_service) + # default is ~17mins of retries, don't do that here since this is called from + # the UI + retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service) except HttpError as e: status_code = e.resp.status if e.resp else None diff --git a/backend/onyx/connectors/google_drive/file_retrieval.py b/backend/onyx/connectors/google_drive/file_retrieval.py index 3fcda064c..c4d8e6257 100644 --- 
a/backend/onyx/connectors/google_drive/file_retrieval.py +++ b/backend/onyx/connectors/google_drive/file_retrieval.py @@ -1,6 +1,7 @@ from collections.abc import Callable from collections.abc import Iterator from datetime import datetime +from datetime import timezone from googleapiclient.discovery import Resource # type: ignore @@ -36,12 +37,12 @@ def _generate_time_range_filter( ) -> str: time_range_filter = "" if start is not None: - time_start = datetime.utcfromtimestamp(start).isoformat() + "Z" + time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat() time_range_filter += ( f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'" ) if end is not None: - time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z" + time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat() time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'" return time_range_filter diff --git a/backend/onyx/connectors/google_utils/google_utils.py b/backend/onyx/connectors/google_utils/google_utils.py index 60ee3373c..4ad6cbe7a 100644 --- a/backend/onyx/connectors/google_utils/google_utils.py +++ b/backend/onyx/connectors/google_utils/google_utils.py @@ -17,9 +17,12 @@ logger = setup_logger() # Google Drive APIs are quite flakey and may 500 for an -# extended period of time. Trying to combat here by adding a very -# long retry period (~20 minutes of trying every minute) -add_retries = retry_builder(tries=50, max_delay=30) +# extended period of time. This is now addressed by checkpointing. +# +# NOTE: We previously tried to combat this here by adding a very +# long retry period (~20 minutes of trying, one request a minute.) +# This is no longer necessary due to checkpointing. +add_retries = retry_builder(tries=5, max_delay=10) NEXT_PAGE_TOKEN_KEY = "nextPageToken" PAGE_TOKEN_KEY = "pageToken" @@ -37,14 +40,14 @@ class GoogleFields(str, Enum): def _execute_with_retry(request: Any) -> Any: - max_attempts = 10 + max_attempts = 6 attempt = 1 while attempt < max_attempts: # Note for reasons unknown, the Google API will sometimes return a 429 # and even after waiting the retry period, it will return another 429. # It could be due to a few possibilities: - # 1. Other things are also requesting from the Gmail API with the same key + # 1. Other things are also requesting from the Drive/Gmail API with the same key # 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly # 3. The retry-after has a maximum and we've already hit the limit for the day # or it's something else... 
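A note on the timestamp portion of the hunk above: the Drive query filter now builds its RFC 3339 strings from a timezone-aware datetime instead of the naive utcfromtimestamp call (deprecated in recent Python versions) with a hand-appended "Z". A minimal sketch of the difference, standard library only; the epoch value below is illustrative and not taken from the patch:

    from datetime import datetime, timezone

    ts = 1742500000  # illustrative epoch seconds

    # old style: naive UTC datetime, "Z" suffix appended manually
    old_style = datetime.utcfromtimestamp(ts).isoformat() + "Z"
    # -> '2025-03-20T19:46:40Z'

    # new style: timezone-aware datetime, UTC offset emitted by isoformat() itself
    new_style = datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
    # -> '2025-03-20T19:46:40+00:00'

Both forms are valid RFC 3339 timestamps; the aware variant simply avoids the deprecated API and makes the UTC assumption explicit in the datetime object itself.
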
From d123713c006f4a24cc931a3b6bdc2b9980c03e0e Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 21 Mar 2025 11:11:00 -0700 Subject: [PATCH 15/18] Fix GPU status request in sync flow (#4318) * Fix GPU status request in sync flow * tweak * Fix test * Fix more tests --- backend/onyx/chat/answer.py | 6 ++++-- backend/onyx/setup.py | 2 +- backend/onyx/utils/gpu_utils.py | 16 ++++++++++++++-- backend/tests/unit/onyx/chat/test_answer.py | 4 ++-- backend/tests/unit/onyx/chat/test_skip_gen_ai.py | 2 +- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index eb9b2130d..0bf937b6c 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -30,7 +30,7 @@ from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import explicit_tool_calling_supported -from onyx.utils.gpu_utils import gpu_status_request +from onyx.utils.gpu_utils import fast_gpu_status_request from onyx.utils.logger import setup_logger logger = setup_logger() @@ -88,7 +88,9 @@ class Answer: rerank_settings is not None and rerank_settings.rerank_provider_type is not None ) - allow_agent_reranking = gpu_status_request() or using_cloud_reranking + allow_agent_reranking = ( + fast_gpu_status_request(indexing=False) or using_cloud_reranking + ) # TODO: this is a hack to force the query to be used for the search tool # this should be removed once we fully unify graph inputs (i.e. diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index 750b35d8d..b1d2a4c04 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None: logger.info( "No existing docs or connectors found. Checking GPU availability for multipass indexing." ) - gpu_available = gpu_status_request() + gpu_available = gpu_status_request(indexing=True) logger.info(f"GPU available: {gpu_available}") current_settings = get_current_search_settings(db_session) diff --git a/backend/onyx/utils/gpu_utils.py b/backend/onyx/utils/gpu_utils.py index 75acc0232..c348e40b1 100644 --- a/backend/onyx/utils/gpu_utils.py +++ b/backend/onyx/utils/gpu_utils.py @@ -1,3 +1,5 @@ +from functools import lru_cache + import requests from retry import retry @@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() -@retry(tries=5, delay=5) -def gpu_status_request(indexing: bool = True) -> bool: +def _get_gpu_status_from_model_server(indexing: bool) -> bool: if indexing: model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}" else: @@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool: except requests.RequestException as e: logger.error(f"Error: Unable to fetch GPU status. 
Error: {str(e)}") raise # Re-raise exception to trigger a retry + + +@retry(tries=5, delay=5) +def gpu_status_request(indexing: bool) -> bool: + return _get_gpu_status_from_model_server(indexing) + + +@lru_cache(maxsize=1) +def fast_gpu_status_request(indexing: bool) -> bool: + """For use in sync flows, where we don't want to retry / we want to cache this.""" + return gpu_status_request(indexing=indexing) diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 8e2a5f448..34f46fff9 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -50,7 +50,7 @@ def answer_instance( mocker: MockerFixture, ) -> Answer: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=True, ) return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config) @@ -400,7 +400,7 @@ def test_no_slow_reranking( mocker: MockerFixture, ) -> None: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=gpu_enabled, ) rerank_settings = ( diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index 146a08f60..c1c17e362 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag( mocker: MockerFixture, ) -> None: mocker.patch( - "onyx.chat.answer.gpu_status_request", + "onyx.chat.answer.fast_gpu_status_request", return_value=True, ) question = config["question"] From 52b96854a26f99925dc2bbde106643210b5e2ced Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 21 Mar 2025 11:11:12 -0700 Subject: [PATCH 16/18] Handle move errors (#4317) * Handle move errors * Make a warning --- backend/model_server/main.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/backend/model_server/main.py b/backend/model_server/main.py index 3a6a56297..bbec7a933 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -65,11 +65,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: app.state.gpu_type = gpu_type - if TEMP_HF_CACHE_PATH.is_dir(): - logger.notice("Moving contents of temp_huggingface to huggingface cache.") - _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH) - shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True) - logger.notice("Moved contents of temp_huggingface to huggingface cache.") + try: + if TEMP_HF_CACHE_PATH.is_dir(): + logger.notice("Moving contents of temp_huggingface to huggingface cache.") + _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH) + shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True) + logger.notice("Moved contents of temp_huggingface to huggingface cache.") + except Exception as e: + logger.warning( + f"Error moving contents of temp_huggingface to huggingface cache: {e}. " + "This is not a critical error and the model server will continue to run." 
+ ) torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads())) logger.notice(f"Torch Threads: {torch.get_num_threads()}") From 61facfb0a8d62590ef5bc6d9a7050777a8dde583 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 21 Mar 2025 14:30:03 -0700 Subject: [PATCH 17/18] Fix slack connector (#4326) --- backend/onyx/connectors/slack/connector.py | 25 +++++++++++++++------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index 83e52c410..b38f216b3 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -438,7 +438,11 @@ def _get_all_doc_ids( class ProcessedSlackMessage(BaseModel): doc: Document | None - thread_ts: str | None + # if the message is part of a thread, this is the thread_ts + # otherwise, this is the message_ts. Either way, will be a unique identifier. + # In the future, if the message becomes a thread, then the thread_ts + # will be set to the message_ts. + thread_or_message_ts: str failure: ConnectorFailure | None @@ -452,6 +456,7 @@ def _process_message( msg_filter_func: Callable[[MessageType], bool] = default_msg_filter, ) -> ProcessedSlackMessage: thread_ts = message.get("thread_ts") + thread_or_message_ts = thread_ts or message["ts"] try: # causes random failures for testing checkpointing / continue on failure # import random @@ -467,16 +472,18 @@ def _process_message( seen_thread_ts=seen_thread_ts, msg_filter_func=msg_filter_func, ) - return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None) + return ProcessedSlackMessage( + doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None + ) except Exception as e: logger.exception(f"Error processing message {message['ts']}") return ProcessedSlackMessage( doc=None, - thread_ts=thread_ts, + thread_or_message_ts=thread_or_message_ts, failure=ConnectorFailure( failed_document=DocumentFailure( document_id=_build_doc_id( - channel_id=channel["id"], thread_ts=(thread_ts or message["ts"]) + channel_id=channel["id"], thread_ts=thread_or_message_ts ), document_link=get_message_link(message, client, channel["id"]), ), @@ -616,7 +623,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): for future in as_completed(futures): processed_slack_message = future.result() doc = processed_slack_message.doc - thread_ts = processed_slack_message.thread_ts + thread_or_message_ts = processed_slack_message.thread_or_message_ts failure = processed_slack_message.failure if doc: # handle race conditions here since this is single @@ -624,11 +631,13 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): # but since this is single threaded, we won't run into simul # writes. At worst, we can duplicate a thread, which will be # deduped later on. 
- if thread_ts not in seen_thread_ts: + if thread_or_message_ts not in seen_thread_ts: yield doc - assert thread_ts, "found non-None doc with None thread_ts" - seen_thread_ts.add(thread_ts) + assert ( + thread_or_message_ts + ), "found non-None doc with None thread_or_message_ts" + seen_thread_ts.add(thread_or_message_ts) elif failure: yield failure From fce81ebb6069775e766f4b28f5b873acf477ae61 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Fri, 21 Mar 2025 14:50:56 -0700 Subject: [PATCH 18/18] Minor ux nits (#4327) * k * quick fix --- web/src/app/chat/message/SourcesDisplay.tsx | 1 + web/src/components/SearchResultIcon.tsx | 2 +- web/src/components/WebResultIcon.tsx | 5 ++-- .../components/chat/sources/SourceCard.tsx | 29 ------------------- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/web/src/app/chat/message/SourcesDisplay.tsx b/web/src/app/chat/message/SourcesDisplay.tsx index 15e3cd75a..ce106e926 100644 --- a/web/src/app/chat/message/SourcesDisplay.tsx +++ b/web/src/app/chat/message/SourcesDisplay.tsx @@ -54,6 +54,7 @@ export const SourceCard: React.FC<{
+
{truncatedIdentifier}
diff --git a/web/src/components/SearchResultIcon.tsx b/web/src/components/SearchResultIcon.tsx index a042ee909..00bc31dc3 100644 --- a/web/src/components/SearchResultIcon.tsx +++ b/web/src/components/SearchResultIcon.tsx @@ -49,7 +49,7 @@ export function SearchResultIcon({ url }: { url: string }) { if (!faviconUrl) { return ; } - if (url.includes("docs.onyx.app")) { + if (url.includes("onyx.app")) { return ; } diff --git a/web/src/components/WebResultIcon.tsx b/web/src/components/WebResultIcon.tsx index 221a094e3..0d78b14a3 100644 --- a/web/src/components/WebResultIcon.tsx +++ b/web/src/components/WebResultIcon.tsx @@ -17,12 +17,11 @@ export function WebResultIcon({ try { hostname = new URL(url).hostname; } catch (e) { - // console.log(e); - hostname = "docs.onyx.app"; + hostname = "onyx.app"; } return ( <> - {hostname == "docs.onyx.app" ? ( + {hostname.includes("onyx.app") ? ( ) : !error ? ( void; -// }) { -// return ( -//
openDocument(doc, setPresentingDocument)} -// className="cursor-pointer h-[80px] text-left overflow-hidden flex flex-col gap-0.5 rounded-lg px-3 py-2 bg-accent-background hover:bg-accent-background-hovered w-[200px]" -// > -//
-// {doc.is_internet || doc.source_type === "web" ? ( -// -// ) : ( -// -// )} -//

{truncateString(doc.semantic_identifier || doc.document_id, 20)}

-//
-//
-//
-// {doc.blurb} -//
-//
-// ); -// } - interface SeeMoreBlockProps { toggleDocumentSelection: () => void; docs: OnyxDocument[];
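
For reference, the icon changes in the last patch broaden the special case from the exact docs.onyx.app host to any hostname containing onyx.app, keeping the same fallback when URL parsing fails. A rough sketch of the equivalent hostname check, written in Python purely for illustration; the function name and sample URLs here are made up and not part of the patch:

    from urllib.parse import urlparse

    def is_onyx_host(url: str) -> bool:
        # mirror of the TSX logic: fall back to "onyx.app" when parsing fails,
        # then treat any hostname containing "onyx.app" as an Onyx page
        try:
            hostname = urlparse(url).hostname or "onyx.app"
        except ValueError:
            hostname = "onyx.app"
        return "onyx.app" in hostname

    assert is_onyx_host("https://docs.onyx.app/introduction")
    assert is_onyx_host("https://onyx.app/pricing")
    assert not is_onyx_host("https://example.com/post")

This keeps the previous docs.onyx.app behavior while also matching the bare onyx.app domain and its other subdomains.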