Added Permission Syncing for Salesforce (#3551)
* Added Permission Syncing for Salesforce
* cleanup
* updated connector doc conversion
* finished salesforce permission syncing
* fixed connector to batch Salesforce queries
* tests!
* k
* Added error handling and check for ee and sync type for postprocessing
* comments
* minor touchups
* tested to work!
* done
* my pie
* lil cleanup
* minor comment
@@ -3,6 +3,10 @@ from sqlalchemy.orm import Session
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from onyx.access.access import (
    _get_access_for_documents as get_access_for_documents_without_groups,
)
@@ -10,6 +14,7 @@ from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_gro
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User

@@ -52,9 +57,20 @@ def _get_access_for_documents(
    )
    doc_id_map = {doc.id: doc for doc in documents}

    # Get all sources in one batch
    doc_id_to_source_map = get_document_sources(
        db_session=db_session,
        document_ids=document_ids,
    )

    access_map = {}
    for document_id, non_ee_access in non_ee_access_dict.items():
        document = doc_id_map[document_id]
        source = doc_id_to_source_map.get(document_id)
        is_only_censored = (
            source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
            and source not in DOC_PERMISSIONS_FUNC_MAP
        )

        ext_u_emails = (
            set(document.external_user_emails)
@@ -70,7 +86,11 @@ def _get_access_for_documents(

        # If the document is determined to be "public" externally (through a SYNC connector)
        # then it's given the same access level as if it were marked public within Onyx
-        is_public_anywhere = document.is_public or non_ee_access.is_public
+        # If it's censored, it's treated as public during the search and permissions
+        # are applied after the search
+        is_public_anywhere = (
+            document.is_public or non_ee_access.is_public or is_only_censored
+        )

        # To avoid collisions of group namings between connectors, they need to be prefixed
        access_map[document_id] = DocumentAccess(
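The `is_only_censored` flag above is just two map-membership checks: a source whose chunks are censored after retrieval, but which has no document-level permission sync, stays visible to the search itself. A minimal sketch with toy stand-ins for the two maps (the real maps are keyed by `DocumentSource`):

# Toy stand-ins for DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION and
# DOC_PERMISSIONS_FUNC_MAP, reduced to their key sets.
censoring_sources = {"salesforce"}
doc_sync_sources = {"google_drive", "gmail", "slack", "confluence"}


def is_only_censored(source: str) -> bool:
    # Censor-only sources are searchable as if public; permissions are
    # enforced on the retrieved chunks afterwards.
    return source in censoring_sources and source not in doc_sync_sources


assert is_only_censored("salesforce") is True
assert is_only_censored("google_drive") is False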
@@ -10,6 +10,7 @@ from onyx.access.utils import prefix_group_w_source
from onyx.configs.constants import DocumentSource
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -106,3 +107,21 @@ def fetch_external_groups_for_user(
            User__ExternalUserGroupId.user_id == user_id
        )
    ).all()


def fetch_external_groups_for_user_email_and_group_ids(
    db_session: Session,
    user_email: str,
    group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
    user = get_user_by_email(db_session=db_session, email=user_email)
    if user is None:
        return []
    user_id = user.id
    user_ext_groups = db_session.scalars(
        select(User__ExternalUserGroupId).where(
            User__ExternalUserGroupId.user_id == user_id,
            User__ExternalUserGroupId.external_user_group_id.in_(group_ids),
        )
    ).all()
    return list(user_ext_groups)
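A minimal usage sketch for the new helper, assuming an open SQLAlchemy session from the application's engine helpers; the group ids here are hypothetical Salesforce record ids:

from ee.onyx.db.external_perm import (
    fetch_external_groups_for_user_email_and_group_ids,
)
from onyx.db.engine import get_session_context_manager

# Hypothetical Salesforce record ids used as external group ids.
candidate_group_ids = ["001AAA", "003BBB"]

with get_session_context_manager() as db_session:
    groups = fetch_external_groups_for_user_email_and_group_ids(
        db_session=db_session,
        user_email="user@example.com",
        group_ids=candidate_group_ids,
    )
    # The membership set mirrors how _get_objects_access_for_user_email
    # in the postprocessing module (further down) consumes this helper.
    accessible_ids = {g.external_user_group_id for g in groups}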
84 backend/ee/onyx/external_permissions/post_query_censoring.py Normal file
@@ -0,0 +1,84 @@
from collections.abc import Callable

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.external_permissions.salesforce.postprocessing import (
    censor_salesforce_chunks,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
from onyx.utils.logger import setup_logger

logger = setup_logger()

DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION: dict[
    DocumentSource,
    # list of chunks to be censored and the user email. returns censored chunks
    Callable[[list[InferenceChunk], str], list[InferenceChunk]],
] = {
    DocumentSource.SALESFORCE: censor_salesforce_chunks,
}


def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
    """
    Returns the set of sources that have censoring enabled.
    This is based on whether the access_type is set to sync and the connector
    source is included in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION.

    NOTE: This means that if a source has a single cc_pair that is sync,
    all chunks for that source will be censored, even if the connector that
    indexed a given chunk is not sync. This was done to avoid fetching the
    cc_pair for every single chunk.
    """
    with get_session_context_manager() as db_session:
        enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session)
        return {
            cc_pair.connector.source
            for cc_pair in enabled_sync_connectors
            if cc_pair.connector.source in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
        }


# NOTE: This is only called if ee is enabled.
def _post_query_chunk_censoring(
    chunks: list[InferenceChunk],
    user: User | None,
) -> list[InferenceChunk]:
    """
    This function checks all chunks to see if they need to be sent to a censoring
    function. If they do, it sends them to the censoring function and returns the
    censored chunks. If they don't, it returns the original chunks.
    """
    if user is None:
        # if user is None, permissions are not enforced
        return chunks

    chunks_to_keep = []
    chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}

    sources_to_censor = _get_all_censoring_enabled_sources()
    for chunk in chunks:
        # Separate out chunks that require permission post-processing by source
        if chunk.source_type in sources_to_censor:
            chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
        else:
            chunks_to_keep.append(chunk)

    # For each source, filter out the chunks using the permission
    # check function for that source
    # TODO: Use a threadpool/multiprocessing to process the sources in parallel
    for source, chunks_for_source in chunks_to_process.items():
        censor_chunks_for_source = DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[source]
        try:
            censored_chunks = censor_chunks_for_source(chunks_for_source, user.email)
        except Exception as e:
            logger.exception(
                f"Failed to censor chunks for source {source} so throwing out all"
                f" chunks for this source and continuing: {e}"
            )
            continue
        chunks_to_keep.extend(censored_chunks)

    return chunks_to_keep
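Supporting another source is a matter of registering one callable in the map above. A hedged sketch with a hypothetical `censor_jira_chunks` (not part of this commit) satisfying the `Callable[[list[InferenceChunk], str], list[InferenceChunk]]` contract:

from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.pipeline import InferenceChunk


def _user_can_read(chunk: InferenceChunk, user_email: str) -> bool:
    # Placeholder permission check; a real implementation would batch
    # calls to the source's permission API rather than check per chunk.
    return False


def censor_jira_chunks(
    chunks: list[InferenceChunk], user_email: str
) -> list[InferenceChunk]:
    return [chunk for chunk in chunks if _user_can_read(chunk, user_email)]


# Registering the callable is all that _post_query_chunk_censoring needs;
# it dispatches purely on chunk.source_type.
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION[DocumentSource.JIRA] = censor_jira_chunks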
226 backend/ee/onyx/external_permissions/salesforce/postprocessing.py Normal file
@@ -0,0 +1,226 @@
import time

from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids
from ee.onyx.external_permissions.salesforce.utils import (
    get_any_salesforce_client_for_doc_id,
)
from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id
from ee.onyx.external_permissions.salesforce.utils import (
    get_salesforce_user_id_from_email,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.db.engine import get_session_context_manager
from onyx.utils.logger import setup_logger

logger = setup_logger()


# Types
ChunkKey = tuple[str, int]  # (doc_id, chunk_id)
ContentRange = tuple[int, int | None]  # (start_index, end_index) None means to the end


# NOTE: Used for testing timing
def _get_dummy_object_access_map(
    object_ids: set[str], user_email: str, chunks: list[InferenceChunk]
) -> dict[str, bool]:
    time.sleep(0.15)
    # return {object_id: True for object_id in object_ids}
    import random

    return {object_id: random.choice([True, False]) for object_id in object_ids}


def _get_objects_access_for_user_email_from_salesforce(
    object_ids: set[str],
    user_email: str,
    chunks: list[InferenceChunk],
) -> dict[str, bool] | None:
    """
    This function wraps the Salesforce call as we may want to change how this
    is done in the future. (E.g. replace it with the function above.)
    """
    # This is cached in the function so the first query takes an extra 0.1-0.3 seconds
    # but subsequent queries for this source are essentially instant
    first_doc_id = chunks[0].document_id
    with get_session_context_manager() as db_session:
        salesforce_client = get_any_salesforce_client_for_doc_id(
            db_session, first_doc_id
        )

    # This is cached in the function so the first query takes an extra 0.1-0.3 seconds
    # but subsequent queries by the same user are essentially instant
    start_time = time.time()
    user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
    end_time = time.time()
    logger.info(
        f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
    )
    if user_id is None:
        return None

    # This is the only query that is not cached in the function
    # so it takes 0.1-0.2 seconds total
    object_id_to_access = get_objects_access_for_user_id(
        salesforce_client, user_id, list(object_ids)
    )
    return object_id_to_access


def _extract_salesforce_object_id_from_url(url: str) -> str:
    return url.split("/")[-1]


def _get_object_ranges_for_chunk(
    chunk: InferenceChunk,
) -> dict[str, list[ContentRange]]:
    """
    Given a chunk, return a dictionary of Salesforce object ids and the content ranges
    for that object id in the current chunk
    """
    if chunk.source_links is None:
        return {}

    object_ranges: dict[str, list[ContentRange]] = {}
    end_index = None
    descending_source_links = sorted(
        chunk.source_links.items(), key=lambda x: x[0], reverse=True
    )
    for start_index, url in descending_source_links:
        object_id = _extract_salesforce_object_id_from_url(url)
        if object_id not in object_ranges:
            object_ranges[object_id] = []
        object_ranges[object_id].append((start_index, end_index))
        end_index = start_index
    return object_ranges


def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk:
    """
    Create a copy of the unfiltered chunk where potentially sensitive content is removed,
    to be added back later if the user has access to each of the sub-objects
    """
    empty_censored_chunk = InferenceChunk(
        **uncensored_chunk.model_dump(),
    )
    empty_censored_chunk.content = ""
    empty_censored_chunk.blurb = ""
    empty_censored_chunk.source_links = {}
    return empty_censored_chunk


def _update_censored_chunk(
    censored_chunk: InferenceChunk,
    uncensored_chunk: InferenceChunk,
    content_range: ContentRange,
) -> InferenceChunk:
    """
    Update the filtered chunk with the content and source links from the unfiltered
    chunk using the content range
    """
    start_index, end_index = content_range

    # Update the content of the filtered chunk
    permitted_content = uncensored_chunk.content[start_index:end_index]
    permitted_section_start_index = len(censored_chunk.content)
    censored_chunk.content = permitted_content + censored_chunk.content

    # Update the source links of the filtered chunk
    if uncensored_chunk.source_links is not None:
        if censored_chunk.source_links is None:
            censored_chunk.source_links = {}
        link_content = uncensored_chunk.source_links[start_index]
        censored_chunk.source_links[permitted_section_start_index] = link_content

    # Update the blurb of the filtered chunk
    censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE]

    return censored_chunk


# TODO: Generalize this to other sources
def censor_salesforce_chunks(
    chunks: list[InferenceChunk],
    user_email: str,
    # This is so we can provide a mock access map for testing
    access_map: dict[str, bool] | None = None,
) -> list[InferenceChunk]:
    # object_id -> list[((doc_id, chunk_id), (start_index, end_index))]
    object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {}

    # (doc_id, chunk_id) -> chunk
    uncensored_chunks: dict[ChunkKey, InferenceChunk] = {}

    # keep track of all object ids that we have seen to make it easier to get
    # the access for these object ids
    object_ids: set[str] = set()

    for chunk in chunks:
        chunk_key = (chunk.document_id, chunk.chunk_id)
        # create a dictionary to quickly look up the unfiltered chunk
        uncensored_chunks[chunk_key] = chunk

        # for each chunk, get a dictionary of object ids and the content ranges
        # for that object id in the current chunk
        object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk)
        for object_id, ranges in object_ranges_for_chunk.items():
            object_ids.add(object_id)
            for start_index, end_index in ranges:
                object_to_content_map.setdefault(object_id, []).append(
                    (chunk_key, (start_index, end_index))
                )

    # This is so we can provide a mock access map for testing
    if access_map is None:
        access_map = _get_objects_access_for_user_email_from_salesforce(
            object_ids=object_ids,
            user_email=user_email,
            chunks=chunks,
        )
        if access_map is None:
            # If the user is not found in Salesforce, access_map will be None,
            # so return an empty list: access cannot be verified for any chunk
            return []

    censored_chunks: dict[ChunkKey, InferenceChunk] = {}
    for object_id, content_list in object_to_content_map.items():
        # if the user does not have access to the object, or the object is not in the
        # access_map, do not include its content in the filtered chunks
        if not access_map.get(object_id, False):
            continue

        # if we got this far, the user has access to the object so we can create or update
        # the filtered chunk(s) for this object
        # NOTE: we only create a censored chunk if the user has access to some
        # part of the chunk
        for chunk_key, content_range in content_list:
            if chunk_key not in censored_chunks:
                censored_chunks[chunk_key] = _create_empty_censored_chunk(
                    uncensored_chunks[chunk_key]
                )

            uncensored_chunk = uncensored_chunks[chunk_key]
            censored_chunk = _update_censored_chunk(
                censored_chunk=censored_chunks[chunk_key],
                uncensored_chunk=uncensored_chunk,
                content_range=content_range,
            )
            censored_chunks[chunk_key] = censored_chunk

    return list(censored_chunks.values())


# NOTE: This is not used anywhere.
def _get_objects_access_for_user_email(
    object_ids: set[str], user_email: str
) -> dict[str, bool]:
    with get_session_context_manager() as db_session:
        external_groups = fetch_external_groups_for_user_email_and_group_ids(
            db_session=db_session,
            user_email=user_email,
            # Maybe make a function that adds a salesforce prefix to the group ids
            group_ids=list(object_ids),
        )
        external_group_ids = {group.external_user_group_id for group in external_groups}
        return {group_id: group_id in external_group_ids for group_id in object_ids}
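The offset bookkeeping is the subtle part: `source_links` maps character offsets in the chunk content to the Salesforce object URL that produced that span, and walking the offsets in descending order lets each object's range end where the next one starts. A standalone sketch of the same idea, under the assumption that it mirrors the committed `_get_object_ranges_for_chunk`:

ContentRange = tuple[int, int | None]  # (start_index, end_index); None means to the end


def object_ranges(source_links: dict[int, str]) -> dict[str, list[ContentRange]]:
    ranges: dict[str, list[ContentRange]] = {}
    end_index: int | None = None
    # Walk offsets from the end of the chunk backwards so each span
    # ends where the previously-seen (later) span begins.
    for start_index, url in sorted(source_links.items(), reverse=True):
        object_id = url.split("/")[-1]
        ranges.setdefault(object_id, []).append((start_index, end_index))
        end_index = start_index
    return ranges


# "object1" owns [0, 26); "object2" owns [26, end of chunk).
assert object_ranges(
    {0: "https://sf.example/object1", 26: "https://sf.example/object2"}
) == {"object2": [(26, None)], "object1": [(0, 26)]}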
174 backend/ee/onyx/external_permissions/salesforce/utils.py Normal file
@@ -0,0 +1,174 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session

from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger

logger = setup_logger()

_ANY_SALESFORCE_CLIENT: Salesforce | None = None


def get_any_salesforce_client_for_doc_id(
    db_session: Session, doc_id: str
) -> Salesforce:
    """
    We create a Salesforce client for the first cc_pair of the first doc_id where
    Salesforce censoring is enabled. After that we just cache and reuse the same
    client for all queries.

    We do this to reduce the number of Postgres queries we make at query time.

    This may be problematic if multiple cc_pairs are used for Salesforce.
    E.g. there are 2 different credential sets for 2 different Salesforce cc_pairs
    but only one has access to the permission data needed for the query.
    """
    global _ANY_SALESFORCE_CLIENT
    if _ANY_SALESFORCE_CLIENT is None:
        cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
        first_cc_pair = cc_pairs[0]
        credential_json = first_cc_pair.credential.credential_json
        _ANY_SALESFORCE_CLIENT = Salesforce(
            username=credential_json["sf_username"],
            password=credential_json["sf_password"],
            security_token=credential_json["sf_security_token"],
        )
    return _ANY_SALESFORCE_CLIENT


def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
    query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
    result = sf_client.query(query)
    if len(result["records"]) == 0:
        return None
    return result["records"][0]["Id"]


# This contains only the user_ids that we have found in Salesforce.
# If we don't know their user_id, we don't store anything in the cache.
_CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {}


def get_salesforce_user_id_from_email(
    sf_client: Salesforce,
    user_email: str,
) -> str | None:
    """
    We cache this so we don't have to query Salesforce on every search query;
    Salesforce user IDs never change.
    Memory usage is fine because we just store 2 small strings per user.

    If the email is not in the cache, we check the local Salesforce database for the info.
    If the user is not found in the local Salesforce database, we query Salesforce.
    Whatever we get back from Salesforce is added to the database.
    If no user_id is found, we add a NULL_ID_STRING to the database for that email so
    we don't query Salesforce again (which is slow), but we still check the local
    Salesforce database on every query until a user id is found. This is acceptable
    because that lookup is quite fast.
    If a user_id is created in Salesforce, it will be added to the local Salesforce
    database the next time the connector is run. Then that value will be found by this
    function and cached.

    NOTE: The first time this runs, it may be slow if the value isn't already in the
    local Salesforce database. (Around 0.1-0.3 seconds.)
    If it's cached or stored in the local Salesforce database, it's fast (<0.001 seconds).
    """
    global _CACHED_SF_EMAIL_TO_ID_MAP
    if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
        if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
            return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]

    db_exists = True
    try:
        # Check if the user is already in the database
        user_id = get_user_id_by_email(user_email)
    except Exception:
        init_db()
        try:
            user_id = get_user_id_by_email(user_email)
        except Exception as e:
            logger.error(f"Error checking if user is in database: {e}")
            user_id = None
            db_exists = False

    # If no entry is found in the database (indicated by user_id being None)...
    if user_id is None:
        # ...query Salesforce and store the result in the database
        user_id = _query_salesforce_user_id(sf_client, user_email)
        if db_exists:
            update_email_to_id_table(user_email, user_id)
            return user_id
        elif user_id is None:
            return None
    elif user_id == NULL_ID_STRING:
        return None
    # If the found user_id is real, cache it
    _CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
    return user_id


_MAX_RECORD_IDS_PER_QUERY = 200


def get_objects_access_for_user_id(
    salesforce_client: Salesforce,
    user_id: str,
    record_ids: list[str],
) -> dict[str, bool]:
    """
    Salesforce has a limit of 200 record ids per query, so we just truncate
    the list of record ids to 200. We only ever retrieve 50 chunks at a time,
    so this should be fine (it's unlikely that all 50 chunks each reference
    4 unique objects).
    If we decide this isn't acceptable we can use multiple queries, but they
    should be run in parallel so query time doesn't get too long.
    """
    truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY]
    record_ids_str = "'" + "','".join(truncated_record_ids) + "'"
    access_query = f"""
    SELECT RecordId, HasReadAccess
    FROM UserRecordAccess
    WHERE RecordId IN ({record_ids_str})
    AND UserId = '{user_id}'
    """
    result = salesforce_client.query_all(access_query)
    return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]}


_CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {}
_DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {}


# NOTE: This is not used anywhere.
def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce:
    """
    Uses a document id to get the cc_pair that indexed that document and uses the credentials
    for that cc_pair to create a Salesforce client.
    Problems:
    - There may be multiple cc_pairs for a document, and we don't know which one to use.
        - right now we just use the first one
    - Building a new Salesforce client for each document is slow.
    - Memory usage could be an issue as we build these dictionaries.
    """
    if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP:
        cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
        first_cc_pair = cc_pairs[0]
        _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id

    cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
    if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
        cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
        if cc_pair is None:
            raise ValueError(f"CC pair {cc_pair_id} not found")
        credential_json = cc_pair.credential.credential_json
        _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
            username=credential_json["sf_username"],
            password=credential_json["sf_password"],
            security_token=credential_json["sf_security_token"],
        )

    return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id]
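The access check ultimately rides on Salesforce's built-in UserRecordAccess object. A minimal standalone sketch of that query via simple_salesforce, assuming valid credentials; the record and user ids below are hypothetical placeholders:

from simple_salesforce import Salesforce

# Placeholder credentials for illustration only.
sf = Salesforce(
    username="sync-user@example.com",
    password="...",  # elided
    security_token="...",  # elided
)

record_ids = ["001xx000003DGb2AAG", "003xx000004TmiQAAS"]  # hypothetical ids
user_id = "005xx000001SvogAAC"  # hypothetical Salesforce User id

ids_str = "'" + "','".join(record_ids) + "'"
result = sf.query_all(
    f"SELECT RecordId, HasReadAccess FROM UserRecordAccess "
    f"WHERE RecordId IN ({ids_str}) AND UserId = '{user_id}'"
)
# Maps each record id to a boolean read-access flag, exactly the shape
# get_objects_access_for_user_id returns.
access = {r["RecordId"]: r["HasReadAccess"] for r in result["records"]}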
@@ -8,6 +8,9 @@ from ee.onyx.external_permissions.confluence.group_sync import confluence_group_
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.onyx.external_permissions.post_query_censoring import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -71,4 +74,7 @@ EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {


def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
-    return source_type in DOC_PERMISSIONS_FUNC_MAP
+    return (
+        source_type in DOC_PERMISSIONS_FUNC_MAP
+        or source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION
+    )
@@ -16,6 +16,9 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
@@ -286,6 +289,8 @@ def connector_permission_sync_generator_task(

    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
    if doc_sync_func is None:
+        if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
+            return None
        raise ValueError(
            f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
        )
@@ -4,34 +4,29 @@ from typing import Any
from simple_salesforce import Salesforce

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
-from onyx.connectors.salesforce.doc_conversion import extract_section
+from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
+from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger

logger = setup_logger()


_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
-_ID_PREFIX = "SALESFORCE_"


class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -65,46 +60,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
            raise ConnectorMissingCredentialError("Salesforce")
        return self._sf_client

-    def _extract_primary_owners(
-        self, sf_object: SalesforceObject
-    ) -> list[BasicExpertInfo] | None:
-        object_dict = sf_object.data
-        if not (last_modified_by_id := object_dict.get("LastModifiedById")):
-            return None
-        if not (last_modified_by := get_record(last_modified_by_id)):
-            return None
-        if not (last_modified_by_name := last_modified_by.data.get("Name")):
-            return None
-        primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
-        return primary_owners
-
-    def _convert_object_instance_to_document(
-        self, sf_object: SalesforceObject
-    ) -> Document:
-        object_dict = sf_object.data
-        salesforce_id = object_dict["Id"]
-        onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
-        base_url = f"https://{self.sf_client.sf_instance}"
-        extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
-        extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
-
-        sections = [extract_section(sf_object, base_url)]
-        for id in get_child_ids(sf_object.id):
-            if not (child_object := get_record(id)):
-                continue
-            sections.append(extract_section(child_object, base_url))
-
-        doc = Document(
-            id=onyx_salesforce_id,
-            sections=sections,
-            source=DocumentSource.SALESFORCE,
-            semantic_identifier=extracted_semantic_identifier,
-            doc_updated_at=extracted_doc_updated_at,
-            primary_owners=self._extract_primary_owners(sf_object),
-            metadata={},
-        )
-        return doc
-
    def _fetch_from_salesforce(
        self,
        start: SecondsSinceUnixEpoch | None = None,
@@ -126,6 +81,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
                f"Found {len(child_types)} child types for {parent_object_type}"
            )

+        # Always want to make sure user is grabbed for permissioning purposes
+        all_object_types.add("User")
+
        logger.info(f"Found total of {len(all_object_types)} object types to fetch")
        logger.debug(f"All object types: {all_object_types}")

@@ -169,9 +127,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
                logger.debug(
                    f"Added {len(new_ids)} new/updated records for {object_type}"
                )
-                # Remove the csv file after it has been used
-                # to successfully update the db
-                os.remove(csv_path)

        logger.info(f"Found {len(updated_ids)} total updated records")
        logger.info(
@@ -196,7 +151,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
                    continue

                docs_to_yield.append(
-                    self._convert_object_instance_to_document(parent_object)
+                    convert_sf_object_to_doc(
+                        sf_object=parent_object,
+                        sf_instance=self.sf_client.sf_instance,
+                    )
                )
                docs_processed += 1

@@ -225,7 +183,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        query_result = self.sf_client.query_all(query)
        doc_metadata_list.extend(
            SlimDocument(
-                id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
+                id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                perm_sync_data={},
            )
            for instance_dict in query_result["records"]
@@ -1,8 +1,18 @@
import re
from collections import OrderedDict

from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger

logger = setup_logger()

ID_PREFIX = "SALESFORCE_"

# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
@@ -103,54 +113,72 @@ def _extract_dict_text(raw_dict: dict) -> str:
    return natural_language_for_dict


-def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
+def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
    return Section(
        text=_extract_dict_text(salesforce_object.data),
        link=f"{base_url}/{salesforce_object.id}",
    )


-def _field_value_is_child_object(field_value: dict) -> bool:
-    """
-    Checks if the field value is a child object.
-    """
-    return (
-        isinstance(field_value, OrderedDict)
-        and "records" in field_value.keys()
-        and isinstance(field_value["records"], list)
-        and len(field_value["records"]) > 0
-        and "Id" in field_value["records"][0].keys()
-    )
-
-
-def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
-    """
-    This goes through the salesforce_object and extracts the top level fields as a Section.
-    It also goes through the child objects and extracts them as Sections.
-    """
-    top_level_dict = {}
-    child_object_sections = []
-    for field_name, field_value in salesforce_object.items():
-        # If the field value is not a child object, add it to the top level dict
-        # to turn into text for the top level section
-        if not _field_value_is_child_object(field_value):
-            top_level_dict[field_name] = field_value
-        # If the field value is a child object, extract the child objects and add them as sections
-        for record in field_value["records"]:
-            child_object_id = record["Id"]
-            child_object_sections.append(
-                Section(
-                    text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
-                    link=f"{base_url}/{child_object_id}",
-                )
-            )
-    top_level_id = salesforce_object["Id"]
-    top_level_section = Section(
-        text=_extract_dict_text(top_level_dict),
-        link=f"{base_url}/{top_level_id}",
-    )
-    return [top_level_section, *child_object_sections]
+def _extract_primary_owners(
+    sf_object: SalesforceObject,
+) -> list[BasicExpertInfo] | None:
+    object_dict = sf_object.data
+    if not (last_modified_by_id := object_dict.get("LastModifiedById")):
+        logger.warning(f"No LastModifiedById found for {sf_object.id}")
+        return None
+    if not (last_modified_by := get_record(last_modified_by_id)):
+        logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
+        return None
+
+    user_data = last_modified_by.data
+    expert_info = BasicExpertInfo(
+        first_name=user_data.get("FirstName"),
+        last_name=user_data.get("LastName"),
+        email=user_data.get("Email"),
+        display_name=user_data.get("Name"),
+    )
+
+    # Check if all fields are None
+    if all(
+        value is None
+        for value in [
+            expert_info.first_name,
+            expert_info.last_name,
+            expert_info.email,
+            expert_info.display_name,
+        ]
+    ):
+        logger.warning(f"No identifying information found for user {user_data}")
+        return None
+
+    return [expert_info]
+
+
+def convert_sf_object_to_doc(
+    sf_object: SalesforceObject,
+    sf_instance: str,
+) -> Document:
+    object_dict = sf_object.data
+    salesforce_id = object_dict["Id"]
+    onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
+    base_url = f"https://{sf_instance}"
+    extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
+    extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
+
+    sections = [_extract_section(sf_object, base_url)]
+    for id in get_child_ids(sf_object.id):
+        if not (child_object := get_record(id)):
+            continue
+        sections.append(_extract_section(child_object, base_url))
+
+    doc = Document(
+        id=onyx_salesforce_id,
+        sections=sections,
+        source=DocumentSource.SALESFORCE,
+        semantic_identifier=extracted_semantic_identifier,
+        doc_updated_at=extracted_doc_updated_at,
+        primary_owners=_extract_primary_owners(sf_object),
+        metadata={},
+    )
+    return doc
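Note the shape this produces: one Section for the parent object, followed by one Section per child record, each linked to `{base_url}/{object_id}`. Those links become the chunk's `source_links` at indexing time, which is exactly what the censoring code above parses object ids back out of. A minimal sketch with stand-in types and hypothetical records:

from dataclasses import dataclass


# Stand-in for onyx.connectors.models.Section.
@dataclass
class Section:
    text: str
    link: str


def build_sections(
    parent: dict, children: list[dict], sf_instance: str
) -> list[Section]:
    base_url = f"https://{sf_instance}"
    # Parent object first, then one section per child record, mirroring
    # convert_sf_object_to_doc's ordering.
    sections = [Section(text=str(parent), link=f"{base_url}/{parent['Id']}")]
    for child in children:
        sections.append(Section(text=str(child), link=f"{base_url}/{child['Id']}"))
    return sections


sections = build_sections(
    {"Id": "001AAA", "Name": "Acme"},        # hypothetical Account
    [{"Id": "003BBB", "Name": "Jane Doe"}],  # hypothetical child Contact
    "example.my.salesforce.com",
)
assert sections[0].link == "https://example.my.salesforce.com/001AAA"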
@@ -77,25 +77,28 @@ def _get_all_queryable_fields_of_sf_type(
    object_description = _get_sf_type_object_json(sf_client, sf_type)
    fields: list[dict[str, Any]] = object_description["fields"]
    valid_fields: set[str] = set()
-    compound_field_names: set[str] = set()
+    field_names_to_remove: set[str] = set()
    for field in fields:
        if compound_field_name := field.get("compoundFieldName"):
-            compound_field_names.add(compound_field_name)
+            # We do want to get name fields even if they are compound
+            if not field.get("nameField"):
+                field_names_to_remove.add(compound_field_name)
        if field.get("type", "base64") == "base64":
            continue
        if field_name := field.get("name"):
            valid_fields.add(field_name)

-    return list(valid_fields - compound_field_names)
+    return list(valid_fields - field_names_to_remove)


-def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
+def _check_if_object_type_is_empty(
+    sf_client: Salesforce, sf_type: str, time_filter: str
+) -> bool:
    """
-    Send a small query to check if the object type is empty so we don't
-    perform extra bulk queries
+    Use the rest api to check to make sure the query will result in a non-empty response
    """
    try:
-        query = f"SELECT Count() FROM {sf_type} LIMIT 1"
+        query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1"
        result = sf_client.query(query)
        if result["totalSize"] == 0:
            return False
@@ -134,7 +137,7 @@ def _bulk_retrieve_from_salesforce(
    sf_type: str,
    time_filter: str,
) -> tuple[str, list[str] | None]:
-    if not _check_if_object_type_is_empty(sf_client, sf_type):
+    if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
        return sf_type, None

    if existing_csvs := _check_for_existing_csvs(sf_type):
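The compound-field change keeps name fields (e.g. an object's `Name`) queryable even though they are compound, while still dropping other compound members. A standalone sketch of the filtering over simplified, hypothetical describe() metadata:

from typing import Any

# Hypothetical field metadata: Name is a compound name field,
# MailingAddress is compound, Photo is base64-typed.
fields: list[dict[str, Any]] = [
    {"name": "Name", "type": "string", "compoundFieldName": "Name", "nameField": True},
    {"name": "FirstName", "type": "string"},
    {"name": "MailingAddress", "type": "address", "compoundFieldName": "MailingAddress"},
    {"name": "Photo", "type": "base64"},
]

valid_fields: set[str] = set()
field_names_to_remove: set[str] = set()
for field in fields:
    if compound_field_name := field.get("compoundFieldName"):
        # Keep name fields even though they are compound.
        if not field.get("nameField"):
            field_names_to_remove.add(compound_field_name)
    if field.get("type", "base64") == "base64":
        continue  # base64 blobs are never queryable here
    if field_name := field.get("name"):
        valid_fields.add(field_name)

assert sorted(valid_fields - field_names_to_remove) == ["FirstName", "Name"]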
@@ -40,20 +40,20 @@ def get_db_connection(


def init_db() -> None:
    """Initialize the SQLite database with required tables if they don't exist."""
-    if os.path.exists(get_sqlite_db_path()):
-        return
-
    # Create database directory if it doesn't exist
    os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)

    with get_db_connection("EXCLUSIVE") as conn:
        cursor = conn.cursor()

-        # Enable WAL mode for better concurrent access and write performance
-        cursor.execute("PRAGMA journal_mode=WAL")
-        cursor.execute("PRAGMA synchronous=NORMAL")
-        cursor.execute("PRAGMA temp_store=MEMORY")
-        cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache
+        db_exists = os.path.exists(get_sqlite_db_path())
+
+        if not db_exists:
+            # Enable WAL mode for better concurrent access and write performance
+            cursor.execute("PRAGMA journal_mode=WAL")
+            cursor.execute("PRAGMA synchronous=NORMAL")
+            cursor.execute("PRAGMA temp_store=MEMORY")
+            cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache

        # Main table for storing Salesforce objects
        cursor.execute(
@@ -90,49 +90,69 @@ def init_db() -> None:
            """
        )

-        # Always recreate indexes to ensure they exist
-        cursor.execute("DROP INDEX IF EXISTS idx_object_type")
-        cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
-        cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
-        cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
-        cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
-
-        # Create covering indexes for common queries
+        # Create a table for User email to ID mapping if it doesn't exist
        cursor.execute(
            """
+            CREATE TABLE IF NOT EXISTS user_email_map (
+                email TEXT PRIMARY KEY,
+                user_id TEXT, -- Nullable to allow for users without IDs
+                FOREIGN KEY (user_id) REFERENCES salesforce_objects(id)
+            ) WITHOUT ROWID
+            """
+        )
+
+        # Create indexes if they don't exist (SQLite ignores IF NOT EXISTS for indexes)
+        def create_index_if_not_exists(index_name: str, create_statement: str) -> None:
+            cursor.execute(
+                f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'"
+            )
+            if not cursor.fetchone():
+                cursor.execute(create_statement)
+
+        create_index_if_not_exists(
+            "idx_object_type",
+            """
            CREATE INDEX idx_object_type
            ON salesforce_objects(object_type, id)
            WHERE object_type IS NOT NULL
-            """
+            """,
        )

-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_parent_id",
            """
            CREATE INDEX idx_parent_id
            ON relationships(parent_id, child_id)
-            """
+            """,
        )

-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_child_parent",
            """
            CREATE INDEX idx_child_parent
            ON relationships(child_id)
            WHERE child_id IS NOT NULL
-            """
+            """,
        )

        # New composite index for fast parent type lookups
-        cursor.execute(
+        create_index_if_not_exists(
+            "idx_relationship_types_lookup",
            """
            CREATE INDEX idx_relationship_types_lookup
            ON relationship_types(parent_type, child_id, parent_id)
-            """
+            """,
        )

        # Analyze tables to help query planner
        cursor.execute("ANALYZE relationships")
        cursor.execute("ANALYZE salesforce_objects")
        cursor.execute("ANALYZE relationship_types")
+        cursor.execute("ANALYZE user_email_map")
+
+        # If database already existed but user_email_map needs to be populated
+        cursor.execute("SELECT COUNT(*) FROM user_email_map")
+        if cursor.fetchone()[0] == 0:
+            _update_user_email_map(conn)

        conn.commit()

@@ -203,7 +223,27 @@ def _update_relationship_tables(
            raise


-def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
+def _update_user_email_map(conn: sqlite3.Connection) -> None:
+    """Update the user_email_map table with current User objects.
+    Called internally by update_sf_db_with_csv when User objects are updated.
+    """
+    cursor = conn.cursor()
+    cursor.execute(
+        """
+        INSERT OR REPLACE INTO user_email_map (email, user_id)
+        SELECT json_extract(data, '$.Email'), id
+        FROM salesforce_objects
+        WHERE object_type = 'User'
+        AND json_extract(data, '$.Email') IS NOT NULL
+        """
+    )
+
+
+def update_sf_db_with_csv(
+    object_type: str,
+    csv_download_path: str,
+    delete_csv_after_use: bool = True,
+) -> list[str]:
    """Update the SF DB with a CSV file using SQLite storage."""
    updated_ids = []

@@ -249,8 +289,17 @@ def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]
                _update_relationship_tables(conn, id, parent_ids)
                updated_ids.append(id)

+        # If we're updating User objects, update the email map
+        if object_type == "User":
+            _update_user_email_map(conn)
+
        conn.commit()

+    if delete_csv_after_use:
+        # Remove the csv file after it has been used
+        # to successfully update the db
+        os.remove(csv_download_path)
+
    return updated_ids


@@ -329,6 +378,9 @@ def get_affected_parent_ids_by_type(
        cursor = conn.cursor()

        for batch_ids in updated_ids_batches:
+            batch_ids = list(set(batch_ids) - updated_parent_ids)
+            if not batch_ids:
+                continue
            id_placeholders = ",".join(["?" for _ in batch_ids])

            for parent_type in parent_types:
@@ -384,3 +436,40 @@ def has_at_least_one_object_of_type(object_type: str) -> bool:
        )
        count = cursor.fetchone()[0]
        return count > 0


+# NULL_ID_STRING is used to indicate that the user ID was queried but not found,
+# as opposed to None, which means it has yet to be queried at all
+NULL_ID_STRING = "N/A"
+
+
+def get_user_id_by_email(email: str) -> str | None:
+    """Get the Salesforce User ID for a given email address.
+
+    Args:
+        email: The email address to look up
+
+    Returns:
+        The stored Salesforce User ID if the email exists in the table
+        (this may be NULL_ID_STRING if Salesforce was queried and no user
+        was found), or None if the email is not in the table at all
+    """
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
+        result = cursor.fetchone()
+        if result is None:
+            return None
+        return result[0]
+
+
+def update_email_to_id_table(email: str, id: str | None) -> None:
+    """Update the email to ID map table with a new email and ID."""
+    id_to_use = id or NULL_ID_STRING
+    with get_db_connection() as conn:
+        cursor = conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
+            (email, id_to_use),
+        )
+        conn.commit()
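The email-to-id map is plain SQLite, and the NULL_ID_STRING sentinel distinguishes "queried Salesforce, no user found" from "never looked up". A self-contained sketch of the table and the sentinel convention (the in-memory database and email are illustrative):

import sqlite3

NULL_ID_STRING = "N/A"  # sentinel: queried Salesforce, no user found

conn = sqlite3.connect(":memory:")  # the real code uses a file-backed db
conn.execute(
    """
    CREATE TABLE IF NOT EXISTS user_email_map (
        email TEXT PRIMARY KEY,
        user_id TEXT
    ) WITHOUT ROWID
    """
)
# INSERT OR REPLACE keeps repeated syncs idempotent per email.
conn.execute(
    "INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",
    ("no-such-user@example.com", NULL_ID_STRING),
)

row = conn.execute(
    "SELECT user_id FROM user_email_map WHERE email = ?",
    ("no-such-user@example.com",),
).fetchone()
# NULL_ID_STRING means "known absent", distinct from no row at all.
assert row[0] == NULL_ID_STRING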
@@ -37,6 +37,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
+from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()

@@ -163,6 +164,17 @@ class SearchPipeline:
        # These chunks are ordered, deduped, and contain no large chunks
        retrieved_chunks = self._get_chunks()

+        # If ee is enabled, censor the chunk sections based on user access
+        # Otherwise, return the retrieved chunks
+        censored_chunks = fetch_ee_implementation_or_noop(
+            "onyx.external_permissions.post_query_censoring",
+            "_post_query_chunk_censoring",
+            retrieved_chunks,
+        )(
+            chunks=retrieved_chunks,
+            user=self.user,
+        )
+
        above = self.search_query.chunks_above
        below = self.search_query.chunks_below

@@ -175,7 +187,7 @@ class SearchPipeline:
        seen_document_ids = set()

        # This preserves the ordering since the chunks are retrieved in score order
-        for chunk in retrieved_chunks:
+        for chunk in censored_chunks:
            if chunk.document_id not in seen_document_ids:
                seen_document_ids.add(chunk.document_id)
                chunk_requests.append(
@@ -225,7 +237,7 @@ class SearchPipeline:
        # This maintains the original chunks ordering. Note, we cannot simply sort by score here
        # as reranking flow may wipe the scores for a lot of the chunks.
        doc_chunk_ranges_map = defaultdict(list)
-        for chunk in retrieved_chunks:
+        for chunk in censored_chunks:
            # The list of ranges for each document is ordered by score
            doc_chunk_ranges_map[chunk.document_id].append(
                ChunkRange(
@@ -274,11 +286,11 @@ class SearchPipeline:

        # In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks
        doc_chunk_ind_to_chunk.update(
-            {(chunk.document_id, chunk.chunk_id): chunk for chunk in retrieved_chunks}
+            {(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks}
        )

        # Build the surroundings for all of the initial retrieved chunks
-        for chunk in retrieved_chunks:
+        for chunk in censored_chunks:
            start_ind = max(0, chunk.chunk_id - above)
            end_ind = chunk.chunk_id + below
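The hook is wired through `fetch_ee_implementation_or_noop`, which keeps the open-source build free of the ee dependency. A hedged sketch of the general dispatch pattern only (not the actual onyx implementation), where the noop falls back to a supplied value; here that fallback is the uncensored chunk list, so non-ee deployments get the retrieved chunks back unchanged:

import importlib
from typing import Any, Callable


def fetch_impl_or_noop(
    module_path: str, attr: str, noop_return_value: Any
) -> Callable[..., Any]:
    # Load the enterprise implementation when present; otherwise return
    # a noop that ignores its arguments and yields the fallback value.
    try:
        module = importlib.import_module(module_path)
        return getattr(module, attr)
    except (ImportError, AttributeError):
        return lambda *args, **kwargs: noop_return_value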
@@ -20,10 +20,12 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null

from onyx.configs.constants import DEFAULT_BOOST
+from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
+from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
@@ -626,6 +628,60 @@ def get_document(
    return doc


+def get_cc_pairs_for_document(
+    db_session: Session,
+    document_id: str,
+) -> list[ConnectorCredentialPair]:
+    stmt = (
+        select(ConnectorCredentialPair)
+        .join(
+            DocumentByConnectorCredentialPair,
+            and_(
+                DocumentByConnectorCredentialPair.connector_id
+                == ConnectorCredentialPair.connector_id,
+                DocumentByConnectorCredentialPair.credential_id
+                == ConnectorCredentialPair.credential_id,
+            ),
+        )
+        .where(DocumentByConnectorCredentialPair.id == document_id)
+    )
+    return list(db_session.execute(stmt).scalars().all())
+
+
+def get_document_sources(
+    db_session: Session,
+    document_ids: list[str],
+) -> dict[str, DocumentSource]:
+    """Gets the sources for a list of document IDs.
+    Returns a dictionary mapping document ID to its source.
+    If a document has multiple sources (multiple CC pairs), returns the first one found.
+    """
+    stmt = (
+        select(
+            DocumentByConnectorCredentialPair.id,
+            Connector.source,
+        )
+        .join(
+            ConnectorCredentialPair,
+            and_(
+                DocumentByConnectorCredentialPair.connector_id
+                == ConnectorCredentialPair.connector_id,
+                DocumentByConnectorCredentialPair.credential_id
+                == ConnectorCredentialPair.credential_id,
+            ),
+        )
+        .join(
+            Connector,
+            ConnectorCredentialPair.connector_id == Connector.id,
+        )
+        .where(DocumentByConnectorCredentialPair.id.in_(document_ids))
+        .distinct()
+    )
+
+    results = db_session.execute(stmt).all()
+    return {doc_id: source for doc_id, source in results}
+
+
def fetch_chunk_counts_for_documents(
    document_ids: list[str],
    db_session: Session,
@@ -23,11 +23,13 @@ class ChunkEmbedding(BaseModel):

class BaseChunk(BaseModel):
    chunk_id: int
-    blurb: str  # The first sentence(s) of the first Section of the chunk
+    # The first sentence(s) of the first Section of the chunk
+    blurb: str
    content: str
    # Holds the link and the offsets into the raw Chunk text
    source_links: dict[int, str] | None
-    section_continuation: bool  # True if this Chunk's start is not at the start of a Section
+    # True if this Chunk's start is not at the start of a Section
+    section_continuation: bool


class DocAwareChunk(BaseChunk):
@@ -0,0 +1,196 @@
from datetime import datetime

from ee.onyx.external_permissions.salesforce.postprocessing import (
    censor_salesforce_chunks,
)
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk


def create_test_chunk(
    doc_id: str,
    chunk_id: int,
    content: str,
    source_links: dict[int, str] | None,
) -> InferenceChunk:
    return InferenceChunk(
        document_id=doc_id,
        chunk_id=chunk_id,
        blurb=content[:BLURB_SIZE],
        content=content,
        source_links=source_links,
        section_continuation=False,
        source_type=DocumentSource.SALESFORCE,
        semantic_identifier="test_chunk",
        title="Test Chunk",
        boost=1,
        recency_bias=1.0,
        score=None,
        hidden=False,
        metadata={},
        match_highlights=[],
        updated_at=datetime.now(),
    )


def test_validate_salesforce_access_single_object() -> None:
    """Test filtering when chunk has a single Salesforce object reference"""
    section = "This is a test document about a Salesforce object."
    test_content = section
    test_chunk = create_test_chunk(
        doc_id="doc1",
        chunk_id=1,
        content=test_content,
        source_links={0: "https://salesforce.com/object1"},
    )

    # Test when user has access
    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={"object1": True},
    )
    assert len(filtered_chunks) == 1
    assert filtered_chunks[0].content == test_content

    # Test when user doesn't have access
    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={"object1": False},
    )
    assert len(filtered_chunks) == 0


def test_validate_salesforce_access_multiple_objects() -> None:
    """Test filtering when chunk has multiple Salesforce object references"""
    section1 = "First part about object1. "
    section2 = "Second part about object2. "
    section3 = "Third part about object3."

    test_content = section1 + section2 + section3
    section1_end = len(section1)
    section2_end = section1_end + len(section2)

    test_chunk = create_test_chunk(
        doc_id="doc1",
        chunk_id=1,
        content=test_content,
        source_links={
            0: "https://salesforce.com/object1",
            section1_end: "https://salesforce.com/object2",
            section2_end: "https://salesforce.com/object3",
        },
    )

    # Test when user has access to all objects
    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={
            "object1": True,
            "object2": True,
            "object3": True,
        },
    )
    assert len(filtered_chunks) == 1
    assert filtered_chunks[0].content == test_content

    # Test when user has access to some objects
    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={
            "object1": True,
            "object2": False,
            "object3": True,
        },
    )
    assert len(filtered_chunks) == 1
    assert section1 in filtered_chunks[0].content
    assert section2 not in filtered_chunks[0].content
    assert section3 in filtered_chunks[0].content

    # Test when user has no access
    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={
            "object1": False,
            "object2": False,
            "object3": False,
        },
    )
    assert len(filtered_chunks) == 0


def test_validate_salesforce_access_multiple_chunks() -> None:
    """Test filtering when there are multiple chunks with different access patterns"""
    section1 = "Content about object1"
    section2 = "Content about object2"

    chunk1 = create_test_chunk(
        doc_id="doc1",
        chunk_id=1,
        content=section1,
        source_links={0: "https://salesforce.com/object1"},
    )
    chunk2 = create_test_chunk(
        doc_id="doc1",
        chunk_id=2,
        content=section2,
        source_links={0: "https://salesforce.com/object2"},
    )

    # Test mixed access
    filtered_chunks = censor_salesforce_chunks(
        chunks=[chunk1, chunk2],
        user_email="test@example.com",
        access_map={
            "object1": True,
            "object2": False,
        },
    )
    assert len(filtered_chunks) == 1
    assert filtered_chunks[0].chunk_id == 1
    assert section1 in filtered_chunks[0].content


def test_validate_salesforce_access_no_source_links() -> None:
    """Test handling of chunks with no source links"""
    section = "Content with no source links"
    test_chunk = create_test_chunk(
        doc_id="doc1",
        chunk_id=1,
        content=section,
        source_links=None,
    )

    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={},
    )
    assert len(filtered_chunks) == 0


def test_validate_salesforce_access_blurb_update() -> None:
    """Test that blurbs are properly updated based on permitted content"""
    section = "First part about object1. "
    long_content = section * 20  # Make it longer than BLURB_SIZE
    test_chunk = create_test_chunk(
        doc_id="doc1",
        chunk_id=1,
        content=long_content,
        source_links={0: "https://salesforce.com/object1"},
    )

    filtered_chunks = censor_salesforce_chunks(
        chunks=[test_chunk],
        user_email="test@example.com",
        access_map={"object1": True},
    )
    assert len(filtered_chunks) == 1
    assert len(filtered_chunks[0].blurb) <= BLURB_SIZE
    assert filtered_chunks[0].blurb.startswith(section)
@@ -15,4 +15,5 @@ export const autoSyncConfigBySource: Record<
  google_drive: {},
  gmail: {},
  slack: {},
+  salesforce: {},
};
@@ -343,6 +343,7 @@ export const validAutoSyncSources = [
  ValidSources.GoogleDrive,
  ValidSources.Gmail,
  ValidSources.Slack,
+  ValidSources.Salesforce,
] as const;

// Create a type from the array elements