mirror of https://github.com/danswer-ai/danswer.git
synced 2025-06-02 03:01:14 +02:00

Reworked salesforce connector to use bulk api (#3581)

This commit is contained in:
parent 3b214133a8
commit d1ec72b5e5

backend/.gitignore (vendored): 1 line changed
backend/.gitignore
@@ -9,3 +9,4 @@ api_keys.py
 vespa-app.zip
 dynamic_config_storage/
 celerybeat-schedule*
+onyx/connectors/salesforce/data/
backend/onyx/connectors/salesforce/connector.py
@@ -1,11 +1,7 @@
 import os
-from collections.abc import Iterator
-from datetime import datetime
-from datetime import UTC
 from typing import Any

 from simple_salesforce import Salesforce
-from simple_salesforce import SFType

 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
@@ -20,36 +16,24 @@ from onyx.connectors.models import BasicExpertInfo
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
 from onyx.connectors.models import SlimDocument
-from onyx.connectors.salesforce.doc_conversion import extract_sections
+from onyx.connectors.salesforce.doc_conversion import extract_section
+from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
+from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
+from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
+from onyx.connectors.salesforce.sqlite_functions import get_child_ids
+from onyx.connectors.salesforce.sqlite_functions import get_record
+from onyx.connectors.salesforce.sqlite_functions import init_db
+from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
+from onyx.connectors.salesforce.utils import SalesforceObject
 from onyx.utils.logger import setup_logger
-from shared_configs.utils import batch_list

 logger = setup_logger()

-# max query length is 20,000 characters, leave 5000 characters for slop
-_MAX_QUERY_LENGTH = 10000
-# There are 22 extra characters per ID so 200 * 22 = 4400 characters which is
-# still well under the max query length
-_MAX_ID_BATCH_SIZE = 200


 _DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
 _ID_PREFIX = "SALESFORCE_"


-def _build_time_filter_for_salesforce(
-    start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
-) -> str:
-    if start is None or end is None:
-        return ""
-    start_datetime = datetime.fromtimestamp(start, UTC)
-    end_datetime = datetime.fromtimestamp(end, UTC)
-    return (
-        f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
-        f"AND LastModifiedDate < {end_datetime.isoformat()}"
-    )
-
-
 class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
     def __init__(
         self,
@@ -64,7 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
             else _DEFAULT_PARENT_OBJECT_TYPES
         )

-    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+    def load_credentials(
+        self,
+        credentials: dict[str, Any],
+    ) -> dict[str, Any] | None:
         self._sf_client = Salesforce(
             username=credentials["sf_username"],
             password=credentials["sf_password"],
@@ -78,203 +65,146 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
             raise ConnectorMissingCredentialError("Salesforce")
         return self._sf_client

-    def _get_sf_type_object_json(self, type_name: str) -> Any:
-        sf_object = SFType(
-            type_name, self.sf_client.session_id, self.sf_client.sf_instance
-        )
-        return sf_object.describe()
-
-    def _get_name_from_id(self, id: str) -> str:
-        try:
-            user_object_info = self.sf_client.query(
-                f"SELECT Name FROM User WHERE Id = '{id}'"
-            )
-            name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
-            return name
-        except Exception:
-            logger.warning(f"Couldnt find name for object id: {id}")
-            return "Null User"
+    def _extract_primary_owners(
+        self, sf_object: SalesforceObject
+    ) -> list[BasicExpertInfo] | None:
+        object_dict = sf_object.data
+        if not (last_modified_by_id := object_dict.get("LastModifiedById")):
+            return None
+        if not (last_modified_by := get_record(last_modified_by_id)):
+            return None
+        if not (last_modified_by_name := last_modified_by.data.get("Name")):
+            return None
+        primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
+        return primary_owners

     def _convert_object_instance_to_document(
-        self, object_dict: dict[str, Any]
+        self, sf_object: SalesforceObject
     ) -> Document:
+        object_dict = sf_object.data
         salesforce_id = object_dict["Id"]
         onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
         base_url = f"https://{self.sf_client.sf_instance}"
         extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
         extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
-        extracted_primary_owners = [
-            BasicExpertInfo(
-                display_name=self._get_name_from_id(object_dict["LastModifiedById"])
-            )
-        ]
+
+        sections = [extract_section(sf_object, base_url)]
+        for id in get_child_ids(sf_object.id):
+            if not (child_object := get_record(id)):
+                continue
+            sections.append(extract_section(child_object, base_url))

         doc = Document(
             id=onyx_salesforce_id,
-            sections=extract_sections(object_dict, base_url),
+            sections=sections,
             source=DocumentSource.SALESFORCE,
             semantic_identifier=extracted_semantic_identifier,
             doc_updated_at=extracted_doc_updated_at,
-            primary_owners=extracted_primary_owners,
+            primary_owners=self._extract_primary_owners(sf_object),
             metadata={},
         )
         return doc

-    def _is_valid_child_object(self, child_relationship: dict) -> bool:
-        if not child_relationship["childSObject"]:
-            return False
-        if not child_relationship["relationshipName"]:
-            return False
-
-        sf_type = child_relationship["childSObject"]
-        object_description = self._get_sf_type_object_json(sf_type)
-        if not object_description["queryable"]:
-            return False
-
-        try:
-            query = f"SELECT Count() FROM {sf_type} LIMIT 1"
-            result = self.sf_client.query(query)
-            if result["totalSize"] == 0:
-                return False
-        except Exception as e:
-            logger.warning(f"Object type {sf_type} doesn't support query: {e}")
-            return False
-
-        if child_relationship["field"]:
-            if child_relationship["field"] == "RelatedToId":
-                return False
-        else:
-            return False
-
-        return True
-
-    def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
-        logger.debug(f"Fetching children for SF type: {sf_type}")
-        object_description = self._get_sf_type_object_json(sf_type)
-
-        children_objects: list[dict] = []
-        for child_relationship in object_description["childRelationships"]:
-            if self._is_valid_child_object(child_relationship):
-                children_objects.append(
-                    {
-                        "relationship_name": child_relationship["relationshipName"],
-                        "object_type": child_relationship["childSObject"],
-                    }
-                )
-        return children_objects
-
-    def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
-        object_description = self._get_sf_type_object_json(sf_type)
-
-        fields = [
-            field.get("name")
-            for field in object_description["fields"]
-            if field.get("type", "base64") != "base64"
-        ]
-
-        return fields
-
-    def _get_parent_object_ids(
-        self, parent_sf_type: str, time_filter_query: str
-    ) -> list[str]:
-        """Fetch all IDs for a given parent object type."""
-        logger.debug(f"Fetching IDs for parent type: {parent_sf_type}")
-        query = f"SELECT Id FROM {parent_sf_type}{time_filter_query}"
-        query_result = self.sf_client.query_all(query)
-        ids = [record["Id"] for record in query_result["records"]]
-        logger.debug(f"Found {len(ids)} IDs for parent type: {parent_sf_type}")
-        return ids
-
-    def _process_id_batch(
-        self,
-        id_batch: list[str],
-        queries: list[str],
-    ) -> dict[str, dict[str, Any]]:
-        """Process a batch of IDs using the given queries."""
-        # Initialize results dictionary for this batch
-        logger.debug(f"Processing batch of {len(id_batch)} IDs")
-        query_results: dict[str, dict[str, Any]] = {}
-
-        # For each query, fetch and combine results for the batch
-        for query in queries:
-            id_filter = f" WHERE Id IN {tuple(id_batch)}"
-            batch_query = query + id_filter
-            logger.debug(f"Executing query with length: {len(batch_query)}")
-            query_result = self.sf_client.query_all(batch_query)
-            logger.debug(f"Retrieved {len(query_result['records'])} records for query")
-
-            for record_dict in query_result["records"]:
-                query_results.setdefault(record_dict["Id"], {}).update(record_dict)
-
-        # Convert results to documents
-        return query_results
-
-    def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
-        """
-        parent_sf_type is a string that represents the Salesforce object type.
-        This function generates queries that will fetch:
-        - all the fields of the parent object type
-        - all the fields of the child objects of the parent object type
-        """
-        logger.debug(f"Generating queries for parent type: {parent_sf_type}")
-        parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
-        logger.debug(f"Found {len(parent_fields)} fields for parent type")
-        child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
-        logger.debug(f"Found {len(child_sf_types)} child types")
-
-        query = f"SELECT {', '.join(parent_fields)}"
-        for child_object_dict in child_sf_types:
-            fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
-            query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
-
-            if len(query_addition) + len(query) > _MAX_QUERY_LENGTH:
-                query += f"\n FROM {parent_sf_type}"
-                yield query
-                query = "SELECT Id" + query_addition
-            else:
-                query += query_addition
-
-        query += f"\n FROM {parent_sf_type}"
-
-        yield query
-
-    def _batch_retrieval(
-        self,
-        id_batches: list[list[str]],
-        queries: list[str],
-    ) -> GenerateDocumentsOutput:
-        doc_batch: list[Document] = []
-        # For each batch of IDs, perform all queries and convert to documents
-        # so they can be yielded in batches
-        for id_batch in id_batches:
-            query_results = self._process_id_batch(id_batch, queries)
-            for doc in query_results.values():
-                doc_batch.append(self._convert_object_instance_to_document(doc))
-                if len(doc_batch) >= self.batch_size:
-                    yield doc_batch
-                    doc_batch = []
-
-        yield doc_batch
-
     def _fetch_from_salesforce(
         self,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
     ) -> GenerateDocumentsOutput:
-        logger.debug(f"Starting Salesforce fetch from {start} to {end}")
-        time_filter_query = _build_time_filter_for_salesforce(start, end)
+        init_db()
+        all_object_types: set[str] = set(self.parent_object_list)

+        logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
+        logger.debug(f"Parent object types: {self.parent_object_list}")
+
+        # This takes like 20 seconds
         for parent_object_type in self.parent_object_list:
-            logger.info(f"Processing parent object type: {parent_object_type}")
-
-            all_ids = self._get_parent_object_ids(parent_object_type, time_filter_query)
-            logger.info(f"Found {len(all_ids)} IDs for {parent_object_type}")
-            id_batches = batch_list(all_ids, _MAX_ID_BATCH_SIZE)
-
-            # Generate all queries we'll need
-            queries = list(self._generate_query_per_parent_type(parent_object_type))
-            logger.info(f"Generated {len(queries)} queries for {parent_object_type}")
-            yield from self._batch_retrieval(id_batches, queries)
+            child_types = get_all_children_of_sf_type(
+                self.sf_client, parent_object_type
+            )
+            all_object_types.update(child_types)
+            logger.debug(
+                f"Found {len(child_types)} child types for {parent_object_type}"
+            )
+
+        logger.info(f"Found total of {len(all_object_types)} object types to fetch")
+        logger.debug(f"All object types: {all_object_types}")
+
+        # checkpoint - we've found all object types, now time to fetch the data
+        logger.info("Starting to fetch CSVs for all object types")
+        # This takes like 30 minutes first time and <2 minutes for updates
+        object_type_to_csv_path = fetch_all_csvs_in_parallel(
+            sf_client=self.sf_client,
+            object_types=all_object_types,
+            start=start,
+            end=end,
+        )
+
+        updated_ids: set[str] = set()
+        # This takes like 10 seconds
+        # This is for testing the rest of the functionality if data has
+        # already been fetched and put in sqlite
+        # from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
+        # for object_type in self.parent_object_list:
+        #     updated_ids.update(list(find_ids_by_type(object_type)))
+
+        # This takes 10-70 minutes first time (idk why the range is so big)
+        total_types = len(object_type_to_csv_path)
+        logger.info(f"Starting to process {total_types} object types")
+
+        for i, (object_type, csv_paths) in enumerate(
+            object_type_to_csv_path.items(), 1
+        ):
+            logger.info(f"Processing object type {object_type} ({i}/{total_types})")
+            # If path is None, it means it failed to fetch the csv
+            if csv_paths is None:
+                continue
+            # Go through each csv path and use it to update the db
+            for csv_path in csv_paths:
+                logger.debug(f"Updating {object_type} with {csv_path}")
+                new_ids = update_sf_db_with_csv(
+                    object_type=object_type,
+                    csv_download_path=csv_path,
+                )
+                updated_ids.update(new_ids)
+                logger.debug(
+                    f"Added {len(new_ids)} new/updated records for {object_type}"
+                )
+                # Remove the csv file after it has been used
+                # to successfully update the db
+                os.remove(csv_path)
+
+        logger.info(f"Found {len(updated_ids)} total updated records")
+        logger.info(
+            f"Starting to process parent objects of types: {self.parent_object_list}"
+        )
+
+        docs_to_yield: list[Document] = []
+        docs_processed = 0
+        # Takes 15-20 seconds per batch
+        for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
+            updated_ids=list(updated_ids),
+            parent_types=self.parent_object_list,
+        ):
+            logger.info(
+                f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
+            )
+            for parent_id in parent_id_batch:
+                if not (parent_object := get_record(parent_id, parent_type)):
+                    logger.warning(
+                        f"Failed to get parent object {parent_id} for {parent_type}"
+                    )
+                    continue
+
+                docs_to_yield.append(
+                    self._convert_object_instance_to_document(parent_object)
+                )
+                docs_processed += 1
+
+                if len(docs_to_yield) >= self.batch_size:
+                    yield docs_to_yield
+                    docs_to_yield = []
+
+        yield docs_to_yield

     def load_from_state(self) -> GenerateDocumentsOutput:
         return self._fetch_from_salesforce()
@@ -305,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):


 if __name__ == "__main__":
-    connector = SalesforceConnector(
-        requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
-    )
+    import time
+
+    connector = SalesforceConnector(requested_objects=["Account"])
     connector.load_credentials(
         {
@@ -316,5 +246,20 @@ if __name__ == "__main__":
             "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
         }
     )
-    document_batches = connector.load_from_state()
-    print(next(document_batches))
+    start_time = time.time()
+    doc_count = 0
+    section_count = 0
+    text_count = 0
+    for doc_batch in connector.load_from_state():
+        doc_count += len(doc_batch)
+        print(f"doc_count: {doc_count}")
+        for doc in doc_batch:
+            section_count += len(doc.sections)
+            for section in doc.sections:
+                text_count += len(section.text)
+    end_time = time.time()
+
+    print(f"Doc count: {doc_count}")
+    print(f"Section count: {section_count}")
+    print(f"Text count: {text_count}")
+    print(f"Time taken: {end_time - start_time}")
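Taken together, the connector change replaces per-ID SOQL batching with a three-stage pipeline: discover every relevant object type, bulk-export each type to CSV and load it into a local store, then re-emit only the affected parent objects as documents. Below is a minimal sketch of that shape, not the connector's actual code; the callables are hypothetical stand-ins for the real helpers named in the comments:

```python
from collections.abc import Callable, Iterator


def pipeline(
    parent_types: list[str],
    discover_children: Callable[[str], set[str]],  # role of get_all_children_of_sf_type
    download_csvs: Callable[[set[str]], dict[str, list[str]]],  # role of fetch_all_csvs_in_parallel
    load_csv: Callable[[str, str], list[str]],  # role of update_sf_db_with_csv
    affected_parents: Callable[[set[str]], list[str]],  # role of get_affected_parent_ids_by_type
    batch_size: int = 100,
) -> Iterator[list[str]]:
    # Stage 1: expand the requested parent types into every type worth fetching.
    all_types: set[str] = set(parent_types)
    for parent_type in parent_types:
        all_types.update(discover_children(parent_type))

    # Stage 2: bulk-download each type and load its CSVs into local storage,
    # collecting the IDs that actually changed.
    updated_ids: set[str] = set()
    for sf_type, csv_paths in download_csvs(all_types).items():
        for csv_path in csv_paths:
            updated_ids.update(load_csv(sf_type, csv_path))

    # Stage 3: re-emit only the parents touched by the update, in batches.
    batch: list[str] = []
    for parent_id in affected_parents(updated_ids):
        batch.append(parent_id)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    yield batch
```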
backend/onyx/connectors/salesforce/doc_conversion.py
@@ -2,6 +2,7 @@ import re
 from collections import OrderedDict

 from onyx.connectors.models import Section
+from onyx.connectors.salesforce.utils import SalesforceObject

 # All of these types of keys are handled by specific fields in the doc
 # conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
@@ -102,6 +103,13 @@ def _extract_dict_text(raw_dict: dict) -> str:
     return natural_language_for_dict


+def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
+    return Section(
+        text=_extract_dict_text(salesforce_object.data),
+        link=f"{base_url}/{salesforce_object.id}",
+    )
+
+
 def _field_value_is_child_object(field_value: dict) -> bool:
     """
     Checks if the field value is a child object.
@@ -115,7 +123,7 @@ def _field_value_is_child_object(field_value: dict) -> bool:
     )


-def extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
+def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
     """
     This goes through the salesforce_object and extracts the top level fields as a Section.
     It also goes through the child objects and extracts them as Sections.
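With this change each record maps to exactly one `Section` via `extract_section`, and the caller assembles the section list itself (parent first, then children), as the connector diff above does. A rough sketch of that assembly, with the lookups passed in as stand-ins for the local-store helpers:

```python
def build_sections(parent, base_url, get_child_ids, get_record, extract_section):
    # The parent record becomes the first Section; each retrievable child adds one more.
    sections = [extract_section(parent, base_url)]
    for child_id in get_child_ids(parent.id):
        child = get_record(child_id)
        if child is not None:  # the child may have failed to sync locally
            sections.append(extract_section(child, base_url))
    return sections
```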
backend/onyx/connectors/salesforce/salesforce_calls.py (new file, 210 lines)
@@ -0,0 +1,210 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from typing import Any
+
+from pytz import UTC
+from simple_salesforce import Salesforce
+from simple_salesforce import SFType
+from simple_salesforce.bulk2 import SFBulk2Handler
+from simple_salesforce.bulk2 import SFBulk2Type
+
+from onyx.connectors.interfaces import SecondsSinceUnixEpoch
+from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
+from onyx.connectors.salesforce.utils import get_object_type_path
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def _build_time_filter_for_salesforce(
+    start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
+) -> str:
+    if start is None or end is None:
+        return ""
+    start_datetime = datetime.fromtimestamp(start, UTC)
+    end_datetime = datetime.fromtimestamp(end, UTC)
+    return (
+        f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
+        f"AND LastModifiedDate < {end_datetime.isoformat()}"
+    )
+
+
+def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
+    sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
+    return sf_object.describe()
+
+
+def _is_valid_child_object(
+    sf_client: Salesforce, child_relationship: dict[str, Any]
+) -> bool:
+    if not child_relationship["childSObject"]:
+        return False
+    if not child_relationship["relationshipName"]:
+        return False
+
+    sf_type = child_relationship["childSObject"]
+    object_description = _get_sf_type_object_json(sf_client, sf_type)
+    if not object_description["queryable"]:
+        return False
+
+    if child_relationship["field"]:
+        if child_relationship["field"] == "RelatedToId":
+            return False
+    else:
+        return False
+
+    return True
+
+
+def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
+    object_description = _get_sf_type_object_json(sf_client, sf_type)
+
+    child_object_types = set()
+    for child_relationship in object_description["childRelationships"]:
+        if _is_valid_child_object(sf_client, child_relationship):
+            logger.debug(
+                f"Found valid child object {child_relationship['childSObject']}"
+            )
+            child_object_types.add(child_relationship["childSObject"])
+    return child_object_types
+
+
+def _get_all_queryable_fields_of_sf_type(
+    sf_client: Salesforce,
+    sf_type: str,
+) -> list[str]:
+    object_description = _get_sf_type_object_json(sf_client, sf_type)
+    fields: list[dict[str, Any]] = object_description["fields"]
+    valid_fields: set[str] = set()
+    compound_field_names: set[str] = set()
+    for field in fields:
+        if compound_field_name := field.get("compoundFieldName"):
+            compound_field_names.add(compound_field_name)
+        if field.get("type", "base64") == "base64":
+            continue
+        if field_name := field.get("name"):
+            valid_fields.add(field_name)
+
+    return list(valid_fields - compound_field_names)
+
+
+def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
+    """
+    Send a small query to check if the object type is empty so we don't
+    perform extra bulk queries
+    """
+    try:
+        query = f"SELECT Count() FROM {sf_type} LIMIT 1"
+        result = sf_client.query(query)
+        if result["totalSize"] == 0:
+            return False
+    except Exception as e:
+        if "OPERATION_TOO_LARGE" not in str(e):
+            logger.warning(f"Object type {sf_type} doesn't support query: {e}")
+            return False
+    return True
+
+
+def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
+    # Check if the csv already exists
+    if os.path.exists(get_object_type_path(sf_type)):
+        existing_csvs = [
+            os.path.join(get_object_type_path(sf_type), f)
+            for f in os.listdir(get_object_type_path(sf_type))
+            if f.endswith(".csv")
+        ]
+        # If the csv already exists, return the path
+        # This is likely due to a previous run that failed
+        # after downloading the csv but before the data was
+        # written to the db
+        if existing_csvs:
+            return existing_csvs
+    return None
+
+
+def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
+    queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
+    query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
+    return query
+
+
+def _bulk_retrieve_from_salesforce(
+    sf_client: Salesforce,
+    sf_type: str,
+    time_filter: str,
+) -> tuple[str, list[str] | None]:
+    if not _check_if_object_type_is_empty(sf_client, sf_type):
+        return sf_type, None
+
+    if existing_csvs := _check_for_existing_csvs(sf_type):
+        return sf_type, existing_csvs
+
+    query = _build_bulk_query(sf_client, sf_type, time_filter)
+
+    bulk_2_handler = SFBulk2Handler(
+        session_id=sf_client.session_id,
+        bulk2_url=sf_client.bulk2_url,
+        proxies=sf_client.proxies,
+        session=sf_client.session,
+    )
+    bulk_2_type = SFBulk2Type(
+        object_name=sf_type,
+        bulk2_url=bulk_2_handler.bulk2_url,
+        headers=bulk_2_handler.headers,
+        session=bulk_2_handler.session,
+    )
+
+    logger.info(f"Downloading {sf_type}")
+    logger.info(f"Query: {query}")
+
+    try:
+        # This downloads the file to a file in the target path with a random name
+        results = bulk_2_type.download(
+            query=query,
+            path=get_object_type_path(sf_type),
+            max_records=1000000,
+        )
+        all_download_paths = [result["file"] for result in results]
+        logger.info(f"Downloaded {sf_type} to {all_download_paths}")
+        return sf_type, all_download_paths
+    except Exception as e:
+        logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
+        return sf_type, None
+
+
+def fetch_all_csvs_in_parallel(
+    sf_client: Salesforce,
+    object_types: set[str],
+    start: SecondsSinceUnixEpoch | None,
+    end: SecondsSinceUnixEpoch | None,
+) -> dict[str, list[str] | None]:
+    """
+    Fetches all the csvs in parallel for the given object types
+    Returns a dict of (sf_type, full_download_path)
+    """
+    time_filter = _build_time_filter_for_salesforce(start, end)
+    time_filter_for_each_object_type = {}
+    # We do this outside of the thread pool executor because this requires
+    # a database connection and we don't want to block the thread pool
+    # executor from running
+    for sf_type in object_types:
+        """Only add time filter if there is at least one object of the type
+        in the database. We aren't worried about partially completed object update runs
+        because this occurs after we check for existing csvs which covers this case"""
+        if has_at_least_one_object_of_type(sf_type):
+            time_filter_for_each_object_type[sf_type] = time_filter
+        else:
+            time_filter_for_each_object_type[sf_type] = ""
+
+    # Run the bulk retrieve in parallel
+    with ThreadPoolExecutor() as executor:
+        results = executor.map(
+            lambda object_type: _bulk_retrieve_from_salesforce(
+                sf_client=sf_client,
+                sf_type=object_type,
+                time_filter=time_filter_for_each_object_type[object_type],
+            ),
+            object_types,
+        )
+        return dict(results)
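For orientation, this is roughly what calling the module's entry point could look like; the credentials and object types below are placeholders, not values from this commit:

```python
from simple_salesforce import Salesforce

# Placeholder credentials; the connector supplies these from its credential dict.
sf_client = Salesforce(
    username="user@example.com",
    password="hunter2",
    security_token="token",
)

# start/end of None means no LastModifiedDate filter, i.e. a full export.
csv_paths_by_type = fetch_all_csvs_in_parallel(
    sf_client=sf_client,
    object_types={"Account", "Contact"},
    start=None,
    end=None,
)
for sf_type, paths in csv_paths_by_type.items():
    # None means the type was empty, unqueryable, or the download failed.
    print(sf_type, paths)
```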
backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py (new file, 209 lines)
@@ -0,0 +1,209 @@
+import csv
+import shelve
+
+from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
+    get_child_to_parent_shelf_path,
+)
+from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
+from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
+from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
+    get_parent_to_child_shelf_path,
+)
+from onyx.connectors.salesforce.utils import SalesforceObject
+from onyx.connectors.salesforce.utils import validate_salesforce_id
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def _update_relationship_shelves(
+    child_id: str,
+    parent_ids: set[str],
+) -> None:
+    """Update the relationship shelf when a record is updated."""
+    try:
+        # Convert child_id to string once
+        str_child_id = str(child_id)
+
+        # First update child to parent mapping
+        with shelve.open(
+            get_child_to_parent_shelf_path(),
+            flag="c",
+            protocol=None,
+            writeback=True,
+        ) as child_to_parent_db:
+            old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
+            child_to_parent_db[str_child_id] = list(parent_ids)
+
+            # Calculate differences outside the next context manager
+            parent_ids_to_remove = old_parent_ids - parent_ids
+            parent_ids_to_add = parent_ids - old_parent_ids
+
+            # Only sync once at the end
+            child_to_parent_db.sync()
+
+        # Then update parent to child mapping in a single transaction
+        if not parent_ids_to_remove and not parent_ids_to_add:
+            return
+        with shelve.open(
+            get_parent_to_child_shelf_path(),
+            flag="c",
+            protocol=None,
+            writeback=True,
+        ) as parent_to_child_db:
+            # Process all removals first
+            for parent_id in parent_ids_to_remove:
+                str_parent_id = str(parent_id)
+                existing_children = set(parent_to_child_db.get(str_parent_id, []))
+                if str_child_id in existing_children:
+                    existing_children.remove(str_child_id)
+                    parent_to_child_db[str_parent_id] = list(existing_children)
+
+            # Then process all additions
+            for parent_id in parent_ids_to_add:
+                str_parent_id = str(parent_id)
+                existing_children = set(parent_to_child_db.get(str_parent_id, []))
+                existing_children.add(str_child_id)
+                parent_to_child_db[str_parent_id] = list(existing_children)
+
+            # Single sync at the end
+            parent_to_child_db.sync()
+
+    except Exception as e:
+        logger.error(f"Error updating relationship shelves: {e}")
+        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
+        raise
+
+
+def get_child_ids(parent_id: str) -> set[str]:
+    """Get all child IDs for a given parent ID.
+
+    Args:
+        parent_id: The ID of the parent object
+
+    Returns:
+        A set of child object IDs
+    """
+    with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
+        return set(parent_to_child_db.get(parent_id, []))
+
+
+def update_sf_db_with_csv(
+    object_type: str,
+    csv_download_path: str,
+) -> list[str]:
+    """Update the SF DB with a CSV file using shelve storage."""
+    updated_ids = []
+    shelf_path = get_object_shelf_path(object_type)
+
+    # First read the CSV to get all the data
+    with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
+        reader = csv.DictReader(f)
+        for row in reader:
+            id = row["Id"]
+            parent_ids = set()
+            field_to_remove: set[str] = set()
+            # Update relationship shelves for any parent references
+            for field, value in row.items():
+                if validate_salesforce_id(value) and field != "Id":
+                    parent_ids.add(value)
+                    field_to_remove.add(field)
+                if not value:
+                    field_to_remove.add(field)
+            _update_relationship_shelves(id, parent_ids)
+            for field in field_to_remove:
+                # We use this to extract the Primary Owner later
+                if field != "LastModifiedById":
+                    del row[field]
+
+            # Update the main object shelf
+            with shelve.open(shelf_path) as object_type_db:
+                object_type_db[id] = row
+            # Update the ID-to-type mapping shelf
+            with shelve.open(get_id_type_shelf_path()) as id_type_db:
+                id_type_db[id] = object_type
+
+            updated_ids.append(id)
+
+    # os.remove(csv_download_path)
+    return updated_ids
+
+
+def get_type_from_id(object_id: str) -> str | None:
+    """Get the type of an object from its ID."""
+    # Look up the object type from the ID-to-type mapping
+    with shelve.open(get_id_type_shelf_path()) as id_type_db:
+        if object_id not in id_type_db:
+            logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
+            return None
+        return id_type_db[object_id]
+
+
+def get_record(
+    object_id: str, object_type: str | None = None
+) -> SalesforceObject | None:
+    """
+    Retrieve the record and return it as a SalesforceObject.
+    The object type will be looked up from the ID-to-type mapping shelf.
+    """
+    if object_type is None:
+        if not (object_type := get_type_from_id(object_id)):
+            return None
+
+    shelf_path = get_object_shelf_path(object_type)
+    with shelve.open(shelf_path) as db:
+        if object_id not in db:
+            logger.warning(f"Object ID {object_id} not found in {shelf_path}")
+            return None
+        data = db[object_id]
+        return SalesforceObject(
+            id=object_id,
+            type=object_type,
+            data=data,
+        )
+
+
+def find_ids_by_type(object_type: str) -> list[str]:
+    """
+    Find all object IDs for rows of the specified type.
+    """
+    shelf_path = get_object_shelf_path(object_type)
+    try:
+        with shelve.open(shelf_path) as db:
+            return list(db.keys())
+    except FileNotFoundError:
+        return []
+
+
+def get_affected_parent_ids_by_type(
+    updated_ids: set[str], parent_types: list[str]
+) -> dict[str, set[str]]:
+    """Get IDs of objects that are of the specified parent types and are either in the updated_ids
+    or have children in the updated_ids.
+
+    Args:
+        updated_ids: List of IDs that were updated
+        parent_types: List of object types to filter by
+
+    Returns:
+        A dictionary of IDs that match the criteria
+    """
+    affected_ids_by_type: dict[str, set[str]] = {}
+
+    # Check each updated ID
+    for updated_id in updated_ids:
+        # Add the ID itself if it's of a parent type
+        updated_type = get_type_from_id(updated_id)
+        if updated_type in parent_types:
+            affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
+            continue
+
+        # Get parents of this ID and add them if they're of a parent type
+        with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
+            parent_ids = child_to_parent_db.get(updated_id, [])
+            for parent_id in parent_ids:
+                parent_type = get_type_from_id(parent_id)
+                if parent_type in parent_types:
+                    affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
+
+    return affected_ids_by_type
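A minimal round trip through these helpers might look like the following; the CSV path and its contents are illustrative, not part of the commit:

```python
# Suppose /tmp/contacts.csv contains:
#   Id,Name,AccountId,LastModifiedById
#   003bm00000EjHCjAAN,Jane Doe,001bm00000fd9Z3AAI,005bm000009zy0TAAQ
new_ids = update_sf_db_with_csv("Contact", "/tmp/contacts.csv")

# The record's type is recovered through the ID-to-type shelf.
record = get_record(new_ids[0])

# The AccountId column is detected as a parent reference, so the contact
# now appears among its account's children.
child_ids = get_child_ids("001bm00000fd9Z3AAI")
assert new_ids[0] in child_ids
```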
backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import os
+
+from onyx.connectors.salesforce.utils import BASE_DATA_PATH
+from onyx.connectors.salesforce.utils import get_object_type_path
+
+
+def get_object_shelf_path(object_type: str) -> str:
+    """Get the path to the shelf file for a specific object type."""
+    base_path = get_object_type_path(object_type)
+    os.makedirs(base_path, exist_ok=True)
+    return os.path.join(base_path, "data.shelf")
+
+
+def get_id_type_shelf_path() -> str:
+    """Get the path to the ID-to-type mapping shelf."""
+    os.makedirs(BASE_DATA_PATH, exist_ok=True)
+    return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
+
+
+def get_parent_to_child_shelf_path() -> str:
+    """Get the path to the parent-to-child mapping shelf."""
+    os.makedirs(BASE_DATA_PATH, exist_ok=True)
+    return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
+
+
+def get_child_to_parent_shelf_path() -> str:
+    """Get the path to the child-to-parent mapping shelf."""
+    os.makedirs(BASE_DATA_PATH, exist_ok=True)
+    return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")
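These helpers give every shelf a fixed location under `BASE_DATA_PATH` and create the directory on first use, so consumers only ever deal in paths. A quick sketch of that contract:

```python
import shelve

# Opening through the path helper; the data directory is created if missing.
with shelve.open(get_id_type_shelf_path()) as id_type_db:
    id_type_db["001bm00000fd9Z3AAI"] = "Account"
```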
(new file, 737 lines)
@@ -0,0 +1,737 @@
+import csv
+import os
+import shutil
+
+from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
+from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
+    get_affected_parent_ids_by_type,
+)
+from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
+from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
+from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
+    update_sf_db_with_csv,
+)
+from onyx.connectors.salesforce.utils import BASE_DATA_PATH
+from onyx.connectors.salesforce.utils import get_object_type_path
+
+_VALID_SALESFORCE_IDS = [
+    "001bm00000fd9Z3AAI",
+    "001bm00000fdYTdAAM",
+    "001bm00000fdYTeAAM",
+    "001bm00000fdYTfAAM",
+    "001bm00000fdYTgAAM",
+    "001bm00000fdYThAAM",
+    "001bm00000fdYTiAAM",
+    "001bm00000fdYTjAAM",
+    "001bm00000fdYTkAAM",
+    "001bm00000fdYTlAAM",
+    "001bm00000fdYTmAAM",
+    "001bm00000fdYTnAAM",
+    "001bm00000fdYToAAM",
+    "500bm00000XoOxtAAF",
+    "500bm00000XoOxuAAF",
+    "500bm00000XoOxvAAF",
+    "500bm00000XoOxwAAF",
+    "500bm00000XoOxxAAF",
+    "500bm00000XoOxyAAF",
+    "500bm00000XoOxzAAF",
+    "500bm00000XoOy0AAF",
+    "500bm00000XoOy1AAF",
+    "500bm00000XoOy2AAF",
+    "500bm00000XoOy3AAF",
+    "500bm00000XoOy4AAF",
+    "500bm00000XoOy5AAF",
+    "500bm00000XoOy6AAF",
+    "500bm00000XoOy7AAF",
+    "500bm00000XoOy8AAF",
+    "500bm00000XoOy9AAF",
+    "500bm00000XoOyAAAV",
+    "500bm00000XoOyBAAV",
+    "500bm00000XoOyCAAV",
+    "500bm00000XoOyDAAV",
+    "500bm00000XoOyEAAV",
+    "500bm00000XoOyFAAV",
+    "500bm00000XoOyGAAV",
+    "500bm00000XoOyHAAV",
+    "500bm00000XoOyIAAV",
+    "003bm00000EjHCjAAN",
+    "003bm00000EjHCkAAN",
+    "003bm00000EjHClAAN",
+    "003bm00000EjHCmAAN",
+    "003bm00000EjHCnAAN",
+    "003bm00000EjHCoAAN",
+    "003bm00000EjHCpAAN",
+    "003bm00000EjHCqAAN",
+    "003bm00000EjHCrAAN",
+    "003bm00000EjHCsAAN",
+    "003bm00000EjHCtAAN",
+    "003bm00000EjHCuAAN",
+    "003bm00000EjHCvAAN",
+    "003bm00000EjHCwAAN",
+    "003bm00000EjHCxAAN",
+    "003bm00000EjHCyAAN",
+    "003bm00000EjHCzAAN",
+    "003bm00000EjHD0AAN",
+    "003bm00000EjHD1AAN",
+    "003bm00000EjHD2AAN",
+    "550bm00000EXc2tAAD",
+    "006bm000006kyDpAAI",
+    "006bm000006kyDqAAI",
+    "006bm000006kyDrAAI",
+    "006bm000006kyDsAAI",
+    "006bm000006kyDtAAI",
+    "006bm000006kyDuAAI",
+    "006bm000006kyDvAAI",
+    "006bm000006kyDwAAI",
+    "006bm000006kyDxAAI",
+    "006bm000006kyDyAAI",
+    "006bm000006kyDzAAI",
+    "006bm000006kyE0AAI",
+    "006bm000006kyE1AAI",
+    "006bm000006kyE2AAI",
+    "006bm000006kyE3AAI",
+    "006bm000006kyE4AAI",
+    "006bm000006kyE5AAI",
+    "006bm000006kyE6AAI",
+    "006bm000006kyE7AAI",
+    "006bm000006kyE8AAI",
+    "006bm000006kyE9AAI",
+    "006bm000006kyEAAAY",
+    "006bm000006kyEBAAY",
+    "006bm000006kyECAAY",
+    "006bm000006kyEDAAY",
+    "006bm000006kyEEAAY",
+    "006bm000006kyEFAAY",
+    "006bm000006kyEGAAY",
+    "006bm000006kyEHAAY",
+    "006bm000006kyEIAAY",
+    "006bm000006kyEJAAY",
+    "005bm000009zy0TAAQ",
+    "005bm000009zy25AAA",
+    "005bm000009zy26AAA",
+    "005bm000009zy28AAA",
+    "005bm000009zy29AAA",
+    "005bm000009zy2AAAQ",
+    "005bm000009zy2BAAQ",
+]
+
+
+def clear_sf_db() -> None:
+    """
+    Clears the SF DB by deleting all files in the data directory.
+    """
+    shutil.rmtree(BASE_DATA_PATH)
+
+
+def create_csv_file(
+    object_type: str, records: list[dict], filename: str = "test_data.csv"
+) -> None:
+    """
+    Creates a CSV file for the given object type and records.
+
+    Args:
+        object_type: The Salesforce object type (e.g. "Account", "Contact")
+        records: List of dictionaries containing the record data
+        filename: Name of the CSV file to create (default: test_data.csv)
+    """
+    if not records:
+        return
+
+    # Get all unique fields from records
+    fields: set[str] = set()
+    for record in records:
+        fields.update(record.keys())
+    fields = set(sorted(list(fields)))  # Sort for consistent order
+
+    # Create CSV file
+    csv_path = os.path.join(get_object_type_path(object_type), filename)
+    with open(csv_path, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=fields)
+        writer.writeheader()
+        for record in records:
+            writer.writerow(record)
+
+    # Update the database with the CSV
+    update_sf_db_with_csv(object_type, csv_path)
+
+
+def create_csv_with_example_data() -> None:
+    """
+    Creates CSV files with example data, organized by object type.
+    """
+    example_data: dict[str, list[dict]] = {
+        "Account": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[0],
+                "Name": "Acme Inc.",
+                "BillingCity": "New York",
+                "Industry": "Technology",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[1],
+                "Name": "Globex Corp",
+                "BillingCity": "Los Angeles",
+                "Industry": "Manufacturing",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[2],
+                "Name": "Initech",
+                "BillingCity": "Austin",
+                "Industry": "Software",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[3],
+                "Name": "TechCorp Solutions",
+                "BillingCity": "San Francisco",
+                "Industry": "Software",
+                "AnnualRevenue": 5000000,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[4],
+                "Name": "BioMed Research",
+                "BillingCity": "Boston",
+                "Industry": "Healthcare",
+                "AnnualRevenue": 12000000,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[5],
+                "Name": "Green Energy Co",
+                "BillingCity": "Portland",
+                "Industry": "Energy",
+                "AnnualRevenue": 8000000,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[6],
+                "Name": "DataFlow Analytics",
+                "BillingCity": "Seattle",
+                "Industry": "Technology",
+                "AnnualRevenue": 3000000,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[7],
+                "Name": "Cloud Nine Services",
+                "BillingCity": "Denver",
+                "Industry": "Cloud Computing",
+                "AnnualRevenue": 7000000,
+            },
+        ],
+        "Contact": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[40],
+                "FirstName": "John",
+                "LastName": "Doe",
+                "Email": "john.doe@acme.com",
+                "Title": "CEO",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[41],
+                "FirstName": "Jane",
+                "LastName": "Smith",
+                "Email": "jane.smith@acme.com",
+                "Title": "CTO",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[42],
+                "FirstName": "Bob",
+                "LastName": "Johnson",
+                "Email": "bob.j@globex.com",
+                "Title": "Sales Director",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[43],
+                "FirstName": "Sarah",
+                "LastName": "Chen",
+                "Email": "sarah.chen@techcorp.com",
+                "Title": "Product Manager",
+                "Phone": "415-555-0101",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[44],
+                "FirstName": "Michael",
+                "LastName": "Rodriguez",
+                "Email": "m.rodriguez@biomed.com",
+                "Title": "Research Director",
+                "Phone": "617-555-0202",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[45],
+                "FirstName": "Emily",
+                "LastName": "Green",
+                "Email": "emily.g@greenenergy.com",
+                "Title": "Sustainability Lead",
+                "Phone": "503-555-0303",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[46],
+                "FirstName": "David",
+                "LastName": "Kim",
+                "Email": "david.kim@dataflow.com",
+                "Title": "Data Scientist",
+                "Phone": "206-555-0404",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[47],
+                "FirstName": "Rachel",
+                "LastName": "Taylor",
+                "Email": "r.taylor@cloudnine.com",
+                "Title": "Cloud Architect",
+                "Phone": "303-555-0505",
+            },
+        ],
+        "Opportunity": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[62],
+                "Name": "Acme Server Upgrade",
+                "Amount": 50000,
+                "Stage": "Prospecting",
+                "CloseDate": "2024-06-30",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[63],
+                "Name": "Globex Manufacturing Line",
+                "Amount": 150000,
+                "Stage": "Negotiation",
+                "CloseDate": "2024-03-15",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[64],
+                "Name": "Initech Software License",
+                "Amount": 75000,
+                "Stage": "Closed Won",
+                "CloseDate": "2024-01-30",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[65],
+                "Name": "TechCorp AI Implementation",
+                "Amount": 250000,
+                "Stage": "Needs Analysis",
+                "CloseDate": "2024-08-15",
+                "Probability": 60,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[66],
+                "Name": "BioMed Lab Equipment",
+                "Amount": 500000,
+                "Stage": "Value Proposition",
+                "CloseDate": "2024-09-30",
+                "Probability": 75,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[67],
+                "Name": "Green Energy Solar Project",
+                "Amount": 750000,
+                "Stage": "Proposal",
+                "CloseDate": "2024-07-15",
+                "Probability": 80,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[68],
+                "Name": "DataFlow Analytics Platform",
+                "Amount": 180000,
+                "Stage": "Negotiation",
+                "CloseDate": "2024-05-30",
+                "Probability": 90,
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[69],
+                "Name": "Cloud Nine Infrastructure",
+                "Amount": 300000,
+                "Stage": "Qualification",
+                "CloseDate": "2024-10-15",
+                "Probability": 40,
+            },
+        ],
+    }
+
+    # Create CSV files for each object type
+    for object_type, records in example_data.items():
+        create_csv_file(object_type, records)
+
+
+def test_query() -> None:
+    """
+    Tests querying functionality by verifying:
+    1. All expected Account IDs are found
+    2. Each Account's data matches what was inserted
+    """
+    # Expected test data for verification
+    expected_accounts: dict[str, dict[str, str | int]] = {
+        _VALID_SALESFORCE_IDS[0]: {
+            "Name": "Acme Inc.",
+            "BillingCity": "New York",
+            "Industry": "Technology",
+        },
+        _VALID_SALESFORCE_IDS[1]: {
+            "Name": "Globex Corp",
+            "BillingCity": "Los Angeles",
+            "Industry": "Manufacturing",
+        },
+        _VALID_SALESFORCE_IDS[2]: {
+            "Name": "Initech",
+            "BillingCity": "Austin",
+            "Industry": "Software",
+        },
+        _VALID_SALESFORCE_IDS[3]: {
+            "Name": "TechCorp Solutions",
+            "BillingCity": "San Francisco",
+            "Industry": "Software",
+            "AnnualRevenue": 5000000,
+        },
+        _VALID_SALESFORCE_IDS[4]: {
+            "Name": "BioMed Research",
+            "BillingCity": "Boston",
+            "Industry": "Healthcare",
+            "AnnualRevenue": 12000000,
+        },
+        _VALID_SALESFORCE_IDS[5]: {
+            "Name": "Green Energy Co",
+            "BillingCity": "Portland",
+            "Industry": "Energy",
+            "AnnualRevenue": 8000000,
+        },
+        _VALID_SALESFORCE_IDS[6]: {
+            "Name": "DataFlow Analytics",
+            "BillingCity": "Seattle",
+            "Industry": "Technology",
+            "AnnualRevenue": 3000000,
+        },
+        _VALID_SALESFORCE_IDS[7]: {
+            "Name": "Cloud Nine Services",
+            "BillingCity": "Denver",
+            "Industry": "Cloud Computing",
+            "AnnualRevenue": 7000000,
+        },
+    }
+
+    # Get all Account IDs
+    account_ids = find_ids_by_type("Account")
+
+    # Verify we found all expected accounts
+    assert len(account_ids) == len(
+        expected_accounts
+    ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
+    assert set(account_ids) == set(
+        expected_accounts.keys()
+    ), "Found account IDs don't match expected IDs"
+
+    # Verify each account's data
+    for acc_id in account_ids:
+        combined = get_record(acc_id)
+        assert combined is not None, f"Could not find account {acc_id}"
+
+        expected = expected_accounts[acc_id]
+
+        # Verify account data matches
+        for key, value in expected.items():
+            value = str(value)
+            assert (
+                combined.data[key] == value
+            ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
+
+    print("All query tests passed successfully!")
+
+
+def test_upsert() -> None:
+    """
+    Tests upsert functionality by:
+    1. Updating an existing account
+    2. Creating a new account
+    3. Verifying both operations were successful
+    """
+    # Create CSV for updating an existing account and adding a new one
+    update_data: list[dict[str, str | int]] = [
+        {
+            "Id": _VALID_SALESFORCE_IDS[0],
+            "Name": "Acme Inc. Updated",
+            "BillingCity": "New York",
+            "Industry": "Technology",
+            "Description": "Updated company info",
+        },
+        {
+            "Id": _VALID_SALESFORCE_IDS[2],
+            "Name": "New Company Inc.",
+            "BillingCity": "Miami",
+            "Industry": "Finance",
+            "AnnualRevenue": 1000000,
+        },
+    ]
+
+    create_csv_file("Account", update_data, "update_data.csv")
+
+    # Verify the update worked
+    updated_record = get_record(_VALID_SALESFORCE_IDS[0])
+    assert updated_record is not None, "Updated record not found"
+    assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
+    assert (
+        updated_record.data["Description"] == "Updated company info"
+    ), "Description not added"
+
+    # Verify the new record was created
+    new_record = get_record(_VALID_SALESFORCE_IDS[2])
+    assert new_record is not None, "New record not found"
+    assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
+    assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
+
+    print("All upsert tests passed successfully!")
+
+
+def test_relationships() -> None:
+    """
+    Tests relationship shelf updates and queries by:
+    1. Creating test data with relationships
+    2. Verifying the relationships are correctly stored
+    3. Testing relationship queries
+    """
+    # Create test data for each object type
+    test_data: dict[str, list[dict[str, str | int]]] = {
+        "Case": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[13],
+                "AccountId": _VALID_SALESFORCE_IDS[0],
+                "Subject": "Test Case 1",
+            },
+            {
+                "Id": _VALID_SALESFORCE_IDS[14],
+                "AccountId": _VALID_SALESFORCE_IDS[0],
+                "Subject": "Test Case 2",
+            },
+        ],
+        "Contact": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[48],
+                "AccountId": _VALID_SALESFORCE_IDS[0],
+                "FirstName": "Test",
+                "LastName": "Contact",
+            }
+        ],
+        "Opportunity": [
+            {
+                "Id": _VALID_SALESFORCE_IDS[62],
+                "AccountId": _VALID_SALESFORCE_IDS[0],
+                "Name": "Test Opportunity",
+                "Amount": 100000,
+            }
+        ],
+    }
+
+    # Create and update CSV files for each object type
+    for object_type, records in test_data.items():
+        create_csv_file(object_type, records, "relationship_test.csv")
+
+    # Test relationship queries
+    # All these objects should be children of Acme Inc.
+    child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
+    assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
+    assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
+    assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
+    assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
+    assert (
+        _VALID_SALESFORCE_IDS[62] in child_ids
+    ), "Opportunity not found in relationship"
+
+    # Test querying relationships for a different account (should be empty)
+    other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
+    assert (
+        len(other_account_children) == 0
+    ), "Expected no children for different account"
+
+    print("All relationship tests passed successfully!")
+
+
+def test_account_with_children() -> None:
+    """
+    Tests querying all accounts and retrieving their child objects.
+    This test verifies that:
+    1. All accounts can be retrieved
+    2. Child objects are correctly linked
+    3. Child object data is complete and accurate
+    """
+    # First get all account IDs
+    account_ids = find_ids_by_type("Account")
+    assert len(account_ids) > 0, "No accounts found"
+
+    # For each account, get its children and verify the data
|
||||||
|
for account_id in account_ids:
|
||||||
|
account = get_record(account_id)
|
||||||
|
assert account is not None, f"Could not find account {account_id}"
|
||||||
|
|
||||||
|
# Get all child objects
|
||||||
|
child_ids = get_child_ids(account_id)
|
||||||
|
|
||||||
|
# For Acme Inc., verify specific relationships
|
||||||
|
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
|
||||||
|
assert (
|
||||||
|
len(child_ids) == 4
|
||||||
|
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
|
||||||
|
|
||||||
|
# Get all child records
|
||||||
|
child_records = []
|
||||||
|
for child_id in child_ids:
|
||||||
|
child_record = get_record(child_id)
|
||||||
|
if child_record is not None:
|
||||||
|
child_records.append(child_record)
|
||||||
|
# Verify Cases
|
||||||
|
cases = [r for r in child_records if r.type == "Case"]
|
||||||
|
assert (
|
||||||
|
len(cases) == 2
|
||||||
|
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
|
||||||
|
case_subjects = {case.data["Subject"] for case in cases}
|
||||||
|
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
|
||||||
|
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
|
||||||
|
|
||||||
|
# Verify Contacts
|
||||||
|
contacts = [r for r in child_records if r.type == "Contact"]
|
||||||
|
assert (
|
||||||
|
len(contacts) == 1
|
||||||
|
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
|
||||||
|
contact = contacts[0]
|
||||||
|
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
|
||||||
|
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
|
||||||
|
|
||||||
|
# Verify Opportunities
|
||||||
|
opportunities = [r for r in child_records if r.type == "Opportunity"]
|
||||||
|
assert (
|
||||||
|
len(opportunities) == 1
|
||||||
|
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
|
||||||
|
opportunity = opportunities[0]
|
||||||
|
assert (
|
||||||
|
opportunity.data["Name"] == "Test Opportunity"
|
||||||
|
), "Opportunity name mismatch"
|
||||||
|
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
|
||||||
|
|
||||||
|
print("All account with children tests passed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_relationship_updates() -> None:
|
||||||
|
"""
|
||||||
|
Tests that relationships are properly updated when a child object's parent reference changes.
|
||||||
|
This test verifies:
|
||||||
|
1. Initial relationship is created correctly
|
||||||
|
2. When parent reference is updated, old relationship is removed
|
||||||
|
3. New relationship is created correctly
|
||||||
|
"""
|
||||||
|
# Create initial test data - Contact linked to Acme Inc.
|
||||||
|
initial_contact = [
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[40],
|
||||||
|
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||||
|
"FirstName": "Test",
|
||||||
|
"LastName": "Contact",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
create_csv_file("Contact", initial_contact, "initial_contact.csv")
|
||||||
|
|
||||||
|
# Verify initial relationship
|
||||||
|
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||||
|
assert (
|
||||||
|
_VALID_SALESFORCE_IDS[40] in acme_children
|
||||||
|
), "Initial relationship not created"
|
||||||
|
|
||||||
|
# Update contact to be linked to Globex Corp instead
|
||||||
|
updated_contact = [
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[40],
|
||||||
|
"AccountId": _VALID_SALESFORCE_IDS[1],
|
||||||
|
"FirstName": "Test",
|
||||||
|
"LastName": "Contact",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
create_csv_file("Contact", updated_contact, "updated_contact.csv")
|
||||||
|
|
||||||
|
# Verify old relationship is removed
|
||||||
|
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||||
|
assert (
|
||||||
|
_VALID_SALESFORCE_IDS[40] not in acme_children
|
||||||
|
), "Old relationship not removed"
|
||||||
|
|
||||||
|
# Verify new relationship is created
|
||||||
|
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||||
|
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
|
||||||
|
|
||||||
|
print("All relationship update tests passed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_affected_parent_ids() -> None:
|
||||||
|
"""
|
||||||
|
Tests get_affected_parent_ids functionality by verifying:
|
||||||
|
1. IDs that are directly in the parent_types list are included
|
||||||
|
2. IDs that have children in the updated_ids list are included
|
||||||
|
3. IDs that are neither of the above are not included
|
||||||
|
"""
|
||||||
|
# Create test data with relationships
|
||||||
|
test_data = {
|
||||||
|
"Account": [
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[0],
|
||||||
|
"Name": "Parent Account 1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[1],
|
||||||
|
"Name": "Parent Account 2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[2],
|
||||||
|
"Name": "Not Affected Account",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"Contact": [
|
||||||
|
{
|
||||||
|
"Id": _VALID_SALESFORCE_IDS[40],
|
||||||
|
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||||
|
"FirstName": "Child",
|
||||||
|
"LastName": "Contact",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create and update CSV files for test data
|
||||||
|
for object_type, records in test_data.items():
|
||||||
|
create_csv_file(object_type, records)
|
||||||
|
|
||||||
|
# Test Case 1: Account directly in updated_ids and parent_types
|
||||||
|
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
|
||||||
|
parent_types = ["Account"]
|
||||||
|
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||||
|
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||||
|
|
||||||
|
# Test Case 2: Account with child in updated_ids
|
||||||
|
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||||
|
parent_types = ["Account"]
|
||||||
|
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||||
|
assert (
|
||||||
|
_VALID_SALESFORCE_IDS[0] in affected_ids
|
||||||
|
), "Parent of updated child not included"
|
||||||
|
|
||||||
|
# Test Case 3: Both direct and indirect affects
|
||||||
|
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
|
||||||
|
parent_types = ["Account"]
|
||||||
|
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||||
|
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
|
||||||
|
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
|
||||||
|
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
|
||||||
|
assert (
|
||||||
|
_VALID_SALESFORCE_IDS[2] not in affected_ids
|
||||||
|
), "Unaffected ID incorrectly included"
|
||||||
|
|
||||||
|
# Test Case 4: No matches
|
||||||
|
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
|
||||||
|
parent_types = ["Opportunity"] # Wrong type
|
||||||
|
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
|
||||||
|
assert len(affected_ids) == 0, "Should return empty list when no matches"
|
||||||
|
|
||||||
|
print("All get_affected_parent_ids tests passed successfully!")
|
||||||
|
|
||||||
|
|
||||||
|
def main_build() -> None:
|
||||||
|
clear_sf_db()
|
||||||
|
create_csv_with_example_data()
|
||||||
|
test_query()
|
||||||
|
test_upsert()
|
||||||
|
test_relationships()
|
||||||
|
test_account_with_children()
|
||||||
|
test_relationship_updates()
|
||||||
|
test_get_affected_parent_ids()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main_build()
386 backend/onyx/connectors/salesforce/sqlite_functions.py Normal file
@ -0,0 +1,386 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager

from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list

logger = setup_logger()


@contextmanager
def get_db_connection(
    isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
    """Get a database connection with proper isolation level and error handling.

    Args:
        isolation_level: SQLite isolation level. None = default "DEFERRED",
            can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
    """
    # 60 second timeout for locks
    conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)

    if isolation_level is not None:
        conn.isolation_level = isolation_level
    try:
        yield conn
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


def init_db() -> None:
    """Initialize the SQLite database with required tables if they don't exist."""
    if os.path.exists(get_sqlite_db_path()):
        return

    # Create database directory if it doesn't exist
    os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)

    with get_db_connection("EXCLUSIVE") as conn:
        cursor = conn.cursor()

        # Enable WAL mode for better concurrent access and write performance
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA temp_store=MEMORY")
        cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache

        # Main table for storing Salesforce objects
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS salesforce_objects (
                id TEXT PRIMARY KEY,
                object_type TEXT NOT NULL,
                data TEXT NOT NULL,  -- JSON serialized data
                last_modified INTEGER DEFAULT (strftime('%s', 'now'))  -- Add timestamp for better cache management
            ) WITHOUT ROWID  -- Optimize for primary key lookups
            """
        )

        # Table for parent-child relationships with covering index
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationships (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id)
            ) WITHOUT ROWID  -- Optimize for primary key lookups
            """
        )

        # New table for caching parent-child relationships with object types
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationship_types (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                parent_type TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id, parent_type)
            ) WITHOUT ROWID
            """
        )

        # Always recreate indexes to ensure they exist
        cursor.execute("DROP INDEX IF EXISTS idx_object_type")
        cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
        cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
        cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
        cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")

        # Create covering indexes for common queries
        cursor.execute(
            """
            CREATE INDEX idx_object_type
            ON salesforce_objects(object_type, id)
            WHERE object_type IS NOT NULL
            """
        )

        cursor.execute(
            """
            CREATE INDEX idx_parent_id
            ON relationships(parent_id, child_id)
            """
        )

        cursor.execute(
            """
            CREATE INDEX idx_child_parent
            ON relationships(child_id)
            WHERE child_id IS NOT NULL
            """
        )

        # New composite index for fast parent type lookups
        cursor.execute(
            """
            CREATE INDEX idx_relationship_types_lookup
            ON relationship_types(parent_type, child_id, parent_id)
            """
        )

        # Analyze tables to help query planner
        cursor.execute("ANALYZE relationships")
        cursor.execute("ANALYZE salesforce_objects")
        cursor.execute("ANALYZE relationship_types")

        conn.commit()


def _update_relationship_tables(
    conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
    """Update the relationship tables when a record is updated.

    Args:
        conn: The database connection to use (must be in a transaction)
        child_id: The ID of the child record
        parent_ids: Set of parent IDs to link to
    """
    try:
        cursor = conn.cursor()

        # Get existing parent IDs
        cursor.execute(
            "SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
        )
        old_parent_ids = {row[0] for row in cursor.fetchall()}

        # Calculate differences
        parent_ids_to_remove = old_parent_ids - parent_ids
        parent_ids_to_add = parent_ids - old_parent_ids

        # Remove old relationships
        if parent_ids_to_remove:
            cursor.executemany(
                "DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
                [(child_id, pid) for pid in parent_ids_to_remove],
            )
            # Also remove from relationship_types
            cursor.executemany(
                "DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
                [(child_id, pid) for pid in parent_ids_to_remove],
            )

        # Add new relationships
        if parent_ids_to_add:
            # First add to relationships table
            cursor.executemany(
                "INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
                [(child_id, pid) for pid in parent_ids_to_add],
            )

            # Then get the types of the parent objects and add to relationship_types
            for parent_id in parent_ids_to_add:
                cursor.execute(
                    "SELECT object_type FROM salesforce_objects WHERE id = ?",
                    (parent_id,),
                )
                result = cursor.fetchone()
                if result:
                    parent_type = result[0]
                    cursor.execute(
                        """
                        INSERT INTO relationship_types (child_id, parent_id, parent_type)
                        VALUES (?, ?, ?)
                        """,
                        (child_id, parent_id, parent_type),
                    )

    except Exception as e:
        logger.error(f"Error updating relationship tables: {e}")
        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
        raise


def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
    """Update the SF DB with a CSV file using SQLite storage."""
    updated_ids = []

    # Use IMMEDIATE to get a write lock at the start of the transaction
    with get_db_connection("IMMEDIATE") as conn:
        cursor = conn.cursor()

        with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if "Id" not in row:
                    logger.warning(
                        f"Row {row} does not have an Id field in {csv_download_path}"
                    )
                    continue
                id = row["Id"]
                parent_ids = set()
                field_to_remove: set[str] = set()

                # Process relationships and clean data
                for field, value in row.items():
                    if validate_salesforce_id(value) and field != "Id":
                        parent_ids.add(value)
                        field_to_remove.add(field)
                    if not value:
                        field_to_remove.add(field)

                # Remove unwanted fields
                for field in field_to_remove:
                    if field != "LastModifiedById":
                        del row[field]

                # Update main object data
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
                    VALUES (?, ?, ?)
                    """,
                    (id, object_type, json.dumps(row)),
                )

                # Update relationships using the same connection
                _update_relationship_tables(conn, id, parent_ids)
                updated_ids.append(id)

        conn.commit()

    return updated_ids
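
# A sketch of how one row flows through update_sf_db_with_csv, using IDs from
# the example test data (the CSV path and the "OwnerId" column here are
# illustrative, not something the connector ships). Given an Account CSV:
#
#   Id,Name,OwnerId,Description
#   001bm00000fd9Z3AAI,Acme Inc.,005bm000009zy0TAAQ,
#
# the "OwnerId" value passes validate_salesforce_id, so it is stripped from
# the stored JSON and recorded as a parent reference instead, and the empty
# "Description" column is dropped:
#
#   updated = update_sf_db_with_csv("Account", "/tmp/account_rows.csv")
#   # updated == ["001bm00000fd9Z3AAI"]
#   # stored data == {"Id": "001bm00000fd9Z3AAI", "Name": "Acme Inc."}
#   # relationships gains (child_id="001bm00000fd9Z3AAI",
#   #                      parent_id="005bm000009zy0TAAQ")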

def get_child_ids(parent_id: str) -> set[str]:
    """Get all child IDs for a given parent ID."""
    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Force index usage with INDEXED BY
        cursor.execute(
            "SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
            (parent_id,),
        )
        child_ids = {row[0] for row in cursor.fetchall()}
        return child_ids


def get_type_from_id(object_id: str) -> str | None:
    """Get the type of an object from its ID."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
        )
        result = cursor.fetchone()
        if not result:
            logger.warning(f"Object ID {object_id} not found")
            return None
        return result[0]


def get_record(
    object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
    """Retrieve the record and return it as a SalesforceObject."""
    if object_type is None:
        object_type = get_type_from_id(object_id)
        if not object_type:
            return None

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
        result = cursor.fetchone()
        if not result:
            logger.warning(f"Object ID {object_id} not found")
            return None

        data = json.loads(result[0])
        return SalesforceObject(id=object_id, type=object_type, data=data)


def find_ids_by_type(object_type: str) -> list[str]:
    """Find all object IDs for rows of the specified type."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
        )
        return [row[0] for row in cursor.fetchall()]


def get_affected_parent_ids_by_type(
    updated_ids: list[str],
    parent_types: list[str],
    batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
    """Get IDs of objects that are of the specified parent types and are either in the
    updated_ids or have children in the updated_ids. Yields tuples of
    (parent_type, affected_ids).
    """
    # SQLite typically has a limit of 999 variables per query
    updated_ids_batches = batch_list(updated_ids, batch_size)
    updated_parent_ids: set[str] = set()

    with get_db_connection() as conn:
        cursor = conn.cursor()

        for batch_ids in updated_ids_batches:
            id_placeholders = ",".join(["?" for _ in batch_ids])

            for parent_type in parent_types:
                affected_ids: set[str] = set()

                # Get directly updated objects of parent types - using index on object_type
                cursor.execute(
                    f"""
                    SELECT id FROM salesforce_objects
                    WHERE id IN ({id_placeholders})
                    AND object_type = ?
                    """,
                    batch_ids + [parent_type],
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Get parent objects of updated objects - using optimized relationship_types table
                cursor.execute(
                    f"""
                    SELECT DISTINCT parent_id
                    FROM relationship_types
                    INDEXED BY idx_relationship_types_lookup
                    WHERE parent_type = ?
                    AND child_id IN ({id_placeholders})
                    """,
                    [parent_type] + batch_ids,
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Remove any parent IDs that have already been processed
                new_affected_ids = affected_ids - updated_parent_ids
                # Add the new affected IDs to the set of updated parent IDs
                updated_parent_ids.update(new_affected_ids)

                if new_affected_ids:
                    yield parent_type, new_affected_ids
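
# A sketch of consuming the generator above. Assuming a Contact row whose
# AccountId column pointed at an Account (IDs below are from the test data
# in this PR), updating the Contact marks its Account as affected:
#
#   affected = dict(
#       get_affected_parent_ids_by_type(
#           updated_ids=["003bm00000EjHCkAAN"],  # the updated Contact
#           parent_types=["Account"],
#       )
#   )
#   # affected == {"Account": {"001bm00000fd9Z3AAI"}} if that Contact's
#   # AccountId referenced the first Account; parent types with no affected
#   # IDs are omitted rather than yielded with empty sets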

def has_at_least_one_object_of_type(object_type: str) -> bool:
    """Check if there is at least one object of the specified type in the database.

    Args:
        object_type: The Salesforce object type to check

    Returns:
        bool: True if at least one object exists, False otherwise
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
            (object_type,),
        )
        count = cursor.fetchone()[0]
        return count > 0
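
Taken together, the helpers above form a small local-ingestion API: initialize the database once, load each bulk-API CSV as it arrives, then read records and relationships back out. A minimal sketch, assuming a downloaded CSV (the path is illustrative; fetching CSVs from Salesforce is handled elsewhere in the connector):

init_db()
updated_ids = update_sf_db_with_csv("Account", "/path/to/accounts.csv")
for object_id in updated_ids:
    record = get_record(object_id)  # SalesforceObject or None
    child_ids = get_child_ids(object_id)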
72 backend/onyx/connectors/salesforce/utils.py Normal file
@ -0,0 +1,72 @@
import os
from dataclasses import dataclass
from typing import Any


@dataclass
class SalesforceObject:
    id: str
    type: str
    data: dict[str, Any]

    def to_dict(self) -> dict[str, Any]:
        return {
            "ID": self.id,
            "Type": self.type,
            "Data": self.data,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
        return cls(
            id=data["Id"],
            type=data["Type"],
            data=data,
        )


# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")


def get_sqlite_db_path() -> str:
    """Get the path to the sqlite db file."""
    return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")


def get_object_type_path(object_type: str) -> str:
    """Get the directory path for a specific object type."""
    type_dir = os.path.join(BASE_DATA_PATH, object_type)
    os.makedirs(type_dir, exist_ok=True)
    return type_dir


_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}


def validate_salesforce_id(salesforce_id: str) -> bool:
    """Validate the checksum portion of an 18-character Salesforce ID.

    Args:
        salesforce_id: An 18-character Salesforce ID

    Returns:
        bool: True if the checksum is valid, False otherwise
    """
    if len(salesforce_id) != 18:
        return False

    chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]

    checksum = salesforce_id[15:18]
    calculated_checksum = ""

    for chunk in chunks:
        result_string = "".join(
            "1" if char.isupper() else "0" for char in reversed(chunk)
        )
        calculated_checksum += _LOOKUP[result_string]

    return checksum == calculated_checksum
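
As a worked example of the checksum logic, take "001bm00000fd9Z3AAI" from the test data below: the three 5-character chunks are "001bm", "00000", and "fd9Z3"; reading each chunk right to left and writing 1 for each uppercase character gives the bit strings "00000", "00000", and "01000", which _LOOKUP maps to "A", "A", and "I", matching the ID's last three characters, so validation passes:

assert validate_salesforce_id("001bm00000fd9Z3AAI")
assert not validate_salesforce_id("001bm00000fd9Z3AAQ")  # wrong checksum
assert not validate_salesforce_id("001bm00000fd9Z3")  # 15-char IDs are rejected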
@ -0,0 +1,746 @@
import csv
import os
import shutil

from onyx.connectors.salesforce.sqlite_functions import find_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path

_VALID_SALESFORCE_IDS = [
    "001bm00000fd9Z3AAI",
    "001bm00000fdYTdAAM",
    "001bm00000fdYTeAAM",
    "001bm00000fdYTfAAM",
    "001bm00000fdYTgAAM",
    "001bm00000fdYThAAM",
    "001bm00000fdYTiAAM",
    "001bm00000fdYTjAAM",
    "001bm00000fdYTkAAM",
    "001bm00000fdYTlAAM",
    "001bm00000fdYTmAAM",
    "001bm00000fdYTnAAM",
    "001bm00000fdYToAAM",
    "500bm00000XoOxtAAF",
    "500bm00000XoOxuAAF",
    "500bm00000XoOxvAAF",
    "500bm00000XoOxwAAF",
    "500bm00000XoOxxAAF",
    "500bm00000XoOxyAAF",
    "500bm00000XoOxzAAF",
    "500bm00000XoOy0AAF",
    "500bm00000XoOy1AAF",
    "500bm00000XoOy2AAF",
    "500bm00000XoOy3AAF",
    "500bm00000XoOy4AAF",
    "500bm00000XoOy5AAF",
    "500bm00000XoOy6AAF",
    "500bm00000XoOy7AAF",
    "500bm00000XoOy8AAF",
    "500bm00000XoOy9AAF",
    "500bm00000XoOyAAAV",
    "500bm00000XoOyBAAV",
    "500bm00000XoOyCAAV",
    "500bm00000XoOyDAAV",
    "500bm00000XoOyEAAV",
    "500bm00000XoOyFAAV",
    "500bm00000XoOyGAAV",
    "500bm00000XoOyHAAV",
    "500bm00000XoOyIAAV",
    "003bm00000EjHCjAAN",
    "003bm00000EjHCkAAN",
    "003bm00000EjHClAAN",
    "003bm00000EjHCmAAN",
    "003bm00000EjHCnAAN",
    "003bm00000EjHCoAAN",
    "003bm00000EjHCpAAN",
    "003bm00000EjHCqAAN",
    "003bm00000EjHCrAAN",
    "003bm00000EjHCsAAN",
    "003bm00000EjHCtAAN",
    "003bm00000EjHCuAAN",
    "003bm00000EjHCvAAN",
    "003bm00000EjHCwAAN",
    "003bm00000EjHCxAAN",
    "003bm00000EjHCyAAN",
    "003bm00000EjHCzAAN",
    "003bm00000EjHD0AAN",
    "003bm00000EjHD1AAN",
    "003bm00000EjHD2AAN",
    "550bm00000EXc2tAAD",
    "006bm000006kyDpAAI",
    "006bm000006kyDqAAI",
    "006bm000006kyDrAAI",
    "006bm000006kyDsAAI",
    "006bm000006kyDtAAI",
    "006bm000006kyDuAAI",
    "006bm000006kyDvAAI",
    "006bm000006kyDwAAI",
    "006bm000006kyDxAAI",
    "006bm000006kyDyAAI",
    "006bm000006kyDzAAI",
    "006bm000006kyE0AAI",
    "006bm000006kyE1AAI",
    "006bm000006kyE2AAI",
    "006bm000006kyE3AAI",
    "006bm000006kyE4AAI",
    "006bm000006kyE5AAI",
    "006bm000006kyE6AAI",
    "006bm000006kyE7AAI",
    "006bm000006kyE8AAI",
    "006bm000006kyE9AAI",
    "006bm000006kyEAAAY",
    "006bm000006kyEBAAY",
    "006bm000006kyECAAY",
    "006bm000006kyEDAAY",
    "006bm000006kyEEAAY",
    "006bm000006kyEFAAY",
    "006bm000006kyEGAAY",
    "006bm000006kyEHAAY",
    "006bm000006kyEIAAY",
    "006bm000006kyEJAAY",
    "005bm000009zy0TAAQ",
    "005bm000009zy25AAA",
    "005bm000009zy26AAA",
    "005bm000009zy28AAA",
    "005bm000009zy29AAA",
    "005bm000009zy2AAAQ",
    "005bm000009zy2BAAQ",
]


def _clear_sf_db() -> None:
    """
    Clears the SF DB by deleting all files in the data directory.
    """
    shutil.rmtree(BASE_DATA_PATH, ignore_errors=True)


def _create_csv_file(
    object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
    """
    Creates a CSV file for the given object type and records.

    Args:
        object_type: The Salesforce object type (e.g. "Account", "Contact")
        records: List of dictionaries containing the record data
        filename: Name of the CSV file to create (default: test_data.csv)
    """
    if not records:
        return

    # Get all unique fields across the records
    fields: set[str] = set()
    for record in records:
        fields.update(record.keys())
    sorted_fields = sorted(fields)  # Sort for consistent column order

    # Create CSV file
    csv_path = os.path.join(get_object_type_path(object_type), filename)
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=sorted_fields)
        writer.writeheader()
        for record in records:
            writer.writerow(record)

    # Update the database with the CSV
    update_sf_db_with_csv(object_type, csv_path)


def _create_csv_with_example_data() -> None:
    """
    Creates CSV files with example data, organized by object type.
    """
    example_data: dict[str, list[dict]] = {
        "Account": [
            {
                "Id": _VALID_SALESFORCE_IDS[0],
                "Name": "Acme Inc.",
                "BillingCity": "New York",
                "Industry": "Technology",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[1],
                "Name": "Globex Corp",
                "BillingCity": "Los Angeles",
                "Industry": "Manufacturing",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[2],
                "Name": "Initech",
                "BillingCity": "Austin",
                "Industry": "Software",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[3],
                "Name": "TechCorp Solutions",
                "BillingCity": "San Francisco",
                "Industry": "Software",
                "AnnualRevenue": 5000000,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[4],
                "Name": "BioMed Research",
                "BillingCity": "Boston",
                "Industry": "Healthcare",
                "AnnualRevenue": 12000000,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[5],
                "Name": "Green Energy Co",
                "BillingCity": "Portland",
                "Industry": "Energy",
                "AnnualRevenue": 8000000,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[6],
                "Name": "DataFlow Analytics",
                "BillingCity": "Seattle",
                "Industry": "Technology",
                "AnnualRevenue": 3000000,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[7],
                "Name": "Cloud Nine Services",
                "BillingCity": "Denver",
                "Industry": "Cloud Computing",
                "AnnualRevenue": 7000000,
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[40],
                "FirstName": "John",
                "LastName": "Doe",
                "Email": "john.doe@acme.com",
                "Title": "CEO",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[41],
                "FirstName": "Jane",
                "LastName": "Smith",
                "Email": "jane.smith@acme.com",
                "Title": "CTO",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[42],
                "FirstName": "Bob",
                "LastName": "Johnson",
                "Email": "bob.j@globex.com",
                "Title": "Sales Director",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[43],
                "FirstName": "Sarah",
                "LastName": "Chen",
                "Email": "sarah.chen@techcorp.com",
                "Title": "Product Manager",
                "Phone": "415-555-0101",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[44],
                "FirstName": "Michael",
                "LastName": "Rodriguez",
                "Email": "m.rodriguez@biomed.com",
                "Title": "Research Director",
                "Phone": "617-555-0202",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[45],
                "FirstName": "Emily",
                "LastName": "Green",
                "Email": "emily.g@greenenergy.com",
                "Title": "Sustainability Lead",
                "Phone": "503-555-0303",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[46],
                "FirstName": "David",
                "LastName": "Kim",
                "Email": "david.kim@dataflow.com",
                "Title": "Data Scientist",
                "Phone": "206-555-0404",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[47],
                "FirstName": "Rachel",
                "LastName": "Taylor",
                "Email": "r.taylor@cloudnine.com",
                "Title": "Cloud Architect",
                "Phone": "303-555-0505",
            },
        ],
        "Opportunity": [
            {
                "Id": _VALID_SALESFORCE_IDS[62],
                "Name": "Acme Server Upgrade",
                "Amount": 50000,
                "Stage": "Prospecting",
                "CloseDate": "2024-06-30",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[63],
                "Name": "Globex Manufacturing Line",
                "Amount": 150000,
                "Stage": "Negotiation",
                "CloseDate": "2024-03-15",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[64],
                "Name": "Initech Software License",
                "Amount": 75000,
                "Stage": "Closed Won",
                "CloseDate": "2024-01-30",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[65],
                "Name": "TechCorp AI Implementation",
                "Amount": 250000,
                "Stage": "Needs Analysis",
                "CloseDate": "2024-08-15",
                "Probability": 60,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[66],
                "Name": "BioMed Lab Equipment",
                "Amount": 500000,
                "Stage": "Value Proposition",
                "CloseDate": "2024-09-30",
                "Probability": 75,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[67],
                "Name": "Green Energy Solar Project",
                "Amount": 750000,
                "Stage": "Proposal",
                "CloseDate": "2024-07-15",
                "Probability": 80,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[68],
                "Name": "DataFlow Analytics Platform",
                "Amount": 180000,
                "Stage": "Negotiation",
                "CloseDate": "2024-05-30",
                "Probability": 90,
            },
            {
                "Id": _VALID_SALESFORCE_IDS[69],
                "Name": "Cloud Nine Infrastructure",
                "Amount": 300000,
                "Stage": "Qualification",
                "CloseDate": "2024-10-15",
                "Probability": 40,
            },
        ],
    }

    # Create CSV files for each object type
    for object_type, records in example_data.items():
        _create_csv_file(object_type, records)


def _test_query() -> None:
    """
    Tests querying functionality by verifying:
    1. All expected Account IDs are found
    2. Each Account's data matches what was inserted
    """
    # Expected test data for verification
    expected_accounts: dict[str, dict[str, str | int]] = {
        _VALID_SALESFORCE_IDS[0]: {
            "Name": "Acme Inc.",
            "BillingCity": "New York",
            "Industry": "Technology",
        },
        _VALID_SALESFORCE_IDS[1]: {
            "Name": "Globex Corp",
            "BillingCity": "Los Angeles",
            "Industry": "Manufacturing",
        },
        _VALID_SALESFORCE_IDS[2]: {
            "Name": "Initech",
            "BillingCity": "Austin",
            "Industry": "Software",
        },
        _VALID_SALESFORCE_IDS[3]: {
            "Name": "TechCorp Solutions",
            "BillingCity": "San Francisco",
            "Industry": "Software",
            "AnnualRevenue": 5000000,
        },
        _VALID_SALESFORCE_IDS[4]: {
            "Name": "BioMed Research",
            "BillingCity": "Boston",
            "Industry": "Healthcare",
            "AnnualRevenue": 12000000,
        },
        _VALID_SALESFORCE_IDS[5]: {
            "Name": "Green Energy Co",
            "BillingCity": "Portland",
            "Industry": "Energy",
            "AnnualRevenue": 8000000,
        },
        _VALID_SALESFORCE_IDS[6]: {
            "Name": "DataFlow Analytics",
            "BillingCity": "Seattle",
            "Industry": "Technology",
            "AnnualRevenue": 3000000,
        },
        _VALID_SALESFORCE_IDS[7]: {
            "Name": "Cloud Nine Services",
            "BillingCity": "Denver",
            "Industry": "Cloud Computing",
            "AnnualRevenue": 7000000,
        },
    }

    # Get all Account IDs
    account_ids = find_ids_by_type("Account")

    # Verify we found all expected accounts
    assert len(account_ids) == len(
        expected_accounts
    ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
    assert set(account_ids) == set(
        expected_accounts.keys()
    ), "Found account IDs don't match expected IDs"

    # Verify each account's data
    for acc_id in account_ids:
        combined = get_record(acc_id)
        assert combined is not None, f"Could not find account {acc_id}"

        expected = expected_accounts[acc_id]

        # Verify account data matches
        for key, value in expected.items():
            value = str(value)
            assert (
                combined.data[key] == value
            ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"

    print("All query tests passed successfully!")


def _test_upsert() -> None:
    """
    Tests upsert functionality by:
    1. Updating an existing account
    2. Creating a new account
    3. Verifying both operations were successful
    """
    # Create CSV for updating an existing account and adding a new one
    update_data: list[dict[str, str | int]] = [
        {
            "Id": _VALID_SALESFORCE_IDS[0],
            "Name": "Acme Inc. Updated",
            "BillingCity": "New York",
            "Industry": "Technology",
            "Description": "Updated company info",
        },
        {
            "Id": _VALID_SALESFORCE_IDS[2],
            "Name": "New Company Inc.",
            "BillingCity": "Miami",
            "Industry": "Finance",
            "AnnualRevenue": 1000000,
        },
    ]

    _create_csv_file("Account", update_data, "update_data.csv")

    # Verify the update worked
    updated_record = get_record(_VALID_SALESFORCE_IDS[0])
    assert updated_record is not None, "Updated record not found"
    assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
    assert (
        updated_record.data["Description"] == "Updated company info"
    ), "Description not added"

    # Verify the new record was created
    new_record = get_record(_VALID_SALESFORCE_IDS[2])
    assert new_record is not None, "New record not found"
    assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
    assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"

    print("All upsert tests passed successfully!")


def _test_relationships() -> None:
    """
    Tests relationship updates and queries by:
    1. Creating test data with relationships
    2. Verifying the relationships are correctly stored
    3. Testing relationship queries
    """
    # Create test data for each object type
    test_data: dict[str, list[dict[str, str | int]]] = {
        "Case": [
            {
                "Id": _VALID_SALESFORCE_IDS[13],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "Subject": "Test Case 1",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[14],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "Subject": "Test Case 2",
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[48],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "FirstName": "Test",
                "LastName": "Contact",
            }
        ],
        "Opportunity": [
            {
                "Id": _VALID_SALESFORCE_IDS[62],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "Name": "Test Opportunity",
                "Amount": 100000,
            }
        ],
    }

    # Create and update CSV files for each object type
    for object_type, records in test_data.items():
        _create_csv_file(object_type, records, "relationship_test.csv")

    # Test relationship queries
    # All these objects should be children of Acme Inc.
    child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
    assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
    assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
    assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
    assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
    assert (
        _VALID_SALESFORCE_IDS[62] in child_ids
    ), "Opportunity not found in relationship"

    # Test querying relationships for a different account (should be empty)
    other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
    assert (
        len(other_account_children) == 0
    ), "Expected no children for different account"

    print("All relationship tests passed successfully!")


def _test_account_with_children() -> None:
    """
    Tests querying all accounts and retrieving their child objects.
    This test verifies that:
    1. All accounts can be retrieved
    2. Child objects are correctly linked
    3. Child object data is complete and accurate
    """
    # First get all account IDs
    account_ids = find_ids_by_type("Account")
    assert len(account_ids) > 0, "No accounts found"

    # For each account, get its children and verify the data
    for account_id in account_ids:
        account = get_record(account_id)
        assert account is not None, f"Could not find account {account_id}"

        # Get all child objects
        child_ids = get_child_ids(account_id)

        # For Acme Inc., verify specific relationships
        if account_id == _VALID_SALESFORCE_IDS[0]:  # Acme Inc.
            assert (
                len(child_ids) == 4
            ), f"Expected 4 children for Acme Inc., found {len(child_ids)}"

            # Get all child records
            child_records = []
            for child_id in child_ids:
                child_record = get_record(child_id)
                if child_record is not None:
                    child_records.append(child_record)

            # Verify Cases
            cases = [r for r in child_records if r.type == "Case"]
            assert (
                len(cases) == 2
            ), f"Expected 2 cases for Acme Inc., found {len(cases)}"
            case_subjects = {case.data["Subject"] for case in cases}
            assert "Test Case 1" in case_subjects, "Test Case 1 not found"
            assert "Test Case 2" in case_subjects, "Test Case 2 not found"

            # Verify Contacts
            contacts = [r for r in child_records if r.type == "Contact"]
            assert (
                len(contacts) == 1
            ), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
            contact = contacts[0]
            assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
            assert contact.data["LastName"] == "Contact", "Contact last name mismatch"

            # Verify Opportunities
            opportunities = [r for r in child_records if r.type == "Opportunity"]
            assert (
                len(opportunities) == 1
            ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
            opportunity = opportunities[0]
            assert (
                opportunity.data["Name"] == "Test Opportunity"
            ), "Opportunity name mismatch"
            assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"

    print("All account with children tests passed successfully!")


def _test_relationship_updates() -> None:
    """
    Tests that relationships are properly updated when a child object's parent reference changes.
    This test verifies:
    1. Initial relationship is created correctly
    2. When parent reference is updated, old relationship is removed
    3. New relationship is created correctly
    """
    # Create initial test data - Contact linked to Acme Inc.
    initial_contact = [
        {
            "Id": _VALID_SALESFORCE_IDS[40],
            "AccountId": _VALID_SALESFORCE_IDS[0],
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    _create_csv_file("Contact", initial_contact, "initial_contact.csv")

    # Verify initial relationship
    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
    assert (
        _VALID_SALESFORCE_IDS[40] in acme_children
    ), "Initial relationship not created"

    # Update contact to be linked to Globex Corp instead
    updated_contact = [
        {
            "Id": _VALID_SALESFORCE_IDS[40],
            "AccountId": _VALID_SALESFORCE_IDS[1],
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    _create_csv_file("Contact", updated_contact, "updated_contact.csv")

    # Verify old relationship is removed
    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
    assert (
        _VALID_SALESFORCE_IDS[40] not in acme_children
    ), "Old relationship not removed"

    # Verify new relationship is created
    globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
    assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"

    print("All relationship update tests passed successfully!")


def _test_get_affected_parent_ids() -> None:
    """
    Tests get_affected_parent_ids functionality by verifying:
    1. IDs that are directly in the parent_types list are included
    2. IDs that have children in the updated_ids list are included
    3. IDs that are neither of the above are not included
    """
    # Create test data with relationships
    test_data = {
        "Account": [
            {
                "Id": _VALID_SALESFORCE_IDS[0],
                "Name": "Parent Account 1",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[1],
                "Name": "Parent Account 2",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[2],
                "Name": "Not Affected Account",
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[40],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "FirstName": "Child",
                "LastName": "Contact",
            }
        ],
    }

    # Create and update CSV files for test data
    for object_type, records in test_data.items():
        _create_csv_file(object_type, records)

    # Test Case 1: Account directly in updated_ids and parent_types
    updated_ids = [_VALID_SALESFORCE_IDS[1]]  # Parent Account 2
    parent_types = ["Account"]
    affected_ids_by_type = dict(
        get_affected_parent_ids_by_type(updated_ids, parent_types)
    )
    assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
    assert (
        _VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"]
    ), "Direct parent ID not included"

    # Test Case 2: Account with child in updated_ids
    updated_ids = [_VALID_SALESFORCE_IDS[40]]  # Child Contact
    parent_types = ["Account"]
    affected_ids_by_type = dict(
        get_affected_parent_ids_by_type(updated_ids, parent_types)
    )
    assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
    assert (
        _VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"]
    ), "Parent of updated child not included"

    # Test Case 3: Both direct and indirect effects
    updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]]  # Both cases
    parent_types = ["Account"]
    affected_ids_by_type = dict(
        get_affected_parent_ids_by_type(updated_ids, parent_types)
    )
    assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
    affected_ids = affected_ids_by_type["Account"]
    assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
    assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
    assert (
        _VALID_SALESFORCE_IDS[2] not in affected_ids
    ), "Unaffected ID incorrectly included"

    # Test Case 4: No matches
    updated_ids = [_VALID_SALESFORCE_IDS[40]]  # Child Contact
    parent_types = ["Opportunity"]  # Wrong type
    affected_ids_by_type = dict(
        get_affected_parent_ids_by_type(updated_ids, parent_types)
    )
    assert len(affected_ids_by_type) == 0, "Should return empty dict when no matches"

    print("All get_affected_parent_ids tests passed successfully!")


def test_salesforce_sqlite() -> None:
    _clear_sf_db()
    init_db()
    _create_csv_with_example_data()
    _test_query()
    _test_upsert()
    _test_relationships()
    _test_account_with_children()
    _test_relationship_updates()
    _test_get_affected_parent_ids()
    _clear_sf_db()