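"""Salesforce connector for Onyx.

Fetches Salesforce objects as CSVs, mirrors them into a local sqlite
database, and converts new or updated parent objects (Account by default)
into Documents for indexing. Also supports slim (ID-only) document
retrieval for permission syncing.
"""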
import os
from typing import Any

from simple_salesforce import Salesforce

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]

class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
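    """Connector that indexes Salesforce records.

    Implements LoadConnector (full loads), PollConnector (time-windowed
    updates), and SlimConnector (ID-only retrieval). Parent object types
    default to _DEFAULT_PARENT_OBJECT_TYPES; their child object types are
    discovered at fetch time via get_all_children_of_sf_type.
    """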
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        requested_objects: list[str] = [],
    ) -> None:
        self.batch_size = batch_size
        self._sf_client: Salesforce | None = None
        self.parent_object_list = (
            [obj.capitalize() for obj in requested_objects]
            if requested_objects
            else _DEFAULT_PARENT_OBJECT_TYPES
        )
    def load_credentials(
        self,
        credentials: dict[str, Any],
    ) -> dict[str, Any] | None:
        self._sf_client = Salesforce(
            username=credentials["sf_username"],
            password=credentials["sf_password"],
            security_token=credentials["sf_security_token"],
        )
        return None
    @property
    def sf_client(self) -> Salesforce:
        if self._sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        return self._sf_client
    def _fetch_from_salesforce(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        init_db()
        all_object_types: set[str] = set(self.parent_object_list)

        logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
        logger.debug(f"Parent object types: {self.parent_object_list}")

        # This takes roughly 20 seconds
        for parent_object_type in self.parent_object_list:
            child_types = get_all_children_of_sf_type(
                self.sf_client, parent_object_type
            )
            all_object_types.update(child_types)
            logger.debug(
                f"Found {len(child_types)} child types for {parent_object_type}"
            )

        # Always make sure the User object is grabbed for permissioning purposes
        all_object_types.add("User")

        logger.info(f"Found a total of {len(all_object_types)} object types to fetch")
        logger.debug(f"All object types: {all_object_types}")

        # Checkpoint: all object types have been found; now fetch the data
        logger.info("Starting to fetch CSVs for all object types")
        # This takes roughly 30 minutes on the first run and <2 minutes for updates
        object_type_to_csv_path = fetch_all_csvs_in_parallel(
            sf_client=self.sf_client,
            object_types=all_object_types,
            start=start,
            end=end,
        )

        updated_ids: set[str] = set()
        # This takes roughly 10 seconds
        # To test the rest of the pipeline when the data has already been
        # fetched and stored in sqlite, uncomment the following:
        # from onyx.connectors.salesforce.sqlite_functions import find_ids_by_type
        # for object_type in self.parent_object_list:
        #     updated_ids.update(list(find_ids_by_type(object_type)))

        # This takes 10-70 minutes on the first run (the cause of the wide
        # variance is unknown)
        total_types = len(object_type_to_csv_path)
        logger.info(f"Starting to process {total_types} object types")

        for i, (object_type, csv_paths) in enumerate(
            object_type_to_csv_path.items(), 1
        ):
            logger.info(f"Processing object type {object_type} ({i}/{total_types})")
            # If csv_paths is None, the CSV fetch failed for this object type
            if csv_paths is None:
                continue
            # Go through each csv path and use it to update the db
            for csv_path in csv_paths:
                logger.debug(f"Updating {object_type} with {csv_path}")
                new_ids = update_sf_db_with_csv(
                    object_type=object_type,
                    csv_download_path=csv_path,
                )
                updated_ids.update(new_ids)
                logger.debug(
                    f"Added {len(new_ids)} new/updated records for {object_type}"
                )

        logger.info(f"Found {len(updated_ids)} total updated records")
        logger.info(
            f"Starting to process parent objects of types: {self.parent_object_list}"
        )

        docs_to_yield: list[Document] = []
        docs_processed = 0
        # Takes 15-20 seconds per batch
        for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
            updated_ids=list(updated_ids),
            parent_types=self.parent_object_list,
        ):
            logger.info(
                f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
            )
            for parent_id in parent_id_batch:
                if not (parent_object := get_record(parent_id, parent_type)):
                    logger.warning(
                        f"Failed to get parent object {parent_id} for {parent_type}"
                    )
                    continue

                docs_to_yield.append(
                    convert_sf_object_to_doc(
                        sf_object=parent_object,
                        sf_instance=self.sf_client.sf_instance,
                    )
                )
                docs_processed += 1

                if len(docs_to_yield) >= self.batch_size:
                    yield docs_to_yield
                    docs_to_yield = []

        yield docs_to_yield
    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_from_salesforce()
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        return self._fetch_from_salesforce(start=start, end=end)
    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []
        for parent_object_type in self.parent_object_list:
            query = f"SELECT Id FROM {parent_object_type}"
            query_result = self.sf_client.query_all(query)
            doc_metadata_list.extend(
                SlimDocument(
                    id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                    perm_sync_data={},
                )
                for instance_dict in query_result["records"]
            )

        yield doc_metadata_list
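
# Hedged usage sketch (illustrative, not part of the connector): shows how a
# caller might pass a heartbeat callback to retrieve_all_slim_documents. The
# _NoopHeartbeat class and _example_slim_retrieval function are hypothetical
# and assume IndexingHeartbeatInterface exposes should_stop() and
# progress(tag, amount); see onyx.indexing.indexing_heartbeat for the
# authoritative contract.
class _NoopHeartbeat(IndexingHeartbeatInterface):
    def should_stop(self) -> bool:
        # Never request an early stop in this sketch
        return False

    def progress(self, tag: str, amount: int) -> None:
        # Discard progress updates; a real implementation might record them
        pass


def _example_slim_retrieval(connector: SalesforceConnector) -> int:
    """Count slim documents while supplying a heartbeat callback."""
    total = 0
    for slim_batch in connector.retrieve_all_slim_documents(
        callback=_NoopHeartbeat()
    ):
        total += len(slim_batch)
    return total
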
if __name__ == "__main__":
    import time

    connector = SalesforceConnector(requested_objects=["Account"])

    connector.load_credentials(
        {
            "sf_username": os.environ["SF_USERNAME"],
            "sf_password": os.environ["SF_PASSWORD"],
            "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
        }
    )
    start_time = time.time()
    doc_count = 0
    section_count = 0
    text_count = 0
    for doc_batch in connector.load_from_state():
        doc_count += len(doc_batch)
        print(f"doc_count: {doc_count}")
        for doc in doc_batch:
            section_count += len(doc.sections)
            for section in doc.sections:
                text_count += len(section.text)
    end_time = time.time()

    print(f"Doc count: {doc_count}")
    print(f"Section count: {section_count}")
    print(f"Text count: {text_count}")
    print(f"Time taken: {end_time - start_time}")
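    # Hedged polling sketch: poll_source takes start/end as seconds since the
    # UNIX epoch; the one-day window below is illustrative only.
    # now = time.time()
    # for doc_batch in connector.poll_source(start=now - 86400, end=now):
    #     print(f"polled batch of {len(doc_batch)} docs")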