import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any

from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type

from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger

logger = setup_logger()


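# The returned clause starts with a leading space so it can be appended directly
# to a SOQL query. isoformat() on the UTC datetimes produces ISO-8601 literals
# (e.g. 2024-01-01T00:00:00+00:00), which SOQL accepts for datetime comparisons.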
def _build_time_filter_for_salesforce(
    start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
    if start is None or end is None:
        return ""
    start_datetime = datetime.fromtimestamp(start, UTC)
    end_datetime = datetime.fromtimestamp(end, UTC)
    return (
        f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
        f"AND LastModifiedDate < {end_datetime.isoformat()}"
    )


def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
    sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
    return sf_object.describe()


def _is_valid_child_object(
    sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
    if not child_relationship["childSObject"]:
        return False
    if not child_relationship["relationshipName"]:
        return False

    sf_type = child_relationship["childSObject"]
    object_description = _get_sf_type_object_json(sf_client, sf_type)
    if not object_description["queryable"]:
        return False

    if child_relationship["field"]:
        if child_relationship["field"] == "RelatedToId":
            return False
    else:
        return False

    return True


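# Returns the names of all child object types related to sf_type, keeping only
# relationships that point at queryable objects (see _is_valid_child_object).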
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
    object_description = _get_sf_type_object_json(sf_client, sf_type)

    child_object_types = set()
    for child_relationship in object_description["childRelationships"]:
        if _is_valid_child_object(sf_client, child_relationship):
            logger.debug(
                f"Found valid child object {child_relationship['childSObject']}"
            )
            child_object_types.add(child_relationship["childSObject"])
    return child_object_types


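# Collects the field names that can be included in a bulk query for this type.
# base64 fields are skipped, and compound fields (e.g. Name, BillingAddress) are
# excluded at the end, since Bulk API queries generally don't support either.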
def _get_all_queryable_fields_of_sf_type(
    sf_client: Salesforce,
    sf_type: str,
) -> list[str]:
    object_description = _get_sf_type_object_json(sf_client, sf_type)
    fields: list[dict[str, Any]] = object_description["fields"]
    valid_fields: set[str] = set()
    compound_field_names: set[str] = set()
    for field in fields:
        if compound_field_name := field.get("compoundFieldName"):
            compound_field_names.add(compound_field_name)
        if field.get("type", "base64") == "base64":
            continue
        if field_name := field.get("name"):
            valid_fields.add(field_name)

    return list(valid_fields - compound_field_names)


def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
    """
    Send a small COUNT() query to check whether the object type has any records
    so we don't perform extra bulk queries for empty types.

    Returns True if at least one record exists (or if the count query fails with
    OPERATION_TOO_LARGE, in which case we still attempt the bulk query); returns
    False if the type is empty or doesn't support the query.
    """
    try:
        query = f"SELECT Count() FROM {sf_type} LIMIT 1"
        result = sf_client.query(query)
        if result["totalSize"] == 0:
            return False
    except Exception as e:
        if "OPERATION_TOO_LARGE" not in str(e):
            logger.warning(f"Object type {sf_type} doesn't support query: {e}")
            return False
    return True


def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
    # Check if csvs for this object type already exist
    if os.path.exists(get_object_type_path(sf_type)):
        existing_csvs = [
            os.path.join(get_object_type_path(sf_type), f)
            for f in os.listdir(get_object_type_path(sf_type))
            if f.endswith(".csv")
        ]
        # If csvs already exist, return their paths.
        # This is likely due to a previous run that failed after
        # downloading the csvs but before the data was written to the db.
        if existing_csvs:
            return existing_csvs
    return None


def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
    queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
    query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
    return query


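# Downloads all records of the given object type (restricted by the optional
# time filter) to csv files via the Bulk API 2.0. Returns (sf_type, csv paths),
# or (sf_type, None) if the type is empty, unqueryable, or the download fails.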
def _bulk_retrieve_from_salesforce(
    sf_client: Salesforce,
    sf_type: str,
    time_filter: str,
) -> tuple[str, list[str] | None]:
    if not _check_if_object_type_is_empty(sf_client, sf_type):
        return sf_type, None

    if existing_csvs := _check_for_existing_csvs(sf_type):
        return sf_type, existing_csvs

    query = _build_bulk_query(sf_client, sf_type, time_filter)

    bulk_2_handler = SFBulk2Handler(
        session_id=sf_client.session_id,
        bulk2_url=sf_client.bulk2_url,
        proxies=sf_client.proxies,
        session=sf_client.session,
    )
    bulk_2_type = SFBulk2Type(
        object_name=sf_type,
        bulk2_url=bulk_2_handler.bulk2_url,
        headers=bulk_2_handler.headers,
        session=bulk_2_handler.session,
    )

    logger.info(f"Downloading {sf_type}")
    logger.info(f"Query: {query}")

    try:
        # This downloads the query results to csv file(s) in the target path,
        # each with a randomly generated name
        results = bulk_2_type.download(
            query=query,
            path=get_object_type_path(sf_type),
            max_records=1000000,
        )
        all_download_paths = [result["file"] for result in results]
        logger.info(f"Downloaded {sf_type} to {all_download_paths}")
        return sf_type, all_download_paths
    except Exception as e:
        logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
        return sf_type, None


def fetch_all_csvs_in_parallel(
    sf_client: Salesforce,
    object_types: set[str],
    start: SecondsSinceUnixEpoch | None,
    end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
    """
    Fetches all the csvs in parallel for the given object types.

    Returns a dict mapping each object type to the list of downloaded csv paths,
    or None if nothing was downloaded for that type. See the usage sketch at the
    end of this module.
    """
    time_filter = _build_time_filter_for_salesforce(start, end)
    time_filter_for_each_object_type = {}
    # We do this outside of the thread pool executor because this requires
    # a database connection and we don't want to block the thread pool
    # executor from running
    for sf_type in object_types:
        # Only add the time filter if there is at least one object of the type
        # in the database. We aren't worried about partially completed object
        # update runs because this occurs after we check for existing csvs,
        # which covers this case.
        if has_at_least_one_object_of_type(sf_type):
            time_filter_for_each_object_type[sf_type] = time_filter
        else:
            time_filter_for_each_object_type[sf_type] = ""

    # Run the bulk retrieve in parallel
    with ThreadPoolExecutor() as executor:
        results = executor.map(
            lambda object_type: _bulk_retrieve_from_salesforce(
                sf_client=sf_client,
                sf_type=object_type,
                time_filter=time_filter_for_each_object_type[object_type],
            ),
            object_types,
        )
        return dict(results)


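# Minimal usage sketch, assuming simple_salesforce username/password/security-token
# auth (placeholder values below), two standard object types, and that the
# connector's local sqlite store is initialized for the
# has_at_least_one_object_of_type() checks.
if __name__ == "__main__":
    sf_client = Salesforce(
        username="user@example.com",  # placeholder
        password="password",  # placeholder
        security_token="token",  # placeholder
    )
    csv_paths_by_type = fetch_all_csvs_in_parallel(
        sf_client=sf_client,
        object_types={"Account", "Contact"},
        start=None,  # no time filter -> download everything
        end=None,
    )
    for object_type, csv_paths in csv_paths_by_type.items():
        print(object_type, csv_paths)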