Added Permission Syncing for Salesforce (#3551)

* Added Permission Syncing for Salesforce

* cleanup

* updated connector doc conversion

* finished salesforce permission syncing

* fixed connector to batch Salesforce queries

* tests!

* k

* Added error handling and check for ee and sync type for postprocessing

* comments

* minor touchups

* tested to work!

* done

* my pie

* lil cleanup

* minor comment
hagen-danswer
2025-01-06 16:37:03 -08:00
committed by GitHub
parent 71c2559ea9
commit e329b63b89
17 changed files with 1012 additions and 132 deletions

View File

@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
    _get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
@@ -52,9 +57,20 @@ def _get_access_for_documents(
    )
    doc_id_map = {doc.id: doc for doc in documents}

    # Get all sources in one batch
    doc_id_to_source_map = get_document_sources(
        db_session=db_session,
        document_ids=document_ids,
    )

    access_map = {}
    for document_id, non_ee_access in non_ee_access_dict.items():
        document = doc_id_map[document_id]
        source = doc_id_to_source_map.get(document_id)
        is_only_censored = (
            source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
            and source not in DOC_PERMISSIONS_FUNC_MAP
        )
        ext_u_emails = (
            set(document.external_user_emails)
@@ -70,7 +86,11 @@ def _get_access_for_documents(
        # If the document is determined to be "public" externally (through a SYNC connector)
        # then it's given the same access level as if it were marked public within Onyx
        is_public_anywhere = document.is_public or non_ee_access.is_public
        # If it's censored, then it's public anywhere during the search, and permissions
        # are applied after the search
        is_public_anywhere = (
            document.is_public or non_ee_access.is_public or is_only_censored
        )
        # To avoid collisions of group namings between connectors, they need to be prefixed
        access_map[document_id] = DocumentAccess(

View File

@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
            User__ExternalUserGroupId.user_id == user_id
        )
    ).all()


def fetch_external_groups_for_user_email_and_group_ids(
    db_session: Session,
    user_email: str,
    group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
    user = get_user_by_email(db_session=db_session, email=user_email)
    if user is None:
        return []
    user_id = user.id
    user_ext_groups = db_session.scalars(
        select(User__ExternalUserGroupId).where(
            User__ExternalUserGroupId.user_id == user_id,
            User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
        )
    ).all()
    return list(user_ext_groups)
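A minimal usage sketch for this new helper (the email and ids below are hypothetical; note that Salesforce record ids double as external group ids in this scheme, as the unused _get_objects_access_for_user_email helper later in this PR shows):

# Hypothetical sketch — get_session_context_manager is the session factory
# used elsewhere in this PR; the ids are placeholders.
with get_session_context_manager() as db_session:
    groups = fetch_external_groups_for_user_email_and_group_ids(
        db_session=db_session,
        user_email="user@example.com",
        group_ids=["001xx000003DGbQ", "001xx000003DGbR"],
    )
    accessible_ids = {g.external_user_group_id for g in groups}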

View File

@@ -0,0 +1,84 @@
from collections.abc import Callable

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
    censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger

logger = setup_logger()

DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
    DocumentSource,
    # list of chunks to be censored and the user email; returns censored chunks
    Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
    DocumentSource.SALESFORCE: censor_salesforce_chunks,
}
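Any future source can opt into post-query censoring by registering a function with the signature above. A hypothetical sketch (neither the function nor the registry entry below is part of this commit):

def _censor_example_chunks(
    chunks: list[InferenceChunk], user_email: str
) -> list[InferenceChunk]:
    # A real implementation would look up per-object access for user_email
    # and strip content the user cannot see; passing chunks through
    # unchanged is only a placeholder.
    return chunks

# Hypothetical registration (DocumentSource.EXAMPLE does not exist):
# DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[DocumentSource.EXAMPLE] = _censor_example_chunks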
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
    """
    Returns the set of sources that have censoring enabled.
    This is based on whether the access_type is set to sync and the connector
    source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.

    NOTE: This means that if a source has even a single cc_pair that is sync,
    all chunks for that source will be censored, even if the connector that
    indexed a given chunk is not sync. This was done to avoid fetching the
    cc_pair for every single chunk.
    """
    with get_session_context_manager() as db_session:
        enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
        return {
            cc_pair.connector.source
            for cc_pair in enabled_sync_connectors
            if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
        }

# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
    chunks: list[InferenceChunk],
    user: User | None,
) -> list[InferenceChunk]:
    """
    This function checks all chunks to see if they need to be sent to a censoring
    function. If they do, it sends them to the censoring function and returns the
    censored chunks. If they don't, it returns the original chunks.
    """
    if user is None:
        # if user is None, permissions are not enforced
        return chunks

    chunks_to_keep = []
    chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}

    sources_to_censor = _get_all_censoring_enabled_sources()
    for chunk in chunks:
        # Separate out chunks that require permission post-processing by source
        if chunk.source_type in sources_to_censor:
            chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
        else:
            chunks_to_keep.append(chunk)

    # For each source, filter the chunks using the permission
    # check function for that source
    # TODO: Use a threadpool/multiprocessing to process the sources in parallel
    for source, chunks_for_source in chunks_to_process.items():
        censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
        try:
            censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
        except Exception as e:
            logger.exception(
                f"Failed to censor chunks for source {source}, so throwing out all"
                f" chunks for this source and continuing: {e}"
            )
            continue
        chunks_to_keep.extend(censored_chunks)

    return chunks_to_keep
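A sketch of the intended call site (variable names are hypothetical placeholders; per the NOTE above, this only runs when ee is enabled):

# `retrieved_chunks` and `current_user` are placeholders for the search
# pipeline's retrieved chunks and the requesting user.
visible_chunks = _post_query_chunk_censoring(chunks=retrieved_chunks, user=current_user)

Note that the returned list places the untouched chunks first and appends the censored ones after, so callers should not rely on this function preserving the input ranking order.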

View File

@@ -0,0 +1,226 @@
import time

from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
    get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
    get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Types
ChunkKey = tuple[str, int]  # (doc_id, chunk_id)
ContentRange = tuple[int, int | None]  # (start_index, end_index); None means to the end


# NOTE: Used for testing timing
def _get_dummy_object_access_map(
    object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
    time.sleep(0.15)
    # return {object_id: True for object_id in object_ids}
    import random

    return {object_id: random.choice([True, False]) for object_id in object_ids}

def _get_objects_access_for_user_email_from_salesforce(
    object_ids: set[str],
    user_email: str,
    chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
    """
    This function wraps the Salesforce call, as we may want to change how this
    is done in the future (e.g. replace it with the above function).
    """
    # This is cached in the function, so the first query takes an extra 0.1-0.3
    # seconds, but subsequent queries for this source are essentially instant
    first_doc_id = chunks[0].document_id
    with get_session_context_manager() as db_session:
        salesforce_client = get_any_salesforce_client_for_doc_id(
            db_session, first_doc_id
        )

    # This is cached in the function, so the first query takes an extra 0.1-0.3
    # seconds, but subsequent queries by the same user are essentially instant
    start_time = time.time()
    user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
    end_time = time.time()
    logger.info(
        f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
    )
    if user_id is None:
        return None

    # This is the only query that is not cached in the function,
    # so it takes 0.1-0.2 seconds total
    object_id_to_access = get_objects_access_for_user_id(
        salesforce_client, user_id, list(object_ids)
    )
    return object_id_to_access


def _extract_salesforce_object_id_from_url(url: str) -> str:
    return url.split("/")[-1]


def _get_object_ranges_for_chunk(
    chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
    """
    Given a chunk, return a dictionary of Salesforce object ids and the content
    ranges for each object id in the current chunk
    """
    if chunk.source_links is None:
        return {}

    object_ranges: dict[str, list[ContentRange]] = {}
    end_index = None
    descending_source_links = sorted(
        chunk.source_links.items(), key=lambda x: x[0], reverse=True
    )
    for start_index, url in descending_source_links:
        object_id = _extract_salesforce_object_id_from_url(url)
        if object_id not in object_ranges:
            object_ranges[object_id] = []
        object_ranges[object_id].append((start_index, end_index))
        end_index = start_index

    return object_ranges
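To make the range computation concrete, a small worked example with hypothetical values:

# Suppose a chunk's source_links map content offsets to Salesforce URLs:
#     {0: ".../001A", 40: ".../001B", 90: ".../001A"}
# Walking the keys in descending order, carrying the previous start as end_index:
#     001A -> (90, None), then 001B -> (40, 90), then 001A -> (0, 40)
# so the function returns:
#     {"001A": [(90, None), (0, 40)], "001B": [(40, 90)]}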
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
    """
    Create a copy of the uncensored chunk where potentially sensitive content is
    removed, to be added back later if the user has access to each of the sub-objects
    """
    empty_censored_chunk = InferenceChunk(
        **uncensored_chunk.model_dump(),
    )
    empty_censored_chunk.content = ""
    empty_censored_chunk.blurb = ""
    empty_censored_chunk.source_links = {}
    return empty_censored_chunk


def _update_censored_chunk(
    censored_chunk: InferenceChunk,
    uncensored_chunk: InferenceChunk,
    content_range: ContentRange,
) -> InferenceChunk:
    """
    Update the censored chunk with the content and source links from the uncensored
    chunk, using the given content range
    """
    start_index, end_index = content_range

    # Update the content of the censored chunk
    permitted_content = uncensored_chunk.content[start_index:end_index]
    permitted_section_start_index = len(censored_chunk.content)
    censored_chunk.content = permitted_content + censored_chunk.content

    # Update the source links of the censored chunk
    if uncensored_chunk.source_links is not None:
        if censored_chunk.source_links is None:
            censored_chunk.source_links = {}
        link_content = uncensored_chunk.source_links[start_index]
        censored_chunk.source_links[permitted_section_start_index] = link_content

    # Update the blurb of the censored chunk
    censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]

    return censored_chunk

# TODO: Generalize this to other sources
def censor_salesforce_chunks(
    chunks: list[InferenceChunk],
    user_email: str,
    # This is so we can provide a mock access map for testing
    access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
    # object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
    object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}

    # (doc_id, chunk_id) -> chunk
    uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}

    # keep track of all object ids that we have seen to make it easier to get
    # the access for these object ids
    object_ids: set[str] = set()

    for chunk in chunks:
        chunk_key = (chunk.document_id, chunk.chunk_id)
        # create a dictionary to quickly look up the uncensored chunk
        uncensored_chunks[chunk_key] = chunk

        # for each chunk, get a dictionary of object ids and the content ranges
        # for that object id in the current chunk
        object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
        for object_id, ranges in object_ranges_for_chunk.items():
            object_ids.add(object_id)
            for start_index, end_index in ranges:
                object_to_content_map.setdefault(object_id, []).append(
                    (chunk_key, (start_index, end_index))
                )

    # This is so we can provide a mock access map for testing
    if access_map is None:
        access_map = _get_objects_access_for_user_email_from_salesforce(
            object_ids=object_ids,
            user_email=user_email,
            chunks=chunks,
        )
    if access_map is None:
        # If the user is not found in Salesforce, access_map will be None;
        # without a Salesforce user we cannot verify access to any chunk,
        # so we return an empty list
        return []

    censored_chunks: dict[ChunkKey, InferenceChunk] = {}
    for object_id, content_list in object_to_content_map.items():
        # if the user does not have access to the object, or the object is not in
        # the access_map, do not include its content in the censored chunks
        if not access_map.get(object_id, False):
            continue

        # if we got this far, the user has access to the object, so we can create
        # or update the censored chunk(s) for this object
        # NOTE: we only create a censored chunk if the user has access to some
        # part of the chunk
        for chunk_key, content_range in content_list:
            if chunk_key not in censored_chunks:
                censored_chunks[chunk_key] = _create_empty_censored_chunk(
                    uncensored_chunks[chunk_key]
                )

            uncensored_chunk = uncensored_chunks[chunk_key]
            censored_chunk = _update_censored_chunk(
                censored_chunk=censored_chunks[chunk_key],
                uncensored_chunk=uncensored_chunk,
                content_range=content_range,
            )
            censored_chunks[chunk_key] = censored_chunk

    return list(censored_chunks.values())
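The optional access_map argument doubles as a testing seam, so the censoring behavior can be checked without a live Salesforce call. A hedged sketch (the chunk fixtures and ids below are hypothetical):

visible = censor_salesforce_chunks(
    chunks=test_chunks,  # placeholder list of InferenceChunk fixtures
    user_email="user@example.com",
    access_map={"001A": True, "001B": False},  # mock access map
)
# Sections linked to 001B are stripped; a chunk in which the user can see
# no section never gets a censored copy, so it is dropped entirely.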

# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
    object_ids: set[str], user_email: str
) -> dict[str, bool]:
    with get_session_context_manager() as db_session:
        external_groups = fetch_external_groups_for_user_email_and_group_ids(
            db_session=db_session,
            user_email=user_email,
            # Maybe make a function that adds a salesforce prefix to the group ids
            group_ids=list(object_ids),
        )
        external_group_ids = {group.external_user_group_id for group in external_groups}
        return {group_id: group_id in external_group_ids for group_id in object_ids}

View File

@@ -0,0 +1,174 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger

logger = setup_logger()

_ANY_SALESFORCE_CLIENT: Salesforce | None = None


def get_any_salesforce_client_for_doc_id(
    db_session: Session, doc_id: str
) -> Salesforce:
    """
    We create a Salesforce client for the first cc_pair of the first doc_id for
    which Salesforce censoring is enabled. After that, we just cache and reuse the
    same client for all queries.
    We do this to reduce the number of Postgres queries we make at query time.

    This may be problematic if multiple cc_pairs are used for Salesforce.
    E.g. there are 2 different credential sets for 2 different Salesforce cc_pairs,
    but only one has the permissions needed for the query.
    """
    global _ANY_SALESFORCE_CLIENT
    if _ANY_SALESFORCE_CLIENT is None:
        cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
        first_cc_pair = cc_pairs[0]
        credential_json = first_cc_pair.credential.credential_json
        _ANY_SALESFORCE_CLIENT = Salesforce(
            username=credential_json["sf_username"],
            password=credential_json["sf_password"],
            security_token=credential_json["sf_security_token"],
        )
    return _ANY_SALESFORCE_CLIENT

def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
    query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
    result = sf_client.query(query)
    if len(result["records"]) == 0:
        return None
    return result["records"][0]["Id"]


# This contains only the user_ids that we have found in Salesforce.
# If we don't know their user_id, we don't store anything in the cache.
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}


def get_salesforce_user_id_from_email(
    sf_client: Salesforce,
    user_email: str,
) -> str | None:
    """
    We cache this because Salesforce user IDs never change, so we don't have to
    query Salesforce on every search.
    Memory usage is fine because we just store 2 small strings per user.

    If the email is not in the cache, we check the local Salesforce database for
    the info. If the user is not found in the local Salesforce database, we query
    Salesforce, and whatever we get back is added to the database.
    If no user_id is found, we add NULL_ID_STRING to the database for that email so
    we don't query Salesforce again (which is slow), but we still check the local
    Salesforce database on every query until a user id is found. This is acceptable
    because the local query is quite fast.
    If a user_id is later created in Salesforce, it will be added to the local
    Salesforce database the next time the connector runs; from then on, that value
    is found here and cached.

    NOTE: The first time this runs, it may be slow if the value isn't already in
    the local Salesforce database (around 0.1-0.3 seconds).
    If it's cached or stored in the local Salesforce database, it's fast
    (<0.001 seconds).
    """
    global _CACHED_SF_EMAIL_TO_ID_MAP
    if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
        if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
            return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]

    db_exists = True
    try:
        # Check if the user is already in the database
        user_id = get_user_id_by_email(user_email)
    except Exception:
        init_db()
        try:
            user_id = get_user_id_by_email(user_email)
        except Exception as e:
            logger.error(f"Error checking if user is in database: {e}")
            user_id = None
            db_exists = False

    # If no entry is found in the database (indicated by user_id being None)...
    if user_id is None:
        # ...query Salesforce and store the result in the database
        user_id = _query_salesforce_user_id(sf_client, user_email)
        if db_exists:
            update_email_to_id_table(user_email, user_id)
        return user_id
    elif user_id == NULL_ID_STRING:
        return None

    # If the found user_id is real, cache it
    _CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
    return user_id

_MAX_RECORD_IDS_PER_QUERY = 200


def get_objects_access_for_user_id(
    salesforce_client: Salesforce,
    user_id: str,
    record_ids: list[str],
) -> dict[str, bool]:
    """
    Salesforce has a limit of 200 record ids per query, so we just truncate
    the list of record ids to 200. We only ever retrieve 50 chunks at a time,
    so this should be fine (it's unlikely that all 50 retrieved chunks contain
    4 unique objects each).
    If we decide this isn't acceptable, we can use multiple queries, but they
    should be run in parallel so the query time doesn't get too long.
    """
    truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
    record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
    access_query = f"""
    SELECT RecordId, HasReadAccess
    FROM UserRecordAccess
    WHERE RecordId IN ({record_ids_str})
    AND UserId = '{user_id}'
    """
    result = salesforce_client.query_all(access_query)
    return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
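For reference, a sketch of the expected result shape (the ids are hypothetical):

# access = get_objects_access_for_user_id(
#     salesforce_client, "005xx000001X8aB", ["001A", "001B"]
# )
# -> {"001A": True, "001B": False}

Record ids truncated away by the 200-id limit simply never appear in the returned map, which is why censor_salesforce_chunks reads it with access_map.get(object_id, False), treating missing ids as inaccessible.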

_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}


# NOTE: This is not used anywhere.
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
    """
    Uses a document id to get the cc_pair that indexed that document, and uses the
    credentials for that cc_pair to create a Salesforce client.
    Problems:
    - There may be multiple cc_pairs for a document, and we don't know which one
      to use. Right now we just use the first one.
    - Building a new Salesforce client for each document is slow.
    - Memory usage could be an issue as we build these dictionaries.
    """
    if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
        cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
        first_cc_pair = cc_pairs[0]
        _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id

    cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
    if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
        if cc_pair is None:
            raise ValueError(f"CC pair {cc_pair_id} not found")
        credential_json = cc_pair.credential.credential_json
        _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
            username=credential_json["sf_username"],
            password=credential_json["sf_password"],
            security_token=credential_json["sf_security_token"],
        )

    return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]

View File

@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
    return source_type in DOC_PERMISSIONS_FUNC_MAP
    return (
        source_type in DOC_PERMISSIONS_FUNC_MAP
        or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
    )
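With this change, a source counts as a valid sync source if it has either an upfront permission-sync function or a post-query censoring function. For example (based on the maps imported in this file):

# check_if_valid_sync_source(DocumentSource.SALESFORCE)   -> True (post-query censoring)
# check_if_valid_sync_source(DocumentSource.GOOGLE_DRIVE) -> True (doc-level permission sync)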