Slack Role Override (#755)

This commit is contained in:
Yuhong Sun
2023-11-22 17:47:18 -08:00
committed by GitHub
parent 35c3511daa
commit bdfb894507
8 changed files with 86 additions and 3 deletions

View File

@@ -105,6 +105,17 @@ def handle_message(
send_to: list[str] | None = None send_to: list[str] | None = None
respond_tag_only = False respond_tag_only = False
respond_team_member_list = None respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels with a persona that has document sets, skip the per-user ACL filter and search
# the full document sets. The admin is warned when configuring a persona with non-public document sets.
bypass_acl = True
if channel_config and channel_config.channel_config: if channel_config and channel_config.channel_config:
channel_conf = channel_config.channel_config channel_conf = channel_config.channel_config
if not bipass_filters and "answer_filters" in channel_conf: if not bipass_filters and "answer_filters" in channel_conf:
@@ -172,6 +183,7 @@ def handle_message(
answer_generation_timeout=answer_generation_timeout, answer_generation_timeout=answer_generation_timeout,
real_time_flow=False, real_time_flow=False,
enable_reflexion=reflexion, enable_reflexion=reflexion,
bypass_acl=bypass_acl,
) )
if not answer.error_msg: if not answer.error_msg:
return answer return answer

View File

@@ -396,3 +396,33 @@ def get_or_create_document_set_by_name(
db_session.commit() db_session.commit()
return new_doc_set return new_doc_set
def check_document_sets_are_public(
    db_session: Session,
    document_set_ids: list[int],
) -> bool:
    """Report whether the given document sets contain only public cc-pairs.

    Args:
        db_session: SQLAlchemy session used to run the lookup.
        document_set_ids: Ids of the document sets to inspect.

    Returns:
        False if at least one connector-credential pair attached to any of
        the given document sets has ``is_public`` set to False; True
        otherwise (including when no pair matches at all).
    """
    # Ids of every connector-credential pair linked to one of the requested sets.
    linked_cc_pair_ids = (
        db_session.query(
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
        )
        .filter(
            DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
        )
        .subquery()
    )

    in_requested_sets = ConnectorCredentialPair.id.in_(
        linked_cc_pair_ids  # type:ignore
    )
    is_private = ConnectorCredentialPair.is_public.is_(False)

    # A single private pair is enough to answer, so fetch at most one row.
    first_private_pair = (
        db_session.query(ConnectorCredentialPair.id)
        .filter(in_requested_sets, is_private)
        .limit(1)
        .first()
    )
    return first_private_pair is None

View File

@@ -50,6 +50,7 @@ def answer_qa_query(
answer_generation_timeout: int = QA_TIMEOUT, answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True, real_time_flow: bool = True,
enable_reflexion: bool = False, enable_reflexion: bool = False,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@@ -91,6 +92,7 @@ def answer_qa_query(
user=user, user=user,
db_session=db_session, db_session=db_session,
document_index=get_default_document_index(), document_index=get_default_document_index(),
bypass_acl=bypass_acl,
retrieval_metrics_callback=retrieval_metrics_callback, retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback, rerank_metrics_callback=rerank_metrics_callback,
) )

View File

@@ -339,7 +339,10 @@ def _build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else "" filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval # CAREFUL touching this one, currently there is no second ACL double-check post retrieval
filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list) if filters.access_control_list is not None:
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
)
source_strs = ( source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None [s.value for s in filters.source_type] if filters.source_type else None

View File

@@ -39,7 +39,7 @@ class BaseFilters(BaseModel):
class IndexFilters(BaseFilters): class IndexFilters(BaseFilters):
access_control_list: list[str] access_control_list: list[str] | None
class ChunkMetric(BaseModel): class ChunkMetric(BaseModel):

View File

@@ -545,6 +545,7 @@ def danswer_search_generator(
db_session: Session, db_session: Session,
document_index: DocumentIndex, document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@@ -561,7 +562,9 @@ def danswer_search_generator(
db_session=db_session, db_session=db_session,
) )
user_acl_filters = build_access_filters_for_user(user, db_session) user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters( final_filters = IndexFilters(
source_type=question.filters.source_type, source_type=question.filters.source_type,
document_set=question.filters.document_set, document_set=question.filters.document_set,
@@ -609,6 +612,7 @@ def danswer_search(
db_session: Session, db_session: Session,
document_index: DocumentIndex, document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@@ -624,6 +628,7 @@ def danswer_search(
db_session=db_session, db_session=db_session,
document_index=document_index, document_index=document_index,
skip_llm_chunk_filter=skip_llm_chunk_filter, skip_llm_chunk_filter=skip_llm_chunk_filter,
bypass_acl=bypass_acl,
retrieval_metrics_callback=retrieval_metrics_callback, retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback, rerank_metrics_callback=rerank_metrics_callback,
) )

View File

@@ -5,12 +5,15 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.db.document_set import check_document_sets_are_public
from danswer.db.document_set import fetch_document_sets from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import insert_document_set from danswer.db.document_set import insert_document_set
from danswer.db.document_set import mark_document_set_as_to_be_deleted from danswer.db.document_set import mark_document_set_as_to_be_deleted
from danswer.db.document_set import update_document_set from danswer.db.document_set import update_document_set
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.models import User from danswer.db.models import User
from danswer.server.models import CheckDocSetPublicRequest
from danswer.server.models import CheckDocSetPublicResponse
from danswer.server.models import ConnectorCredentialPairDescriptor from danswer.server.models import ConnectorCredentialPairDescriptor
from danswer.server.models import ConnectorSnapshot from danswer.server.models import ConnectorSnapshot
from danswer.server.models import CredentialSnapshot from danswer.server.models import CredentialSnapshot
@@ -82,6 +85,7 @@ def list_document_sets(
id=document_set_db_model.id, id=document_set_db_model.id,
name=document_set_db_model.name, name=document_set_db_model.name,
description=document_set_db_model.description, description=document_set_db_model.description,
contains_non_public=any([not cc_pair.is_public for cc_pair in cc_pairs]),
cc_pair_descriptors=[ cc_pair_descriptors=[
ConnectorCredentialPairDescriptor( ConnectorCredentialPairDescriptor(
id=cc_pair.id, id=cc_pair.id,
@@ -99,3 +103,15 @@ def list_document_sets(
) )
for document_set_db_model, cc_pairs in document_set_info for document_set_db_model, cc_pairs in document_set_info
] ]
# NOTE(review): this is a GET route whose input is a Pydantic model; presumably the
# framework reads it from the request body — confirm that all clients can send one.
@router.get("/document-set-public")
def document_set_public(
    check_public_request: CheckDocSetPublicRequest,
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CheckDocSetPublicResponse:
    """Tell the caller whether all of the requested document sets are public.

    The ``current_user`` dependency is resolved but its value is unused.
    """
    requested_ids = check_public_request.document_set_ids
    return CheckDocSetPublicResponse(
        is_public=check_document_sets_are_public(
            db_session=db_session,
            document_set_ids=requested_ids,
        )
    )

View File

@@ -461,12 +461,21 @@ class DocumentSetUpdateRequest(BaseModel):
cc_pair_ids: list[int] cc_pair_ids: list[int]
class CheckDocSetPublicRequest(BaseModel):
    """Request payload for the document-set public check."""

    # Ids of the document sets to check.
    document_set_ids: list[int]
class CheckDocSetPublicResponse(BaseModel):
    """Response payload for the document-set public check."""

    # Result of the check for the requested document sets.
    is_public: bool
class DocumentSet(BaseModel): class DocumentSet(BaseModel):
id: int id: int
name: str name: str
description: str description: str
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor] cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
is_up_to_date: bool is_up_to_date: bool
contains_non_public: bool
@classmethod @classmethod
def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet": def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
@@ -474,6 +483,12 @@ class DocumentSet(BaseModel):
id=document_set_model.id, id=document_set_model.id,
name=document_set_model.name, name=document_set_model.name,
description=document_set_model.description, description=document_set_model.description,
contains_non_public=any(
[
not cc_pair.is_public
for cc_pair in document_set_model.connector_credential_pairs
]
),
cc_pair_descriptors=[ cc_pair_descriptors=[
ConnectorCredentialPairDescriptor( ConnectorCredentialPairDescriptor(
id=cc_pair.id, id=cc_pair.id,