Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-10 11:50:32 +02:00)
Feature/salesforce correctness 2 (#4506)
* refactor salesforce sqlite db access
* more refactoring
* refactor again
* refactor again
* rename object
* add finalizer to ensure db connection is always closed
* avoid unnecessarily nesting connections and commit regularly when possible
* remove db usage from csv download
* dead code
* hide deprecation warning in ddtrace
* remove unused param

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
parent c93cebe1ab
commit c83ee06062
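The "add finalizer to ensure db connection is always closed" item in the commit message refers to code that is not shown in this diff. A minimal sketch of that pattern, assuming a wrapper class along the lines of OnyxSalesforceSQLite (all names below are illustrative, not the actual implementation):

import sqlite3
import weakref


class SalesforceSQLiteSketch:
    """Illustrative only: close the SQLite connection even if the caller forgets to."""

    def __init__(self, db_path: str) -> None:
        self._conn = sqlite3.connect(db_path)
        # weakref.finalize runs when the object is garbage collected or at
        # interpreter exit, so the connection is closed even on error paths.
        self._finalizer = weakref.finalize(self, self._conn.close)

    def close(self) -> None:
        # Explicit close; calling the finalizer also marks it dead so it
        # will not run a second time later.
        self._finalizer()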
@@ -111,43 +111,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

    @staticmethod
    def _download_object_csvs(
        sf_db: OnyxSalesforceSQLite,
        all_types_to_filter: dict[str, bool],
        directory: str,
        parent_object_list: list[str],
        sf_client: Salesforce,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> None:
        all_types: set[str] = set(parent_object_list)

        logger.info(
            f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
        )

        # This takes like 20 seconds
        for parent_object_type in parent_object_list:
            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
            logger.debug(
                f"Found {len(child_types)} child types for {parent_object_type}"
            )

            all_types.update(child_types)

        # Always want to make sure user is grabbed for permissioning purposes
        all_types.add("User")

        logger.info(f"All object types: num={len(all_types)} list={all_types}")

        # gc.collect()

        # checkpoint - we've found all object types, now time to fetch the data
        logger.info("Fetching CSVs for all object types")

        # This takes like 30 minutes first time and <2 minutes for updates
        object_type_to_csv_path = fetch_all_csvs_in_parallel(
            sf_db=sf_db,
            sf_client=sf_client,
            object_types=all_types,
            all_types_to_filter=all_types_to_filter,
            start=start,
            end=end,
            target_dir=directory,
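The hunk above interleaves removed and added parameter lines. Judging from the updated call site later in this diff, the refactored method appears to drop sf_db and parent_object_list and take the precomputed filter map instead. A sketch of the apparent new call shape (inferred from this diff, not verbatim source):

# Inferred argument order, taken from the call in _fetch_from_salesforce below.
SalesforceConnector._download_object_csvs(
    all_types_to_filter,  # dict[str, bool]: whether to time-filter each object type
    temp_dir,
    sf_client,
    start,
    end,
)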
@@ -224,6 +200,30 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

        return updated_ids

    @staticmethod
    def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
        all_types: set[str] = set(parent_types)

        # Step 1 - get all object types
        logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")

        # This takes like 20 seconds
        for parent_object_type in parent_types:
            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
            logger.debug(
                f"Found {len(child_types)} child types for {parent_object_type}"
            )

            all_types.update(child_types)

        # Always want to make sure user is grabbed for permissioning purposes
        all_types.add("User")

        logger.info(f"All object types: num={len(all_types)} list={all_types}")

        # gc.collect()
        return all_types

    def _fetch_from_salesforce(
        self,
        temp_dir: str,
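A quick usage sketch of the new _get_all_types helper; the parent type names here are hypothetical, while the real caller passes self.parent_object_list, as shown in the next hunk:

# Hypothetical parent types; the connector's actual list comes from its configuration.
parent_types = ["Account", "Opportunity"]
all_types = SalesforceConnector._get_all_types(parent_types, sf_client)
# all_types now holds the parents, every child object type reachable from them,
# and always "User" (needed for permissioning).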
@@ -244,9 +244,24 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        sf_db.apply_schema()
        sf_db.log_stats()

        # Step 1 - download
        # Step 1.1 - add child object types + "User" type to the list of types
        all_types = SalesforceConnector._get_all_types(
            self.parent_object_list, self._sf_client
        )

        """Only add time filter if there is at least one object of the type
        in the database. We aren't worried about partially completed object update runs
        because this occurs after we check for existing csvs which covers this case"""
        all_types_to_filter: dict[str, bool] = {}
        for sf_type in all_types:
            if sf_db.has_at_least_one_object_of_type(sf_type):
                all_types_to_filter[sf_type] = True
            else:
                all_types_to_filter[sf_type] = False

        # Step 1.2 - bulk download the CSV for each object type
        SalesforceConnector._download_object_csvs(
            sf_db, temp_dir, self.parent_object_list, self._sf_client, start, end
            all_types_to_filter, temp_dir, self._sf_client, start, end
        )
        gc.collect()

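The loop above only records, per object type, whether the local sqlite DB already has at least one row of that type; only those types get a time filter on the next sync. An equivalent construction using a dict comprehension (purely illustrative, not in the diff):

# Equivalent way to build the filter map from the types and the sqlite DB.
all_types_to_filter = {
    sf_type: sf_db.has_at_least_one_object_of_type(sf_type) for sf_type in all_types
}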
@@ -10,7 +10,6 @@ from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type

from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -184,9 +183,8 @@ def _bulk_retrieve_from_salesforce(


def fetch_all_csvs_in_parallel(
    sf_db: OnyxSalesforceSQLite,
    sf_client: Salesforce,
    object_types: set[str],
    all_types_to_filter: dict[str, bool],
    start: SecondsSinceUnixEpoch | None,
    end: SecondsSinceUnixEpoch | None,
    target_dir: str,
@@ -219,20 +217,16 @@ def fetch_all_csvs_in_parallel(
    )

    time_filter_for_each_object_type = {}
    # We do this outside of the thread pool executor because this requires
    # a database connection and we don't want to block the thread pool
    # executor from running
    for sf_type in object_types:
        """Only add time filter if there is at least one object of the type
        in the database. We aren't worried about partially completed object update runs
        because this occurs after we check for existing csvs which covers this case"""
        if sf_db.has_at_least_one_object_of_type(sf_type):
            if sf_type in created_date_types:
                time_filter_for_each_object_type[sf_type] = created_date_time_filter
            else:
                time_filter_for_each_object_type[sf_type] = last_modified_time_filter
        else:

    for sf_type, apply_filter in all_types_to_filter.items():
        if not apply_filter:
            time_filter_for_each_object_type[sf_type] = ""
            continue

        if sf_type in created_date_types:
            time_filter_for_each_object_type[sf_type] = created_date_time_filter
        else:
            time_filter_for_each_object_type[sf_type] = last_modified_time_filter

    # Run the bulk retrieve in parallel
    with ThreadPoolExecutor() as executor:
@@ -243,6 +237,6 @@ def fetch_all_csvs_in_parallel(
                time_filter=time_filter_for_each_object_type[object_type],
                target_dir=target_dir,
            ),
            object_types,
            all_types_to_filter.keys(),
        )
        return dict(results)
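The last hunk of this file maps the per-type fetch over all_types_to_filter.keys() instead of the old object_types set, then materializes the (object_type, csv_paths) pairs into a dict. A minimal, self-contained sketch of that pattern (fetch_one is a stand-in, not the real _bulk_retrieve_from_salesforce):

from concurrent.futures import ThreadPoolExecutor

all_types_to_filter = {"Account": True, "User": False}  # illustrative input

def fetch_one(object_type: str) -> tuple[str, list[str]]:
    # Stand-in for the real bulk CSV download; returns the type and its CSV paths.
    return object_type, [f"/tmp/{object_type}.csv"]

with ThreadPoolExecutor() as executor:
    results = executor.map(fetch_one, all_types_to_filter.keys())

object_type_to_csv_path = dict(results)  # e.g. {"Account": [...], "User": [...]}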
@@ -572,28 +572,3 @@ class OnyxSalesforceSQLite:
                AND json_extract(data, '$.Email') IS NOT NULL
                """
            )


# @contextmanager
# def get_db_connection(
#     directory: str,
#     isolation_level: str | None = None,
# ) -> Iterator[sqlite3.Connection]:
#     """Get a database connection with proper isolation level and error handling.
#
#     Args:
#         isolation_level: SQLite isolation level. None = default "DEFERRED",
#         can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
#     """
#     # 60 second timeout for locks
#     conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)
#
#     if isolation_level is not None:
#         conn.isolation_level = isolation_level
#     try:
#         yield conn
#     except Exception:
#         conn.rollback()
#         raise
#     finally:
#         conn.close()
@@ -5,4 +5,4 @@ markers =
filterwarnings =
    ignore::DeprecationWarning
    ignore::cryptography.utils.CryptographyDeprecationWarning
    ignore::PendingDeprecationWarning:ddtrace.internal.module
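This last hunk corresponds to the "hide deprecation warning in ddtrace" item in the commit message: the pytest filterwarnings entry uses the standard action::category:module filter syntax. For reference, a rough standard-library equivalent of what that line asks pytest to register (illustrative, not part of the change):

import warnings

# Suppress PendingDeprecationWarning raised from modules whose names start
# with "ddtrace.internal.module".
warnings.filterwarnings(
    "ignore",
    category=PendingDeprecationWarning,
    module=r"ddtrace\.internal\.module",
)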