Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-11 12:20:24 +02:00)
Feature/salesforce correctness 2 (#4506)
* refactor salesforce sqlite db access
* more refactoring
* refactor again
* refactor again
* rename object
* add finalizer to ensure db connection is always closed
* avoid unnecessarily nesting connections and commit regularly when possible
* remove db usage from csv download
* dead code
* hide deprecation warning in ddtrace
* remove unused param

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
parent: c93cebe1ab
commit: c83ee06062
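The thread running through these changes is that the CSV download path no longer touches the SQLite database: the caller decides up front, per object type, whether an incremental time filter should be applied, and only that dict[str, bool] crosses into fetch_all_csvs_in_parallel. A minimal sketch of the pattern follows; the helper names here are illustrative, not the connector's actual API.

from concurrent.futures import ThreadPoolExecutor
from typing import Callable

def build_filter_flags(
    all_types: set[str], has_existing_rows: Callable[[str], bool]
) -> dict[str, bool]:
    # Decided once, on the caller's thread, while the SQLite handle is still in scope.
    return {sf_type: has_existing_rows(sf_type) for sf_type in all_types}

def fetch_all(
    filter_flags: dict[str, bool], fetch_one: Callable[[str, bool], str]
) -> dict[str, str]:
    # Worker threads only ever see plain data (the bool flags), never the DB connection.
    with ThreadPoolExecutor() as executor:
        results = executor.map(
            lambda sf_type: (sf_type, fetch_one(sf_type, filter_flags[sf_type])),
            filter_flags.keys(),
        )
    return dict(results)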
@@ -111,43 +111,19 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
 
     @staticmethod
     def _download_object_csvs(
-        sf_db: OnyxSalesforceSQLite,
+        all_types_to_filter: dict[str, bool],
         directory: str,
-        parent_object_list: list[str],
         sf_client: Salesforce,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> None:
-        all_types: set[str] = set(parent_object_list)
-
-        logger.info(
-            f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
-        )
-
-        # This takes like 20 seconds
-        for parent_object_type in parent_object_list:
-            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
-            logger.debug(
-                f"Found {len(child_types)} child types for {parent_object_type}"
-            )
-
-            all_types.update(child_types)
-
-        # Always want to make sure user is grabbed for permissioning purposes
-        all_types.add("User")
-
-        logger.info(f"All object types: num={len(all_types)} list={all_types}")
-
-        # gc.collect()
-
         # checkpoint - we've found all object types, now time to fetch the data
         logger.info("Fetching CSVs for all object types")
 
         # This takes like 30 minutes first time and <2 minutes for updates
         object_type_to_csv_path = fetch_all_csvs_in_parallel(
-            sf_db=sf_db,
             sf_client=sf_client,
-            object_types=all_types,
+            all_types_to_filter=all_types_to_filter,
            start=start,
             end=end,
             target_dir=directory,
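With this hunk applied, _download_object_csvs is a pure download step: it takes the precomputed all_types_to_filter mapping instead of a database handle and a parent-type list. A hedged sketch of a call, with illustrative values (the real call site appears in the _fetch_from_salesforce hunk further down):

# Illustrative values only; in the connector these come from _fetch_from_salesforce.
all_types_to_filter = {"Account": True, "Contact": False, "User": True}

SalesforceConnector._download_object_csvs(
    all_types_to_filter,     # per-type flag: apply an incremental time filter?
    "/tmp/salesforce_csvs",  # directory the bulk CSV exports are written to
    sf_client,               # an authenticated simple_salesforce.Salesforce client
    start=None,              # optional SecondsSinceUnixEpoch window start
    end=None,                # optional window end
)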
@@ -224,6 +200,30 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
 
         return updated_ids
 
+    @staticmethod
+    def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
+        all_types: set[str] = set(parent_types)
+
+        # Step 1 - get all object types
+        logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
+
+        # This takes like 20 seconds
+        for parent_object_type in parent_types:
+            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
+            logger.debug(
+                f"Found {len(child_types)} child types for {parent_object_type}"
+            )
+
+            all_types.update(child_types)
+
+        # Always want to make sure user is grabbed for permissioning purposes
+        all_types.add("User")
+
+        logger.info(f"All object types: num={len(all_types)} list={all_types}")
+
+        # gc.collect()
+        return all_types
+
     def _fetch_from_salesforce(
         self,
         temp_dir: str,
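The extracted _get_all_types helper depends only on the parent type list and the Salesforce client, so it can be exercised in isolation. A small usage sketch; the client construction and import path are illustrative (the connector itself uses self._sf_client):

from simple_salesforce import Salesforce
# Import path assumed for illustration:
from onyx.connectors.salesforce.connector import SalesforceConnector

sf_client = Salesforce(
    username="user@example.com", password="...", security_token="..."
)

# Expands the configured parent types with all of their child types and
# always includes "User" for permissioning.
all_types = SalesforceConnector._get_all_types(["Account", "Opportunity"], sf_client)
print(f"Will download {len(all_types)} object types")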
@@ -244,9 +244,24 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
         sf_db.apply_schema()
         sf_db.log_stats()
 
-        # Step 1 - download
+        # Step 1.1 - add child object types + "User" type to the list of types
+        all_types = SalesforceConnector._get_all_types(
+            self.parent_object_list, self._sf_client
+        )
+
+        """Only add time filter if there is at least one object of the type
+        in the database. We aren't worried about partially completed object update runs
+        because this occurs after we check for existing csvs which covers this case"""
+        all_types_to_filter: dict[str, bool] = {}
+        for sf_type in all_types:
+            if sf_db.has_at_least_one_object_of_type(sf_type):
+                all_types_to_filter[sf_type] = True
+            else:
+                all_types_to_filter[sf_type] = False
+
+        # Step 1.2 - bulk download the CSV for each object type
         SalesforceConnector._download_object_csvs(
-            sf_db, temp_dir, self.parent_object_list, self._sf_client, start, end
+            all_types_to_filter, temp_dir, self._sf_client, start, end
         )
         gc.collect()
 
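The per-type flag construction in the hunk above could equally be written as a dict comprehension; this is just an equivalent alternative, not what the commit ships:

# Equivalent to the for/if/else block building all_types_to_filter above.
all_types_to_filter: dict[str, bool] = {
    sf_type: sf_db.has_at_least_one_object_of_type(sf_type) for sf_type in all_types
}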
@@ -10,7 +10,6 @@ from simple_salesforce.bulk2 import SFBulk2Handler
 from simple_salesforce.bulk2 import SFBulk2Type
 
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
-from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -184,9 +183,8 @@ def _bulk_retrieve_from_salesforce(
 
 
 def fetch_all_csvs_in_parallel(
-    sf_db: OnyxSalesforceSQLite,
     sf_client: Salesforce,
-    object_types: set[str],
+    all_types_to_filter: dict[str, bool],
     start: SecondsSinceUnixEpoch | None,
     end: SecondsSinceUnixEpoch | None,
     target_dir: str,
@@ -219,20 +217,16 @@ def fetch_all_csvs_in_parallel(
     )
 
     time_filter_for_each_object_type = {}
-    # We do this outside of the thread pool executor because this requires
-    # a database connection and we don't want to block the thread pool
-    # executor from running
-    for sf_type in object_types:
-        """Only add time filter if there is at least one object of the type
-        in the database. We aren't worried about partially completed object update runs
-        because this occurs after we check for existing csvs which covers this case"""
-        if sf_db.has_at_least_one_object_of_type(sf_type):
-            if sf_type in created_date_types:
-                time_filter_for_each_object_type[sf_type] = created_date_time_filter
-            else:
-                time_filter_for_each_object_type[sf_type] = last_modified_time_filter
-        else:
-            time_filter_for_each_object_type[sf_type] = ""
+
+    for sf_type, apply_filter in all_types_to_filter.items():
+        if not apply_filter:
+            time_filter_for_each_object_type[sf_type] = ""
+            continue
+
+        if sf_type in created_date_types:
+            time_filter_for_each_object_type[sf_type] = created_date_time_filter
+        else:
+            time_filter_for_each_object_type[sf_type] = last_modified_time_filter
 
     # Run the bulk retrieve in parallel
     with ThreadPoolExecutor() as executor:
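The rewritten loop is driven entirely by the all_types_to_filter flags, so no database access remains inside fetch_all_csvs_in_parallel. A self-contained sketch of the same control flow, with hypothetical filter strings standing in for the created_date_time_filter and last_modified_time_filter values built earlier in the function:

def build_time_filters(
    all_types_to_filter: dict[str, bool],
    created_date_types: set[str],
    created_filter: str,
    modified_filter: str,
) -> dict[str, str]:
    time_filters: dict[str, str] = {}
    for sf_type, apply_filter in all_types_to_filter.items():
        if not apply_filter:
            # First sync for this type: fetch everything, no time filter.
            time_filters[sf_type] = ""
            continue
        # Incremental sync: some types filter on CreatedDate, the rest on LastModifiedDate.
        if sf_type in created_date_types:
            time_filters[sf_type] = created_filter
        else:
            time_filters[sf_type] = modified_filter
    return time_filters

# Hypothetical usage with illustrative filter fragments:
print(
    build_time_filters(
        {"Account": True, "Case": False},
        created_date_types={"Account"},
        created_filter=" WHERE CreatedDate >= 2024-01-01T00:00:00Z",
        modified_filter=" WHERE LastModifiedDate >= 2024-01-01T00:00:00Z",
    )
)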
@@ -243,6 +237,6 @@ def fetch_all_csvs_in_parallel(
                 time_filter=time_filter_for_each_object_type[object_type],
                 target_dir=target_dir,
             ),
-            object_types,
+            all_types_to_filter.keys(),
         )
     return dict(results)
@@ -572,28 +572,3 @@ class OnyxSalesforceSQLite:
             AND json_extract(data, '$.Email') IS NOT NULL
             """
         )
-
-
-# @contextmanager
-# def get_db_connection(
-#     directory: str,
-#     isolation_level: str | None = None,
-# ) -> Iterator[sqlite3.Connection]:
-#     """Get a database connection with proper isolation level and error handling.
-
-#     Args:
-#         isolation_level: SQLite isolation level. None = default "DEFERRED",
-#             can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
-#     """
-#     # 60 second timeout for locks
-#     conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)
-
-#     if isolation_level is not None:
-#         conn.isolation_level = isolation_level
-#     try:
-#         yield conn
-#     except Exception:
-#         conn.rollback()
-#         raise
-#     finally:
-#         conn.close()
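This hunk only deletes the commented-out get_db_connection context manager. The commit message's "add finalizer to ensure db connection is always closed" suggests the replacement lives on OnyxSalesforceSQLite itself; below is a minimal sketch of that finalizer idea, not the class's actual implementation:

import sqlite3
import weakref

class SQLiteHandle:
    """Sketch only: close the connection even if the owner forgets to."""

    def __init__(self, db_path: str) -> None:
        self._conn = sqlite3.connect(db_path, timeout=60.0)
        # weakref.finalize runs close() when this handle is garbage collected
        # or at interpreter exit, whichever comes first.
        self._finalizer = weakref.finalize(self, self._conn.close)

    def close(self) -> None:
        # Explicit close is still preferred; the finalizer object is idempotent,
        # so invoking it again later is a no-op.
        self._finalizer()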
@@ -5,4 +5,4 @@ markers =
 filterwarnings =
     ignore::DeprecationWarning
     ignore::cryptography.utils.CryptographyDeprecationWarning
+    ignore::PendingDeprecationWarning:ddtrace.internal.module
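The new pytest filterwarnings entry scopes the ignore to warnings raised from ddtrace.internal.module (the field after the second colon is a module regex), so other PendingDeprecationWarnings stay visible. Roughly the same effect with the stdlib API, for reference:

import warnings

# Ignore PendingDeprecationWarning originating from ddtrace.internal.module only.
warnings.filterwarnings(
    "ignore",
    category=PendingDeprecationWarning,
    module=r"ddtrace\.internal\.module",
)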