Improve Salesforce connector

This commit is contained in:
hagen-danswer
2024-12-29 13:19:14 -08:00
committed by Chris Weaver
parent 9bffeb65af
commit d14ef431a7
5 changed files with 304 additions and 148 deletions

View File

@@ -26,6 +26,10 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
jobs:
connectors-check:

View File

@@ -1,7 +1,7 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from datetime import UTC
from typing import Any
from simple_salesforce import Salesforce
@@ -19,23 +19,36 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.utils import extract_dict_text
from onyx.connectors.salesforce.doc_conversion import extract_sections
from onyx.utils.logger import setup_logger
# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
from shared_configs.utils import batch_list
logger = setup_logger()
# max query length is 20,000 characters, leave 5000 characters for slop
_MAX_QUERY_LENGTH = 10000
# There are 22 extra characters per ID so 200 * 22 = 4400 characters which is
# still well under the max query length
_MAX_ID_BATCH_SIZE = 200
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -44,33 +57,34 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
requested_objects: list[str] = [],
) -> None:
self.batch_size = batch_size
self.sf_client: Salesforce | None = None
self._sf_client: Salesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else DEFAULT_PARENT_OBJECT_TYPES
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.sf_client = Salesforce(
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
)
return None
def _get_sf_type_object_json(self, type_name: str) -> Any:
if self.sf_client is None:
@property
def sf_client(self) -> Salesforce:
if self._sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
return self._sf_client
def _get_sf_type_object_json(self, type_name: str) -> Any:
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
def _get_name_from_id(self, id: str) -> str:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
@@ -84,14 +98,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
) -> Document:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
base_url = f"https://{self.sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
@@ -101,7 +111,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
doc = Document(
id=onyx_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
sections=extract_sections(object_dict, base_url),
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
@@ -111,9 +121,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
@@ -142,9 +149,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
logger.debug(f"Fetching children for SF type: {sf_type}")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
@@ -159,9 +164,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
fields = [
@@ -172,23 +174,60 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
return fields
def _get_parent_object_ids(
self, parent_sf_type: str, time_filter_query: str
) -> list[str]:
"""Fetch all IDs for a given parent object type."""
logger.debug(f"Fetching IDs for parent type: {parent_sf_type}")
query = f"SELECT Id FROM {parent_sf_type}{time_filter_query}"
query_result = self.sf_client.query_all(query)
ids = [record["Id"] for record in query_result["records"]]
logger.debug(f"Found {len(ids)} IDs for parent type: {parent_sf_type}")
return ids
def _process_id_batch(
self,
id_batch: list[str],
queries: list[str],
) -> dict[str, dict[str, Any]]:
"""Process a batch of IDs using the given queries."""
# Initialize results dictionary for this batch
logger.debug(f"Processing batch of {len(id_batch)} IDs")
query_results: dict[str, dict[str, Any]] = {}
# For each query, fetch and combine results for the batch
for query in queries:
id_filter = f" WHERE Id IN {tuple(id_batch)}"
batch_query = query + id_filter
logger.debug(f"Executing query with length: {len(batch_query)}")
query_result = self.sf_client.query_all(batch_query)
logger.debug(f"Retrieved {len(query_result['records'])} records for query")
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
# Convert results to documents
return query_results
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
This function takes in an object_type and generates query(s) designed to grab
information associated to objects of that type.
It does that by getting all the fields of the parent object type.
Then it gets all the child objects of that object type and all the fields of
those children as well.
parent_sf_type is a string that represents the Salesforce object type.
This function generates queries that will fetch:
- all the fields of the parent object type
- all the fields of the child objects of the parent object type
"""
logger.debug(f"Generating queries for parent type: {parent_sf_type}")
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
logger.debug(f"Found {len(parent_fields)} fields for parent type")
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
logger.debug(f"Found {len(child_sf_types)} child types")
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
if len(query_addition) + len(query) > _MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
@@ -199,45 +238,41 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
yield query
def _batch_retrieval(
self,
id_batches: list[list[str]],
queries: list[str],
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
# For each batch of IDs, perform all queries and convert to documents
# so they can be yielded in batches
for id_batch in id_batches:
query_results = self._process_id_batch(id_batch, queries)
for doc in query_results.values():
doc_batch.append(self._convert_object_instance_to_document(doc))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def _fetch_from_salesforce(
self,
start: datetime | None = None,
end: datetime | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
logger.debug(f"Starting Salesforce fetch from {start} to {end}")
time_filter_query = _build_time_filter_for_salesforce(start, end)
doc_batch: list[Document] = []
for parent_object_type in self.parent_object_list:
logger.debug(f"Processing: {parent_object_type}")
query_results: dict = {}
for query in self._generate_query_per_parent_type(parent_object_type):
if start is not None and end is not None:
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
all_ids = self._get_parent_object_ids(parent_object_type, time_filter_query)
id_batches = batch_list(all_ids, _MAX_ID_BATCH_SIZE)
query_result = self.sf_client.query_all(query)
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
logger.info(
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
)
for combined_object_dict in query_results.values():
doc_batch.append(
self._convert_object_instance_to_document(combined_object_dict)
)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
# Generate all queries we'll need
queries = list(self._generate_query_per_parent_type(parent_object_type))
yield from self._batch_retrieval(id_batches, queries)
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
@@ -245,26 +280,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
perm_sync_data={},
)
for instance_dict in query_result["records"]

View File

@@ -0,0 +1,147 @@
import re
from collections import OrderedDict
from onyx.connectors.models import Section
# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: dict | list) -> dict | list:
"""Clean and transform Salesforce API response data by recursively:
1. Extracting records from the response if present
2. Merging attributes into the main dictionary
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
4. Removing '__c' suffix from custom field names
5. Removing None values and empty containers
Args:
data: A dictionary or list from Salesforce API response
Returns:
Cleaned dictionary or list with transformed keys and filtered values
"""
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
# remove the custom object indicator for display
if "__c" in key:
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
# Only add non-empty dictionaries or lists
if filtered_value:
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
if filtered_item:
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
"""Convert a nested dictionary or list into a human-readable string format.
Recursively traverses the data structure and formats it with:
- Key-value pairs on separate lines
- Nested structures indented for readability
- Lists and dictionaries handled with appropriate formatting
Args:
data: The dictionary or list to convert
indent: Number of spaces to indent (default: 0)
Returns:
A formatted string representation of the data structure
"""
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent + 2))
return "\n".join(result)
def _extract_dict_text(raw_dict: dict) -> str:
"""Extract text from a Salesforce API response dictionary by:
1. Cleaning the dictionary
2. Converting the cleaned dictionary to natural language
"""
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_for_dict = _json_to_natural_language(processed_dict)
return natural_language_for_dict
def _field_value_is_child_object(field_value: dict) -> bool:
"""
Checks if the field value is a child object.
"""
return (
isinstance(field_value, OrderedDict)
and "records" in field_value.keys()
and isinstance(field_value["records"], list)
and "Id" in field_value["records"][0].keys()
)
def extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
"""
This goes through the salesforce_object and extracts the top level fields as a Section.
It also goes through the child objects and extracts them as Sections.
"""
top_level_dict = {}
child_object_sections = []
for field_name, field_value in salesforce_object.items():
# If the field value is not a child object, add it to the top level dict
# to turn into text for the top level section
if not _field_value_is_child_object(field_value):
top_level_dict[field_name] = field_value
continue
# If the field value is a child object, extract the child objects and add them as sections
for record in field_value["records"]:
child_object_id = record["Id"]
child_object_sections.append(
Section(
text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
link=f"{base_url}/{child_object_id}",
)
)
top_level_id = salesforce_object["Id"]
top_level_section = Section(
text=_extract_dict_text(top_level_dict),
link=f"{base_url}/{top_level_id}",
)
return [top_level_section, *child_object_sections]

View File

@@ -1,66 +0,0 @@
import re
from typing import Union
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
if "__c" in key: # remove the custom object indicator for display
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
if filtered_value: # Only add non-empty dictionaries or lists
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
if filtered_item: # Only add non-empty dictionaries or lists
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent))
else:
result.append(f"{indent_str}{data}")
return "\n".join(result)
def extract_dict_text(raw_dict: dict) -> str:
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_dict = _json_to_natural_language(processed_dict)
return natural_language_dict

View File

@@ -0,0 +1,42 @@
import os
import pytest
from onyx.connectors.salesforce.connector import SalesforceConnector
@pytest.fixture
def salesforce_connector() -> SalesforceConnector:
connector = SalesforceConnector(
requested_objects=["Account", "Contact", "Opportunity"],
)
connector.load_credentials(
{
"sf_username": os.environ["SF_USERNAME"],
"sf_password": os.environ["SF_PASSWORD"],
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
return connector
# TODO: make the credentials not expire
@pytest.mark.xfail(
reason=(
"Credentials change over time, so this test will fail if run when "
"the credentials expire."
)
)
def test_salesforce_connector_slim(salesforce_connector: SalesforceConnector) -> None:
# Get all doc IDs from the full connector
all_full_doc_ids = set()
for doc_batch in salesforce_connector.load_from_state():
all_full_doc_ids.update([doc.id for doc in doc_batch])
# Get all doc IDs from the slim connector
all_slim_doc_ids = set()
for slim_doc_batch in salesforce_connector.retrieve_all_slim_documents():
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
# The set of full doc IDs should be always be a subset of the slim doc IDs
assert all_full_doc_ids.issubset(all_slim_doc_ids)