# Patch summary (unified git diff over eight Python backend files).
#
# What the patch does, as far as the hunks below show:
#   * slack/handlers/handle_message.py: sets bypass_acl = True when the channel
#     config has a persona whose document_sets is truthy, and passes
#     bypass_acl=... into the answer call.
#   * direct_qa/answer_question.py: answer_qa_query gains bypass_acl (default
#     False) and forwards it to the search call it makes.
#   * search/search_runner.py: danswer_search and danswer_search_generator gain
#     bypass_acl (default False); the generator uses None instead of
#     build_access_filters_for_user(user, db_session) when bypass_acl is set.
#   * search/models.py: IndexFilters.access_control_list becomes list[str] | None.
#   * document_index/vespa/index.py: _build_vespa_filters only appends the ACL
#     filter when filters.access_control_list is not None.
#   * db/document_set.py: adds check_document_sets_are_public, which returns
#     True when no connector-credential pair linked to the given document set
#     ids has is_public set to False.
#   * server/document_set.py: adds GET /document-set-public (requires
#     current_user) wrapping that check, and fills contains_non_public in
#     list_document_sets.
#   * server/models.py: adds CheckDocSetPublicRequest / CheckDocSetPublicResponse
#     and a contains_non_public field on DocumentSet (also set in from_model).
#
# NOTE(review): the original line breaks and indentation of this patch were
#   lost (everything is collapsed onto five physical lines), so hunk line
#   counts cannot be verified from this text -- re-export the patch before
#   trying to apply it.
# NOTE(review): document_set_public is registered with @router.get but takes a
#   Pydantic model parameter (CheckDocSetPublicRequest), which FastAPI reads
#   from the request body. Bodies on GET are poorly supported by clients and
#   proxies -- consider POST or a query parameter; confirm against the caller.
# NOTE(review): with bypass_acl the Vespa query carries no ACL clause at all,
#   and the existing comment in _build_vespa_filters says there is no second
#   ACL check after retrieval. Safety rests on the admin-side warning
#   mentioned in handle_message -- confirm that warning exists in the UI.
# NOTE(review): only None skips the ACL filter; what _build_or_filters emits
#   for an empty list is not visible here -- verify it cannot widen access.
# NOTE(review): .limit(1) before .first() looks redundant, and
#   any([...]) builds a throwaway list where a generator would do.
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index b19ea8d9f..cf9fb348d 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -105,6 +105,17 @@ def handle_message( send_to: list[str] | None = None respond_tag_only = False respond_team_member_list = None + + bypass_acl = False + if ( + channel_config + and channel_config.persona + and channel_config.persona.document_sets + ): + # For Slack channels, use the full document set, admin will be warned when configuring it + # with non-public document sets + bypass_acl = True + + if channel_config and channel_config.channel_config: channel_conf = channel_config.channel_config if not bipass_filters and "answer_filters" in channel_conf: @@ -172,6 +183,7 @@ def handle_message( answer_generation_timeout=answer_generation_timeout, real_time_flow=False, enable_reflexion=reflexion, + bypass_acl=bypass_acl, ) if not answer.error_msg: return answer diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 328abb268..26dd2b3f1 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -396,3 +396,33 @@ def get_or_create_document_set_by_name( db_session.commit() return new_doc_set + + +def check_document_sets_are_public( + db_session: Session, + document_set_ids: list[int], +) -> bool: + connector_credential_pair_ids = ( + db_session.query( + DocumentSet__ConnectorCredentialPair.connector_credential_pair_id + ) + .filter( + DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids) + ) + .subquery() + ) + + not_public_exists = ( + db_session.query(ConnectorCredentialPair.id) + .filter( + ConnectorCredentialPair.id.in_( + connector_credential_pair_ids # type:ignore + ), + ConnectorCredentialPair.is_public.is_(False), + ) + .limit(1) + .first() + is not None + ) + + return not
not_public_exists diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index c94d3225c..30deb66b3 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -50,6 +50,7 @@ def answer_qa_query( answer_generation_timeout: int = QA_TIMEOUT, real_time_flow: bool = True, enable_reflexion: bool = False, + bypass_acl: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, @@ -91,6 +92,7 @@ def answer_qa_query( user=user, db_session=db_session, document_index=get_default_document_index(), + bypass_acl=bypass_acl, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 441a20599..e95bbd564 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -339,7 +339,10 @@ def _build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> filter_str = f"!({HIDDEN}=true) and " if not include_hidden else "" # CAREFUL touching this one, currently there is no second ACL double-check post retrieval - filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list) + if filters.access_control_list is not None: + filter_str += _build_or_filters( + ACCESS_CONTROL_LIST, filters.access_control_list + ) source_strs = ( [s.value for s in filters.source_type] if filters.source_type else None diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index ef0d4c2cb..06258ea43 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -39,7 +39,7 @@ class BaseFilters(BaseModel): class IndexFilters(BaseFilters): - access_control_list: list[str] + access_control_list:
list[str] | None class ChunkMetric(BaseModel): diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 2a6097b54..08fc06e00 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -545,6 +545,7 @@ def danswer_search_generator( db_session: Session, document_index: DocumentIndex, skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, + bypass_acl: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, @@ -561,7 +562,9 @@ def danswer_search_generator( db_session=db_session, ) - user_acl_filters = build_access_filters_for_user(user, db_session) + user_acl_filters = ( + None if bypass_acl else build_access_filters_for_user(user, db_session) + ) final_filters = IndexFilters( source_type=question.filters.source_type, document_set=question.filters.document_set, @@ -609,6 +612,7 @@ def danswer_search( db_session: Session, document_index: DocumentIndex, skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, + bypass_acl: bool = False, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, @@ -624,6 +628,7 @@ def danswer_search( db_session=db_session, document_index=document_index, skip_llm_chunk_filter=skip_llm_chunk_filter, + bypass_acl=bypass_acl, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) diff --git a/backend/danswer/server/document_set.py b/backend/danswer/server/document_set.py index 013069409..79a666e8e 100644 --- a/backend/danswer/server/document_set.py +++ b/backend/danswer/server/document_set.py @@ -5,12 +5,15 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user +from danswer.db.document_set import
check_document_sets_are_public from danswer.db.document_set import fetch_document_sets from danswer.db.document_set import insert_document_set from danswer.db.document_set import mark_document_set_as_to_be_deleted from danswer.db.document_set import update_document_set from danswer.db.engine import get_session from danswer.db.models import User +from danswer.server.models import CheckDocSetPublicRequest +from danswer.server.models import CheckDocSetPublicResponse from danswer.server.models import ConnectorCredentialPairDescriptor from danswer.server.models import ConnectorSnapshot from danswer.server.models import CredentialSnapshot @@ -82,6 +85,7 @@ def list_document_sets( id=document_set_db_model.id, name=document_set_db_model.name, description=document_set_db_model.description, + contains_non_public=any([not cc_pair.is_public for cc_pair in cc_pairs]), cc_pair_descriptors=[ ConnectorCredentialPairDescriptor( id=cc_pair.id, @@ -99,3 +103,15 @@ def list_document_sets( ) for document_set_db_model, cc_pairs in document_set_info ] + + +@router.get("/document-set-public") +def document_set_public( + check_public_request: CheckDocSetPublicRequest, + _: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> CheckDocSetPublicResponse: + is_public = check_document_sets_are_public( + document_set_ids=check_public_request.document_set_ids, db_session=db_session + ) + return CheckDocSetPublicResponse(is_public=is_public) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 57cb7386b..343633f2c 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -461,12 +461,21 @@ class DocumentSetUpdateRequest(BaseModel): cc_pair_ids: list[int] +class CheckDocSetPublicRequest(BaseModel): + document_set_ids: list[int] + + +class CheckDocSetPublicResponse(BaseModel): + is_public: bool + + class DocumentSet(BaseModel): id: int name: str description: str cc_pair_descriptors:
list[ConnectorCredentialPairDescriptor] is_up_to_date: bool + contains_non_public: bool @classmethod def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet": @@ -474,6 +483,12 @@ class DocumentSet(BaseModel): id=document_set_model.id, name=document_set_model.name, description=document_set_model.description, + contains_non_public=any( + [ + not cc_pair.is_public + for cc_pair in document_set_model.connector_credential_pairs + ] + ), cc_pair_descriptors=[ ConnectorCredentialPairDescriptor( id=cc_pair.id,