Reworked salesforce connector to use bulk api (#3581)

This commit is contained in:
hagen-danswer 2025-01-02 18:09:02 -08:00 committed by GitHub
parent 3b214133a8
commit d1ec72b5e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 2545 additions and 202 deletions

1
backend/.gitignore vendored
View File

@ -9,3 +9,4 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/

View File

@ -1,11 +1,7 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import UTC
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@ -20,36 +16,24 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.doc_conversion import extract_sections
from onyx.connectors.salesforce.doc_conversion import extract_section
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
# max query length is 20,000 characters, leave 5000 characters for slop
_MAX_QUERY_LENGTH = 10000
# There are 22 extra characters per ID so 200 * 22 = 4400 characters which is
# still well under the max query length
_MAX_ID_BATCH_SIZE = 200
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
@ -64,7 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
@ -78,203 +65,146 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Salesforce")
return self._sf_client
def _get_sf_type_object_json(self, type_name: str) -> Any:
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
def _get_name_from_id(self, id: str) -> str:
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
)
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
return name
except Exception:
logger.warning(f"Couldnt find name for object id: {id}")
return "Null User"
def _extract_primary_owners(
self, sf_object: SalesforceObject
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
return None
if not (last_modified_by := get_record(last_modified_by_id)):
return None
if not (last_modified_by_name := last_modified_by.data.get("Name")):
return None
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
return primary_owners
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
self, sf_object: SalesforceObject
) -> Document:
object_dict = sf_object.data
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
base_url = f"https://{self.sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
)
]
sections = [extract_section(sf_object, base_url)]
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(extract_section(child_object, base_url))
doc = Document(
id=onyx_salesforce_id,
sections=extract_sections(object_dict, base_url),
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=extracted_primary_owners,
primary_owners=self._extract_primary_owners(sf_object),
metadata={},
)
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = self._get_sf_type_object_json(sf_type)
if not object_description["queryable"]:
return False
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = self.sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
logger.debug(f"Fetching children for SF type: {sf_type}")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
for child_relationship in object_description["childRelationships"]:
if self._is_valid_child_object(child_relationship):
children_objects.append(
{
"relationship_name": child_relationship["relationshipName"],
"object_type": child_relationship["childSObject"],
}
)
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
object_description = self._get_sf_type_object_json(sf_type)
fields = [
field.get("name")
for field in object_description["fields"]
if field.get("type", "base64") != "base64"
]
return fields
def _get_parent_object_ids(
self, parent_sf_type: str, time_filter_query: str
) -> list[str]:
"""Fetch all IDs for a given parent object type."""
logger.debug(f"Fetching IDs for parent type: {parent_sf_type}")
query = f"SELECT Id FROM {parent_sf_type}{time_filter_query}"
query_result = self.sf_client.query_all(query)
ids = [record["Id"] for record in query_result["records"]]
logger.debug(f"Found {len(ids)} IDs for parent type: {parent_sf_type}")
return ids
def _process_id_batch(
self,
id_batch: list[str],
queries: list[str],
) -> dict[str, dict[str, Any]]:
"""Process a batch of IDs using the given queries."""
# Initialize results dictionary for this batch
logger.debug(f"Processing batch of {len(id_batch)} IDs")
query_results: dict[str, dict[str, Any]] = {}
# For each query, fetch and combine results for the batch
for query in queries:
id_filter = f" WHERE Id IN {tuple(id_batch)}"
batch_query = query + id_filter
logger.debug(f"Executing query with length: {len(batch_query)}")
query_result = self.sf_client.query_all(batch_query)
logger.debug(f"Retrieved {len(query_result['records'])} records for query")
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
# Convert results to documents
return query_results
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
parent_sf_type is a string that represents the Salesforce object type.
This function generates queries that will fetch:
- all the fields of the parent object type
- all the fields of the child objects of the parent object type
"""
logger.debug(f"Generating queries for parent type: {parent_sf_type}")
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
logger.debug(f"Found {len(parent_fields)} fields for parent type")
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
logger.debug(f"Found {len(child_sf_types)} child types")
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > _MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
else:
query += query_addition
query += f"\n FROM {parent_sf_type}"
yield query
def _batch_retrieval(
self,
id_batches: list[list[str]],
queries: list[str],
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
# For each batch of IDs, perform all queries and convert to documents
# so they can be yielded in batches
for id_batch in id_batches:
query_results = self._process_id_batch(id_batch, queries)
for doc in query_results.values():
doc_batch.append(self._convert_object_instance_to_document(doc))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def _fetch_from_salesforce(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
logger.debug(f"Starting Salesforce fetch from {start} to {end}")
time_filter_query = _build_time_filter_for_salesforce(start, end)
init_db()
all_object_types: set[str] = set(self.parent_object_list)
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in self.parent_object_list:
logger.info(f"Processing parent object type: {parent_object_type}")
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
all_ids = self._get_parent_object_ids(parent_object_type, time_filter_query)
logger.info(f"Found {len(all_ids)} IDs for {parent_object_type}")
id_batches = batch_list(all_ids, _MAX_ID_BATCH_SIZE)
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
# Generate all queries we'll need
queries = list(self._generate_query_per_parent_type(parent_object_type))
logger.info(f"Generated {len(queries)} queries for {parent_object_type}")
yield from self._batch_retrieval(id_batches, queries)
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
)
updated_ids: set[str] = set()
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
# for object_type in self.parent_object_list:
# updated_ids.update(list(find_ids_by_type(object_type)))
# This takes 10-70 minutes first time (idk why the range is so big)
total_types = len(object_type_to_csv_path)
logger.info(f"Starting to process {total_types} object types")
for i, (object_type, csv_paths) in enumerate(
object_type_to_csv_path.items(), 1
):
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
object_type=object_type,
csv_download_path=csv_path,
)
updated_ids.update(new_ids)
logger.debug(
f"Added {len(new_ids)} new/updated records for {object_type}"
)
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_path)
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
docs_to_yield: list[Document] = []
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
logger.info(
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
docs_to_yield.append(
self._convert_object_instance_to_document(parent_object)
)
docs_processed += 1
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
yield docs_to_yield
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
@ -305,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
)
import time
connector = SalesforceConnector(requested_objects=["Account"])
connector.load_credentials(
{
@ -316,5 +246,20 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
for doc_batch in connector.load_from_state():
doc_count += len(doc_batch)
print(f"doc_count: {doc_count}")
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")
print(f"Text count: {text_count}")
print(f"Time taken: {end_time - start_time}")

View File

@ -2,6 +2,7 @@ import re
from collections import OrderedDict
from onyx.connectors.models import Section
from onyx.connectors.salesforce.utils import SalesforceObject
# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
@ -102,6 +103,13 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
def _field_value_is_child_object(field_value: dict) -> bool:
"""
Checks if the field value is a child object.
@ -115,7 +123,7 @@ def _field_value_is_child_object(field_value: dict) -> bool:
)
def extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
"""
This goes through the salesforce_object and extracts the top level fields as a Section.
It also goes through the child objects and extracts them as Sections.

View File

@ -0,0 +1,210 @@
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any
from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
return sf_object.describe()
def _is_valid_child_object(
sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = _get_sf_type_object_json(sf_client, sf_type)
if not object_description["queryable"]:
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
child_object_types = set()
for child_relationship in object_description["childRelationships"]:
if _is_valid_child_object(sf_client, child_relationship):
logger.debug(
f"Found valid child object {child_relationship['childSObject']}"
)
child_object_types.add(child_relationship["childSObject"])
return child_object_types
def _get_all_queryable_fields_of_sf_type(
sf_client: Salesforce,
sf_type: str,
) -> list[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
compound_field_names: set[str] = set()
for field in fields:
if compound_field_name := field.get("compoundFieldName"):
compound_field_names.add(compound_field_name)
if field.get("type", "base64") == "base64":
continue
if field_name := field.get("name"):
valid_fields.add(field_name)
return list(valid_fields - compound_field_names)
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
"""
Send a small query to check if the object type is empty so we don't
perform extra bulk queries
"""
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
return True
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
# Check if the csv already exists
if os.path.exists(get_object_type_path(sf_type)):
existing_csvs = [
os.path.join(get_object_type_path(sf_type), f)
for f in os.listdir(get_object_type_path(sf_type))
if f.endswith(".csv")
]
# If the csv already exists, return the path
# This is likely due to a previous run that failed
# after downloading the csv but before the data was
# written to the db
if existing_csvs:
return existing_csvs
return None
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def _bulk_retrieve_from_salesforce(
sf_client: Salesforce,
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
if not _check_if_object_type_is_empty(sf_client, sf_type):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
return sf_type, existing_csvs
query = _build_bulk_query(sf_client, sf_type, time_filter)
bulk_2_handler = SFBulk2Handler(
session_id=sf_client.session_id,
bulk2_url=sf_client.bulk2_url,
proxies=sf_client.proxies,
session=sf_client.session,
)
bulk_2_type = SFBulk2Type(
object_name=sf_type,
bulk2_url=bulk_2_handler.bulk2_url,
headers=bulk_2_handler.headers,
session=bulk_2_handler.session,
)
logger.info(f"Downloading {sf_type}")
logger.info(f"Query: {query}")
try:
# This downloads the file to a file in the target path with a random name
results = bulk_2_type.download(
query=query,
path=get_object_type_path(sf_type),
max_records=1000000,
)
all_download_paths = [result["file"] for result in results]
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
return sf_type, all_download_paths
except Exception as e:
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
return sf_type, None
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
object_types: set[str],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
"""
Fetches all the csvs in parallel for the given object types
Returns a dict of (sf_type, full_download_path)
"""
time_filter = _build_time_filter_for_salesforce(start, end)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if has_at_least_one_object_of_type(sf_type):
time_filter_for_each_object_type[sf_type] = time_filter
else:
time_filter_for_each_object_type[sf_type] = ""
# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
results = executor.map(
lambda object_type: _bulk_retrieve_from_salesforce(
sf_client=sf_client,
sf_type=object_type,
time_filter=time_filter_for_each_object_type[object_type],
),
object_types,
)
return dict(results)

View File

@ -0,0 +1,209 @@
import csv
import shelve
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_child_to_parent_shelf_path,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_parent_to_child_shelf_path,
)
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _update_relationship_shelves(
child_id: str,
parent_ids: set[str],
) -> None:
"""Update the relationship shelf when a record is updated."""
try:
# Convert child_id to string once
str_child_id = str(child_id)
# First update child to parent mapping
with shelve.open(
get_child_to_parent_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as child_to_parent_db:
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
child_to_parent_db[str_child_id] = list(parent_ids)
# Calculate differences outside the next context manager
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Only sync once at the end
child_to_parent_db.sync()
# Then update parent to child mapping in a single transaction
if not parent_ids_to_remove and not parent_ids_to_add:
return
with shelve.open(
get_parent_to_child_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as parent_to_child_db:
# Process all removals first
for parent_id in parent_ids_to_remove:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
if str_child_id in existing_children:
existing_children.remove(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Then process all additions
for parent_id in parent_ids_to_add:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
existing_children.add(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Single sync at the end
parent_to_child_db.sync()
except Exception as e:
logger.error(f"Error updating relationship shelves: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID.
Args:
parent_id: The ID of the parent object
Returns:
A set of child object IDs
"""
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
return set(parent_to_child_db.get(parent_id, []))
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
) -> list[str]:
"""Update the SF DB with a CSV file using shelve storage."""
updated_ids = []
shelf_path = get_object_shelf_path(object_type)
# First read the CSV to get all the data
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Update relationship shelves for any parent references
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
_update_relationship_shelves(id, parent_ids)
for field in field_to_remove:
# We use this to extract the Primary Owner later
if field != "LastModifiedById":
del row[field]
# Update the main object shelf
with shelve.open(shelf_path) as object_type_db:
object_type_db[id] = row
# Update the ID-to-type mapping shelf
with shelve.open(get_id_type_shelf_path()) as id_type_db:
id_type_db[id] = object_type
updated_ids.append(id)
# os.remove(csv_download_path)
return updated_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
# Look up the object type from the ID-to-type mapping
with shelve.open(get_id_type_shelf_path()) as id_type_db:
if object_id not in id_type_db:
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
return None
return id_type_db[object_id]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""
Retrieve the record and return it as a SalesforceObject.
The object type will be looked up from the ID-to-type mapping shelf.
"""
if object_type is None:
if not (object_type := get_type_from_id(object_id)):
return None
shelf_path = get_object_shelf_path(object_type)
with shelve.open(shelf_path) as db:
if object_id not in db:
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
return None
data = db[object_id]
return SalesforceObject(
id=object_id,
type=object_type,
data=data,
)
def find_ids_by_type(object_type: str) -> list[str]:
"""
Find all object IDs for rows of the specified type.
"""
shelf_path = get_object_shelf_path(object_type)
try:
with shelve.open(shelf_path) as db:
return list(db.keys())
except FileNotFoundError:
return []
def get_affected_parent_ids_by_type(
updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
or have children in the updated_ids.
Args:
updated_ids: List of IDs that were updated
parent_types: List of object types to filter by
Returns:
A dictionary of IDs that match the criteria
"""
affected_ids_by_type: dict[str, set[str]] = {}
# Check each updated ID
for updated_id in updated_ids:
# Add the ID itself if it's of a parent type
updated_type = get_type_from_id(updated_id)
if updated_type in parent_types:
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
continue
# Get parents of this ID and add them if they're of a parent type
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
parent_ids = child_to_parent_db.get(updated_id, [])
for parent_id in parent_ids:
parent_type = get_type_from_id(parent_id)
if parent_type in parent_types:
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
return affected_ids_by_type

