Slack Role Override (#755)

This commit is contained in:
Yuhong Sun 2023-11-22 17:47:18 -08:00 committed by GitHub
parent 35c3511daa
commit bdfb894507
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 86 additions and 3 deletions

View File

@ -105,6 +105,17 @@ def handle_message(
send_to: list[str] | None = None
respond_tag_only = False
respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels, search the full document set regardless of the user's ACL; the admin is
# warned when configuring a channel with non-public document sets
bypass_acl = True
if channel_config and channel_config.channel_config:
channel_conf = channel_config.channel_config
if not bipass_filters and "answer_filters" in channel_conf:
@ -172,6 +183,7 @@ def handle_message(
answer_generation_timeout=answer_generation_timeout,
real_time_flow=False,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
)
if not answer.error_msg:
return answer

View File

@ -396,3 +396,33 @@ def get_or_create_document_set_by_name(
db_session.commit()
return new_doc_set
def check_document_sets_are_public(
    db_session: Session,
    document_set_ids: list[int],
) -> bool:
    """Tell whether the given document sets contain only public connector-credential pairs.

    Looks up the connector-credential pairs linked to any of
    `document_set_ids` and searches for one whose `is_public` flag is False.

    Args:
        db_session: Session used to run the lookup.
        document_set_ids: IDs of the document sets to inspect.

    Returns:
        True when no non-public connector-credential pair is found among the
        given document sets, False as soon as one exists.
    """
    # IDs of every connector-credential pair attached to the requested sets.
    cc_pair_ids_for_sets = (
        db_session.query(
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
        )
        .filter(
            DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
        )
        .subquery()
    )

    in_requested_sets = ConnectorCredentialPair.id.in_(
        cc_pair_ids_for_sets  # type:ignore
    )
    is_private = ConnectorCredentialPair.is_public.is_(False)

    # A single matching row is enough to answer the question.
    first_private_pair = (
        db_session.query(ConnectorCredentialPair.id)
        .filter(in_requested_sets, is_private)
        .limit(1)
        .first()
    )
    return first_private_pair is None

View File

@ -50,6 +50,7 @@ def answer_qa_query(
answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
enable_reflexion: bool = False,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@ -91,6 +92,7 @@ def answer_qa_query(
user=user,
db_session=db_session,
document_index=get_default_document_index(),
bypass_acl=bypass_acl,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)

View File

@ -339,7 +339,10 @@ def _build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
if filters.access_control_list is not None:
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
)
source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None

View File

@ -39,7 +39,7 @@ class BaseFilters(BaseModel):
class IndexFilters(BaseFilters):
access_control_list: list[str]
access_control_list: list[str] | None
class ChunkMetric(BaseModel):

View File

@ -545,6 +545,7 @@ def danswer_search_generator(
db_session: Session,
document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@ -561,7 +562,9 @@ def danswer_search_generator(
db_session=db_session,
)
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
source_type=question.filters.source_type,
document_set=question.filters.document_set,
@ -609,6 +612,7 @@ def danswer_search(
db_session: Session,
document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
@ -624,6 +628,7 @@ def danswer_search(
db_session=db_session,
document_index=document_index,
skip_llm_chunk_filter=skip_llm_chunk_filter,
bypass_acl=bypass_acl,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)

View File

@ -5,12 +5,15 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.document_set import check_document_sets_are_public
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import insert_document_set
from danswer.db.document_set import mark_document_set_as_to_be_deleted
from danswer.db.document_set import update_document_set
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.models import CheckDocSetPublicRequest
from danswer.server.models import CheckDocSetPublicResponse
from danswer.server.models import ConnectorCredentialPairDescriptor
from danswer.server.models import ConnectorSnapshot
from danswer.server.models import CredentialSnapshot
@ -82,6 +85,7 @@ def list_document_sets(
id=document_set_db_model.id,
name=document_set_db_model.name,
description=document_set_db_model.description,
contains_non_public=any([not cc_pair.is_public for cc_pair in cc_pairs]),
cc_pair_descriptors=[
ConnectorCredentialPairDescriptor(
id=cc_pair.id,
@ -99,3 +103,15 @@ def list_document_sets(
)
for document_set_db_model, cc_pairs in document_set_info
]
@router.get("/document-set-public")
def document_set_public(
    check_public_request: CheckDocSetPublicRequest,
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CheckDocSetPublicResponse:
    """Report whether all of the requested document sets are public.

    Args:
        check_public_request: Carries the `document_set_ids` to inspect.
        _: Result of the `current_user` dependency; resolved but unused here.
        db_session: Database session supplied by the `get_session` dependency.

    Returns:
        A `CheckDocSetPublicResponse` whose `is_public` field holds the result
        of `check_document_sets_are_public` for the requested IDs.
    """
    requested_ids = check_public_request.document_set_ids
    all_public = check_document_sets_are_public(
        db_session=db_session,
        document_set_ids=requested_ids,
    )
    return CheckDocSetPublicResponse(is_public=all_public)

View File

@ -461,12 +461,21 @@ class DocumentSetUpdateRequest(BaseModel):
cc_pair_ids: list[int]
class CheckDocSetPublicRequest(BaseModel):
    """Request payload listing the document sets whose visibility should be checked."""

    # IDs of the document sets to inspect.
    document_set_ids: list[int]
class CheckDocSetPublicResponse(BaseModel):
    """Response payload carrying the outcome of a document-set visibility check."""

    # Outcome of the visibility check for the requested document sets.
    is_public: bool
class DocumentSet(BaseModel):
id: int
name: str
description: str
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
is_up_to_date: bool
contains_non_public: bool
@classmethod
def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
@ -474,6 +483,12 @@ class DocumentSet(BaseModel):
id=document_set_model.id,
name=document_set_model.name,
description=document_set_model.description,
contains_non_public=any(
[
not cc_pair.is_public
for cc_pair in document_set_model.connector_credential_pairs
]
),
cc_pair_descriptors=[
ConnectorCredentialPairDescriptor(
id=cc_pair.id,