mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-28 21:05:17 +02:00
Slack Role Override (#755)
This commit is contained in:
@@ -105,6 +105,17 @@ def handle_message(
|
|||||||
send_to: list[str] | None = None
|
send_to: list[str] | None = None
|
||||||
respond_tag_only = False
|
respond_tag_only = False
|
||||||
respond_team_member_list = None
|
respond_team_member_list = None
|
||||||
|
|
||||||
|
bypass_acl = False
|
||||||
|
if (
|
||||||
|
channel_config
|
||||||
|
and channel_config.persona
|
||||||
|
and channel_config.persona.document_sets
|
||||||
|
):
|
||||||
|
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||||
|
# with non-public document sets
|
||||||
|
bypass_acl = True
|
||||||
|
|
||||||
if channel_config and channel_config.channel_config:
|
if channel_config and channel_config.channel_config:
|
||||||
channel_conf = channel_config.channel_config
|
channel_conf = channel_config.channel_config
|
||||||
if not bipass_filters and "answer_filters" in channel_conf:
|
if not bipass_filters and "answer_filters" in channel_conf:
|
||||||
@@ -172,6 +183,7 @@ def handle_message(
|
|||||||
answer_generation_timeout=answer_generation_timeout,
|
answer_generation_timeout=answer_generation_timeout,
|
||||||
real_time_flow=False,
|
real_time_flow=False,
|
||||||
enable_reflexion=reflexion,
|
enable_reflexion=reflexion,
|
||||||
|
bypass_acl=bypass_acl,
|
||||||
)
|
)
|
||||||
if not answer.error_msg:
|
if not answer.error_msg:
|
||||||
return answer
|
return answer
|
||||||
|
@@ -396,3 +396,33 @@ def get_or_create_document_set_by_name(
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
return new_doc_set
|
return new_doc_set
|
||||||
|
|
||||||
|
|
||||||
|
def check_document_sets_are_public(
|
||||||
|
db_session: Session,
|
||||||
|
document_set_ids: list[int],
|
||||||
|
) -> bool:
|
||||||
|
connector_credential_pair_ids = (
|
||||||
|
db_session.query(
|
||||||
|
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
|
||||||
|
)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
not_public_exists = (
|
||||||
|
db_session.query(ConnectorCredentialPair.id)
|
||||||
|
.filter(
|
||||||
|
ConnectorCredentialPair.id.in_(
|
||||||
|
connector_credential_pair_ids # type:ignore
|
||||||
|
),
|
||||||
|
ConnectorCredentialPair.is_public.is_(False),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
.first()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return not not_public_exists
|
||||||
|
@@ -50,6 +50,7 @@ def answer_qa_query(
|
|||||||
answer_generation_timeout: int = QA_TIMEOUT,
|
answer_generation_timeout: int = QA_TIMEOUT,
|
||||||
real_time_flow: bool = True,
|
real_time_flow: bool = True,
|
||||||
enable_reflexion: bool = False,
|
enable_reflexion: bool = False,
|
||||||
|
bypass_acl: bool = False,
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
@@ -91,6 +92,7 @@ def answer_qa_query(
|
|||||||
user=user,
|
user=user,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
document_index=get_default_document_index(),
|
document_index=get_default_document_index(),
|
||||||
|
bypass_acl=bypass_acl,
|
||||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
rerank_metrics_callback=rerank_metrics_callback,
|
||||||
)
|
)
|
||||||
|
@@ -339,7 +339,10 @@ def _build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
|
|||||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||||
|
|
||||||
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
|
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
|
||||||
filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
|
if filters.access_control_list is not None:
|
||||||
|
filter_str += _build_or_filters(
|
||||||
|
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||||
|
)
|
||||||
|
|
||||||
source_strs = (
|
source_strs = (
|
||||||
[s.value for s in filters.source_type] if filters.source_type else None
|
[s.value for s in filters.source_type] if filters.source_type else None
|
||||||
|
@@ -39,7 +39,7 @@ class BaseFilters(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class IndexFilters(BaseFilters):
|
class IndexFilters(BaseFilters):
|
||||||
access_control_list: list[str]
|
access_control_list: list[str] | None
|
||||||
|
|
||||||
|
|
||||||
class ChunkMetric(BaseModel):
|
class ChunkMetric(BaseModel):
|
||||||
|
@@ -545,6 +545,7 @@ def danswer_search_generator(
|
|||||||
db_session: Session,
|
db_session: Session,
|
||||||
document_index: DocumentIndex,
|
document_index: DocumentIndex,
|
||||||
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||||
|
bypass_acl: bool = False,
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
@@ -561,7 +562,9 @@ def danswer_search_generator(
|
|||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
user_acl_filters = (
|
||||||
|
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||||
|
)
|
||||||
final_filters = IndexFilters(
|
final_filters = IndexFilters(
|
||||||
source_type=question.filters.source_type,
|
source_type=question.filters.source_type,
|
||||||
document_set=question.filters.document_set,
|
document_set=question.filters.document_set,
|
||||||
@@ -609,6 +612,7 @@ def danswer_search(
|
|||||||
db_session: Session,
|
db_session: Session,
|
||||||
document_index: DocumentIndex,
|
document_index: DocumentIndex,
|
||||||
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||||
|
bypass_acl: bool = False,
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
@@ -624,6 +628,7 @@ def danswer_search(
|
|||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
document_index=document_index,
|
document_index=document_index,
|
||||||
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
||||||
|
bypass_acl=bypass_acl,
|
||||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
rerank_metrics_callback=rerank_metrics_callback,
|
||||||
)
|
)
|
||||||
|
@@ -5,12 +5,15 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.auth.users import current_user
|
from danswer.auth.users import current_user
|
||||||
|
from danswer.db.document_set import check_document_sets_are_public
|
||||||
from danswer.db.document_set import fetch_document_sets
|
from danswer.db.document_set import fetch_document_sets
|
||||||
from danswer.db.document_set import insert_document_set
|
from danswer.db.document_set import insert_document_set
|
||||||
from danswer.db.document_set import mark_document_set_as_to_be_deleted
|
from danswer.db.document_set import mark_document_set_as_to_be_deleted
|
||||||
from danswer.db.document_set import update_document_set
|
from danswer.db.document_set import update_document_set
|
||||||
from danswer.db.engine import get_session
|
from danswer.db.engine import get_session
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
|
from danswer.server.models import CheckDocSetPublicRequest
|
||||||
|
from danswer.server.models import CheckDocSetPublicResponse
|
||||||
from danswer.server.models import ConnectorCredentialPairDescriptor
|
from danswer.server.models import ConnectorCredentialPairDescriptor
|
||||||
from danswer.server.models import ConnectorSnapshot
|
from danswer.server.models import ConnectorSnapshot
|
||||||
from danswer.server.models import CredentialSnapshot
|
from danswer.server.models import CredentialSnapshot
|
||||||
@@ -82,6 +85,7 @@ def list_document_sets(
|
|||||||
id=document_set_db_model.id,
|
id=document_set_db_model.id,
|
||||||
name=document_set_db_model.name,
|
name=document_set_db_model.name,
|
||||||
description=document_set_db_model.description,
|
description=document_set_db_model.description,
|
||||||
|
contains_non_public=any([not cc_pair.is_public for cc_pair in cc_pairs]),
|
||||||
cc_pair_descriptors=[
|
cc_pair_descriptors=[
|
||||||
ConnectorCredentialPairDescriptor(
|
ConnectorCredentialPairDescriptor(
|
||||||
id=cc_pair.id,
|
id=cc_pair.id,
|
||||||
@@ -99,3 +103,15 @@ def list_document_sets(
|
|||||||
)
|
)
|
||||||
for document_set_db_model, cc_pairs in document_set_info
|
for document_set_db_model, cc_pairs in document_set_info
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/document-set-public")
|
||||||
|
def document_set_public(
|
||||||
|
check_public_request: CheckDocSetPublicRequest,
|
||||||
|
_: User = Depends(current_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> CheckDocSetPublicResponse:
|
||||||
|
is_public = check_document_sets_are_public(
|
||||||
|
document_set_ids=check_public_request.document_set_ids, db_session=db_session
|
||||||
|
)
|
||||||
|
return CheckDocSetPublicResponse(is_public=is_public)
|
||||||
|
@@ -461,12 +461,21 @@ class DocumentSetUpdateRequest(BaseModel):
|
|||||||
cc_pair_ids: list[int]
|
cc_pair_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckDocSetPublicRequest(BaseModel):
|
||||||
|
document_set_ids: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckDocSetPublicResponse(BaseModel):
|
||||||
|
is_public: bool
|
||||||
|
|
||||||
|
|
||||||
class DocumentSet(BaseModel):
|
class DocumentSet(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
|
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
|
||||||
is_up_to_date: bool
|
is_up_to_date: bool
|
||||||
|
contains_non_public: bool
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
|
def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet":
|
||||||
@@ -474,6 +483,12 @@ class DocumentSet(BaseModel):
|
|||||||
id=document_set_model.id,
|
id=document_set_model.id,
|
||||||
name=document_set_model.name,
|
name=document_set_model.name,
|
||||||
description=document_set_model.description,
|
description=document_set_model.description,
|
||||||
|
contains_non_public=any(
|
||||||
|
[
|
||||||
|
not cc_pair.is_public
|
||||||
|
for cc_pair in document_set_model.connector_credential_pairs
|
||||||
|
]
|
||||||
|
),
|
||||||
cc_pair_descriptors=[
|
cc_pair_descriptors=[
|
||||||
ConnectorCredentialPairDescriptor(
|
ConnectorCredentialPairDescriptor(
|
||||||
id=cc_pair.id,
|
id=cc_pair.id,
|
||||||
|
Reference in New Issue
Block a user