Feature/salesforce correctness 2 (#4506)

* refactor salesforce sqlite db access

* more refactoring

* refactor again

* refactor again

* rename object

* add finalizer to ensure db connection is always closed (see the sketch after this list)

* avoid unnecessarily nesting connections and commit regularly when possible

* remove db usage from csv download

* remove dead code

* hide deprecation warning in ddtrace

* remove unused param
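
A minimal sketch of the connection-handling pattern described in the bullets above (a finalizer plus regular commits). The class, table, and method names here are hypothetical stand-ins, not the connector's real OnyxSalesforceSQLite API:

import sqlite3
import weakref


class SalesforceDB:
    """Hypothetical wrapper: one long-lived connection, frequent commits, and a
    finalizer so the connection is closed even if close() is never called."""

    def __init__(self, db_path: str) -> None:
        # A single connection instead of opening a nested one per operation.
        self._conn = sqlite3.connect(db_path, timeout=60.0)
        self._conn.execute(
            "CREATE TABLE IF NOT EXISTS objects (id TEXT PRIMARY KEY, data TEXT)"
        )
        # The callback is a bound method of the connection, not of self, so the
        # finalizer does not keep this wrapper object alive.
        self._finalizer = weakref.finalize(self, self._conn.close)

    def upsert_object(self, object_id: str, data: str) -> None:
        self._conn.execute(
            "INSERT OR REPLACE INTO objects (id, data) VALUES (?, ?)",
            (object_id, data),
        )
        # Commit regularly so readers are not stuck behind one huge transaction.
        self._conn.commit()

    def close(self) -> None:
        # Idempotent: weakref.finalize runs its callback at most once.
        self._finalizer()


if __name__ == "__main__":
    db = SalesforceDB(":memory:")
    db.upsert_object("001xx0000001", '{"Name": "Acme"}')
    db.close()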

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
rkuo-danswer 2025-04-23 18:05:52 -07:00 committed by GitHub
parent c93cebe1ab
commit c83ee06062
4 changed files with 55 additions and 71 deletions


@@ -111,43 +111,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@staticmethod
def _download_object_csvs(
sf_db: OnyxSalesforceSQLite,
all_types_to_filter: dict[str, bool],
directory: str,
parent_object_list: list[str],
sf_client: Salesforce,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> None:
all_types: set[str] = set(parent_object_list)
logger.info(
f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
)
# This takes like 20 seconds
for parent_object_type in parent_object_list:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
all_types.update(child_types)
# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# gc.collect()
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Fetching CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_db=sf_db,
sf_client=sf_client,
object_types=all_types,
all_types_to_filter=all_types_to_filter,
start=start,
end=end,
target_dir=directory,
@@ -224,6 +200,30 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
return updated_ids
@staticmethod
def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
all_types: set[str] = set(parent_types)
# Step 1 - get all object types
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
# This takes like 20 seconds
for parent_object_type in parent_types:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
all_types.update(child_types)
# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# gc.collect()
return all_types
def _fetch_from_salesforce(
self,
temp_dir: str,
@@ -244,9 +244,24 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
sf_db.apply_schema()
sf_db.log_stats()
# Step 1 - download
# Step 1.1 - add child object types + "User" type to the list of types
all_types = SalesforceConnector._get_all_types(
self.parent_object_list, self._sf_client
)
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
all_types_to_filter: dict[str, bool] = {}
for sf_type in all_types:
if sf_db.has_at_least_one_object_of_type(sf_type):
all_types_to_filter[sf_type] = True
else:
all_types_to_filter[sf_type] = False
# Step 1.2 - bulk download the CSV for each object type
SalesforceConnector._download_object_csvs(
sf_db, temp_dir, self.parent_object_list, self._sf_client, start, end
all_types_to_filter, temp_dir, self._sf_client, start, end
)
gc.collect()
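
For context on this hunk: the caller now performs the only database-dependent check and hands plain booleans to the download step, so the CSV download code no longer needs a database handle. A minimal, self-contained sketch of that split, using a hypothetical stand-in rather than the real OnyxSalesforceSQLite:

class FakeSalesforceDB:
    """Hypothetical stand-in exposing only the method used for the check."""

    def __init__(self, existing_types: set[str]) -> None:
        self._existing_types = existing_types

    def has_at_least_one_object_of_type(self, sf_type: str) -> bool:
        return sf_type in self._existing_types


def build_types_to_filter(
    sf_db: FakeSalesforceDB, all_types: set[str]
) -> dict[str, bool]:
    # Same outcome as the loop in the diff: only object types that already have
    # at least one local row get an incremental time filter on the next sync.
    return {
        sf_type: sf_db.has_at_least_one_object_of_type(sf_type)
        for sf_type in all_types
    }


if __name__ == "__main__":
    sf_db = FakeSalesforceDB(existing_types={"Account", "User"})
    print(build_types_to_filter(sf_db, {"Account", "Case", "User"}))
    # {'Account': True, 'User': True, 'Case': False}  (ordering may vary)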


@@ -10,7 +10,6 @@ from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -184,9 +183,8 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_db: OnyxSalesforceSQLite,
sf_client: Salesforce,
object_types: set[str],
all_types_to_filter: dict[str, bool],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,
@@ -219,20 +217,16 @@ def fetch_all_csvs_in_parallel(
)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if sf_db.has_at_least_one_object_of_type(sf_type):
if sf_type in created_date_types:
time_filter_for_each_object_type[sf_type] = created_date_time_filter
else:
time_filter_for_each_object_type[sf_type] = last_modified_time_filter
else:
for sf_type, apply_filter in all_types_to_filter.items():
if not apply_filter:
time_filter_for_each_object_type[sf_type] = ""
continue
if sf_type in created_date_types:
time_filter_for_each_object_type[sf_type] = created_date_time_filter
else:
time_filter_for_each_object_type[sf_type] = last_modified_time_filter
# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
@@ -243,6 +237,6 @@
time_filter=time_filter_for_each_object_type[object_type],
target_dir=target_dir,
),
object_types,
all_types_to_filter.keys(),
)
return dict(results)


@@ -572,28 +572,3 @@ class OnyxSalesforceSQLite:
AND json_extract(data, '$.Email') IS NOT NULL
"""
)
# @contextmanager
# def get_db_connection(
# directory: str,
# isolation_level: str | None = None,
# ) -> Iterator[sqlite3.Connection]:
# """Get a database connection with proper isolation level and error handling.
# Args:
# isolation_level: SQLite isolation level. None = default "DEFERRED",
# can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
# """
# # 60 second timeout for locks
# conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)
# if isolation_level is not None:
# conn.isolation_level = isolation_level
# try:
# yield conn
# except Exception:
# conn.rollback()
# raise
# finally:
# conn.close()


@@ -5,4 +5,4 @@ markers =
filterwarnings =
ignore::DeprecationWarning
ignore::cryptography.utils.CryptographyDeprecationWarning
ignore::PendingDeprecationWarning:ddtrace.internal.module
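
The new entry suppresses PendingDeprecationWarning only when it is raised from ddtrace.internal.module. As an illustration (not part of the change), a rough Python equivalent of that filter using the standard warnings module:

import warnings

# Ignore PendingDeprecationWarning, but only when the warning originates from
# ddtrace.internal.module; the module argument is a regex matched against the
# start of the issuing module's name.
warnings.filterwarnings(
    "ignore",
    category=PendingDeprecationWarning,
    module=r"ddtrace\.internal\.module",
)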