mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 21:33:56 +02:00
Added Permission Syncing for Salesforce (#3551)
* Added Permission Syncing for Salesforce * cleanup * updated connector doc conversion * finished salesforce permission syncing * fixed connector to batch Salesforce queries * tests! * k * Added error handling and check for ee and sync type for postprocessing * comments * minor touchups * tested to work! * done * my pie * lil cleanup * minor comment
This commit is contained in:
@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from onyx.access.access import (
|
||||
_get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_group
|
||||
from onyx.db.document import get_document_sources
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -52,9 +57,20 @@ def _get_access_for_documents(
|
||||
)
|
||||
doc_id_map = {doc.id: doc for doc in documents}
|
||||
|
||||
# Get all sources in one batch
|
||||
doc_id_to_source_map = get_document_sources(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
access_map = {}
|
||||
for document_id, non_ee_access in non_ee_access_dict.items():
|
||||
document = doc_id_map[document_id]
|
||||
source = doc_id_to_source_map.get(document_id)
|
||||
is_only_censored = (
|
||||
source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
and source not in DOC_PERMISSIONS_FUNC_MAP
|
||||
)
|
||||
|
||||
ext_u_emails = (
|
||||
set(document.external_user_emails)
|
||||
@@ -70,7 +86,11 @@ def _get_access_for_documents(
|
||||
|
||||
# If the document is determined to be "public" externally (through a SYNC connector)
|
||||
# then it's given the same access level as if it were marked public within Onyx
|
||||
is_public_anywhere = document.is_public or non_ee_access.is_public
|
||||
# If its censored, then it's public anywhere during the search and then permissions are
|
||||
# applied after the search
|
||||
is_public_anywhere = (
|
||||
document.is_public or non_ee_access.is_public or is_only_censored
|
||||
)
|
||||
|
||||
# To avoid collisions of group namings between connectors, they need to be prefixed
|
||||
access_map[document_id] = DocumentAccess(
|
||||
|
@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
|
||||
User__ExternalUserGroupId.user_id == user_id
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session: Session,
|
||||
user_email: str,
|
||||
group_ids: list[str],
|
||||
) -> list[User__ExternalUserGroupId]:
|
||||
user = get_user_by_email(db_session=db_session, email=user_email)
|
||||
if user is None:
|
||||
return []
|
||||
user_id = user.id
|
||||
user_ext_groups = db_session.scalars(
|
||||
select(User__ExternalUserGroupId).where(
|
||||
User__ExternalUserGroupId.user_id == user_id,
|
||||
User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
|
||||
)
|
||||
).all()
|
||||
return list(user_ext_groups)
|
||||
|
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
84
backend/ee/onyx/external_permissions/post_query_censoring.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.pipeline import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
|
||||
DocumentSource,
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
Callable[[list[InferenceChunk], str], list[InferenceChunk]],
|
||||
] = {
|
||||
DocumentSource.SALESFORCE: censor_salesforce_chunks,
|
||||
}
|
||||
|
||||
|
||||
def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
|
||||
"""
|
||||
Returns the set of sources that have censoring enabled.
|
||||
This is based on if the access_type is set to sync and the connector
|
||||
source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.
|
||||
|
||||
NOTE: This means if there is a source has a single cc_pair that is sync,
|
||||
all chunks for that source will be censored, even if the connector that
|
||||
indexed that chunk is not sync. This was done to avoid getting the cc_pair
|
||||
for every single chunk.
|
||||
"""
|
||||
with get_session_context_manager() as db_session:
|
||||
enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
|
||||
return {
|
||||
cc_pair.connector.source
|
||||
for cc_pair in enabled_sync_connectors
|
||||
if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
}
|
||||
|
||||
|
||||
# NOTE: This is only called if ee is enabled.
|
||||
def _post_query_chunk_censoring(
|
||||
chunks: list[InferenceChunk],
|
||||
user: User | None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
This function checks all chunks to see if they need to be sent to a censoring
|
||||
function. If they do, it sends them to the censoring function and returns the
|
||||
censored chunks. If they don't, it returns the original chunks.
|
||||
"""
|
||||
if user is None:
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
chunks_to_keep = []
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
for chunk in chunks:
|
||||
# Separate out chunks that require permission post-processing by source
|
||||
if chunk.source_type in sources_to_censor:
|
||||
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
|
||||
else:
|
||||
chunks_to_keep.append(chunk)
|
||||
|
||||
# For each source, filter out the chunks using the permission
|
||||
# check function for that source
|
||||
# TODO: Use a threadpool/multiprocessing to process the sources in parallel
|
||||
for source, chunks_for_source in chunks_to_process.items():
|
||||
censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
|
||||
try:
|
||||
censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to censor chunks for source {source} so throwing out all"
|
||||
f" chunks for this source and continuing: {e}"
|
||||
)
|
||||
continue
|
||||
chunks_to_keep.extend(censored_chunks)
|
||||
|
||||
return chunks_to_keep
|
@@ -0,0 +1,226 @@
|
||||
import time
|
||||
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_any_salesforce_client_for_doc_id,
|
||||
)
|
||||
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
|
||||
from ee.onyx.external_permissions.salesforce.utils import (
|
||||
get_salesforce_user_id_from_email,
|
||||
)
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Types
|
||||
ChunkKey = tuple[str, int] # (doc_id, chunk_id)
|
||||
ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end
|
||||
|
||||
|
||||
# NOTE: Used for testing timing
|
||||
def _get_dummy_object_access_map(
|
||||
object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
|
||||
) -> dict[str, bool]:
|
||||
time.sleep(0.15)
|
||||
# return {object_id: True for object_id in object_ids}
|
||||
import random
|
||||
|
||||
return {object_id: random.choice([True, False]) for object_id in object_ids}
|
||||
|
||||
|
||||
def _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids: set[str],
|
||||
user_email: str,
|
||||
chunks: list[InferenceChunk],
|
||||
) -> dict[str, bool] | None:
|
||||
"""
|
||||
This function wraps the salesforce call as we may want to change how this
|
||||
is done in the future. (E.g. replace it with the above function)
|
||||
"""
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries for this source are essentially instant
|
||||
first_doc_id = chunks[0].document_id
|
||||
with get_session_context_manager() as db_session:
|
||||
salesforce_client = get_any_salesforce_client_for_doc_id(
|
||||
db_session, first_doc_id
|
||||
)
|
||||
|
||||
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
|
||||
# but subsequent queries by the same user are essentially instant
|
||||
start_time = time.time()
|
||||
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
|
||||
)
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
# This is the only query that is not cached in the function
|
||||
# so it takes 0.1-0.2 seconds total
|
||||
object_id_to_access = get_objects_access_for_user_id(
|
||||
salesforce_client, user_id, list(object_ids)
|
||||
)
|
||||
return object_id_to_access
|
||||
|
||||
|
||||
def _extract_salesforce_object_id_from_url(url: str) -> str:
|
||||
return url.split("/")[-1]
|
||||
|
||||
|
||||
def _get_object_ranges_for_chunk(
|
||||
chunk: InferenceChunk,
|
||||
) -> dict[str, list[ContentRange]]:
|
||||
"""
|
||||
Given a chunk, return a dictionary of salesforce object ids and the content ranges
|
||||
for that object id in the current chunk
|
||||
"""
|
||||
if chunk.source_links is None:
|
||||
return {}
|
||||
|
||||
object_ranges: dict[str, list[ContentRange]] = {}
|
||||
end_index = None
|
||||
descending_source_links = sorted(
|
||||
chunk.source_links.items(), key=lambda x: x[0], reverse=True
|
||||
)
|
||||
for start_index, url in descending_source_links:
|
||||
object_id = _extract_salesforce_object_id_from_url(url)
|
||||
if object_id not in object_ranges:
|
||||
object_ranges[object_id] = []
|
||||
object_ranges[object_id].append((start_index, end_index))
|
||||
end_index = start_index
|
||||
return object_ranges
|
||||
|
||||
|
||||
def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
|
||||
"""
|
||||
Create a copy of the unfiltered chunk where potentially sensitive content is removed
|
||||
to be added later if the user has access to each of the sub-objects
|
||||
"""
|
||||
empty_censored_chunk = InferenceChunk(
|
||||
**uncensored_chunk.model_dump(),
|
||||
)
|
||||
empty_censored_chunk.content = ""
|
||||
empty_censored_chunk.blurb = ""
|
||||
empty_censored_chunk.source_links = {}
|
||||
return empty_censored_chunk
|
||||
|
||||
|
||||
def _update_censored_chunk(
|
||||
censored_chunk: InferenceChunk,
|
||||
uncensored_chunk: InferenceChunk,
|
||||
content_range: ContentRange,
|
||||
) -> InferenceChunk:
|
||||
"""
|
||||
Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges
|
||||
"""
|
||||
start_index, end_index = content_range
|
||||
|
||||
# Update the content of the filtered chunk
|
||||
permitted_content = uncensored_chunk.content[start_index:end_index]
|
||||
permitted_section_start_index = len(censored_chunk.content)
|
||||
censored_chunk.content = permitted_content + censored_chunk.content
|
||||
|
||||
# Update the source links of the filtered chunk
|
||||
if uncensored_chunk.source_links is not None:
|
||||
if censored_chunk.source_links is None:
|
||||
censored_chunk.source_links = {}
|
||||
link_content = uncensored_chunk.source_links[start_index]
|
||||
censored_chunk.source_links[permitted_section_start_index] = link_content
|
||||
|
||||
# Update the blurb of the filtered chunk
|
||||
censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]
|
||||
|
||||
return censored_chunk
|
||||
|
||||
|
||||
# TODO: Generalize this to other sources
|
||||
def censor_salesforce_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
user_email: str,
|
||||
# This is so we can provide a mock access map for testing
|
||||
access_map: dict[str, bool] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
# object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
|
||||
object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}
|
||||
|
||||
# (doc_id, chunk_id) -> chunk
|
||||
uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
|
||||
# keep track of all object ids that we have seen to make it easier to get
|
||||
# the access for these object ids
|
||||
object_ids: set[str] = set()
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_key = (chunk.document_id, chunk.chunk_id)
|
||||
# create a dictionary to quickly look up the unfiltered chunk
|
||||
uncensored_chunks[chunk_key] = chunk
|
||||
|
||||
# for each chunk, get a dictionary of object ids and the content ranges
|
||||
# for that object id in the current chunk
|
||||
object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
|
||||
for object_id, ranges in object_ranges_for_chunk.items():
|
||||
object_ids.add(object_id)
|
||||
for start_index, end_index in ranges:
|
||||
object_to_content_map.setdefault(object_id, []).append(
|
||||
(chunk_key, (start_index, end_index))
|
||||
)
|
||||
|
||||
# This is so we can provide a mock access map for testing
|
||||
if access_map is None:
|
||||
access_map = _get_objects_access_for_user_email_from_salesforce(
|
||||
object_ids=object_ids,
|
||||
user_email=user_email,
|
||||
chunks=chunks,
|
||||
)
|
||||
if access_map is None:
|
||||
# If the user is not found in Salesforce, access_map will be None
|
||||
# so we should just return an empty list because no chunks will be
|
||||
# censored
|
||||
return []
|
||||
|
||||
censored_chunks: dict[ChunkKey, InferenceChunk] = {}
|
||||
for object_id, content_list in object_to_content_map.items():
|
||||
# if the user does not have access to the object, or the object is not in the
|
||||
# access_map, do not include its content in the filtered chunks
|
||||
if not access_map.get(object_id, False):
|
||||
continue
|
||||
|
||||
# if we got this far, the user has access to the object so we can create or update
|
||||
# the filtered chunk(s) for this object
|
||||
# NOTE: we only create a censored chunk if the user has access to some
|
||||
# part of the chunk
|
||||
for chunk_key, content_range in content_list:
|
||||
if chunk_key not in censored_chunks:
|
||||
censored_chunks[chunk_key] = _create_empty_censored_chunk(
|
||||
uncensored_chunks[chunk_key]
|
||||
)
|
||||
|
||||
uncensored_chunk = uncensored_chunks[chunk_key]
|
||||
censored_chunk = _update_censored_chunk(
|
||||
censored_chunk=censored_chunks[chunk_key],
|
||||
uncensored_chunk=uncensored_chunk,
|
||||
content_range=content_range,
|
||||
)
|
||||
censored_chunks[chunk_key] = censored_chunk
|
||||
|
||||
return list(censored_chunks.values())
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_objects_access_for_user_email(
|
||||
object_ids: set[str], user_email: str
|
||||
) -> dict[str, bool]:
|
||||
with get_session_context_manager() as db_session:
|
||||
external_groups = fetch_external_groups_for_user_email_and_group_ids(
|
||||
db_session=db_session,
|
||||
user_email=user_email,
|
||||
# Maybe make a function that adds a salesforce prefix to the group ids
|
||||
group_ids=list(object_ids),
|
||||
)
|
||||
external_group_ids = {group.external_user_group_id for group in external_groups}
|
||||
return {group_id: group_id in external_group_ids for group_id in object_ids}
|
174
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
174
backend/ee/onyx/external_permissions/salesforce/utils.py
Normal file
@@ -0,0 +1,174 @@
|
||||
from simple_salesforce import Salesforce
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import get_cc_pairs_for_document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ANY_SALESFORCE_CLIENT: Salesforce | None = None
|
||||
|
||||
|
||||
def get_any_salesforce_client_for_doc_id(
|
||||
db_session: Session, doc_id: str
|
||||
) -> Salesforce:
|
||||
"""
|
||||
We create a salesforce client for the first cc_pair for the first doc_id where
|
||||
salesforce censoring is enabled. After that we just cache and reuse the same
|
||||
client for all queries.
|
||||
|
||||
We do this to reduce the number of postgres queries we make at query time.
|
||||
|
||||
This may be problematic if they are using multiple cc_pairs for salesforce.
|
||||
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
|
||||
but only one has the permissions to access the permissions needed for the query.
|
||||
"""
|
||||
global _ANY_SALESFORCE_CLIENT
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
return _ANY_SALESFORCE_CLIENT
|
||||
|
||||
|
||||
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
|
||||
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
|
||||
result = sf_client.query(query)
|
||||
if len(result["records"]) == 0:
|
||||
return None
|
||||
return result["records"][0]["Id"]
|
||||
|
||||
|
||||
# This contains only the user_ids that we have found in Salesforce.
|
||||
# If we don't know their user_id, we don't store anything in the cache.
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}
|
||||
|
||||
|
||||
def get_salesforce_user_id_from_email(
|
||||
sf_client: Salesforce,
|
||||
user_email: str,
|
||||
) -> str | None:
|
||||
"""
|
||||
We cache this so we don't have to query Salesforce for every query and salesforce
|
||||
user IDs never change.
|
||||
Memory usage is fine because we just store 2 small strings per user.
|
||||
|
||||
If the email is not in the cache, we check the local salesforce database for the info.
|
||||
If the user is not found in the local salesforce database, we query Salesforce.
|
||||
Whatever we get back from Salesforce is added to the database.
|
||||
If no user_id is found, we add a NULL_ID_STRING to the database for that email so
|
||||
we don't query Salesforce again (which is slow) but we still check the local salesforce
|
||||
database every query until a user id is found. This is acceptable because the query time
|
||||
is quite fast.
|
||||
If a user_id is created in Salesforce, it will be added to the local salesforce database
|
||||
next time the connector is run. Then that value will be found in this function and cached.
|
||||
|
||||
NOTE: First time this runs, it may be slow if it hasn't already been updated in the local
|
||||
salesforce database. (Around 0.1-0.3 seconds)
|
||||
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
|
||||
"""
|
||||
global _CACHED_SF_EMAIL_TO_ID_MAP
|
||||
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
|
||||
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
|
||||
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
|
||||
|
||||
db_exists = True
|
||||
try:
|
||||
# Check if the user is already in the database
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception:
|
||||
init_db()
|
||||
try:
|
||||
user_id = get_user_id_by_email(user_email)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if user is in database: {e}")
|
||||
user_id = None
|
||||
db_exists = False
|
||||
|
||||
# If no entry is found in the database (indicated by user_id being None)...
|
||||
if user_id is None:
|
||||
# ...query Salesforce and store the result in the database
|
||||
user_id = _query_salesforce_user_id(sf_client, user_email)
|
||||
if db_exists:
|
||||
update_email_to_id_table(user_email, user_id)
|
||||
return user_id
|
||||
elif user_id is None:
|
||||
return None
|
||||
elif user_id == NULL_ID_STRING:
|
||||
return None
|
||||
# If the found user_id is real, cache it
|
||||
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
|
||||
return user_id
|
||||
|
||||
|
||||
_MAX_RECORD_IDS_PER_QUERY = 200
|
||||
|
||||
|
||||
def get_objects_access_for_user_id(
|
||||
salesforce_client: Salesforce,
|
||||
user_id: str,
|
||||
record_ids: list[str],
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Salesforce has a limit of 200 record ids per query. So we just truncate
|
||||
the list of record ids to 200. We only ever retrieve 50 chunks at a time
|
||||
so this should be fine (unlikely that we retrieve all 50 chunks contain
|
||||
4 unique objects).
|
||||
If we decide this isn't acceptable we can use multiple queries but they
|
||||
should be in parallel so query time doesn't get too long.
|
||||
"""
|
||||
truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
|
||||
record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
|
||||
access_query = f"""
|
||||
SELECT RecordId, HasReadAccess
|
||||
FROM UserRecordAccess
|
||||
WHERE RecordId IN ({record_ids_str})
|
||||
AND UserId = '{user_id}'
|
||||
"""
|
||||
result = salesforce_client.query_all(access_query)
|
||||
return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}
|
||||
|
||||
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}
|
||||
|
||||
|
||||
# NOTE: This is not used anywhere.
|
||||
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
|
||||
"""
|
||||
Uses a document id to get the cc_pair that indexed that document and uses the credentials
|
||||
for that cc_pair to create a Salesforce client.
|
||||
Problems:
|
||||
- There may be multiple cc_pairs for a document, and we don't know which one to use.
|
||||
- right now we just use the first one
|
||||
- Building a new Salesforce client for each document is slow.
|
||||
- Memory usage could be an issue as we build these dictionaries.
|
||||
"""
|
||||
if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
_DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id
|
||||
|
||||
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
|
||||
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
security_token=credential_json["sf_security_token"],
|
||||
)
|
||||
|
||||
return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]
|
@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
|
||||
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from ee.onyx.external_permissions.post_query_censoring import (
|
||||
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
|
||||
)
|
||||
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
return source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
return (
|
||||
source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
|
||||
)
|
||||
|
Reference in New Issue
Block a user