Added Permission Syncing for Salesforce (#3551)

* Added Permission Syncing for Salesforce

* cleanup

* updated connector doc conversion

* finished salesforce permission syncing

* fixed connector to batch Salesforce queries

* tests!

* k

* Added error handling and checks for EE and sync type in postprocessing

* comments

* minor touchups

* tested to work!

* done

* my pie

* lil cleanup

* minor comment
This commit is contained in:
hagen-danswer
2025-01-06 16:37:03 -08:00
committed by GitHub
parent 71c2559ea9
commit e329b63b89
17 changed files with 1012 additions and 132 deletions

onyx/connectors/salesforce/connector.py

@@ -4,34 +4,29 @@ from typing import Any
 from simple_salesforce import Salesforce

 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.constants import DocumentSource
-from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from onyx.connectors.interfaces import GenerateDocumentsOutput
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.connectors.interfaces import LoadConnector
 from onyx.connectors.interfaces import PollConnector
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.interfaces import SlimConnector
-from onyx.connectors.models import BasicExpertInfo
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
 from onyx.connectors.models import SlimDocument
-from onyx.connectors.salesforce.doc_conversion import extract_section
+from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
+from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
 from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
 from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
 from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
-from onyx.connectors.salesforce.sqlite_functions import get_child_ids
 from onyx.connectors.salesforce.sqlite_functions import get_record
 from onyx.connectors.salesforce.sqlite_functions import init_db
 from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
-from onyx.connectors.salesforce.utils import SalesforceObject
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

 _DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
-_ID_PREFIX = "SALESFORCE_"


 class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -65,46 +60,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
             raise ConnectorMissingCredentialError("Salesforce")
         return self._sf_client

-    def _extract_primary_owners(
-        self, sf_object: SalesforceObject
-    ) -> list[BasicExpertInfo] | None:
-        object_dict = sf_object.data
-        if not (last_modified_by_id := object_dict.get("LastModifiedById")):
-            return None
-        if not (last_modified_by := get_record(last_modified_by_id)):
-            return None
-        if not (last_modified_by_name := last_modified_by.data.get("Name")):
-            return None
-        primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
-        return primary_owners
-
-    def _convert_object_instance_to_document(
-        self, sf_object: SalesforceObject
-    ) -> Document:
-        object_dict = sf_object.data
-        salesforce_id = object_dict["Id"]
-        onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
-        base_url = f"https://{self.sf_client.sf_instance}"
-        extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
-        extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
-
-        sections = [extract_section(sf_object, base_url)]
-        for id in get_child_ids(sf_object.id):
-            if not (child_object := get_record(id)):
-                continue
-            sections.append(extract_section(child_object, base_url))
-
-        doc = Document(
-            id=onyx_salesforce_id,
-            sections=sections,
-            source=DocumentSource.SALESFORCE,
-            semantic_identifier=extracted_semantic_identifier,
-            doc_updated_at=extracted_doc_updated_at,
-            primary_owners=self._extract_primary_owners(sf_object),
-            metadata={},
-        )
-        return doc
-
     def _fetch_from_salesforce(
         self,
         start: SecondsSinceUnixEpoch | None = None,
@@ -126,6 +81,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
                 f"Found {len(child_types)} child types for {parent_object_type}"
            )

+        # Always make sure the User object type is fetched for permissioning purposes
+        all_object_types.add("User")
+
        logger.info(f"Found total of {len(all_object_types)} object types to fetch")
        logger.debug(f"All object types: {all_object_types}")
@@ -169,9 +127,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
            logger.debug(
                f"Added {len(new_ids)} new/updated records for {object_type}"
            )
-
-            # Remove the csv file after it has been used
-            # to successfully update the db
-            os.remove(csv_path)

        logger.info(f"Found {len(updated_ids)} total updated records")
        logger.info(
@@ -196,7 +151,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
                continue

            docs_to_yield.append(
-                self._convert_object_instance_to_document(parent_object)
+                convert_sf_object_to_doc(
+                    sf_object=parent_object,
+                    sf_instance=self.sf_client.sf_instance,
+                )
            )
            docs_processed += 1
@@ -225,7 +183,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
            query_result = self.sf_client.query_all(query)
            doc_metadata_list.extend(
                SlimDocument(
-                    id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
+                    id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                    perm_sync_data={},
                )
                for instance_dict in query_result["records"]
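
With the conversion logic moved out of the connector class, `convert_sf_object_to_doc` can be called anywhere a stored record and the instance hostname are available. A minimal sketch of the new call path, assuming a populated local SQLite store (the record ID and hostname below are hypothetical):

```python
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.sqlite_functions import get_record

# Hypothetical Account ID; any parent record present in the local DB works
if parent_object := get_record("001XXXXXXXXXXXXXXX"):
    doc = convert_sf_object_to_doc(
        sf_object=parent_object,
        sf_instance="example.my.salesforce.com",  # assumed instance hostname
    )
    print(doc.id)  # -> "SALESFORCE_001XXXXXXXXXXXXXXX"
```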

onyx/connectors/salesforce/doc_conversion.py

@@ -1,8 +1,18 @@
 import re
-from collections import OrderedDict

+from onyx.configs.constants import DocumentSource
+from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
+from onyx.connectors.models import BasicExpertInfo
+from onyx.connectors.models import Document
 from onyx.connectors.models import Section
+from onyx.connectors.salesforce.sqlite_functions import get_child_ids
+from onyx.connectors.salesforce.sqlite_functions import get_record
 from onyx.connectors.salesforce.utils import SalesforceObject
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+ID_PREFIX = "SALESFORCE_"
 # All of these types of keys are handled by specific fields in the doc
 # conversion process (e.g. URLs) or are not useful for the user (e.g. UUIDs)
@@ -103,54 +113,72 @@ def _extract_dict_text(raw_dict: dict) -> str:
     return natural_language_for_dict


-def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
+def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
     return Section(
         text=_extract_dict_text(salesforce_object.data),
         link=f"{base_url}/{salesforce_object.id}",
     )


-def _field_value_is_child_object(field_value: dict) -> bool:
-    """
-    Checks if the field value is a child object.
-    """
-    return (
-        isinstance(field_value, OrderedDict)
-        and "records" in field_value.keys()
-        and isinstance(field_value["records"], list)
-        and len(field_value["records"]) > 0
-        and "Id" in field_value["records"][0].keys()
-    )
-
-
-def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
-    """
-    This goes through the salesforce_object and extracts the top level fields as a Section.
-    It also goes through the child objects and extracts them as Sections.
-    """
-    top_level_dict = {}
-    child_object_sections = []
-
-    for field_name, field_value in salesforce_object.items():
-        # If the field value is not a child object, add it to the top level dict
-        # to turn into text for the top level section
-        if not _field_value_is_child_object(field_value):
-            top_level_dict[field_name] = field_value
-            continue
-
-        # If the field value is a child object, extract the child objects and add them as sections
-        for record in field_value["records"]:
-            child_object_id = record["Id"]
-            child_object_sections.append(
-                Section(
-                    text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
-                    link=f"{base_url}/{child_object_id}",
-                )
-            )
-
-    top_level_id = salesforce_object["Id"]
-    top_level_section = Section(
-        text=_extract_dict_text(top_level_dict),
-        link=f"{base_url}/{top_level_id}",
-    )
-    return [top_level_section, *child_object_sections]
+def _extract_primary_owners(
+    sf_object: SalesforceObject,
+) -> list[BasicExpertInfo] | None:
+    object_dict = sf_object.data
+    if not (last_modified_by_id := object_dict.get("LastModifiedById")):
+        logger.warning(f"No LastModifiedById found for {sf_object.id}")
+        return None
+    if not (last_modified_by := get_record(last_modified_by_id)):
+        logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
+        return None
+
+    user_data = last_modified_by.data
+    expert_info = BasicExpertInfo(
+        first_name=user_data.get("FirstName"),
+        last_name=user_data.get("LastName"),
+        email=user_data.get("Email"),
+        display_name=user_data.get("Name"),
+    )
+
+    # Check if all fields are None
+    if all(
+        value is None
+        for value in [
+            expert_info.first_name,
+            expert_info.last_name,
+            expert_info.email,
+            expert_info.display_name,
+        ]
+    ):
+        logger.warning(f"No identifying information found for user {user_data}")
+        return None
+
+    return [expert_info]
+
+
+def convert_sf_object_to_doc(
+    sf_object: SalesforceObject,
+    sf_instance: str,
+) -> Document:
+    object_dict = sf_object.data
+    salesforce_id = object_dict["Id"]
+    onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
+    base_url = f"https://{sf_instance}"
+    extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
+    extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
+
+    sections = [_extract_section(sf_object, base_url)]
+    for id in get_child_ids(sf_object.id):
+        if not (child_object := get_record(id)):
+            continue
+        sections.append(_extract_section(child_object, base_url))
+
+    doc = Document(
+        id=onyx_salesforce_id,
+        sections=sections,
+        source=DocumentSource.SALESFORCE,
+        semantic_identifier=extracted_semantic_identifier,
+        doc_updated_at=extracted_doc_updated_at,
+        primary_owners=_extract_primary_owners(sf_object),
+        metadata={},
+    )
+    return doc
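
`_extract_primary_owners` now carries first name, last name, and email through to `BasicExpertInfo` instead of only a display name, and returns `None` (with a warning) when the user record has no identifying fields at all. A rough sketch of that all-None guard, with hypothetical user payloads:

```python
from onyx.connectors.models import BasicExpertInfo

# Hypothetical LastModifiedBy payloads as stored in the local DB
full_user = {"FirstName": "Ada", "LastName": "Lovelace",
             "Email": "ada@example.com", "Name": "Ada Lovelace"}
empty_user: dict = {}  # no identifying fields at all

def to_expert(user_data: dict) -> BasicExpertInfo | None:
    expert = BasicExpertInfo(
        first_name=user_data.get("FirstName"),
        last_name=user_data.get("LastName"),
        email=user_data.get("Email"),
        display_name=user_data.get("Name"),
    )
    # Mirrors the commit's guard: drop owners with no identity at all
    if all(v is None for v in (expert.first_name, expert.last_name,
                               expert.email, expert.display_name)):
        return None
    return expert

assert to_expert(empty_user) is None
assert to_expert(full_user) is not None
```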

onyx/connectors/salesforce/salesforce_calls.py

@@ -77,25 +77,28 @@ def _get_all_queryable_fields_of_sf_type(
     object_description = _get_sf_type_object_json(sf_client, sf_type)
     fields: list[dict[str, Any]] = object_description["fields"]
     valid_fields: set[str] = set()
-    compound_field_names: set[str] = set()
+    field_names_to_remove: set[str] = set()
     for field in fields:
         if compound_field_name := field.get("compoundFieldName"):
-            compound_field_names.add(compound_field_name)
+            # We do want to get name fields even if they are compound
+            if not field.get("nameField"):
+                field_names_to_remove.add(compound_field_name)
         if field.get("type", "base64") == "base64":
             continue
         if field_name := field.get("name"):
             valid_fields.add(field_name)

-    return list(valid_fields - compound_field_names)
+    return list(valid_fields - field_names_to_remove)


-def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
+def _check_if_object_type_is_empty(
+    sf_client: Salesforce, sf_type: str, time_filter: str
+) -> bool:
     """
-    Send a small query to check if the object type is empty so we don't
-    perform extra bulk queries
+    Use the REST API to make sure the query will return a non-empty response
     """
     try:
-        query = f"SELECT Count() FROM {sf_type} LIMIT 1"
+        query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
         result = sf_client.query(query)
         if result["totalSize"] == 0:
             return False
@@ -134,7 +137,7 @@ def _bulk_retrieve_from_salesforce(
     sf_type: str,
     time_filter: str,
 ) -> tuple[str, list[str] | None]:
-    if not _check_if_object_type_is_empty(sf_client, sf_type):
+    if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
         return sf_type, None

     if existing_csvs := _check_for_existing_csvs(sf_type):
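
Threading `time_filter` into the emptiness probe means the cheap `Count()` query now checks the same window the bulk job would fetch, so object types with no changes in the poll range are skipped entirely. Note that the filter is interpolated directly after the type name, so it must carry its own leading `WHERE`; a hypothetical illustration:

```python
# Hypothetical poll window; SOQL datetime literals are unquoted
time_filter = " WHERE LastModifiedDate > 2025-01-01T00:00:00Z"
sf_type = "Account"
query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
# -> SELECT Count() FROM Account WHERE LastModifiedDate > 2025-01-01T00:00:00Z LIMIT 1
```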

onyx/connectors/salesforce/sqlite_functions.py

@@ -40,20 +40,20 @@ def get_db_connection(
 def init_db() -> None:
     """Initialize the SQLite database with required tables if they don't exist."""
-    if os.path.exists(get_sqlite_db_path()):
-        return
-
     # Create database directory if it doesn't exist
     os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
+    db_exists = os.path.exists(get_sqlite_db_path())

     with get_db_connection("EXCLUSIVE") as conn:
         cursor = conn.cursor()
-        # Enable WAL mode for better concurrent access and write performance
-        cursor.execute("PRAGMA journal_mode=WAL")
-        cursor.execute("PRAGMA synchronous=NORMAL")
-        cursor.execute("PRAGMA temp_store=MEMORY")
-        cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache
+        if not db_exists:
+            # Enable WAL mode for better concurrent access and write performance
+            cursor.execute("PRAGMA journal_mode=WAL")
+            cursor.execute("PRAGMA synchronous=NORMAL")
+            cursor.execute("PRAGMA temp_store=MEMORY")
+            cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache

         # Main table for storing Salesforce objects
         cursor.execute(
@@ -90,49 +90,69 @@ def init_db() -> None:
             """
         )

-        # Always recreate indexes to ensure they exist
-        cursor.execute("DROP INDEX IF EXISTS idx_object_type")
-        cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
-        cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
-        cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
-        cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
-
-        # Create covering indexes for common queries
+        # Create a table for User email to ID mapping if it doesn't exist
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS user_email_map (
+                email TEXT PRIMARY KEY,
+                user_id TEXT,  -- Nullable to allow for users without IDs
+                FOREIGN KEY (user_id) REFERENCES salesforce_objects(id)
+            ) WITHOUT ROWID
+            """
+        )
+
+        # Create each index only if it doesn't already exist
+        def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
+            cursor.execute(
+                f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'"
+            )
+            if not cursor.fetchone():
+                cursor.execute(create_statement)
+
-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_object_type",
             """
             CREATE INDEX idx_object_type
             ON salesforce_objects(object_type, id)
             WHERE object_type IS NOT NULL
-            """
+            """,
         )
-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_parent_id",
             """
             CREATE INDEX idx_parent_id
             ON relationships(parent_id, child_id)
-            """
+            """,
         )
-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_child_parent",
             """
             CREATE INDEX idx_child_parent
             ON relationships(child_id)
             WHERE child_id IS NOT NULL
-            """
+            """,
         )

         # New composite index for fast parent type lookups
-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_relationship_types_lookup",
             """
             CREATE INDEX idx_relationship_types_lookup
             ON relationship_types(parent_type, child_id, parent_id)
-            """
+            """,
         )

         # Analyze tables to help query planner
         cursor.execute("ANALYZE relationships")
         cursor.execute("ANALYZE salesforce_objects")
         cursor.execute("ANALYZE relationship_types")
+        cursor.execute("ANALYZE user_email_map")
+
+        # If database already existed but user_email_map needs to be populated
+        cursor.execute("SELECT COUNT(*) FROM user_email_map")
+        if cursor.fetchone()[0] == 0:
+            _update_user_email_map(conn)

         conn.commit()
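
Checking `sqlite_master` replaces the old drop-and-recreate approach, so re-running `init_db` against a large existing database no longer rebuilds every index. The same pattern in standalone form (the table and index names here are made up for illustration):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (a INTEGER, b INTEGER)")

def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
    # sqlite_master lists every schema object; 'index' rows are what we need
    row = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='index' AND name=?",
        (index_name,),
    ).fetchone()
    if row is None:
        conn.execute(create_statement)

create_index_if_not_exists("idx_t_a", "CREATE INDEX idx_t_a ON t(a)")
create_index_if_not_exists("idx_t_a", "CREATE INDEX idx_t_a ON t(a)")  # no-op on repeat
```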
@@ -203,7 +223,27 @@ def _update_relationship_tables(
         raise


+def _update_user_email_map(conn: sqlite3.Connection) -> None:
+    """Update the user_email_map table with current User objects.
+
+    Called internally when User objects are updated and when an existing
+    database is initialized without the map populated.
+    """
+    cursor = conn.cursor()
+    cursor.execute(
+        """
+        INSERT OR REPLACE INTO user_email_map (email, user_id)
+        SELECT json_extract(data, '$.Email'), id
+        FROM salesforce_objects
+        WHERE object_type = 'User'
+        AND json_extract(data, '$.Email') IS NOT NULL
+        """
+    )
+
+
-def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
+def update_sf_db_with_csv(
+    object_type: str,
+    csv_download_path: str,
+    delete_csv_after_use: bool = True,
+) -> list[str]:
     """Update the SF DB with a CSV file using SQLite storage."""
     updated_ids = []
@@ -249,8 +289,17 @@ def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]
                 _update_relationship_tables(conn, id, parent_ids)
             updated_ids.append(id)

+        # If we're updating User objects, update the email map
+        if object_type == "User":
+            _update_user_email_map(conn)
+
         conn.commit()

+    if delete_csv_after_use:
+        # Remove the csv file after it has been used
+        # to successfully update the db
+        os.remove(csv_download_path)
+
     return updated_ids
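
This is where the `os.remove(csv_path)` deleted from the connector ended up: CSV cleanup now lives next to the code that consumes the file, and callers can opt out to keep the CSVs around. A hypothetical call that preserves the file for debugging:

```python
# Hypothetical path; keep the CSV for inspection instead of deleting it
updated_ids = update_sf_db_with_csv(
    object_type="Account",
    csv_download_path="/tmp/salesforce/Account.csv",
    delete_csv_after_use=False,
)
print(f"Updated {len(updated_ids)} Account records")
```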
@@ -329,6 +378,9 @@ def get_affected_parent_ids_by_type(
         cursor = conn.cursor()
         for batch_ids in updated_ids_batches:
             batch_ids = list(set(batch_ids) - updated_parent_ids)
+            if not batch_ids:
+                continue
+
             id_placeholders = ",".join(["?" for _ in batch_ids])
             for parent_type in parent_types:
@@ -384,3 +436,40 @@ def has_at_least_one_object_of_type(object_type: str) -> bool:
         )
         count = cursor.fetchone()[0]
         return count > 0
+
+
+# NULL_ID_STRING indicates that the user ID was queried but not found,
+# as opposed to None, which means the email has not been queried yet
+NULL_ID_STRING = "N/A"
+
+
+def get_user_id_by_email(email: str) -> str | None:
+    """Get the Salesforce User ID for a given email address.
+
+    Args:
+        email: The email address to look up
+
+    Returns:
+        None if the email has never been looked up (no row in the table);
+        otherwise the stored Salesforce User ID, which is NULL_ID_STRING
+        if the email was looked up before but no matching user was found.
+    """
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
+        result = cursor.fetchone()
+        if result is None:
+            return None
+        return result[0]
+
+
+def update_email_to_id_table(email: str, id: str | None) -> None:
+    """Update the email to ID map table with a new email and ID."""
+    id_to_use = id or NULL_ID_STRING
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
+            (email, id_to_use),
+        )
+        conn.commit()
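
Together, `get_user_id_by_email` and `update_email_to_id_table` form a small lookup cache for permission syncing: `None` means the email has never been resolved, and both hits and misses are written back so the same email is not queried twice. A sketch of the intended pattern, with a hypothetical `query_salesforce_for_user_id` standing in for the real SOQL lookup:

```python
from onyx.connectors.salesforce.sqlite_functions import (
    NULL_ID_STRING,
    get_user_id_by_email,
    update_email_to_id_table,
)

def resolve_user_id(email: str) -> str | None:
    cached = get_user_id_by_email(email)
    if cached == NULL_ID_STRING:
        return None  # resolved before; known to have no Salesforce user
    if cached is not None:
        return cached  # cache hit

    # Cache miss: hypothetical helper that performs the actual SOQL query
    user_id = query_salesforce_for_user_id(email)
    update_email_to_id_table(email, user_id)  # stores NULL_ID_STRING for None
    return user_id
```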