View File

@ -0,0 +1,29 @@
import os
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
def get_object_shelf_path(object_type: str) -> str:
"""Get the path to the shelf file for a specific object type."""
base_path = get_object_type_path(object_type)
os.makedirs(base_path, exist_ok=True)
return os.path.join(base_path, "data.shelf")
def get_id_type_shelf_path() -> str:
"""Get the path to the ID-to-type mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
def get_parent_to_child_shelf_path() -> str:
"""Get the path to the parent-to-child mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
def get_child_to_parent_shelf_path() -> str:
"""Get the path to the child-to-parent mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")

View File

@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
fields = set(sorted(list(fields))) # Sort for consistent order
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 0, "Should return empty list when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()

View File

@ -0,0 +1,386 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@contextmanager
def get_db_connection(
isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Get a database connection with proper isolation level and error handling.
Args:
isolation_level: SQLite isolation level. None = default "DEFERRED",
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
"""
# 60 second timeout for locks
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
if isolation_level is not None:
conn.isolation_level = isolation_level
try:
yield conn
except Exception:
conn.rollback()
raise
finally:
conn.close()
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
if os.path.exists(get_sqlite_db_path()):
return
# Create database directory if it doesn't exist
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
# Main table for storing Salesforce objects
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS salesforce_objects (
id TEXT PRIMARY KEY,
object_type TEXT NOT NULL,
data TEXT NOT NULL, -- JSON serialized data
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# Table for parent-child relationships with covering index
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationships (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id)
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# New table for caching parent-child relationships with object types
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationship_types (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
parent_type TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id, parent_type)
) WITHOUT ROWID
"""
)
# Always recreate indexes to ensure they exist
cursor.execute("DROP INDEX IF EXISTS idx_object_type")
cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
# Create covering indexes for common queries
cursor.execute(
"""
CREATE INDEX idx_object_type
ON salesforce_objects(object_type, id)
WHERE object_type IS NOT NULL
"""
)
cursor.execute(
"""
CREATE INDEX idx_parent_id
ON relationships(parent_id, child_id)
"""
)
cursor.execute(
"""
CREATE INDEX idx_child_parent
ON relationships(child_id)
WHERE child_id IS NOT NULL
"""
)
# New composite index for fast parent type lookups
cursor.execute(
"""
CREATE INDEX idx_relationship_types_lookup
ON relationship_types(parent_type, child_id, parent_id)
"""
)
# Analyze tables to help query planner
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
conn.commit()
def _update_relationship_tables(
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
"""Update the relationship tables when a record is updated.
Args:
conn: The database connection to use (must be in a transaction)
child_id: The ID of the child record
parent_ids: Set of parent IDs to link to
"""
try:
cursor = conn.cursor()
# Get existing parent IDs
cursor.execute(
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
)
old_parent_ids = {row[0] for row in cursor.fetchall()}
# Calculate differences
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Remove old relationships
if parent_ids_to_remove:
cursor.executemany(
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Also remove from relationship_types
cursor.executemany(
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Add new relationships
if parent_ids_to_add:
# First add to relationships table
cursor.executemany(
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
[(child_id, pid) for pid in parent_ids_to_add],
)
# Then get the types of the parent objects and add to relationship_types
for parent_id in parent_ids_to_add:
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?",
(parent_id,),
)
result = cursor.fetchone()
if result:
parent_type = result[0]
cursor.execute(
"""
INSERT INTO relationship_types (child_id, parent_id, parent_type)
VALUES (?, ?, ?)
""",
(child_id, parent_id, parent_type),
)
except Exception as e:
logger.error(f"Error updating relationship tables: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
# Use IMMEDIATE to get a write lock at the start of the transaction
with get_db_connection("IMMEDIATE") as conn:
cursor = conn.cursor()
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
if "Id" not in row:
logger.warning(
f"Row {row} does not have an Id field in {csv_download_path}"
)
continue
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Process relationships and clean data
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
# Remove unwanted fields
for field in field_to_remove:
if field != "LastModifiedById":
del row[field]
# Update main object data
cursor.execute(
"""
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
VALUES (?, ?, ?)
""",
(id, object_type, json.dumps(row)),
)
# Update relationships using the same connection
_update_relationship_tables(conn, id, parent_ids)
updated_ids.append(id)
conn.commit()
return updated_ids
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
# Force index usage with INDEXED BY
cursor.execute(
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
(parent_id,),
)
child_ids = {row[0] for row in cursor.fetchall()}
return child_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
return result[0]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if object_type is None:
object_type = get_type_from_id(object_id)
if not object_type:
return None
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(object_type: str) -> list[str]:
"""Find all object IDs for rows of the specified type."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
)
return [row[0] for row in cursor.fetchall()]
def get_affected_parent_ids_by_type(
updated_ids: list[str],
parent_types: list[str],
batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
"""Get IDs of objects that are of the specified parent types and are either in the
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
"""
# SQLite typically has a limit of 999 variables
updated_ids_batches = batch_list(updated_ids, batch_size)
updated_parent_ids: set[str] = set()
with get_db_connection() as conn:
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
id_placeholders = ",".join(["?" for _ in batch_ids])
for parent_type in parent_types:
affected_ids: set[str] = set()
# Get directly updated objects of parent types - using index on object_type
cursor.execute(
f"""
SELECT id FROM salesforce_objects
WHERE id IN ({id_placeholders})
AND object_type = ?
""",
batch_ids + [parent_type],
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Get parent objects of updated objects - using optimized relationship_types table
cursor.execute(
f"""
SELECT DISTINCT parent_id
FROM relationship_types
INDEXED BY idx_relationship_types_lookup
WHERE parent_type = ?
AND child_id IN ({id_placeholders})
""",
[parent_type] + batch_ids,
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Remove any parent IDs that have already been processed
new_affected_ids = affected_ids - updated_parent_ids
# Add the new affected IDs to the set of updated parent IDs
updated_parent_ids.update(new_affected_ids)
if new_affected_ids:
yield parent_type, new_affected_ids
def has_at_least_one_object_of_type(object_type: str) -> bool:
"""Check if there is at least one object of the specified type in the database.
Args:
object_type: The Salesforce object type to check
Returns:
bool: True if at least one object exists, False otherwise
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
(object_type,),
)
count = cursor.fetchone()[0]
return count > 0

View File

@ -0,0 +1,72 @@
import os
from dataclasses import dataclass
from typing import Any
@dataclass
class SalesforceObject:
id: str
type: str
data: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
return {
"ID": self.id,
"Type": self.type,
"Data": self.data,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
return cls(
id=data["Id"],
type=data["Type"],
data=data,
)
# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def get_sqlite_db_path() -> str:
"""Get the path to the sqlite db file."""
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)
os.makedirs(type_dir, exist_ok=True)
return type_dir
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
def validate_salesforce_id(salesforce_id: str) -> bool:
"""Validate the checksum portion of an 18-character Salesforce ID.
Args:
salesforce_id: An 18-character Salesforce ID
Returns:
bool: True if the checksum is valid, False otherwise
"""
if len(salesforce_id) != 18:
return False
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
checksum = salesforce_id[15:18]
calculated_checksum = ""
for chunk in chunks:
result_string = "".join(
"1" if char.isupper() else "0" for char in reversed(chunk)
)
calculated_checksum += _LOOKUP[result_string]
return checksum == calculated_checksum

View File

@ -0,0 +1,746 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.sqlite_functions import find_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def _clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH, ignore_errors=True)
def _create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
fields = set(sorted(list(fields))) # Sort for consistent order
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def _create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
_create_csv_file(object_type, records)
def _test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def _test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
_create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def _test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
_create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def _test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def _test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
_create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
_create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def _test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
_create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = [_VALID_SALESFORCE_IDS[1]] # Parent Account 2
parent_types = ["Account"]
affected_ids_by_type = dict(
get_affected_parent_ids_by_type(updated_ids, parent_types)
)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"]
), "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact
parent_types = ["Account"]
affected_ids_by_type = dict(
get_affected_parent_ids_by_type(updated_ids, parent_types)
)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"]
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]] # Both cases
parent_types = ["Account"]
affected_ids_by_type = dict(
get_affected_parent_ids_by_type(updated_ids, parent_types)
)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type["Account"]
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids_by_type = dict(
get_affected_parent_ids_by_type(updated_ids, parent_types)
)
assert len(affected_ids_by_type) == 0, "Should return empty dict when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def test_salesforce_sqlite() -> None:
_clear_sf_db()
init_db()
_create_csv_with_example_data()
_test_query()
_test_upsert()
_test_relationships()
_test_account_with_children()
_test_relationship_updates()
_test_get_affected_parent_ids()
_clear_sf_db()