Added Permission Syncing for Salesforce (#3551)

* Added Permission Syncing for Salesforce

* cleanup

* updated connector doc conversion

* finished salesforce permission syncing

* fixed connector to batch Salesforce queries

* tests!

* k

* Added error handling and check for ee and sync type for postprocessing

* comments

* minor touchups

* tested to work!

* done

* my pie

* lil cleanup

* minor comment
hagen-danswer
2025-01-06 16:37:03 -08:00
committed by GitHub
parent 71c2559ea9
commit e329b63b89
17 changed files with 1012 additions and 132 deletions

View File

@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
_get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
@@ -52,9 +57,20 @@ def _get_access_for_documents(
)
doc_id_map = {doc.id: doc for doc in documents}
# Get all sources in one batch
doc_id_to_source_map = get_document_sources(
db_session=db_session,
document_ids=document_ids,
)
access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
source = doc_id_to_source_map.get(document_id)
is_only_censored = (
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
and source not in DOC_PERMISSIONS_FUNC_MAP
)
ext_u_emails = (
set(document.external_user_emails)
@@ -70,7 +86,11 @@ def _get_access_for_documents(
# If the document is determined to be "public" externally (through a SYNC connector)
# then it's given the same access level as if it were marked public within Onyx
is_public_anywhere = document.is_public or non_ee_access.is_public
# If it's censored, then it's treated as public during the search and
# permissions are applied after the search
is_public_anywhere = (
document.is_public or non_ee_access.is_public or is_only_censored
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(

View File

@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
User__ExternalUserGroupId.user_id == user_id
)
).all()
def fetch_external_groups_for_user_email_and_group_ids(
db_session: Session,
user_email: str,
group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
user = get_user_by_email(db_session=db_session, email=user_email)
if user is None:
return []
user_id = user.id
user_ext_groups = db_session.scalars(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id,
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
)
).all()
return list(user_ext_groups)
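
For orientation, a hedged usage sketch of the new helper (the record ids and email are invented for illustration); this mirrors how `_get_objects_access_for_user_email` in the Salesforce postprocessing module later consumes it:

from onyx.db.engine import get_session_context_manager

# Hypothetical: check which of these Salesforce record ids the user can
# access, based on previously synced external group membership.
object_ids = ["001XXXXXXXXXXXXAAA", "001XXXXXXXXXXXXBBB"]
with get_session_context_manager() as db_session:
    groups = fetch_external_groups_for_user_email_and_group_ids(
        db_session=db_session,
        user_email="user@example.com",
        group_ids=object_ids,
    )
accessible_ids = {g.external_user_group_id for g in groups}
access_map = {oid: oid in accessible_ids for oid in object_ids}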

View File

@@ -0,0 +1,84 @@
from collections.abc import Callable
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
DocumentSource,
# list of chunks to be censored and the user email. returns censored chunks
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
DocumentSource.SALESFORCE: censor_salesforce_chunks,
}
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
"""
Returns the set of sources that have censoring enabled.
This is based on whether the access_type is set to sync and the connector
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
NOTE: This means that if a source has even a single cc_pair that is sync,
all chunks for that source will be censored, even if the connector that
indexed a given chunk is not sync. This was done to avoid fetching the
cc_pair for every single chunk.
"""
with get_session_context_manager() as db_session:
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
return {
cc_pair.connector.source
for cc_pair in enabled_sync_connectors
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
}
# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
chunks: list[InferenceChunk],
user: User | None,
) -> list[InferenceChunk]:
"""
This function checks all chunks to see if they need to be sent to a censoring
function. If they do, it sends them to the censoring function and returns the
censored chunks. If they don't, it returns the original chunks.
"""
if user is None:
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
for chunk in chunks:
# Separate out chunks that require permission post-processing by source
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
for source, chunks_for_source in chunks_to_process.items():
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
try:
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
except Exception as e:
logger.exception(
f"Failed to censor chunks for source {source} so throwing out all"
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep
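
The map above is the single extension point for post-query censoring. A hedged sketch of wiring up another source (`censor_jira_chunks` is hypothetical, shown only to illustrate the expected signature):

from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk


def censor_jira_chunks(
    chunks: list[InferenceChunk], user_email: str
) -> list[InferenceChunk]:
    # Hypothetical: look up per-object access for user_email, then drop or
    # trim any chunks the user cannot read and return the survivors.
    raise NotImplementedError


DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[DocumentSource.JIRA] = censor_jira_chunks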

View File

@@ -0,0 +1,226 @@
import time
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Types
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
# NOTE: Used for testing timing
def _get_dummy_object_access_map(
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
time.sleep(0.15)
# return {object_id: True for object_id in object_ids}
import random
return {object_id: random.choice([True, False]) for object_id in object_ids}
def _get_objects_access_for_user_email_from_salesforce(
object_ids: set[str],
user_email: str,
chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
"""
This function wraps the salesforce call as we may want to change how this
is done in the future. (E.g. replace it with the above function)
"""
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries for this source are essentially instant
first_doc_id = chunks[0].document_id
with get_session_context_manager() as db_session:
salesforce_client = get_any_salesforce_client_for_doc_id(
db_session, first_doc_id
)
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
return None
# This is the only query that is not cached in the function
# so it takes 0.1-0.2 seconds total
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
return object_id_to_access
def _extract_salesforce_object_id_from_url(url: str) -> str:
return url.split("/")[-1]
def _get_object_ranges_for_chunk(
chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
"""
Given a chunk, return a dictionary of salesforce object ids and the content ranges
for that object id in the current chunk
"""
if chunk.source_links is None:
return {}
object_ranges: dict[str, list[ContentRange]] = {}
end_index = None
descending_source_links = sorted(
chunk.source_links.items(), key=lambda x: x[0], reverse=True
)
for start_index, url in descending_source_links:
object_id = _extract_salesforce_object_id_from_url(url)
if object_id not in object_ranges:
object_ranges[object_id] = []
object_ranges[object_id].append((start_index, end_index))
end_index = start_index
return object_ranges
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
"""
Create a copy of the uncensored chunk with the potentially sensitive content
removed; content is added back later for each sub-object the user has access to
"""
empty_censored_chunk = InferenceChunk(
**uncensored_chunk.model_dump(),
)
empty_censored_chunk.content = ""
empty_censored_chunk.blurb = ""
empty_censored_chunk.source_links = {}
return empty_censored_chunk
def _update_censored_chunk(
censored_chunk: InferenceChunk,
uncensored_chunk: InferenceChunk,
content_range: ContentRange,
) -> InferenceChunk:
"""
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
"""
    start_index, end_index = content_range

    # Update the content of the censored chunk. The permitted content is
    # prepended, so it starts at offset 0 and any source links already on
    # the censored chunk shift right by its length.
    permitted_content = uncensored_chunk.content[start_index:end_index]

    if uncensored_chunk.source_links is not None:
        shifted_links = {
            offset + len(permitted_content): link
            for offset, link in (censored_chunk.source_links or {}).items()
        }
        shifted_links[0] = uncensored_chunk.source_links[start_index]
        censored_chunk.source_links = shifted_links

    censored_chunk.content = permitted_content + censored_chunk.content

    # Update the blurb of the censored chunk
    censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
    return censored_chunk
# TODO: Generalize this to other sources
def censor_salesforce_chunks(
chunks: list[InferenceChunk],
user_email: str,
# This is so we can provide a mock access map for testing
access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
# (doc_id, chunk_id) -> chunk
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
# keep track of all object ids that we have seen to make it easier to get
# the access for these object ids
object_ids: set[str] = set()
for chunk in chunks:
chunk_key = (chunk.document_id, chunk.chunk_id)
# create a dictionary to quickly look up the unfiltered chunk
uncensored_chunks[chunk_key] = chunk
# for each chunk, get a dictionary of object ids and the content ranges
# for that object id in the current chunk
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
for object_id, ranges in object_ranges_for_chunk.items():
object_ids.add(object_id)
for start_index, end_index in ranges:
object_to_content_map.setdefault(object_id, []).append(
(chunk_key, (start_index, end_index))
)
# This is so we can provide a mock access map for testing
if access_map is None:
access_map = _get_objects_access_for_user_email_from_salesforce(
object_ids=object_ids,
user_email=user_email,
chunks=chunks,
)
if access_map is None:
# If the user is not found in Salesforce, access_map will be None.
# In that case we can't verify access to anything, so we return an
# empty list and no chunks from this source are shown.
return []
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
for object_id, content_list in object_to_content_map.items():
# if the user does not have access to the object, or the object is not in the
# access_map, do not include its content in the filtered chunks
if not access_map.get(object_id, False):
continue
# if we got this far, the user has access to the object so we can create or update
# the filtered chunk(s) for this object
# NOTE: we only create a censored chunk if the user has access to some
# part of the chunk
for chunk_key, content_range in content_list:
if chunk_key not in censored_chunks:
censored_chunks[chunk_key] = _create_empty_censored_chunk(
uncensored_chunks[chunk_key]
)
uncensored_chunk = uncensored_chunks[chunk_key]
censored_chunk = _update_censored_chunk(
censored_chunk=censored_chunks[chunk_key],
uncensored_chunk=uncensored_chunk,
content_range=content_range,
)
censored_chunks[chunk_key] = censored_chunk
return list(censored_chunks.values())
# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
object_ids: set[str], user_email: str
) -> dict[str, bool]:
with get_session_context_manager() as db_session:
external_groups = fetch_external_groups_for_user_email_and_group_ids(
db_session=db_session,
user_email=user_email,
# Maybe make a function that adds a salesforce prefix to the group ids
group_ids=list(object_ids),
)
external_group_ids = {group.external_user_group_id for group in external_groups}
return {group_id: group_id in external_group_ids for group_id in object_ids}
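
The descending-order trick in `_get_object_ranges_for_chunk` is easiest to see with concrete numbers. A small self-contained illustration (offsets and URLs invented):

# source_links maps a content offset to the object URL that starts there.
source_links = {
    0: "https://salesforce.com/objA",   # objA owns content[0:26]
    26: "https://salesforce.com/objB",  # objB owns content[26:] (to the end)
}

# Walking the links in descending offset order lets each link close the
# range of the one that starts before it.
end: int | None = None
ranges: dict[str, list[tuple[int, int | None]]] = {}
for start, url in sorted(source_links.items(), reverse=True):
    object_id = url.split("/")[-1]
    ranges.setdefault(object_id, []).append((start, end))
    end = start

assert ranges == {"objB": [(26, None)], "objA": [(0, 26)]}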

View File

@@ -0,0 +1,174 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger
logger = setup_logger()
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
def get_any_salesforce_client_for_doc_id(
db_session: Session, doc_id: str
) -> Salesforce:
"""
We create a salesforce client for the first cc_pair for the first doc_id where
salesforce censoring is enabled. After that we just cache and reuse the same
client for all queries.
We do this to reduce the number of postgres queries we make at query time.
This may be problematic if multiple cc_pairs are used for Salesforce.
E.g. there are 2 different credential sets for 2 different Salesforce
cc_pairs, but only one of them can read the permission data needed for the query.
"""
global _ANY_SALESFORCE_CLIENT
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = first_cc_pair.credential.credential_json
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _ANY_SALESFORCE_CLIENT
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
# If we don't know their user_id, we don't store anything in the cache.
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
def get_salesforce_user_id_from_email(
sf_client: Salesforce,
user_email: str,
) -> str | None:
"""
We cache this so we don't have to query Salesforce on every search;
Salesforce user IDs never change.
Memory usage is fine because we store just two small strings per user.
If the email is not in the cache, we check the local salesforce database for the info.
If the user is not found in the local salesforce database, we query Salesforce.
Whatever we get back from Salesforce is added to the database.
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
we don't query Salesforce again (which is slow) but we still check the local salesforce
database every query until a user id is found. This is acceptable because the query time
is quite fast.
If a user_id is created in Salesforce, it will be added to the local salesforce database
next time the connector is run. Then that value will be found in this function and cached.
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
salesforce database. (Around 0.1-0.3 seconds)
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
"""
global _CACHED_SF_EMAIL_TO_ID_MAP
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
db_exists = True
try:
# Check if the user is already in the database
user_id = get_user_id_by_email(user_email)
except Exception:
init_db()
try:
user_id = get_user_id_by_email(user_email)
except Exception as e:
logger.error(f"Error checking if user is in database: {e}")
user_id = None
db_exists = False
# If no entry is found in the database (indicated by user_id being None)...
if user_id is None:
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
if db_exists:
update_email_to_id_table(user_email, user_id)
return user_id
    if user_id == NULL_ID_STRING:
        # A previous Salesforce lookup already found no user for this email
        return None
# If the found user_id is real, cache it
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
return user_id
_MAX_RECORD_IDS_PER_QUERY = 200
def get_objects_access_for_user_id(
salesforce_client: Salesforce,
user_id: str,
record_ids: list[str],
) -> dict[str, bool]:
"""
Salesforce has a limit of 200 record ids per query, so we just truncate
the list of record ids to 200. We only ever retrieve 50 chunks at a time,
so this should be fine (it's unlikely that 50 chunks reference more than
200 unique objects).
If we decide this isn't acceptable, we can use multiple queries, but they
should run in parallel so the query time doesn't get too long.
"""
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
access_query = f"""
SELECT RecordId, HasReadAccess
FROM UserRecordAccess
WHERE RecordId IN ({record_ids_str})
AND UserId = '{user_id}'
"""
result = salesforce_client.query_all(access_query)
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
# NOTE: This is not used anywhere.
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
"""
Uses a document id to get the cc_pair that indexed that document and uses the credentials
for that cc_pair to create a Salesforce client.
Problems:
- There may be multiple cc_pairs for a document, and we don't know which one to use.
- right now we just use the first one
- Building a new Salesforce client for each document is slow.
- Memory usage could be an issue as we build these dictionaries.
"""
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
security_token=credential_json["sf_security_token"],
)
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]
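
As the docstring on `get_objects_access_for_user_id` notes, the 200-id truncation could be replaced by parallel batched queries. That is not part of this commit, but a hedged sketch of one way it might look:

from concurrent.futures import ThreadPoolExecutor


def _get_objects_access_batched(
    salesforce_client: Salesforce,
    user_id: str,
    record_ids: list[str],
) -> dict[str, bool]:
    # Hypothetical alternative: split into 200-id batches (Salesforce's
    # per-query limit) and run the batches in parallel.
    batches = [
        record_ids[i : i + _MAX_RECORD_IDS_PER_QUERY]
        for i in range(0, len(record_ids), _MAX_RECORD_IDS_PER_QUERY)
    ]
    access: dict[str, bool] = {}
    with ThreadPoolExecutor(max_workers=4) as executor:
        for result in executor.map(
            lambda batch: get_objects_access_for_user_id(
                salesforce_client, user_id, batch
            ),
            batches,
        ):
            access.update(result)
    return access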

View File

@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in DOC_PERMISSIONS_FUNC_MAP
return (
source_type in DOC_PERMISSIONS_FUNC_MAP
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
)

View File

@@ -16,6 +16,9 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -286,6 +289,8 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
return None
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)

View File

@@ -4,34 +4,29 @@ from typing import Any
from simple_salesforce import Salesforce
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.doc_conversion import extract_section
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -65,46 +60,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Salesforce")
return self._sf_client
def _extract_primary_owners(
self, sf_object: SalesforceObject
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
return None
if not (last_modified_by := get_record(last_modified_by_id)):
return None
if not (last_modified_by_name := last_modified_by.data.get("Name")):
return None
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
return primary_owners
def _convert_object_instance_to_document(
self, sf_object: SalesforceObject
) -> Document:
object_dict = sf_object.data
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
base_url = f"https://{self.sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
sections = [extract_section(sf_object, base_url)]
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(extract_section(child_object, base_url))
doc = Document(
id=onyx_salesforce_id,
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=self._extract_primary_owners(sf_object),
metadata={},
)
return doc
def _fetch_from_salesforce(
self,
start: SecondsSinceUnixEpoch | None = None,
@@ -126,6 +81,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
f"Found {len(child_types)} child types for {parent_object_type}"
)
# Always want to make sure user is grabbed for permissioning purposes
all_object_types.add("User")
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
@@ -169,9 +127,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
logger.debug(
f"Added {len(new_ids)} new/updated records for {object_type}"
)
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_path)
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
@@ -196,7 +151,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
continue
docs_to_yield.append(
self._convert_object_instance_to_document(parent_object)
convert_sf_object_to_doc(
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
)
docs_processed += 1
@@ -225,7 +183,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
query_result = self.sf_client.query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
perm_sync_data={},
)
for instance_dict in query_result["records"]

View File

@@ -1,8 +1,18 @@
import re
from collections import OrderedDict
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
ID_PREFIX = "SALESFORCE_"
# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
@@ -103,54 +113,72 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
    return Section(
        text=_extract_dict_text(salesforce_object.data),
        link=f"{base_url}/{salesforce_object.id}",
    )
def _field_value_is_child_object(field_value: dict) -> bool:
    """
    Checks if the field value is a child object.
    """
    return (
        isinstance(field_value, OrderedDict)
        and "records" in field_value.keys()
        and isinstance(field_value["records"], list)
        and len(field_value["records"]) > 0
        and "Id" in field_value["records"][0].keys()
    )
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
    """
    This goes through the salesforce_object and extracts the top-level fields as a Section.
    It also goes through the child objects and extracts them as Sections.
    """
    top_level_dict = {}
    child_object_sections = []
    for field_name, field_value in salesforce_object.items():
        # If the field value is not a child object, add it to the top-level dict
        # to turn into text for the top-level section
        if not _field_value_is_child_object(field_value):
            top_level_dict[field_name] = field_value
            continue
        # If the field value is a child object, extract the child objects and add them as sections
        for record in field_value["records"]:
            child_object_id = record["Id"]
            child_object_sections.append(
                Section(
                    text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
                    link=f"{base_url}/{child_object_id}",
                )
            )
    top_level_id = salesforce_object["Id"]
    top_level_section = Section(
        text=_extract_dict_text(top_level_dict),
        link=f"{base_url}/{top_level_id}",
    )
    return [top_level_section, *child_object_sections]
def _extract_primary_owners(
    sf_object: SalesforceObject,
) -> list[BasicExpertInfo] | None:
    object_dict = sf_object.data
    if not (last_modified_by_id := object_dict.get("LastModifiedById")):
        logger.warning(f"No LastModifiedById found for {sf_object.id}")
        return None
    if not (last_modified_by := get_record(last_modified_by_id)):
        logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
        return None
    user_data = last_modified_by.data
    expert_info = BasicExpertInfo(
        first_name=user_data.get("FirstName"),
        last_name=user_data.get("LastName"),
        email=user_data.get("Email"),
        display_name=user_data.get("Name"),
    )
    # Check if all fields are None
    if all(
        value is None
        for value in [
            expert_info.first_name,
            expert_info.last_name,
            expert_info.email,
            expert_info.display_name,
        ]
    ):
        logger.warning(f"No identifying information found for user {user_data}")
        return None
    return [expert_info]
def convert_sf_object_to_doc(
    sf_object: SalesforceObject,
    sf_instance: str,
) -> Document:
    object_dict = sf_object.data
    salesforce_id = object_dict["Id"]
    onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
    base_url = f"https://{sf_instance}"
    extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
    extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
    sections = [_extract_section(sf_object, base_url)]
    for id in get_child_ids(sf_object.id):
        if not (child_object := get_record(id)):
            continue
        sections.append(_extract_section(child_object, base_url))
    doc = Document(
        id=onyx_salesforce_id,
        sections=sections,
        source=DocumentSource.SALESFORCE,
        semantic_identifier=extracted_semantic_identifier,
        doc_updated_at=extracted_doc_updated_at,
        primary_owners=_extract_primary_owners(sf_object),
        metadata={},
    )
    return doc

View File

@@ -77,25 +77,28 @@ def _get_all_queryable_fields_of_sf_type(
object_description = _get_sf_type_object_json(sf_client, sf_type)
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
compound_field_names: set[str] = set()
field_names_to_remove: set[str] = set()
for field in fields:
if compound_field_name := field.get("compoundFieldName"):
compound_field_names.add(compound_field_name)
# We do want to get name fields even if they are compound
if not field.get("nameField"):
field_names_to_remove.add(compound_field_name)
if field.get("type", "base64") == "base64":
continue
if field_name := field.get("name"):
valid_fields.add(field_name)
return list(valid_fields - compound_field_names)
return list(valid_fields - field_names_to_remove)
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
def _check_if_object_type_is_empty(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
"""
Send a small query to check if the object type is empty so we don't
perform extra bulk queries
Use the rest api to check to make sure the query will result in a non-empty response
"""
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
@@ -134,7 +137,7 @@ def _bulk_retrieve_from_salesforce(
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
if not _check_if_object_type_is_empty(sf_client, sf_type):
if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
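
The new `time_filter` argument matters for polling runs: without it, the probe reports an object type as non-empty even when nothing changed in the poll window, triggering a needless bulk query. A hedged example of the resulting probe (the filter string format is assumed from standard SOQL datetime literals):

# Hypothetical poll window filter as passed in by the caller:
time_filter = " WHERE LastModifiedDate > 2024-01-01T00:00:00Z"

# Probe query sent by _check_if_object_type_is_empty for the Account type:
query = f"SELECT Count() FROM Account{time_filter} LIMIT 1"
# -> SELECT Count() FROM Account WHERE LastModifiedDate > 2024-01-01T00:00:00Z LIMIT 1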

View File

@@ -40,20 +40,20 @@ def get_db_connection(
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
if os.path.exists(get_sqlite_db_path()):
return
# Create database directory if it doesn't exist
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
db_exists = os.path.exists(get_sqlite_db_path())
if not db_exists:
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
# Main table for storing Salesforce objects
cursor.execute(
@@ -90,49 +90,69 @@ def init_db() -> None:
"""
)
# Always recreate indexes to ensure they exist
cursor.execute("DROP INDEX IF EXISTS idx_object_type")
cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
# Create covering indexes for common queries
# Create a table for User email to ID mapping if it doesn't exist
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS user_email_map (
email TEXT PRIMARY KEY,
user_id TEXT, -- Nullable to allow for users without IDs
FOREIGN KEY (user_id) REFERENCES salesforce_objects(id)
) WITHOUT ROWID
"""
)
# Create indexes only if they don't already exist (checked via sqlite_master)
def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
cursor.execute(
f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'"
)
if not cursor.fetchone():
cursor.execute(create_statement)
create_index_if_not_exists(
"idx_object_type",
"""
CREATE INDEX idx_object_type
ON salesforce_objects(object_type, id)
WHERE object_type IS NOT NULL
"""
""",
)
cursor.execute(
create_index_if_not_exists(
"idx_parent_id",
"""
CREATE INDEX idx_parent_id
ON relationships(parent_id, child_id)
"""
""",
)
cursor.execute(
create_index_if_not_exists(
"idx_child_parent",
"""
CREATE INDEX idx_child_parent
ON relationships(child_id)
WHERE child_id IS NOT NULL
"""
""",
)
# New composite index for fast parent type lookups
cursor.execute(
create_index_if_not_exists(
"idx_relationship_types_lookup",
"""
CREATE INDEX idx_relationship_types_lookup
ON relationship_types(parent_type, child_id, parent_id)
"""
""",
)
# Analyze tables to help query planner
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
cursor.execute("ANALYZE user_email_map")
# If database already existed but user_email_map needs to be populated
cursor.execute("SELECT COUNT(*) FROM user_email_map")
if cursor.fetchone()[0] == 0:
_update_user_email_map(conn)
conn.commit()
@@ -203,7 +223,27 @@ def _update_relationship_tables(
raise
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
def _update_user_email_map(conn: sqlite3.Connection) -> None:
"""Update the user_email_map table with current User objects.
Called internally by update_sf_db_with_csv when User objects are updated.
"""
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO user_email_map (email, user_id)
SELECT json_extract(data, '$.Email'), id
FROM salesforce_objects
WHERE object_type = 'User'
AND json_extract(data, '$.Email') IS NOT NULL
"""
)
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
delete_csv_after_use: bool = True,
) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
@@ -249,8 +289,17 @@ def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]
_update_relationship_tables(conn, id, parent_ids)
updated_ids.append(id)
# If we're updating User objects, update the email map
if object_type == "User":
_update_user_email_map(conn)
conn.commit()
if delete_csv_after_use:
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_download_path)
return updated_ids
@@ -329,6 +378,9 @@ def get_affected_parent_ids_by_type(
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
batch_ids = list(set(batch_ids) - updated_parent_ids)
if not batch_ids:
continue
id_placeholders = ",".join(["?" for _ in batch_ids])
for parent_type in parent_types:
@@ -384,3 +436,40 @@ def has_at_least_one_object_of_type(object_type: str) -> bool:
)
count = cursor.fetchone()[0]
return count > 0
# NULL_ID_STRING indicates that the user ID was queried but not found,
# as opposed to None, which means it has yet to be queried at all
NULL_ID_STRING = "N/A"
def get_user_id_by_email(email: str) -> str | None:
"""Get the Salesforce User ID for a given email address.
Args:
email: The email address to look up
Returns:
The stored user ID for the email, or None if the email has never been
recorded in the table. Note the stored value may be NULL_ID_STRING if a
previous Salesforce lookup found no user for this email.
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
result = cursor.fetchone()
if result is None:
return None
return result[0]
def update_email_to_id_table(email: str, id: str | None) -> None:
"""Update the email to ID map table with a new email and ID."""
id_to_use = id or NULL_ID_STRING
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
(email, id_to_use),
)
conn.commit()
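
Taken together, `NULL_ID_STRING`, `get_user_id_by_email`, and `update_email_to_id_table` give each email one of three states. A short illustration of the contract (IDs invented; assumes `init_db` has run):

update_email_to_id_table("known@example.com", "005XXXXXXXXXXXXXXX")
update_email_to_id_table("missing@example.com", None)  # stored as NULL_ID_STRING

get_user_id_by_email("known@example.com")    # -> "005XXXXXXXXXXXXXXX"
get_user_id_by_email("missing@example.com")  # -> "N/A": queried before, not found
get_user_id_by_email("never@example.com")    # -> None: never queried, try Salesforce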

View File

@@ -37,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -163,6 +164,17 @@ class SearchPipeline:
# These chunks are ordered, deduped, and contain no large chunks
retrieved_chunks = self._get_chunks()
# If ee is enabled, censor the chunk sections based on user access.
# Otherwise this is a no-op and the retrieved chunks pass through unchanged.
censored_chunks = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
)
above = self.search_query.chunks_above
below = self.search_query.chunks_below
@@ -175,7 +187,7 @@ class SearchPipeline:
seen_document_ids = set()
# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
for chunk in censored_chunks:
if chunk.document_id not in seen_document_ids:
seen_document_ids.add(chunk.document_id)
chunk_requests.append(
@@ -225,7 +237,7 @@ class SearchPipeline:
# This maintains the original chunks ordering. Note, we cannot simply sort by score here
# as reranking flow may wipe the scores for a lot of the chunks.
doc_chunk_ranges_map = defaultdict(list)
for chunk in retrieved_chunks:
for chunk in censored_chunks:
# The list of ranges for each document is ordered by score
doc_chunk_ranges_map[chunk.document_id].append(
ChunkRange(
@@ -274,11 +286,11 @@ class SearchPipeline:
# In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks
doc_chunk_ind_to_chunk.update(
{(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks}
{(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks}
)
# Build the surroundings for all of the initial retrieved chunks
for chunk in retrieved_chunks:
for chunk in censored_chunks:
start_ind = max(0, chunk.chunk_id - above)
end_ind = chunk.chunk_id + below
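
A note on the `fetch_ee_implementation_or_noop` call above: on non-ee builds it returns a no-op whose return value is the third argument, so `censored_chunks` falls back to `retrieved_chunks` unchanged. A simplified sketch of that dispatch pattern (the real helper in `onyx.utils.variable_functionality` also consults the global ee flag; this is an approximation):

import importlib
from collections.abc import Callable
from typing import Any


def fetch_ee_implementation_or_noop(
    module: str, attribute: str, noop_return_value: Any = None
) -> Callable:
    # Simplified: try to resolve ee.<module>.<attribute>; if the ee code is
    # unavailable, return a function that ignores its arguments and hands
    # back the provided fallback value.
    try:
        ee_module = importlib.import_module(f"ee.{module}")
        return getattr(ee_module, attribute)
    except (ImportError, AttributeError):
        return lambda *args, **kwargs: noop_return_value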

View File

@@ -20,10 +20,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
@@ -626,6 +628,60 @@ def get_document(
return doc
def get_cc_pairs_for_document(
db_session: Session,
document_id: str,
) -> list[ConnectorCredentialPair]:
stmt = (
select(ConnectorCredentialPair)
.join(
DocumentByConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.where(DocumentByConnectorCredentialPair.id == document_id)
)
return list(db_session.execute(stmt).scalars().all())
def get_document_sources(
db_session: Session,
document_ids: list[str],
) -> dict[str, DocumentSource]:
"""Gets the sources for a list of document IDs.
Returns a dictionary mapping document ID to its source.
If a document has multiple sources (multiple CC pairs), an arbitrary one of them is returned.
"""
stmt = (
select(
DocumentByConnectorCredentialPair.id,
Connector.source,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(
Connector,
ConnectorCredentialPair.connector_id == Connector.id,
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.distinct()
)
results = db_session.execute(stmt).all()
return {doc_id: source for doc_id, source in results}
def fetch_chunk_counts_for_documents(
document_ids: list[str],
db_session: Session,

View File

@@ -23,11 +23,13 @@ class ChunkEmbedding(BaseModel):
class BaseChunk(BaseModel):
chunk_id: int
blurb: str # The first sentence(s) of the first Section of the chunk
# The first sentence(s) of the first Section of the chunk
blurb: str
content: str
# Holds the link and the offsets into the raw Chunk text
source_links: dict[int, str] | None
section_continuation: bool # True if this Chunk's start is not at the start of a Section
# True if this Chunk's start is not at the start of a Section
section_continuation: bool
class DocAwareChunk(BaseChunk):

View File

@@ -0,0 +1,196 @@
from datetime import datetime
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
def create_test_chunk(
doc_id: str,
chunk_id: int,
content: str,
source_links: dict[int, str] | None,
) -> InferenceChunk:
return InferenceChunk(
document_id=doc_id,
chunk_id=chunk_id,
blurb=content[:BLURB_SIZE],
content=content,
source_links=source_links,
section_continuation=False,
source_type=DocumentSource.SALESFORCE,
semantic_identifier="test_chunk",
title="Test Chunk",
boost=1,
recency_bias=1.0,
score=None,
hidden=False,
metadata={},
match_highlights=[],
updated_at=datetime.now(),
)
def test_validate_salesforce_access_single_object() -> None:
"""Test filtering when chunk has a single Salesforce object reference"""
section = "This is a test document about a Salesforce object."
test_content = section
test_chunk = create_test_chunk(
doc_id="doc1",
chunk_id=1,
content=test_content,
source_links={0: "https://salesforce.com/object1"},
)
# Test when user has access
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={"object1": True},
)
assert len(filtered_chunks) == 1
assert filtered_chunks[0].content == test_content
# Test when user doesn't have access
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={"object1": False},
)
assert len(filtered_chunks) == 0
def test_validate_salesforce_access_multiple_objects() -> None:
"""Test filtering when chunk has multiple Salesforce object references"""
section1 = "First part about object1. "
section2 = "Second part about object2. "
section3 = "Third part about object3."
test_content = section1 + section2 + section3
section1_end = len(section1)
section2_end = section1_end + len(section2)
test_chunk = create_test_chunk(
doc_id="doc1",
chunk_id=1,
content=test_content,
source_links={
0: "https://salesforce.com/object1",
section1_end: "https://salesforce.com/object2",
section2_end: "https://salesforce.com/object3",
},
)
# Test when user has access to all objects
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={
"object1": True,
"object2": True,
"object3": True,
},
)
assert len(filtered_chunks) == 1
assert filtered_chunks[0].content == test_content
# Test when user has access to some objects
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={
"object1": True,
"object2": False,
"object3": True,
},
)
assert len(filtered_chunks) == 1
assert section1 in filtered_chunks[0].content
assert section2 not in filtered_chunks[0].content
assert section3 in filtered_chunks[0].content
# Test when user has no access
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={
"object1": False,
"object2": False,
"object3": False,
},
)
assert len(filtered_chunks) == 0
def test_validate_salesforce_access_multiple_chunks() -> None:
"""Test filtering when there are multiple chunks with different access patterns"""
section1 = "Content about object1"
section2 = "Content about object2"
chunk1 = create_test_chunk(
doc_id="doc1",
chunk_id=1,
content=section1,
source_links={0: "https://salesforce.com/object1"},
)
chunk2 = create_test_chunk(
doc_id="doc1",
chunk_id=2,
content=section2,
source_links={0: "https://salesforce.com/object2"},
)
# Test mixed access
filtered_chunks = censor_salesforce_chunks(
chunks=[chunk1, chunk2],
user_email="test@example.com",
access_map={
"object1": True,
"object2": False,
},
)
assert len(filtered_chunks) == 1
assert filtered_chunks[0].chunk_id == 1
assert section1 in filtered_chunks[0].content
def test_validate_salesforce_access_no_source_links() -> None:
"""Test handling of chunks with no source links"""
section = "Content with no source links"
test_chunk = create_test_chunk(
doc_id="doc1",
chunk_id=1,
content=section,
source_links=None,
)
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={},
)
assert len(filtered_chunks) == 0
def test_validate_salesforce_access_blurb_update() -> None:
"""Test that blurbs are properly updated based on permitted content"""
section = "First part about object1. "
long_content = section * 20 # Make it longer than BLURB_SIZE
test_chunk = create_test_chunk(
doc_id="doc1",
chunk_id=1,
content=long_content,
source_links={0: "https://salesforce.com/object1"},
)
filtered_chunks = censor_salesforce_chunks(
chunks=[test_chunk],
user_email="test@example.com",
access_map={"object1": True},
)
assert len(filtered_chunks) == 1
assert len(filtered_chunks[0].blurb) <= BLURB_SIZE
assert filtered_chunks[0].blurb.startswith(section)

View File

@@ -15,4 +15,5 @@ export const autoSyncConfigBySource: Record<
google_drive: {},
gmail: {},
slack: {},
salesforce: {},
};

View File

@@ -343,6 +343,7 @@ export const validAutoSyncSources = [
ValidSources.GoogleDrive,
ValidSources.Gmail,
ValidSources.Slack,
ValidSources.Salesforce,
] as const;
// Create a type from the array elements