import gc
import os
import sys
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any

from simple_salesforce import Salesforce

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]


class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
    """Approach outline

    Initial sync
    - get the first-level child object types of the parent object types
    - bulk export all object types to CSV
    -- bulk exports of an object type contain parent ids, but not child ids
    - load all CSVs into the DB
    - generate documents for all parent objects and yield them

    Delta syncs
    - query all changed records (includes children and parents)
    - extrapolate all changed parent objects
    - for each parent object, construct a query and yield the result back
    """
    MAX_BATCH_BYTES = 1024 * 1024

    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        requested_objects: list[str] = [],
    ) -> None:
        self.batch_size = batch_size
        self._sf_client: Salesforce | None = None
        self.parent_object_list = (
            [obj.capitalize() for obj in requested_objects]
            if requested_objects
            else _DEFAULT_PARENT_OBJECT_TYPES
        )
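        # For example, requested_objects=["account", "contact"] yields
        # parent_object_list == ["Account", "Contact"]; note that
        # str.capitalize() lowercases everything after the first character.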
    def load_credentials(
        self,
        credentials: dict[str, Any],
    ) -> dict[str, Any] | None:
        domain = "test" if credentials.get("is_sandbox") else None
        self._sf_client = Salesforce(
            username=credentials["sf_username"],
            password=credentials["sf_password"],
            security_token=credentials["sf_security_token"],
            domain=domain,
        )
        return None

    @property
    def sf_client(self) -> Salesforce:
        if self._sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        return self._sf_client

    @staticmethod
    def reconstruct_object_types(directory: str) -> dict[str, list[str] | None]:
        """
        Scans the given directory for all CSV files and reconstructs the available object types.
        Assumes filenames are formatted as "ObjectType.filename.csv" or "ObjectType.csv".

        Args:
            directory (str): The path to the directory containing CSV files.

        Returns:
            dict[str, list[str]]: A dictionary mapping object types to lists of file paths.
        """
        object_types = defaultdict(list)

        for filename in os.listdir(directory):
            if filename.endswith(".csv"):
                parts = filename.split(".", 1)  # Split on the first period
                object_type = parts[0]  # Take the first part as the object type
                object_types[object_type].append(os.path.join(directory, filename))

        return dict(object_types)

    @staticmethod
    def _download_object_csvs(
        all_types_to_filter: dict[str, bool],
        directory: str,
        sf_client: Salesforce,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> None:
        # checkpoint - we've found all object types, now time to fetch the data
        logger.info("Fetching CSVs for all object types")

        # This takes like 30 minutes first time and <2 minutes for updates
        object_type_to_csv_path = fetch_all_csvs_in_parallel(
            sf_client=sf_client,
            all_types_to_filter=all_types_to_filter,
            start=start,
            end=end,
            target_dir=directory,
        )

        # print useful information
        num_csvs = 0
        num_bytes = 0
        for object_type, csv_paths in object_type_to_csv_path.items():
            if not csv_paths:
                continue

            for csv_path in csv_paths:
                if not csv_path:
                    continue

                file_path = Path(csv_path)
                file_size = file_path.stat().st_size
                num_csvs += 1
                num_bytes += file_size
                logger.info(
                    f"CSV info: object_type={object_type} path={csv_path} bytes={file_size}"
                )

        logger.info(f"CSV info total: total_csvs={num_csvs} total_bytes={num_bytes}")

    @staticmethod
    def _load_csvs_to_db(csv_directory: str, sf_db: OnyxSalesforceSQLite) -> set[str]:
        updated_ids: set[str] = set()

        object_type_to_csv_path = SalesforceConnector.reconstruct_object_types(
            csv_directory
        )

        # NOTE(rkuo): this timing note is meaningless without a reference point in terms
        # of number of records, etc
        # This takes like 10 seconds

        # This is for testing the rest of the functionality if data has
        # already been fetched and put in sqlite
        # from onyx.connectors.salesforce.sf_db.sqlite_functions import find_ids_by_type
        # for object_type in self.parent_object_list:
        #     updated_ids.update(list(find_ids_by_type(object_type)))

        # This takes 10-70 minutes first time (idk why the range is so big)
        total_types = len(object_type_to_csv_path)
        logger.info(f"Starting to process {total_types} object types")

        for i, (object_type, csv_paths) in enumerate(
            object_type_to_csv_path.items(), 1
        ):
            logger.info(f"Processing object type {object_type} ({i}/{total_types})")
            # If path is None, it means it failed to fetch the csv
            if csv_paths is None:
                continue

            # Go through each csv path and use it to update the db
            for csv_path in csv_paths:
                logger.debug(
                    f"Processing CSV: object_type={object_type} "
                    f"csv={csv_path} "
                    f"len={Path(csv_path).stat().st_size}"
                )
                new_ids = sf_db.update_from_csv(
                    object_type=object_type,
                    csv_download_path=csv_path,
                )
                updated_ids.update(new_ids)
                logger.debug(
                    f"Added {len(new_ids)} new/updated records for {object_type}"
                )

                os.remove(csv_path)

        return updated_ids

    @staticmethod
    def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]:
        all_types: set[str] = set(parent_types)

        # Step 1 - get all object types
        logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")

        # This takes like 20 seconds
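        # Example (org-dependent, illustrative): for the parent type "Account",
        # the child types discovered here typically include "Contact",
        # "Opportunity", and "Case", plus any custom objects with a lookup to Account.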
        for parent_object_type in parent_types:
            child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
            logger.debug(
                f"Found {len(child_types)} child types for {parent_object_type}"
            )

            all_types.update(child_types)

        # Always want to make sure user is grabbed for permissioning purposes
        all_types.add("User")

        logger.info(f"All object types: num={len(all_types)} list={all_types}")

        # gc.collect()
        return all_types

    def _fetch_from_salesforce(
        self,
        temp_dir: str,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        logger.info("_fetch_from_salesforce starting.")
        if not self._sf_client:
            raise RuntimeError("self._sf_client is None!")

        updated_ids: set[str] = set()
        docs_processed = 0

        sf_db = OnyxSalesforceSQLite(os.path.join(temp_dir, "salesforce_db.sqlite"))
        sf_db.connect()

        try:
            sf_db.apply_schema()
            sf_db.log_stats()

            # Step 1.1 - add child object types + "User" type to the list of types
            all_types = SalesforceConnector._get_all_types(
                self.parent_object_list, self._sf_client
            )

            """Only add a time filter if there is at least one object of the type
            already in the database. We aren't worried about partially completed
            object update runs because this happens after we check for existing
            CSVs, which covers that case."""
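            # all_types_to_filter maps object type -> whether to apply the start/end
            # time filter when that type's CSV is downloaded; False forces a full
            # export for types we have never stored before.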
            all_types_to_filter: dict[str, bool] = {}
            for sf_type in all_types:
                if sf_db.has_at_least_one_object_of_type(sf_type):
                    all_types_to_filter[sf_type] = True
                else:
                    all_types_to_filter[sf_type] = False

            # Step 1.2 - bulk download the CSV for each object type
            SalesforceConnector._download_object_csvs(
                all_types_to_filter, temp_dir, self._sf_client, start, end
            )
            gc.collect()

            # Step 2 - load CSV's to sqlite
            updated_ids = SalesforceConnector._load_csvs_to_db(temp_dir, sf_db)
            gc.collect()

            logger.info(f"Found {len(updated_ids)} total updated records")
            logger.info(
                f"Starting to process parent objects of types: {self.parent_object_list}"
            )

            # Step 3 - extract and index docs
            batches_processed = 0
            docs_to_yield: list[Document] = []
            docs_to_yield_bytes = 0

            # Takes 15-20 seconds per batch
            for parent_type, parent_id_batch in sf_db.get_affected_parent_ids_by_type(
                updated_ids=list(updated_ids),
                parent_types=self.parent_object_list,
            ):
                batches_processed += 1
                logger.info(
                    f"Processing batch: index={batches_processed} "
                    f"object_type={parent_type} "
                    f"len={len(parent_id_batch)} "
                    f"processed={docs_processed} "
                    f"remaining={len(updated_ids) - docs_processed}"
                )
                for parent_id in parent_id_batch:
                    parent_object = sf_db.get_record(parent_id, parent_type)
                    if not parent_object:
                        logger.warning(
                            f"Failed to get parent object {parent_id} for {parent_type}"
                        )
                        continue

                    doc = convert_sf_object_to_doc(
                        sf_db,
                        sf_object=parent_object,
                        sf_instance=self.sf_client.sf_instance,
                    )
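                    # NOTE: sys.getsizeof() only measures the shallow size of the
                    # Document object, not its sections' text, so MAX_BATCH_BYTES is
                    # an approximate cap rather than an exact memory bound.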
                    doc_sizeof = sys.getsizeof(doc)
                    docs_to_yield_bytes += doc_sizeof
                    docs_to_yield.append(doc)
                    docs_processed += 1

                    # memory usage is sensitive to the input length, so we're yielding immediately
                    # if the batch exceeds a certain byte length
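                    # (the size check happens after appending, so a single oversized
                    # document is still yielded on its own rather than dropped)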
                    if (
                        len(docs_to_yield) >= self.batch_size
                        or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
                    ):
                        yield docs_to_yield
                        docs_to_yield = []
                        docs_to_yield_bytes = 0

                # observed a memory leak / size issue with the account table if we don't gc.collect here.
                gc.collect()

            yield docs_to_yield

        finally:
            logger.info(
                f"Final processing stats: "
                f"processed={docs_processed} "
                f"remaining={len(updated_ids) - docs_processed}"
            )

            sf_db.close()

    def load_from_state(self) -> GenerateDocumentsOutput:
        if MULTI_TENANT:
            # if multi tenant, we cannot expect the sqlite db to be cached/present
            with tempfile.TemporaryDirectory() as temp_dir:
                return self._fetch_from_salesforce(temp_dir)

        # nuke the db since we're starting from scratch
        sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
        if os.path.exists(sqlite_db_path):
            logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
            os.remove(sqlite_db_path)
        return self._fetch_from_salesforce(BASE_DATA_PATH)

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if MULTI_TENANT:
            # if multi tenant, we cannot expect the sqlite db to be cached/present
            with tempfile.TemporaryDirectory() as temp_dir:
                return self._fetch_from_salesforce(temp_dir, start=start, end=end)

        if start == 0:
            # nuke the db if we're starting from scratch
            sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
            if os.path.exists(sqlite_db_path):
                logger.info(
                    f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
                )
                os.remove(sqlite_db_path)
        # pass the poll window through so non-multi-tenant polls are also incremental
        return self._fetch_from_salesforce(BASE_DATA_PATH, start=start, end=end)
    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []
        for parent_object_type in self.parent_object_list:
            query = f"SELECT Id FROM {parent_object_type}"
            query_result = self.sf_client.query_all(query)
            doc_metadata_list.extend(
                SlimDocument(
                    id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                    perm_sync_data={},
                )
                for instance_dict in query_result["records"]
            )

        yield doc_metadata_list


if __name__ == "__main__":
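    # Ad-hoc smoke test. Assumes SF_USERNAME, SF_PASSWORD and SF_SECURITY_TOKEN
    # are set in the environment, e.g. (filename is illustrative):
    #   SF_USERNAME=... SF_PASSWORD=... SF_SECURITY_TOKEN=... python connector.py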
    import time

    connector = SalesforceConnector(requested_objects=["Account"])

    connector.load_credentials(
        {
            "sf_username": os.environ["SF_USERNAME"],
            "sf_password": os.environ["SF_PASSWORD"],
            "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
        }
    )
    start_time = time.monotonic()
    doc_count = 0
    section_count = 0
    text_count = 0
    for doc_batch in connector.load_from_state():
        doc_count += len(doc_batch)
        print(f"doc_count: {doc_count}")
        for doc in doc_batch:
            section_count += len(doc.sections)
            for section in doc.sections:
                if isinstance(section, TextSection) and section.text is not None:
                    text_count += len(section.text)
    end_time = time.monotonic()

    print(f"Doc count: {doc_count}")
    print(f"Section count: {section_count}")
    print(f"Text count: {text_count}")
    print(f"Time taken: {end_time - start_time}")