From c83ee06062d379e4ac8a0bbaf7cc9da1bf6f7c0b Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Wed, 23 Apr 2025 18:05:52 -0700
Subject: [PATCH] Feature/salesforce correctness 2 (#4506)

* refactor salesforce sqlite db access

* more refactoring

* refactor again

* rename object

* add finalizer to ensure db connection is always closed

* avoid unnecessarily nesting connections and commit regularly when possible

* remove db usage from csv download

* dead code

* hide deprecation warning in ddtrace

* remove unused param

---------

Co-authored-by: Richard Kuo (Onyx)
---
 .../onyx/connectors/salesforce/connector.py   | 71 +++++++++++--------
 .../connectors/salesforce/salesforce_calls.py | 28 +++-----
 .../connectors/salesforce/sqlite_functions.py | 25 -------
 backend/pytest.ini                            |  2 +-
 4 files changed, 55 insertions(+), 71 deletions(-)

diff --git a/backend/onyx/connectors/salesforce/connector.py b/backend/onyx/connectors/salesforce/connector.py
index c9076d4c8..9141e8fb2 100644
--- a/backend/onyx/connectors/salesforce/connector.py
+++ b/backend/onyx/connectors/salesforce/connector.py
@@ -111,43 +111,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
 
     @staticmethod
     def _download_object_csvs(
-        sf_db: OnyxSalesforceSQLite,
+        all_types_to_filter: dict[str, bool],
         directory: str,
-        parent_object_list: list[str],
         sf_client: Salesforce,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> None:
-        all_types: set[str] = set(parent_object_list)
-
-        logger.info(
-            f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
-        )
-
-        # This takes like 20 seconds
-        for parent_object_type in parent_object_list:
-            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
-            logger.debug(
-                f"Found {len(child_types)} child types for {parent_object_type}"
-            )
-
-            all_types.update(child_types)
-
-        # Always want to make sure user is grabbed for permissioning purposes
-        all_types.add("User")
-
-        logger.info(f"All object types: num={len(all_types)} list={all_types}")
-
-        # gc.collect()
-
         # checkpoint - we've found all object types, now time to fetch the data
         logger.info("Fetching CSVs for all object types")
 
         # This takes like 30 minutes first time and <2 minutes for updates
         object_type_to_csv_path = fetch_all_csvs_in_parallel(
-            sf_db=sf_db,
             sf_client=sf_client,
-            object_types=all_types,
+            all_types_to_filter=all_types_to_filter,
             start=start,
             end=end,
             target_dir=directory,
@@ -224,6 +200,30 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
 
         return updated_ids
 
+    @staticmethod
+    def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
+        all_types: set[str] = set(parent_types)
+
+        # Step 1 - get all object types
+        logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
+
+        # This takes like 20 seconds
+        for parent_object_type in parent_types:
+            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
+            logger.debug(
+                f"Found {len(child_types)} child types for {parent_object_type}"
+            )
+
+            all_types.update(child_types)
+
+        # Always make sure the User object is fetched for permissioning purposes
+        all_types.add("User")
+
+        logger.info(f"All object types: num={len(all_types)} list={all_types}")
+
+        # gc.collect()
+        return all_types
+
     def _fetch_from_salesforce(
         self,
         temp_dir: str,
@@ -244,9 +244,24 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
         sf_db.apply_schema()
         sf_db.log_stats()
 
-        # Step 1 - download
+        # Step 1.1 - add child object types + "User" type to the list of types
+        all_types = SalesforceConnector._get_all_types(
+            self.parent_object_list, self._sf_client
+        )
+
+        # Only add a time filter if there is at least one object of the type
+        # in the database. We aren't worried about partially completed object
+        # update runs, because this happens after we check for existing CSVs,
+        # which covers that case.
+        all_types_to_filter: dict[str, bool] = {}
+        for sf_type in all_types:
+            if sf_db.has_at_least_one_object_of_type(sf_type):
+                all_types_to_filter[sf_type] = True
+            else:
+                all_types_to_filter[sf_type] = False
+
+        # Step 1.2 - bulk download the CSV for each object type
         SalesforceConnector._download_object_csvs(
-            sf_db, temp_dir, self.parent_object_list, self._sf_client, start, end
+            all_types_to_filter, temp_dir, self._sf_client, start, end
         )
 
         gc.collect()
diff --git a/backend/onyx/connectors/salesforce/salesforce_calls.py b/backend/onyx/connectors/salesforce/salesforce_calls.py
index 233281801..51ff6dc6a 100644
--- a/backend/onyx/connectors/salesforce/salesforce_calls.py
+++ b/backend/onyx/connectors/salesforce/salesforce_calls.py
@@ -10,7 +10,6 @@ from simple_salesforce.bulk2 import SFBulk2Handler
 from simple_salesforce.bulk2 import SFBulk2Type
 
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
-from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -184,9 +183,8 @@ def _bulk_retrieve_from_salesforce(
 
 
 def fetch_all_csvs_in_parallel(
-    sf_db: OnyxSalesforceSQLite,
     sf_client: Salesforce,
-    object_types: set[str],
+    all_types_to_filter: dict[str, bool],
     start: SecondsSinceUnixEpoch | None,
     end: SecondsSinceUnixEpoch | None,
     target_dir: str,
@@ -219,20 +217,16 @@ def fetch_all_csvs_in_parallel(
     )
 
     time_filter_for_each_object_type = {}
-    # We do this outside of the thread pool executor because this requires
-    # a database connection and we don't want to block the thread pool
-    # executor from running
-    for sf_type in object_types:
-        """Only add time filter if there is at least one object of the type
-        in the database. We aren't worried about partially completed object update runs
-        because this occurs after we check for existing csvs which covers this case"""
-        if sf_db.has_at_least_one_object_of_type(sf_type):
-            if sf_type in created_date_types:
-                time_filter_for_each_object_type[sf_type] = created_date_time_filter
-            else:
-                time_filter_for_each_object_type[sf_type] = last_modified_time_filter
-        else:
+
+    for sf_type, apply_filter in all_types_to_filter.items():
+        if not apply_filter:
             time_filter_for_each_object_type[sf_type] = ""
+            continue
+
+        if sf_type in created_date_types:
+            time_filter_for_each_object_type[sf_type] = created_date_time_filter
+        else:
+            time_filter_for_each_object_type[sf_type] = last_modified_time_filter
 
     # Run the bulk retrieve in parallel
     with ThreadPoolExecutor() as executor:
@@ -243,6 +237,6 @@ def fetch_all_csvs_in_parallel(
                 time_filter=time_filter_for_each_object_type[object_type],
                 target_dir=target_dir,
             ),
-        object_types,
+        all_types_to_filter.keys(),
     )
     return dict(results)
diff --git a/backend/onyx/connectors/salesforce/sqlite_functions.py b/backend/onyx/connectors/salesforce/sqlite_functions.py
index c3b6409a1..f6e030efe 100644
--- a/backend/onyx/connectors/salesforce/sqlite_functions.py
+++ b/backend/onyx/connectors/salesforce/sqlite_functions.py
@@ -572,28 +572,3 @@ class OnyxSalesforceSQLite:
             AND json_extract(data, '$.Email') IS NOT NULL
             """
         )
-
-
-# @contextmanager
-# def get_db_connection(
-#     directory: str,
-#     isolation_level: str | None = None,
-# ) -> Iterator[sqlite3.Connection]:
-#     """Get a database connection with proper isolation level and error handling.
-
-#     Args:
-#         isolation_level: SQLite isolation level. None = default "DEFERRED",
-#             can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
-#     """
-#     # 60 second timeout for locks
-#     conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)
-
-#     if isolation_level is not None:
-#         conn.isolation_level = isolation_level
-#     try:
-#         yield conn
-#     except Exception:
-#         conn.rollback()
-#         raise
-#     finally:
-#         conn.close()
diff --git a/backend/pytest.ini b/backend/pytest.ini
index 954a02740..55587c398 100644
--- a/backend/pytest.ini
+++ b/backend/pytest.ini
@@ -5,4 +5,4 @@ markers =
 filterwarnings =
     ignore::DeprecationWarning
    ignore::cryptography.utils.CryptographyDeprecationWarning
-    
\ No newline at end of file
+    ignore::PendingDeprecationWarning:ddtrace.internal.module
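-- 
Reviewer notes with illustrative sketches follow; none of this is part of the
patch itself.

The heart of the change: fetch_all_csvs_in_parallel() no longer receives a live
OnyxSalesforceSQLite handle. The connector probes the local store once, on the
main thread, and the worker threads only ever see a plain dict[str, bool]. A
minimal runnable sketch of that pattern follows; has_existing_object and the
WHERE-clause strings are hypothetical stand-ins, not the Onyx implementation:

from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def build_filter_map(
    object_types: set[str], has_existing_object: Callable[[str], bool]
) -> dict[str, bool]:
    # Probe local storage once, on the main thread: True means rows of this
    # type already exist, so an incremental time filter is safe.
    return {sf_type: has_existing_object(sf_type) for sf_type in object_types}


def resolve_time_filter(
    sf_type: str, apply_filter: bool, created_date_types: set[str]
) -> str:
    if not apply_filter:
        return ""  # first sync of this type: full download, no WHERE clause
    # Some objects are only filterable by CreatedDate; the rest use
    # LastModifiedDate (mirrors the branch in the patched loop).
    field = "CreatedDate" if sf_type in created_date_types else "LastModifiedDate"
    return f"WHERE {field} >= 2025-01-01T00:00:00Z"  # placeholder bound


if __name__ == "__main__":
    # Pretend only Account already has rows in the local store.
    types_to_filter = build_filter_map({"Account", "User"}, lambda t: t == "Account")
    with ThreadPoolExecutor() as executor:
        results = executor.map(
            lambda t: (t, resolve_time_filter(t, types_to_filter[t], {"Case"})),
            types_to_filter.keys(),
        )
    print(dict(results))

The patch spells the dict construction as an explicit if/else; a comprehension
like the one above is behaviorally equivalent.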
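The "add finalizer to ensure db connection is always closed" bullet has no
visible hunk here. One plausible shape for it, assuming weakref.finalize as the
safety net; the class below is illustrative, not the actual OnyxSalesforceSQLite
code:

import sqlite3
import weakref


class SalesforceSQLiteSketch:
    """Illustrative wrapper; not the Onyx class."""

    def __init__(self, db_path: str) -> None:
        self._conn = sqlite3.connect(db_path, timeout=60.0)
        # Runs self._conn.close() at garbage collection or interpreter exit
        # if close() was never called explicitly.
        self._finalizer = weakref.finalize(self, self._conn.close)

    def close(self) -> None:
        # detach() returns None once the finalizer has run, so the
        # connection is never closed twice.
        if self._finalizer.detach():
            self._conn.close()


db = SalesforceSQLiteSketch(":memory:")
db.close()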
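On the pytest.ini line: filterwarnings entries use the standard warnings filter
spec, action:message:category:module:lineno, so
ignore::PendingDeprecationWarning:ddtrace.internal.module suppresses the warning
only when it originates from ddtrace.internal.module rather than hiding the
whole category. The equivalent programmatic form, useful for sanity-checking
the scoping:

import warnings

# Scoped ignore: the module field is a regex matched against the start of
# the name of the module that triggers the warning.
warnings.filterwarnings(
    "ignore",
    category=PendingDeprecationWarning,
    module=r"ddtrace\.internal\.module",
)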