Added permission syncing (#2340)

* Added permission syncing on the backend

* Reworked to work with celery

alembic fix

fixed test

* frontend changes

* got groups working

* added comments and fixed public docs

* fixed merge issues

* frontend complete!

* frontend cleanup and mypy fixes

* refactored connector access_type selection

* mypy fixes

* minor refactor and frontend improvements

* get to fetch

* renames and comments

* minor change to var names

* got curator stuff working

* addressed pablo's comments

* refactored user_external_group to reference users table

* implemented polling

* small refactor

* fixed a whoopsie on the frontend

* added scripts to seed dummy docs and test query times

* fixed frontend build issue

* alembic fix

* handled is_public overlap

* yuhong feedback

* added more checks for sync

* black

* mypy

* fixed circular import

* todos

* alembic fix

* alembic
hagen-danswer 2024-09-19 15:07:36 -07:00 committed by GitHub
parent ef104e9a82
commit 2274cab554
79 changed files with 2192 additions and 1049 deletions

View File

@ -1,7 +1,7 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f17bf3b0d9f1
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""

View File

@ -0,0 +1,162 @@
"""Add Permission Syncing
Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # The admin user who set up the connectors will temporarily lose access to the docs;
    # currently the only way to give back access is to rerun the sync from the beginning
op.add_column(
"connector_credential_pair",
sa.Column(
"access_type",
sa.String(),
nullable=True,
),
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
)
op.alter_column("connector_credential_pair", "access_type", nullable=False)
op.add_column(
"connector_credential_pair",
sa.Column(
"auto_sync_options",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
)
op.drop_column("connector_credential_pair", "is_public")
op.add_column(
"document",
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
)
op.add_column(
"document",
sa.Column(
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
),
)
op.add_column(
"document",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
op.create_table(
"user__external_user_group_id",
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("user_id"),
)
op.drop_column("external_permission", "user_id")
op.drop_column("email_to_external_user_cache", "user_id")
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
)
op.execute(
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
)
op.alter_column("connector_credential_pair", "is_public", nullable=False)
op.drop_column("connector_credential_pair", "auto_sync_options")
op.drop_column("connector_credential_pair", "access_type")
op.drop_column("connector_credential_pair", "last_time_perm_sync")
op.drop_column("document", "external_user_emails")
op.drop_column("document", "external_user_group_ids")
op.drop_column("document", "is_public")
op.drop_table("user__external_user_group_id")
# Drop the enum type at the end of the downgrade
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@ -1,7 +1,7 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 52a219fb5233
Revises: 0ebb1d516877
Create Date: 2024-09-11 13:55:46.101149
"""

View File

@ -1,7 +1,7 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: bceb1e139447
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""

View File

@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_email
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
@ -18,10 +18,13 @@ def _get_access_for_document(
document_id=document_id,
)
if not info:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=info[2] if info else False,
)
def get_access_for_document(
@ -34,6 +37,16 @@ def get_access_for_document(
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
)
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
@ -42,13 +55,27 @@ def _get_access_for_documents(
db_session=db_session,
document_ids=document_ids,
)
return {
document_id: DocumentAccess.build(
user_ids=user_ids, user_groups=[], is_public=is_public
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=set(),
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_ids, is_public in document_access_info
for document_id, user_emails, is_public in document_access_info
}
    # Sometimes a document has not been indexed by the indexing job yet; in those cases
    # the document does not exist, so we use the least permissive access. Specifically,
    # the EE version checks the MIT version's permissions and creates a superset. This
    # ensures that this flow does not fail even if the document has not yet been indexed.
for doc_id in document_ids:
if doc_id not in doc_access:
doc_access[doc_id] = get_null_document_access()
return doc_access
def get_access_for_documents(
document_ids: list[str],
@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
matches one entry in the returned set.
"""
if user:
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@ -1,30 +1,72 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
class ExternalAccess:
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Danswer
is_public: bool
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users; None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
cls,
user_emails: list[str | None],
user_groups: list[str],
external_user_emails: list[str],
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)
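
As an illustration of the new ACL shape (all values hypothetical), constructing the dataclass directly, as the EE access code does:

access = DocumentAccess(
    user_emails={"chris@danswer.ai"},
    user_groups={"engineering"},
    external_user_emails={"outside@example.com"},
    external_user_group_ids={"GOOGLE_DRIVE_eng-drive"},  # already source-prefixed
    is_public=False,
)
access.to_acl()
# -> {"user_email:chris@danswer.ai", "group:engineering",
#     "user_email:outside@example.com",
#     "external_group:GOOGLE_DRIVE_eng-drive"}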

View File

@ -1,10 +1,24 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
from danswer.configs.constants import DocumentSource
def prefix_user_email(user_email: str) -> str:
"""Prefixes a user email to eliminate collision with group names.
This applies to both Danswer users and external users; sharing one prefix makes the
query-time ACL matching more efficient"""
return f"user_email:{user_email}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
"""Prefixes a user group name to eliminate collision with user emails.
This assumes that user emails are prefixed with a different prefix."""
return f"group:{user_group_name}"
def prefix_external_group(ext_group_name: str) -> str:
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@ -128,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.PRUNE,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(

View File

@ -56,11 +56,11 @@ def _get_connector_runner(
try:
runnable_connector = instantiate_connector(
attempt.connector_credential_pair.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
db_session=db_session,
source=attempt.connector_credential_pair.connector.source,
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")

View File

@ -124,11 +124,11 @@ def identify_connector_class(
def instantiate_connector(
db_session: Session,
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)

View File

@ -6,7 +6,6 @@ from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_service_account,
)
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@ -407,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
@ -509,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:

View File

@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import SCOPES
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str,
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
return None
def get_google_drive_creds_for_service_account(
service_account_key_json_str: str,
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
@ -107,7 +169,7 @@ def update_credential_access_tokens(
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)

View File

@ -1,7 +1,7 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
]
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
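
Given these constants, build_gdrive_scopes() (shown above) composes the scope list; a sketch of the expected results:

# MIT build:
# ["https://www.googleapis.com/auth/drive.readonly",
#  "https://www.googleapis.com/auth/drive.metadata.readonly"]
# Enterprise build additionally appends
# "https://www.googleapis.com/auth/cloud-identity.groups.readonly"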

View File

@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
# The default title is semantic_identifier though unless otherwise specified
title: str | None = None
from_ingestion_api: bool = False
    # Anything else specific to this particular connector type that other parts of the
    # code may find useful. If you're unsure, this can be left as None
additional_info: Any = None
def get_title_for_document_index(
self,

View File

@ -211,7 +211,7 @@ def handle_message(
with Session(get_sqlalchemy_engine()) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(message_info.email, db_session)
add_non_web_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
logger = setup_logger()
@ -74,7 +79,7 @@ def _add_user_filters(
.correlate(ConnectorCredentialPair)
)
else:
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
@ -94,8 +99,7 @@ def get_connector_credential_pairs(
) # noqa
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
results = db_session.scalars(stmt)
return list(results.all())
return list(db_session.scalars(stmt).all())
def add_deletion_failure_message(
@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
access_type=AccessType.PUBLIC,
name="DefaultCCPair",
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=True,
)
db_session.add(association)
db_session.commit()
@ -336,8 +340,9 @@ def add_credential_to_connector(
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
is_public: bool,
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
@ -345,6 +350,13 @@ def add_credential_to_connector(
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
)
if credential is None:
error_msg = (
f"Credential {credential_id} does not exist or does not belong to user"
@ -375,12 +387,13 @@ def add_credential_to_connector(
credential_id=credential_id,
name=cc_pair_name,
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=is_public,
access_type=access_type,
auto_sync_options=auto_sync_options,
)
db_session.add(association)
db_session.flush() # make sure the association has an id
if groups:
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
@ -423,6 +436,10 @@ def remove_credential_from_connector(
)
if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)
db_session.delete(association)
db_session.commit()
return StatusResponse(

View File

@ -4,7 +4,6 @@ from collections.abc import Generator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
@ -17,14 +16,17 @@ from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.tag import delete_document_tags_for_documents__no_commit
from danswer.db.utils import model_to_dict
from danswer.document_index.interfaces import DocumentMetadata
@ -186,16 +188,14 @@ def get_document_counts_for_cc_pairs(
def get_access_info_for_document(
db_session: Session,
document_id: str,
) -> tuple[str, list[UUID | None], bool] | None:
) -> tuple[str, list[str | None], bool] | None:
"""Gets access info for a single document by calling the get_access_info_for_documents function
and passing a list with a single document ID.
Args:
db_session (Session): The database session to use.
document_id (str): The document ID to fetch access info for.
Returns:
Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
and a boolean indicating if the document is globally public, or None if no results are found.
"""
results = get_access_info_for_documents(db_session, [document_id])
@ -208,19 +208,27 @@ def get_access_info_for_document(
def get_access_info_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[UUID | None], bool]]:
) -> Sequence[tuple[str, list[str | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for cc pairs that the document is associated with + whether any
of the associated cc pairs are intending to make the document globally public.
Returns the list where each element contains:
- Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
    - List of emails of Danswer users with direct access to the doc (includes a "None" element if
    the connector was set up by an admin when auth was off)
    - bool for whether the document is public (the document can also later be marked public by
    the automatic permission sync step)
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
"public_doc"
),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
stmt = (
select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.join(
stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
)
@ -233,6 +241,13 @@ def get_access_info_for_documents(
== ConnectorCredentialPair.credential_id,
),
)
.outerjoin(
User,
and_(
Credential.user_id == User.id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
)
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
@ -278,9 +293,19 @@ def upsert_documents(
for doc in seen_documents.values()
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
"boost": insert_stmt.excluded.boost,
"hidden": insert_stmt.excluded.hidden,
"semantic_id": insert_stmt.excluded.semantic_id,
"link": insert_stmt.excluded.link,
"primary_owners": insert_stmt.excluded.primary_owners,
"secondary_owners": insert_stmt.excluded.secondary_owners,
},
)
db_session.execute(on_conflict_stmt)
db_session.commit()

View File

@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
ids=missing_cc_pair_ids,
)
for cc_pair in cc_pairs:
if not cc_pair.is_public:
if cc_pair.access_type != AccessType.PUBLIC:
raise ValueError(
f"Connector Credential Pair with ID: '{cc_pair.id}'"
" is not owned by the specified groups"
@ -704,7 +705,7 @@ def check_document_sets_are_public(
ConnectorCredentialPair.id.in_(
connector_credential_pair_ids # type:ignore
),
ConnectorCredentialPair.is_public.is_(False),
ConnectorCredentialPair.access_type != AccessType.PUBLIC,
)
.limit(1)
.first()

View File

@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
def is_active(self) -> bool:
return self == ConnectorCredentialPairStatus.ACTIVE
class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"

View File

@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.chat import get_chat_message
from danswer.db.enums import AccessType
from danswer.db.models import ChatMessageFeedback
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document as DbDocument
@ -94,7 +95,7 @@ def _add_user_filters(
.correlate(CCPair)
)
else:
where_clause |= CCPair.is_public == True # noqa: E712
where_clause |= CCPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)

View File

@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@ -388,10 +389,20 @@ class ConnectorCredentialPair(Base):
# controls whether the documents indexed by this CC pair are visible to all
    # or if they are only visible to those who are given explicit access
# (e.g. via owning the credential or being a part of a group that is given access)
is_public: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
access_type: Mapped[AccessType] = mapped_column(
Enum(AccessType, native_enum=False), nullable=False
)
# special info needed for the auto-sync feature. The exact structure depends on the
# source type (defined in the connector's `source` field)
# E.g. for google_drive perm sync:
# {"customer_id": "123567", "company_domain": "@danswer.ai"}
auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
    # Time finished; not used for calculating backend jobs, which use time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
@ -422,6 +433,7 @@ class ConnectorCredentialPair(Base):
class Document(Base):
__tablename__ = "document"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Danswer)
@ -465,7 +477,18 @@ class Document(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
# Permission sync columns
# Email addresses are saved at the document level for externally synced permissions
    # This is because the normal flow of assigning permissions, which goes through the
    # cc_pair, doesn't apply here
external_user_emails: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# These group ids have been prefixed by the source type
external_user_group_ids: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
@ -1674,95 +1697,18 @@ class StandardAnswer(Base):
"""Tables related to Permission Sync"""
class PermissionSyncStatus(str, PyEnum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
class PermissionSyncJobType(str, PyEnum):
USER_LEVEL = "user_level"
GROUP_LEVEL = "group_level"
class PermissionSyncRun(Base):
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
the users or it is sync-ing the groups"""
__tablename__ = "permission_sync_run"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
# Not strictly needed but makes it easy to use without fetching from cc_pair
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
# Currently all sync jobs are handled as a group permission sync or a user permission sync
update_type: Mapped[PermissionSyncJobType] = mapped_column(
Enum(PermissionSyncJobType)
)
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=True
)
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
class ExternalPermission(Base):
class User__ExternalUserGroupId(Base):
"""Maps user info both internal and external to the name of the external group
This maps the user to all of their external groups so that the external group name can be
attached to the ACL list matching during query time. User level permissions can be handled by
directly adding the Danswer user to the doc ACL list"""
__tablename__ = "external_permission"
__tablename__ = "user__external_user_group_id"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
external_permission_group: Mapped[str] = mapped_column(String)
user = relationship("User")
class EmailToExternalUserCache(Base):
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
when the user joins. Used as a cache for when fetching external groups which have their own
user ids, this can easily be mapped back to users already known in Danswer without needing
to call external APIs to get the user emails.
This way when groups are updated in the external tool and we need to update the mapping of
internal users to the groups, we can sync the internal users to the external groups they are
part of using this.
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
part of alpha in some external tool.
"""
__tablename__ = "email_to_external_user_cache"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_user_id: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
user = relationship("User")
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
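
A hedged sketch of how a sync run might refresh these rows for one user and cc_pair (helper name and flow are illustrative, not the actual EE implementation):

def replace_user_external_groups__no_commit(
    db_session: Session, user_id: UUID, cc_pair_id: int, group_ids: list[str]
) -> None:
    # drop stale memberships for this cc_pair, then re-add the current ones
    db_session.query(User__ExternalUserGroupId).filter(
        User__ExternalUserGroupId.user_id == user_id,
        User__ExternalUserGroupId.cc_pair_id == cc_pair_id,
    ).delete()
    db_session.add_all(
        User__ExternalUserGroupId(
            user_id=user_id,
            external_user_group_id=group_id,  # already source-prefixed
            cc_pair_id=cc_pair_id,
        )
        for group_id in group_ids
    )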
class UsageReport(Base):

View File

@ -22,6 +22,17 @@ def list_users(
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = db_session.query(User).filter(User.email == email).first() # type: ignore
@ -34,20 +45,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return user
def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
def _generate_non_web_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user = User(
return User(
email=email,
hashed_password=hashed_pass,
has_web_login=False,
role=UserRole.BASIC,
)
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.commit()
return user
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.flush() # generate id
return user
def batch_add_non_web_user_if_not_exists__no_commit(
db_session: Session, emails: list[str]
) -> list[User]:
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
db_session.add_all(new_users)
db_session.flush() # generate ids
return found_users + new_users
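
For example (illustrative call site), a permission-sync step can resolve a batch of external emails to Danswer users in one pass and commit once at the end:

users = batch_add_non_web_user_if_not_exists__no_commit(
    db_session, emails=["a@acme.com", "b@acme.com"]
)
# ... attach the users to external groups / doc ACLs here ...
db_session.commit()  # the __no_commit helpers leave the commit to the caller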

View File

@ -265,7 +265,13 @@ def index_doc_batch(
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
no_access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
)
ctx = index_doc_batch_prepare(
document_batch=document_batch,

View File

@ -18,6 +18,7 @@ from danswer.db.connector_credential_pair import (
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
@ -201,7 +202,7 @@ def associate_credential_to_connector(
db_session=db_session,
user=user,
target_group_ids=metadata.groups,
object_is_public=metadata.is_public,
object_is_public=metadata.access_type == AccessType.PUBLIC,
)
try:
@ -211,7 +212,8 @@ def associate_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
cc_pair_name=metadata.name,
is_public=True if metadata.is_public is None else metadata.is_public,
access_type=metadata.access_type,
auto_sync_options=metadata.auto_sync_options,
groups=metadata.groups,
)

View File

@ -64,6 +64,7 @@ from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
@ -559,7 +560,7 @@ def get_connector_indexing_status(
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
credential=CredentialSnapshot.from_credential_db_model(credential),
public_doc=cc_pair.is_public,
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
last_finished_status=(
@ -668,12 +669,15 @@ def create_connector_with_mock_credential(
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
db_session=db_session,
user=user,
            connector_id=cast(int, connector_response.id),  # will always be an int
credential_id=credential.id,
is_public=connector_data.is_public or False,
access_type=access_type,
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)

View File

@ -10,6 +10,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import DocumentErrorSummary
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel):
number_of_index_attempts: int
last_index_attempt_status: IndexingStatus | None
latest_deletion_attempt: DeletionAttemptSnapshot | None
is_public: bool
access_type: AccessType
is_editable_for_current_user: bool
deletion_failure_message: str | None
@ -261,7 +262,7 @@ class CCPairFullInfo(BaseModel):
number_of_index_attempts=number_of_index_attempts,
last_index_attempt_status=last_indexing_status,
latest_deletion_attempt=latest_deletion_attempt,
is_public=cc_pair_model.is_public,
access_type=cc_pair_model.access_type,
is_editable_for_current_user=is_editable_for_current_user,
deletion_failure_message=cc_pair_model.deletion_failure_message,
)
@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel):
credential: CredentialSnapshot
owner: str
groups: list[int]
public_doc: bool
access_type: AccessType
last_finished_status: IndexingStatus | None
last_status: IndexingStatus | None
last_success: datetime | None
@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel):
class ConnectorCredentialPairMetadata(BaseModel):
name: str | None = None
is_public: bool | None = None
access_type: AccessType
auto_sync_options: dict[str, Any] | None = None
groups: list[int] = Field(default_factory=list)
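
An example request body matching this model (values hypothetical; auto_sync_options follows the google_drive example from the ConnectorCredentialPair model comment earlier in the diff):

{
    "name": "Drive - Eng",
    "access_type": "sync",
    "auto_sync_options": {"customer_id": "123567", "company_domain": "@danswer.ai"},
    "groups": []
}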

View File

@ -47,7 +47,6 @@ class DocumentSet(BaseModel):
description: str
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
is_up_to_date: bool
contains_non_public: bool
is_public: bool
# For Private Document Sets, who should be able to access these
users: list[UUID]
@ -59,12 +58,6 @@ class DocumentSet(BaseModel):
id=document_set_model.id,
name=document_set_model.name,
description=document_set_model.description,
contains_non_public=any(
[
not cc_pair.is_public
for cc_pair in document_set_model.connector_credential_pairs
]
),
cc_pair_descriptors=[
ConnectorCredentialPairDescriptor(
id=cc_pair.id,

View File

@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel):
api_url: str | None = None
model_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider

View File

@ -48,6 +48,7 @@ from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
logger = setup_logger()
@ -243,6 +244,11 @@ async def delete_user(
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
delete_user__ext_group_for_user__no_commit(
db_session=db_session,
user_id=user_to_delete.id,
)
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()

View File

@ -5,8 +5,11 @@ from danswer.access.access import (
)
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_group
from danswer.db.document import get_documents_by_ids
from danswer.db.models import User
from ee.danswer.db.external_perm import fetch_external_groups_for_user
from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user
@ -17,7 +20,13 @@ def _get_access_for_document(
) -> DocumentAccess:
id_to_access = _get_access_for_documents([document_id], db_session)
if len(id_to_access) == 0:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
)
return next(iter(id_to_access.values()))
@ -30,22 +39,48 @@ def _get_access_for_documents(
document_ids=document_ids,
db_session=db_session,
)
user_group_info = {
user_group_info: dict[str, list[str]] = {
document_id: group_names
for document_id, group_names in fetch_user_groups_for_documents(
db_session=db_session,
document_ids=document_ids,
)
}
documents = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)
doc_id_map = {doc.id: doc for doc in documents}
return {
document_id: DocumentAccess(
user_ids=non_ee_access.user_ids,
user_groups=user_group_info.get(document_id, []), # type: ignore
is_public=non_ee_access.is_public,
access_map = {}
for document_id, non_ee_access in non_ee_access_dict.items():
document = doc_id_map[document_id]
ext_u_emails = (
set(document.external_user_emails)
if document.external_user_emails
else set()
)
for document_id, non_ee_access in non_ee_access_dict.items()
}
ext_u_groups = (
set(document.external_user_group_ids)
if document.external_user_group_ids
else set()
)
# If the document is determined to be "public" externally (through a SYNC connector)
# then it's given the same access level as if it were marked public within Danswer
is_public_anywhere = document.is_public or non_ee_access.is_public
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
@ -56,7 +91,20 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
NOTE: is imported in danswer.access.access by `fetch_versioned_implementation`
DO NOT REMOVE."""
user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
return set(
[prefix_user_group(user_group.name) for user_group in user_groups]
).union(get_acl_for_user_without_groups(user, db_session))
db_user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
prefixed_user_groups = [
prefix_user_group(db_user_group.name) for db_user_group in db_user_groups
]
db_external_groups = (
fetch_external_groups_for_user(db_session, user.id) if user else []
)
prefixed_external_groups = [
prefix_external_group(db_external_group.external_user_group_id)
for db_external_group in db_external_groups
]
user_acl = set(prefixed_user_groups + prefixed_external_groups)
user_acl.update(get_acl_for_user_without_groups(user, db_session))
return user_acl
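
Illustratively, the resulting filter set for a user might look like (hypothetical values):

# {"user_email:chris@danswer.ai", PUBLIC_DOC_PAT,
#  "group:engineering", "external_group:GOOGLE_DRIVE_eng-drive"}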

View File

@ -11,7 +11,13 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
run_permission_sync_entrypoint,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
logger = setup_logger()
@ -20,6 +26,13 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
@ -30,6 +43,23 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_sync_external_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
)
@celery_app.task(
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
@ -64,7 +94,11 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"autogenerate-usage-report": {
"sync-external-permissions": {
"task": "check_sync_external_permissions_task",
"schedule": timedelta(seconds=60), # TODO: optimize this
},
"autogenerate_usage_report": {
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},

View File

@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
@ -30,11 +33,30 @@ def should_perform_chat_ttl_check(
return True
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.info("TTL check is already being performed. Skipping.")
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
return True
def should_perform_external_permissions_check(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)
latest_task = get_latest_task(task_name, db_session)
if not latest_task:
return True
if check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
return True
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None:
"""This function is likely to move in the worker refactor happening next."""
key = key_bytes.decode("utf-8")

View File

@ -1,224 +0,0 @@
import logging
import time
from datetime import datetime
import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_PERMISSIONS_APP_NAME
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt
from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
# If the indexing dies, it's most likely due to resource constraints;
# restarting just delays the eventual failure and is not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})
def cleanup_perm_sync_jobs(
existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
# Just reusing the same timeout, fine for now
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
with Session(get_sqlalchemy_engine()) as db_session:
# clean up completed jobs
for (attempt_id, details), job in existing_jobs.items():
perm_sync_attempt = get_perm_sync_attempt(
attempt_id=attempt_id, db_session=db_session
)
# do nothing for ongoing jobs that haven't been stopped
if (
not job.done()
and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
):
continue
if job.status == "error":
logger.error(job.exception())
job.release()
del existing_jobs_copy[(attempt_id, details)]
# clean up in-progress jobs that were never completed
expire_perm_sync_timed_out(
timeout_hours=timeout_hours,
db_session=db_session,
)
return existing_jobs_copy
def create_group_sync_jobs(
existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
"""Creates new relational DB group permission sync job for each source that:
- has permission sync enabled
- has at least 1 connector (enabled or paused)
- has no sync already running
"""
existing_jobs_copy = existing_jobs.copy()
sources_w_runs = [
key[1]
for key in existing_jobs_copy.keys()
if isinstance(key[1], DocumentSource)
]
with Session(get_sqlalchemy_engine()) as db_session:
sources_w_connector = fetch_sources_with_connectors(db_session)
for source_type in sources_w_connector:
if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
continue
if source_type in sources_w_runs:
continue
db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
perm_sync = create_perm_sync(
source_type=source_type,
group_update=True,
cc_pair_id=None,
db_session=db_session,
)
run = client.submit(db_group_fnc, pure=False)
logger.info(
f"Kicked off group permission sync for source type {source_type}"
)
if run:
existing_jobs_copy[(perm_sync.id, source_type)] = run
return existing_jobs_copy
def create_connector_perm_sync_jobs(
existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
"""Update Document Index ACL sync job for each cc-pair where:
- source type has permission sync enabled
- has no sync already running
"""
existing_jobs_copy = existing_jobs.copy()
cc_pairs_w_runs = [
key[1]
for key in existing_jobs_copy.keys()
if isinstance(key[1], DocumentSource)
]
with Session(get_sqlalchemy_engine()) as db_session:
sources_w_connector = fetch_sources_with_connectors(db_session)
for source_type in sources_w_connector:
if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
continue
_, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
cc_pairs = get_cc_pairs_by_source(source_type, db_session)
for cc_pair in cc_pairs:
if cc_pair.id in cc_pairs_w_runs:
continue
perm_sync = create_perm_sync(
source_type=source_type,
group_update=False,
cc_pair_id=cc_pair.id,
db_session=db_session,
)
run = client.submit(index_sync_fnc, cc_pair.id, pure=False)
logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")
if run:
existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run
return existing_jobs_copy
def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
client: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warnings about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
client = Client(cluster_primary)
if LOG_LEVEL.lower() == "debug":
client.register_worker_plugin(ResourceLogger())
else:
client = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# Any jobs still in progress on restart must have died
mark_all_inprogress_permission_sync_failed(db_session)
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")
if existing_jobs:
logger.debug(
"Found existing permission sync jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
try:
# TODO turn this on when it works
"""
existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
existing_jobs = create_group_sync_jobs(
existing_jobs=existing_jobs, client=client
)
existing_jobs = create_connector_perm_sync_jobs(
existing_jobs=existing_jobs, client=client
)
"""
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
def update__main() -> None:
logger.notice("Starting Permission Syncing Loop")
init_sqlalchemy_engine(POSTGRES_PERMISSIONS_APP_NAME)
permission_loop()
if __name__ == "__main__":
update__main()

View File

@ -1,6 +1,6 @@
def name_user_group_sync_task(user_group_id: int) -> str:
return f"user_group_sync_task__{user_group_id}"
def name_chat_ttl_task(retention_limit_days: int) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_permissions_task(cc_pair_id: int) -> str:
return f"sync_external_permissions_task__{cc_pair_id}"

View File

@ -1,12 +0,0 @@
from danswer.utils.logger import setup_logger
logger = setup_logger()
def confluence_update_db_group() -> None:
logger.debug("Not yet implemented group sync for confluence, no-op")
def confluence_update_index_acl(cc_pair_id: int) -> None:
logger.debug("Not yet implemented ACL sync for confluence, no-op")

View File

@ -1,8 +0,0 @@
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl
CONNECTOR_PERMISSION_FUNC_MAP = {
DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
}

View File

@ -3,6 +3,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import AccessType
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import UserGroup__ConnectorCredentialPair
@ -32,14 +33,30 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
source_type: DocumentSource,
db_session: Session,
source_type: DocumentSource,
only_sync: bool,
) -> list[ConnectorCredentialPair]:
cc_pairs = (
query = (
db_session.query(ConnectorCredentialPair)
.join(ConnectorCredentialPair.connector)
.filter(Connector.source == source_type)
.all()
)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
cc_pairs = query.all()
return cc_pairs
def get_all_auto_sync_cc_pairs(
db_session: Session,
) -> list[ConnectorCredentialPair]:
return (
db_session.query(ConnectorCredentialPair)
.where(
ConnectorCredentialPair.access_type == AccessType.SYNC,
)
.all()
)

View File

@ -1,14 +1,47 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import Document
from danswer.access.models import ExternalAccess
from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import Document as DbDocument
def fetch_documents_from_ids(
db_session: Session, document_ids: list[str]
) -> Sequence[Document]:
return db_session.scalars(
select(Document).where(Document.id.in_(document_ids))
).all()
def upsert_document_external_perms__no_commit(
db_session: Session,
doc_id: str,
external_access: ExternalAccess,
source_type: DocumentSource,
) -> None:
"""
This sets the permissions for a document in postgres.
NOTE: this will replace any existing external access; it will not do a union
"""
document = db_session.scalars(
select(DbDocument).where(DbDocument.id == doc_id)
).first()
prefixed_external_groups = [
prefix_group_w_source(
ext_group_name=group_id,
source=source_type,
)
for group_id in external_access.external_user_group_ids
]
if not document:
# If the document does not exist, still store the external access
# so that if the document is added later, the external access is already stored
document = DbDocument(
id=doc_id,
semantic_id="",
external_user_emails=external_access.external_user_emails,
external_user_group_ids=prefixed_external_groups,
is_public=external_access.is_public,
)
db_session.add(document)
return
document.external_user_emails = list(external_access.external_user_emails)
document.external_user_group_ids = prefixed_external_groups
document.is_public = external_access.is_public

View File

@ -0,0 +1,77 @@
from collections.abc import Sequence
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import User__ExternalUserGroupId
class ExternalUserGroup(BaseModel):
id: str
user_ids: list[UUID]
def delete_user__ext_group_for_user__no_commit(
db_session: Session,
user_id: UUID,
) -> None:
db_session.execute(
delete(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id
)
)
def delete_user__ext_group_for_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
) -> None:
db_session.execute(
delete(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.cc_pair_id == cc_pair_id
)
)
def replace_user__ext_group_for_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
group_defs: list[ExternalUserGroup],
source: DocumentSource,
) -> None:
"""
This function clears all existing external user group relations for a given cc_pair_id
and replaces them with the new group definitions.
"""
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
new_external_permissions = [
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=prefix_group_w_source(external_group.id, source),
cc_pair_id=cc_pair_id,
)
for external_group in group_defs
for user_id in external_group.user_ids
]
db_session.add_all(new_external_permissions)
def fetch_external_groups_for_user(
db_session: Session,
user_id: UUID,
) -> Sequence[User__ExternalUserGroupId]:
return db_session.scalars(
select(User__ExternalUserGroupId).where(
User__ExternalUserGroupId.user_id == user_id
)
).all()
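A minimal usage sketch for the replace helper (illustrative values; session helpers as used elsewhere in this PR). The __no_commit suffix means the caller owns the transaction, so commit once at the end:
from uuid import uuid4
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.engine import get_sqlalchemy_engine
with Session(get_sqlalchemy_engine()) as db_session:
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=1,  # hypothetical cc-pair
group_defs=[ExternalUserGroup(id="eng-team", user_ids=[uuid4()])],
source=DocumentSource.GOOGLE_DRIVE,
)
db_session.commit()  # the helpers above intentionally leave committing to the caller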

View File

@ -1,72 +0,0 @@
from datetime import timedelta
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.models import PermissionSyncRun
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
logger = setup_logger()
def mark_all_inprogress_permission_sync_failed(
db_session: Session,
) -> None:
stmt = (
update(PermissionSyncRun)
.where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
.values(status=PermissionSyncStatus.FAILED)
)
db_session.execute(stmt)
db_session.commit()
def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
try:
return db_session.scalars(stmt).one()
except NoResultFound:
raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")
def expire_perm_sync_timed_out(
timeout_hours: int,
db_session: Session,
) -> None:
cutoff_time = func.now() - timedelta(hours=timeout_hours)
update_stmt = (
update(PermissionSyncRun)
.where(
PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
PermissionSyncRun.updated_at < cutoff_time,
)
.values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
)
db_session.execute(update_stmt)
db_session.commit()
def create_perm_sync(
source_type: DocumentSource,
group_update: bool,
cc_pair_id: int | None,
db_session: Session,
) -> PermissionSyncRun:
new_run = PermissionSyncRun(
source_type=source_type,
status=PermissionSyncStatus.IN_PROGRESS,
group_update=group_update,
cc_pair_id=cc_pair_id,
)
db_session.add(new_run)
db_session.commit()
return new_run

View File

@ -199,7 +199,7 @@ def fetch_documents_for_user_group_paginated(
def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[int, list[str]]]:
) -> Sequence[tuple[str, list[str]]]:
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
.join(

View File

@ -0,0 +1,19 @@
from typing import Any
from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
logger = setup_logger()
def confluence_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented ACL sync for confluence, no-op")

View File

@ -0,0 +1,19 @@
from typing import Any
from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
logger = setup_logger()
def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
logger.debug("Not yet implemented group sync for confluence, no-op")

View File

@ -0,0 +1,148 @@
from collections.abc import Iterator
from typing import Any
from typing import cast
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat that here by retrying
# with backoff (up to 5 tries, 5s initial delay, capped at 30s between attempts)
add_retries = retry_builder(tries=5, delay=5, max_delay=30)
logger = setup_logger()
def _fetch_permissions_paginated(
drive_service: Any, drive_file_id: str
) -> Iterator[dict[str, Any]]:
next_token = None
# Check if the file is trashed
# Returning nothing here will cause the external permissions to
# be empty, which will get written to vespa (fail closed)
try:
file_metadata = add_retries(
lambda: drive_service.files()
.get(fileId=drive_file_id, fields="id, trashed")
.execute()
)()
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
return
logger.error(f"Failed to fetch permissions: {e}")
raise
if file_metadata.get("trashed", False):
logger.debug(f"File with ID {drive_file_id} is trashed")
return
# Get paginated permissions for the file id
while True:
try:
permissions_resp: dict[str, Any] = add_retries(
lambda: (
drive_service.permissions()
.list(
fileId=drive_file_id,
fields="permissions(id, emailAddress, role, type, domain)",
supportsAllDrives=True,
pageToken=next_token,
)
.execute()
)
)()
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
break
logger.error(f"Failed to fetch permissions: {e}")
raise
for permission in permissions_resp.get("permissions", []):
yield permission
next_token = permissions_resp.get("nextPageToken")
if not next_token:
break
def _fetch_google_permissions_for_document_id(
db_session: Session,
drive_file_id: str,
raw_credentials_json: dict[str, str],
company_google_domains: list[str],
) -> ExternalAccess:
# Authenticate and construct service
google_drive_creds, _ = get_google_drive_creds(
raw_credentials_json, scopes=FETCH_PERMISSIONS_SCOPES
)
if not google_drive_creds.valid:
raise ValueError("Invalid Google Drive credentials")
drive_service = build("drive", "v3", credentials=google_drive_creds)
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
for permission in _fetch_permissions_paginated(drive_service, drive_file_id):
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
elif permission_type == "group":
group_emails.add(permission["emailAddress"])
elif permission_type == "domain":
if permission["domain"] in company_google_domains:
public = True
elif permission_type == "anyone":
public = True
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_emails,
is_public=public,
)
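For instance, a file shared with one user and one group that is not domain-public would come back as (illustrative values):
ExternalAccess(
external_user_emails={"alice@example.com"},
external_user_group_ids={"eng-group@example.com"},
is_public=False,
)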
def gdrive_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
"""
Adds the external permissions to the documents in postgres.
If a document doesn't already exist in postgres, we create
it so that when it is indexed later, the permissions are
already populated
"""
for doc in docs_with_additional_info:
ext_access = _fetch_google_permissions_for_document_id(
db_session=db_session,
drive_file_id=doc.additional_info,
raw_credentials_json=cc_pair.credential.credential_json,
company_google_domains=[
cast(dict[str, str], sync_details)["company_domain"]
],
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc.id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)

View File

@ -0,0 +1,147 @@
from collections.abc import Iterator
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from sqlalchemy.orm import Session
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
logger = setup_logger()
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat that here by retrying
# with backoff (up to 5 tries, 5s initial delay, capped at 30s between attempts)
add_retries = retry_builder(tries=5, delay=5, max_delay=30)
def _fetch_groups_paginated(
google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
identity_source: str | None = None,
customer_id: str | None = None,
) -> Iterator[dict[str, Any]]:
# Note that Google Drive does not use or update the user_cache, as the user email
# comes directly with the call to fetch the groups; therefore this is not a valid
# place to save on requests
if identity_source is None and customer_id is None:
raise ValueError(
"Either identity_source or customer_id must be provided to fetch groups"
)
cloud_identity_service = build(
"cloudidentity", "v1", credentials=google_drive_creds
)
parent = (
f"identitysources/{identity_source}"
if identity_source
else f"customers/{customer_id}"
)
while True:
try:
groups_resp: dict[str, Any] = add_retries(
lambda: (cloud_identity_service.groups().list(parent=parent).execute())
)()
for group in groups_resp.get("groups", []):
yield group
next_token = groups_resp.get("nextPageToken")
if not next_token:
break
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
break
logger.error(f"Error fetching groups: {e}")
raise
def _fetch_group_members_paginated(
google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
group_name: str,
) -> Iterator[dict[str, Any]]:
cloud_identity_service = build(
"cloudidentity", "v1", credentials=google_drive_creds
)
next_token = None
while True:
try:
membership_info = add_retries(
lambda: (
cloud_identity_service.groups()
.memberships()
.searchTransitiveMemberships(
parent=group_name, pageToken=next_token
)
.execute()
)
)()
for member in membership_info.get("memberships", []):
yield member
next_token = membership_info.get("nextPageToken")
if not next_token:
break
except HttpError as e:
if e.resp.status == 404 or e.resp.status == 403:
break
logger.error(f"Error fetching group members: {e}")
raise
def gdrive_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
docs_with_additional_info: list[DocsWithAdditionalInfo],
sync_details: dict[str, Any],
) -> None:
google_drive_creds, _ = get_google_drive_creds(
cc_pair.credential.credential_json,
scopes=FETCH_GROUPS_SCOPES,
)
danswer_groups: list[ExternalUserGroup] = []
for group in _fetch_groups_paginated(
google_drive_creds,
identity_source=sync_details.get("identity_source"),
customer_id=sync_details.get("customer_id"),
):
# The id is the group email
group_email = group["groupKey"]["id"]
group_member_emails: list[str] = []
for member in _fetch_group_members_paginated(google_drive_creds, group["name"]):
member_keys = member["preferredMemberKey"]
member_emails = [member_key["id"] for member_key in member_keys]
for member_email in member_emails:
group_member_emails.append(member_email)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_email, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)
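Note that, as with gdrive_doc_sync, nothing is committed here; run_permission_sync_entrypoint commits once after both the group and doc sync steps and the vespa update have succeeded.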

View File

@ -0,0 +1,126 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_function_map import (
DOC_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
FULL_FETCH_PERIOD_IN_SECONDS,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
GROUP_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_utils import (
get_docs_with_additional_info,
)
logger = setup_logger()
def run_permission_sync_entrypoint(
db_session: Session,
cc_pair_id: int,
) -> None:
# TODO: separate out group and doc sync
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No permission sync function found for source type: {source_type}"
)
sync_details = cc_pair.auto_sync_options
if sync_details is None:
raise ValueError(f"No auto sync options found for source type: {source_type}")
# If the source type does not support polling, we only fetch the permissions
# every FULL_FETCH_PERIOD_IN_SECONDS seconds
full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
if full_fetch_period is not None:
last_sync = cc_pair.last_time_perm_sync
if (
last_sync
and (
datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
).total_seconds()
< full_fetch_period
):
return
# Here we run the connector to grab all the ids.
# This may grab ids before they are indexed, but that is fine because
# we create a document in postgres to hold the permissions info
# until the indexing job has a chance to run
docs_with_additional_info = get_docs_with_additional_info(
db_session=db_session,
cc_pair=cc_pair,
)
# This function updates:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
logger.debug(f"Syncing groups for {source_type}")
if group_sync_func is not None:
group_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# This function updates:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
logger.debug(f"Syncing docs for {source_type}")
doc_sync_func(
db_session,
cc_pair,
docs_with_additional_info,
sync_details,
)
# This function fetches the updated access for the documents
# and returns a dictionary of document_ids and access
# This is the access we want to update vespa with
docs_access = get_access_for_documents(
document_ids=[doc.id for doc in docs_with_additional_info],
db_session=db_session,
)
# Then we build the update requests to update vespa
update_reqs = [
UpdateRequest(document_ids=[doc_id], access=doc_access)
for doc_id, doc_access in docs_access.items()
]
# Don't bother syncing the secondary index; it will be synced after the switch anyway
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=None,
)
try:
# update vespa
document_index.update(update_reqs)
# update postgres
db_session.commit()
except Exception as e:
logger.error(f"Error updating document index: {e}")
db_session.rollback()

View File

@ -0,0 +1,54 @@
from collections.abc import Callable
from typing import Any
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo
GroupSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
None,
]
DocSyncFuncType = Callable[
[Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
None,
]
# These functions update:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
}
# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
}
# None means that the connector supports polling from last_time_perm_sync to now
FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
# Polling is supported
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 10 minutes
DocumentSource.CONFLUENCE: 10 * 60,
}
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
return source_type in DOC_PERMISSIONS_FUNC_MAP
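To illustrate the registration model (a sketch only; the Slack function names below are assumptions, not part of this PR), adding permission syncing for a new source is just a matter of extending these maps:
# Hypothetical wiring for a new source:
# DOC_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_doc_sync  # required
# GROUP_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_group_sync  # optional
# FULL_FETCH_PERIOD_IN_SECONDS[DocumentSource.SLACK] = 10 * 60  # if polling is unsupported
assert check_if_valid_sync_source(DocumentSource.GOOGLE_DRIVE)  # registered above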

View File

@ -0,0 +1,56 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
logger = setup_logger()
class DocsWithAdditionalInfo(BaseModel):
id: str
additional_info: Any
def get_docs_with_additional_info(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocsWithAdditionalInfo]:
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.POLL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
assert isinstance(runnable_connector, PollConnector)
current_time = datetime.now(timezone.utc)
start_time = (
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
if cc_pair.last_time_perm_sync
else 0
)
cc_pair.last_time_perm_sync = current_time
doc_batch_generator = runnable_connector.poll_source(
start=start_time, end=current_time.timestamp()
)
docs_with_additional_info = [
DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
for doc_batch in doc_batch_generator
for doc in doc_batch
]
logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")
return docs_with_additional_info
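additional_info is connector-specific; for Google Drive it is the drive file id that gdrive_doc_sync later passes to the permissions API (drive_file_id=doc.additional_info above). An illustrative instance with made-up values:
DocsWithAdditionalInfo(id="some-doc-id", additional_info="1AbCdEfGhIjK")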

View File

@ -0,0 +1,166 @@
"""
launch:
- api server
- postgres
- vespa
- model server (this is only needed so the api server can start up; no embedding is done)
Run this script to seed the database with dummy documents.
Then run test_query_times.py to test query times.
"""
import random
from datetime import datetime
from danswer.access.models import DocumentAccess
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import IndexChunk
from danswer.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
TOTAL_DOC_SETS = 8
TOTAL_ACL_ENTRIES_PER_CATEGORY = 80
def generate_random_embedding(dim: int) -> Embedding:
return [random.uniform(-1, 1) for _ in range(dim)]
def generate_random_identifier() -> str:
return f"dummy_doc_{random.randint(1, 1000)}"
def generate_dummy_chunk(
doc_id: str,
chunk_id: int,
embedding_dim: int,
number_of_acl_entries: int,
number_of_document_sets: int,
) -> DocMetadataAwareIndexChunk:
document = Document(
id=doc_id,
source=DocumentSource.GOOGLE_DRIVE,
sections=[],
metadata={},
semantic_identifier=generate_random_identifier(),
)
chunk = IndexChunk(
chunk_id=chunk_id,
blurb=f"Blurb for chunk {chunk_id} of document {doc_id}.",
content=f"Content for chunk {chunk_id} of document {doc_id}. This is dummy text for testing purposes.",
source_links={},
section_continuation=False,
source_document=document,
title_prefix=f"Title prefix for doc {doc_id}",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
embeddings=ChunkEmbedding(
full_embedding=generate_random_embedding(embedding_dim),
mini_chunk_embeddings=[],
),
title_embedding=generate_random_embedding(embedding_dim),
)
document_set_names = []
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
for i in range(number_of_acl_entries):
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,
external_user_group_ids=external_user_group_ids,
is_public=random.choice([True, False]),
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
)
@log_function_time()
def do_insertion(
vespa_index: VespaIndex, all_chunks: list[DocMetadataAwareIndexChunk]
) -> None:
insertion_records = vespa_index.index(all_chunks)
print(f"Indexed {len(insertion_records)} documents.")
print(
f"New documents: {sum(1 for record in insertion_records if not record.already_existed)}"
)
print(
f"Existing documents updated: {sum(1 for record in insertion_records if record.already_existed)}"
)
@log_function_time()
def seed_dummy_docs(
number_of_document_sets: int,
number_of_acl_entries: int,
num_docs: int = 1000,
chunks_per_doc: int = 5,
batch_size: int = 100,
) -> None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
index_name = search_settings.index_name
embedding_dim = search_settings.model_dim
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
print(index_name)
all_chunks = []
chunk_count = 0
for doc_num in range(num_docs):
doc_id = f"dummy_doc_{doc_num}_{datetime.now().isoformat()}"
for chunk_num in range(chunks_per_doc):
chunk = generate_dummy_chunk(
doc_id=doc_id,
chunk_id=chunk_num,
embedding_dim=embedding_dim,
number_of_acl_entries=number_of_acl_entries,
number_of_document_sets=number_of_document_sets,
)
all_chunks.append(chunk)
chunk_count += 1
if len(all_chunks) >= chunks_per_doc * batch_size:
do_insertion(vespa_index, all_chunks)
print(
f"Indexed {chunk_count} chunks out of {num_docs * chunks_per_doc}."
)
print(
f"percentage: {chunk_count / (num_docs * chunks_per_doc) * 100:.2f}% \n"
)
all_chunks = []
if all_chunks:
do_insertion(vespa_index, all_chunks)
if __name__ == "__main__":
seed_dummy_docs(
number_of_document_sets=TOTAL_DOC_SETS,
number_of_acl_entries=TOTAL_ACL_ENTRIES_PER_CATEGORY,
num_docs=100000,
chunks_per_doc=5,
batch_size=1000,
)

View File

@ -0,0 +1,122 @@
"""
RUN THIS AFTER SEED_DUMMY_DOCS.PY
"""
import random
import time
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.vespa.index import VespaIndex
from danswer.search.models import IndexFilters
from scripts.query_time_check.seed_dummy_docs import TOTAL_ACL_ENTRIES_PER_CATEGORY
from scripts.query_time_check.seed_dummy_docs import TOTAL_DOC_SETS
from shared_configs.model_server_models import Embedding
# make sure these are smaller than TOTAL_ACL_ENTRIES_PER_CATEGORY and TOTAL_DOC_SETS, respectively
NUMBER_OF_ACL_ENTRIES_PER_QUERY = 6
NUMBER_OF_DOC_SETS_PER_QUERY = 2
def get_slowest_99th_percentile(results: list[float]) -> float:
return sorted(results)[int(0.99 * len(results))]
# Generate random filters
def _random_filters() -> IndexFilters:
"""
Generate random filters for the query containing:
- one random user email
- NUMBER_OF_ACL_ENTRIES_PER_QUERY groups
- NUMBER_OF_ACL_ENTRIES_PER_QUERY external groups
- NUMBER_OF_DOC_SETS_PER_QUERY document sets
"""
access_control_list = [
f"user_email:user_{random.randint(0, TOTAL_ACL_ENTRIES_PER_CATEGORY - 1)}@example.com",
]
acl_indices = random.sample(
range(TOTAL_ACL_ENTRIES_PER_CATEGORY), NUMBER_OF_ACL_ENTRIES_PER_QUERY
)
for i in acl_indices:
access_control_list.append(f"group:group_{acl_indices[i]}")
access_control_list.append(f"external_group:external_group_{acl_indices[i]}")
doc_sets = []
doc_set_indices = random.sample(
range(TOTAL_DOC_SETS), NUMBER_OF_ACL_ENTRIES_PER_QUERY
)
for i in doc_set_indices:
doc_sets.append(f"document_set:Document Set {doc_set_indices[i]}")
return IndexFilters(
source_type=[DocumentSource.GOOGLE_DRIVE],
document_set=doc_sets,
tags=[],
access_control_list=access_control_list,
)
def test_hybrid_retrieval_times(
number_of_queries: int,
) -> None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
index_name = search_settings.index_name
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
# Generate random queries
queries = [f"Random Query {i}" for i in range(number_of_queries)]
# Generate random embeddings
embeddings = [
Embedding([random.random() for _ in range(DOC_EMBEDDING_DIM)])
for _ in range(number_of_queries)
]
total_time = 0.0
results = []
for i in range(number_of_queries):
start_time = time.time()
vespa_index.hybrid_retrieval(
query=queries[i],
query_embedding=embeddings[i],
final_keywords=None,
filters=_random_filters(),
hybrid_alpha=0.5,
time_decay_multiplier=1.0,
num_to_retrieve=50,
offset=0,
title_content_ratio=0.5,
)
end_time = time.time()
query_time = end_time - start_time
total_time += query_time
results.append(query_time)
print(f"Query {i+1}: {query_time:.4f} seconds")
avg_time = total_time / number_of_queries
fast_time = min(results)
slow_time = max(results)
ninety_ninth_percentile = get_slowest_99th_percentile(results)
# Write results to a file
_OUTPUT_PATH = "query_times_results_large_more.txt"
with open(_OUTPUT_PATH, "w") as f:
f.write(f"Average query time: {avg_time:.4f} seconds\n")
f.write(f"Fastest query: {fast_time:.4f} seconds\n")
f.write(f"Slowest query: {slow_time:.4f} seconds\n")
f.write(f"99th percentile: {ninety_ninth_percentile:.4f} seconds\n")
print(f"Results written to {_OUTPUT_PATH}")
print(f"\nAverage query time: {avg_time:.4f} seconds")
print(f"Fastest query: {fast_time:.4f} seconds")
print(f"Slowest query: {max(results):.4f} seconds")
print(f"99th percentile: {get_slowest_99th_percentile(results):.4f} seconds")
if __name__ == "__main__":
test_hybrid_retrieval_times(number_of_queries=1000)

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
@ -22,7 +23,7 @@ def _cc_pair_creator(
connector_id: int,
credential_id: int,
name: str | None = None,
is_public: bool = True,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
@ -30,7 +31,7 @@ def _cc_pair_creator(
request = {
"name": name,
"is_public": is_public,
"access_type": access_type,
"groups": groups or [],
}
@ -47,7 +48,7 @@ def _cc_pair_creator(
name=name,
connector_id=connector_id,
credential_id=credential_id,
is_public=is_public,
access_type=access_type,
groups=groups or [],
)
@ -56,7 +57,7 @@ class CCPairManager:
@staticmethod
def create_from_scratch(
name: str | None = None,
is_public: bool = True,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
source: DocumentSource = DocumentSource.FILE,
input_type: InputType = InputType.LOAD_STATE,
@ -69,7 +70,7 @@ class CCPairManager:
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config,
is_public=is_public,
is_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
@ -77,7 +78,7 @@ class CCPairManager:
credential_json=credential_json,
name=name,
source=source,
curator_public=is_public,
curator_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
@ -85,7 +86,7 @@ class CCPairManager:
connector_id=connector.id,
credential_id=credential.id,
name=name,
is_public=is_public,
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
)
@ -95,7 +96,7 @@ class CCPairManager:
connector_id: int,
credential_id: int,
name: str | None = None,
is_public: bool = True,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: TestUser | None = None,
) -> TestCCPair:
@ -103,7 +104,7 @@ class CCPairManager:
connector_id=connector_id,
credential_id=credential_id,
name=name,
is_public=is_public,
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
)
@ -172,7 +173,7 @@ class CCPairManager:
retrieved_cc_pair.name == cc_pair.name
and retrieved_cc_pair.connector.id == cc_pair.connector_id
and retrieved_cc_pair.credential.id == cc_pair.credential_id
and retrieved_cc_pair.public_doc == cc_pair.is_public
and retrieved_cc_pair.access_type == cc_pair.access_type
and set(retrieved_cc_pair.groups) == set(cc_pair.groups)
):
return

View File

@ -3,6 +3,7 @@ from uuid import uuid4
import requests
from danswer.configs.constants import DocumentSource
from danswer.db.enums import AccessType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
@ -22,7 +23,7 @@ def _verify_document_permissions(
) -> None:
acl_keys = set(retrieved_doc["access_control_list"].keys())
print(f"ACL keys: {acl_keys}")
if cc_pair.is_public:
if cc_pair.access_type == AccessType.PUBLIC:
if "PUBLIC" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} is public but"
@ -30,10 +31,10 @@ def _verify_document_permissions(
)
if doc_creating_user is not None:
if f"user_id:{doc_creating_user.id}" not in acl_keys:
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if group_names is not None:

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel
from pydantic import Field
from danswer.auth.schemas import UserRole
from danswer.db.enums import AccessType
from danswer.search.enums import RecencyBiasSetting
from danswer.server.documents.models import DocumentSource
from danswer.server.documents.models import InputType
@ -67,7 +68,7 @@ class TestCCPair(BaseModel):
name: str
connector_id: int
credential_id: int
is_public: bool
access_type: AccessType
groups: list[int]
documents: list[SimpleTestDocument] = Field(default_factory=list)

View File

@ -5,6 +5,7 @@ the permissions of the curator manipulating connector-credential pairs.
import pytest
from requests.exceptions import HTTPError
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
@ -91,8 +92,8 @@ def test_cc_pair_permissions(reset: None) -> None:
connector_id=connector_1.id,
credential_id=credential_1.id,
name="invalid_cc_pair_1",
access_type=AccessType.PUBLIC,
groups=[user_group_1.id],
is_public=True,
user_performing_action=curator,
)
@ -103,8 +104,8 @@ def test_cc_pair_permissions(reset: None) -> None:
connector_id=connector_1.id,
credential_id=credential_1.id,
name="invalid_cc_pair_2",
access_type=AccessType.PRIVATE,
groups=[user_group_1.id, user_group_2.id],
is_public=False,
user_performing_action=curator,
)
@ -115,8 +116,8 @@ def test_cc_pair_permissions(reset: None) -> None:
connector_id=connector_1.id,
credential_id=credential_1.id,
name="invalid_cc_pair_2",
access_type=AccessType.PRIVATE,
groups=[],
is_public=False,
user_performing_action=curator,
)
@ -129,8 +130,8 @@ def test_cc_pair_permissions(reset: None) -> None:
# connector_id=connector_2.id,
# credential_id=credential_1.id,
# name="invalid_cc_pair_3",
# access_type=AccessType.PRIVATE,
# groups=[user_group_1.id],
# is_public=False,
# user_performing_action=curator,
# )
@ -141,8 +142,8 @@ def test_cc_pair_permissions(reset: None) -> None:
connector_id=connector_1.id,
credential_id=credential_2.id,
name="invalid_cc_pair_4",
access_type=AccessType.PRIVATE,
groups=[user_group_1.id],
is_public=False,
user_performing_action=curator,
)
@ -154,8 +155,8 @@ def test_cc_pair_permissions(reset: None) -> None:
name="valid_cc_pair",
connector_id=connector_1.id,
credential_id=credential_1.id,
access_type=AccessType.PRIVATE,
groups=[user_group_1.id],
is_public=False,
user_performing_action=curator,
)

View File

@ -1,6 +1,7 @@
import pytest
from requests.exceptions import HTTPError
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
@ -47,14 +48,14 @@ def test_doc_set_permissions_setup(reset: None) -> None:
# Admin creates a cc_pair
private_cc_pair = CCPairManager.create_from_scratch(
is_public=False,
access_type=AccessType.PRIVATE,
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
# Admin creates a public cc_pair
public_cc_pair = CCPairManager.create_from_scratch(
is_public=True,
access_type=AccessType.PUBLIC,
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)

View File

@ -1,6 +1,7 @@
"""
This test tests the happy path for curator permissions
"""
from danswer.db.enums import AccessType
from danswer.db.models import UserRole
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@ -64,8 +65,8 @@ def test_whole_curator_flow(reset: None) -> None:
connector_id=test_connector.id,
credential_id=test_credential.id,
name="curator_test_cc_pair",
access_type=AccessType.PRIVATE,
groups=[user_group_1.id],
is_public=False,
user_performing_action=curator,
)

View File

@ -1,6 +1,6 @@
"use client";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { HealthCheckBanner } from "@/components/health/healthcheck";
@ -12,7 +12,6 @@ import { useFormContext } from "@/components/context/FormContext";
import { getSourceDisplayName } from "@/lib/sources";
import { SourceIcon } from "@/components/SourceIcon";
import { useState } from "react";
import { submitConnector } from "@/components/admin/connectors/ConnectorForm";
import { deleteCredential, linkCredential } from "@/lib/credential";
import { submitFiles } from "./pages/utils/files";
import { submitGoogleSite } from "./pages/utils/google_site";
@ -38,7 +37,8 @@ import {
useGoogleDriveCredentials,
} from "./pages/utils/hooks";
import { Formik } from "formik";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm";
import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector";
import NavigationRow from "./NavigationRow";
export interface AdvancedConfig {
@ -46,6 +46,64 @@ export interface AdvancedConfig {
pruneFreq: number;
indexingStart: string;
}
import { Connector, ConnectorBase } from "@/lib/connectors/connectors";
const BASE_CONNECTOR_URL = "/api/manage/admin/connector";
export async function submitConnector<T>(
connector: ConnectorBase<T>,
connectorId?: number,
fakeCredential?: boolean,
isPublicCcpair?: boolean // exclusively for mock credentials, when we also need to specify cc-pair details
): Promise<{ message: string; isSuccess: boolean; response?: Connector<T> }> {
const isUpdate = connectorId !== undefined;
if (!connector.connector_specific_config) {
connector.connector_specific_config = {} as T;
}
try {
if (fakeCredential) {
const response = await fetch(
"/api/manage/admin/connector-with-mock-credential",
{
method: isUpdate ? "PATCH" : "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ ...connector, is_public: isPublicCcpair }),
}
);
if (response.ok) {
const responseJson = await response.json();
return { message: "Success!", isSuccess: true, response: responseJson };
} else {
const errorData = await response.json();
return { message: `Error: ${errorData.detail}`, isSuccess: false };
}
} else {
const response = await fetch(
BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""),
{
method: isUpdate ? "PATCH" : "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(connector),
}
);
if (response.ok) {
const responseJson = await response.json();
return { message: "Success!", isSuccess: true, response: responseJson };
} else {
const errorData = await response.json();
return { message: `Error: ${errorData.detail}`, isSuccess: false };
}
}
} catch (error) {
return { message: `Error: ${error}`, isSuccess: false };
}
}
export default function AddConnector({
connector,
@ -84,10 +142,35 @@ export default function AddConnector({
const { liveGDriveCredential } = useGoogleDriveCredentials();
const { liveGmailCredential } = useGmailCredentials();
const {
data: appCredentialData,
isLoading: isAppCredentialLoading,
error: isAppCredentialError,
} = useSWR<{ client_id: string }, FetchError>(
"/api/manage/admin/connector/google-drive/app-credential",
errorHandlingFetcher
);
const {
data: serviceAccountKeyData,
isLoading: isServiceAccountKeyLoading,
error: isServiceAccountKeyError,
} = useSWR<{ service_account_email: string }, FetchError>(
"/api/manage/admin/connector/google-drive/service-account-key",
errorHandlingFetcher
);
// Check if credential is activated
const credentialActivated =
(connector === "google_drive" && liveGDriveCredential) ||
(connector === "gmail" && liveGmailCredential) ||
(connector === "google_drive" &&
(liveGDriveCredential ||
appCredentialData ||
serviceAccountKeyData ||
currentCredential)) ||
(connector === "gmail" &&
(liveGmailCredential ||
appCredentialData ||
serviceAccountKeyData ||
currentCredential)) ||
currentCredential;
// Check if there are no credentials
@ -159,10 +242,12 @@ export default function AddConnector({
const {
name,
groups,
is_public: isPublic,
access_type,
pruneFreq,
indexingStart,
refreshFreq,
auto_sync_options,
is_public,
...connector_specific_config
} = values;
@ -204,6 +289,7 @@ export default function AddConnector({
advancedConfiguration.refreshFreq,
advancedConfiguration.pruneFreq,
advancedConfiguration.indexingStart,
values.access_type == "public",
name
);
if (response) {
@ -219,7 +305,7 @@ export default function AddConnector({
setPopup,
setSelectedFiles,
name,
isPublic,
access_type == "public",
groups
);
if (response) {
@ -234,15 +320,15 @@ export default function AddConnector({
input_type: connector == "web" ? "load_state" : "poll", // single case
name: name,
source: connector,
is_public: access_type == "public",
refresh_freq: advancedConfiguration.refreshFreq || null,
prune_freq: advancedConfiguration.pruneFreq || null,
indexing_start: advancedConfiguration.indexingStart || null,
is_public: isPublic,
groups: groups,
},
undefined,
credentialActivated ? false : true,
isPublic
access_type == "public"
);
// If no credential
if (!credentialActivated) {
@ -261,8 +347,9 @@ export default function AddConnector({
response.id,
credential?.id!,
name,
isPublic,
groups
access_type,
groups,
auto_sync_options
);
if (linkCredentialResponse.ok) {
onSuccess();
@ -366,11 +453,8 @@ export default function AddConnector({
selectedFiles={selectedFiles}
/>
<IsPublicGroupSelector
removeIndent
formikProps={formikProps}
objectName="Connector"
/>
<AccessTypeForm connector={connector} />
<AccessTypeGroupSelector />
</Card>
)}

View File

@ -6,12 +6,14 @@ import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { credentialTemplates } from "@/lib/connectors/credentials";
import Link from "next/link";
import { useUser } from "@/components/user/UserProvider";
import { useContext } from "react";
export default function Sidebar() {
const { formStep, setFormStep, connector, allowAdvanced, allowCreate } =
useFormContext();
const combinedSettings = useContext(SettingsContext);
const { isLoadingUser, isAdmin } = useUser();
if (!combinedSettings) {
return null;
}
@ -59,7 +61,9 @@ export default function Sidebar() {
className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-200 cursor-pointer transition-all duration-150 flex gap-x-2"
>
<SettingsIcon className="flex-none " />
<p className="my-auto flex items-center text-sm">Admin Page</p>
<p className="my-auto flex items-center text-sm">
{isAdmin ? "Admin Page" : "Curator Page"}
</p>
</Link>
</div>

View File

@ -222,36 +222,46 @@ export const DriveJsonUploadSection = ({
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{appCredentialData.client_id}</p>
</div>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/google-drive/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate("/api/manage/admin/connector/google-drive/app-credential");
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/google-drive/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate(
"/api/manage/admin/connector/google-drive/app-credential"
);
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
)}
</div>
);
}

View File

@ -80,7 +80,7 @@ export const submitFiles = async (
connector.id,
credentialId,
name,
isPublic,
isPublic ? "public" : "private",
groups
);
if (!credentialResponse.ok) {

View File

@ -10,6 +10,7 @@ export const submitGoogleSite = async (
refreshFreq: number,
pruneFreq: number,
indexingStart: Date,
is_public: boolean,
name?: string
) => {
const uploadCreateAndTriggerConnector = async () => {
@ -42,6 +43,7 @@ export const submitGoogleSite = async (
base_url: base_url,
zip_path: filePaths[0],
},
is_public: is_public,
refresh_freq: refreshFreq,
prune_freq: pruneFreq,
indexing_start: indexingStart,

View File

@ -155,7 +155,7 @@ export const DocumentSetCreationForm = ({
// Filter visible cc pairs
const visibleCcPairs = localCcPairs.filter(
(ccPair) =>
ccPair.public_doc ||
ccPair.access_type === "public" ||
(ccPair.groups.length > 0 &&
props.values.groups.every((group) =>
ccPair.groups.includes(group)
@ -228,7 +228,7 @@ export const DocumentSetCreationForm = ({
// Filter non-visible cc pairs
const nonVisibleCcPairs = localCcPairs.filter(
(ccPair) =>
!ccPair.public_doc &&
ccPair.access_type !== "public" &&
(ccPair.groups.length === 0 ||
!props.values.groups.every((group) =>
ccPair.groups.includes(group)

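The two filters above are complements of a single visibility rule; a sketch under an assumed CCPairLike shape (illustrative only, not part of this diff):

// Sketch: a cc pair is visible to a document set iff it is public, or it
// has groups and every group selected for the document set is among them.
type CCPairLike = { access_type: string; groups: number[] };

const isVisibleToDocumentSet = (
  ccPair: CCPairLike,
  selectedGroups: number[]
) =>
  ccPair.access_type === "public" ||
  (ccPair.groups.length > 0 &&
    selectedGroups.every((group) => ccPair.groups.includes(group)));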
View File

@ -232,7 +232,7 @@ function ConnectorRow({
<TableCell>{getActivityBadge()}</TableCell>
{isPaidEnterpriseFeaturesEnabled && (
<TableCell>
{ccPairsIndexingStatus.public_doc ? (
{ccPairsIndexingStatus.access_type === "public" ? (
<Badge
size="md"
color={isEditable ? "green" : "gray"}
@ -334,7 +334,8 @@ export function CCPairIndexingStatusTable({
(status) =>
status.cc_pair_status === ConnectorCredentialPairStatus.ACTIVE
).length,
public: statuses.filter((status) => status.public_doc).length,
public: statuses.filter((status) => status.access_type === "public")
.length,
totalDocsIndexed: statuses.reduce(
(sum, status) => sum + status.docs_indexed,
0
@ -420,7 +421,7 @@ export function CCPairIndexingStatusTable({
credential_json: {},
admin_public: false,
},
public_doc: true,
access_type: "public",
docs_indexed: 1000,
last_success: "2023-07-01T12:00:00Z",
last_finished_status: "success",

View File

@ -16,7 +16,7 @@ export const ConnectorEditor = ({
<div className="mb-3 flex gap-2 flex-wrap">
{allCCPairs
// remove public docs, since they don't make sense as part of a group
.filter((ccPair) => !ccPair.public_doc)
.filter((ccPair) => ccPair.access_type !== "public")
.map((ccPair) => {
const ind = selectedCCPairIds.indexOf(ccPair.cc_pair_id);
let isSelected = ind !== -1;

View File

@ -28,6 +28,11 @@ export const UserGroupCreationForm = ({
}: UserGroupCreationFormProps) => {
const isUpdate = existingUserGroup !== undefined;
// Keep only the ccPairs with access_type "private"
const privateCcPairs = ccPairs.filter(
(ccPair) => ccPair.access_type === "private"
);
return (
<Modal className="w-fit" onOutsideClick={onClose}>
<div className="px-8 py-6 bg-background">
@ -96,7 +101,7 @@ export const UserGroupCreationForm = ({
<Divider />
<h2 className="mb-1 font-medium">
Select which connectors this group has access to:
Select which private connectors this group has access to:
</h2>
<p className="mb-3 text-xs">
All documents indexed by the selected connectors will be
@ -104,7 +109,7 @@ export const UserGroupCreationForm = ({
</p>
<ConnectorEditor
allCCPairs={ccPairs}
allCCPairs={privateCcPairs}
selectedCCPairIds={values.cc_pair_ids}
setSetCCPairIds={(ccPairsIds) =>
setFieldValue("cc_pair_ids", ccPairsIds)

View File

@ -76,7 +76,7 @@ export const AddConnectorForm: React.FC<AddConnectorFormProps> = ({
.includes(ccPair.cc_pair_id)
)
// remove public docs, since they don't make sense as part of a group
.filter((ccPair) => !ccPair.public_doc)
.filter((ccPair) => ccPair.access_type !== "public")
.map((ccPair) => {
return {
name: ccPair.name?.toString() || "",

View File

@ -0,0 +1,88 @@
import { DefaultDropdown } from "@/components/Dropdown";
import {
AccessType,
ValidAutoSyncSources,
ConfigurableSources,
validAutoSyncSources,
} from "@/lib/types";
import { Text, Title } from "@tremor/react";
import { useUser } from "@/components/user/UserProvider";
import { useField } from "formik";
import { AutoSyncOptions } from "./AutoSyncOptions";
function isValidAutoSyncSource(
value: ConfigurableSources
): value is ValidAutoSyncSources {
return validAutoSyncSources.includes(value as ValidAutoSyncSources);
}
export function AccessTypeForm({
connector,
}: {
connector: ConfigurableSources;
}) {
const [access_type, meta, access_type_helpers] =
useField<AccessType>("access_type");
const isAutoSyncSupported = isValidAutoSyncSource(connector);
const { isLoadingUser, isAdmin } = useUser();
const options = [
{
name: "Private",
value: "private",
description:
"Only users who have expliticly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
},
];
if (isAdmin) {
options.push({
name: "Public",
value: "public",
description:
"Everyone with an account on Danswer can access the documents pulled in by this connector",
});
}
if (isAutoSyncSupported && isAdmin) {
options.push({
name: "Auto Sync",
value: "sync",
description:
"We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.",
});
}
return (
<div>
<div className="flex gap-x-2 items-center">
<label className="text-text-950 font-medium">Document Access</label>
</div>
<p className="text-sm text-text-500 mb-2">
Control who has access to the documents indexed by this connector.
</p>
{isAdmin && (
<>
<DefaultDropdown
options={options}
selected={access_type.value}
onSelect={(selected) =>
access_type_helpers.setValue(selected as AccessType)
}
includeDefault={false}
/>
{access_type.value === "sync" && isAutoSyncSupported && (
<div className="mt-6">
<AutoSyncOptions
connectorType={connector as ValidAutoSyncSources}
/>
</div>
)}
</>
)}
</div>
);
}
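For reference, a minimal sketch of how AccessTypeForm and AccessTypeGroupSelector are meant to be mounted together; the Formik wrapper and import paths below are assumptions, not part of this diff:

import { Formik, Form } from "formik";
import { AccessTypeForm } from "./AccessTypeForm"; // path assumed
import { AccessTypeGroupSelector } from "./AccessTypeGroupSelector"; // path assumed

// Sketch: both components read Formik state via useField, so "access_type"
// and "groups" must exist in initialValues.
const ConnectorAccessSection = () => (
  <Formik
    initialValues={{ access_type: "private", groups: [] as number[] }}
    onSubmit={() => {}}
  >
    <Form>
      <AccessTypeForm connector="google_drive" />
      <AccessTypeGroupSelector />
    </Form>
  </Formik>
);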

View File

@ -0,0 +1,147 @@
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import React, { useState, useEffect } from "react";
import { FieldArray, ArrayHelpers, ErrorMessage, useField } from "formik";
import { Text, Divider } from "@tremor/react";
import { FiUsers } from "react-icons/fi";
import { UserGroup, User, UserRole } from "@/lib/types";
import { useUserGroups } from "@/lib/hooks";
import { AccessType } from "@/lib/types";
import { useUser } from "@/components/user/UserProvider";
// This component should be included in all forms that need group / public access
// to be set; access control and permissioning are handled within the component itself.
export function AccessTypeGroupSelector({}: {}) {
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const { isAdmin, user, isLoadingUser, isCurator } = useUser();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const [shouldHideContent, setShouldHideContent] = useState(false);
const [access_type, meta, access_type_helpers] =
useField<AccessType>("access_type");
const [groups, groups_meta, groups_helpers] = useField<number[]>("groups");
useEffect(() => {
if (user && userGroups && isPaidEnterpriseFeaturesEnabled) {
const isUserAdmin = user.role === UserRole.ADMIN;
if (!isUserAdmin) {
access_type_helpers.setValue("private");
}
if (userGroups.length === 1 && !isUserAdmin) {
groups_helpers.setValue([userGroups[0].id]);
setShouldHideContent(true);
} else if (access_type.value !== "private") {
groups_helpers.setValue([]);
setShouldHideContent(false);
} else {
setShouldHideContent(false);
}
}
}, [user, userGroups, access_type.value]);
if (isLoadingUser || userGroupsIsLoading) {
return <div>Loading...</div>;
}
if (!isPaidEnterpriseFeaturesEnabled) {
return null;
}
if (shouldHideContent) {
return (
<>
{userGroups && (
<div className="mb-1 font-medium text-base">
This Connector will be assigned to group <b>{userGroups[0].name}</b>
.
</div>
)}
</>
);
}
return (
<div>
{(access_type.value === "private" || isCurator) &&
userGroups &&
userGroups?.length > 0 && (
<>
<Divider />
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
Assign group access for this Connector
</div>
</div>
{userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
) : (
<Text className="mb-3">
{isAdmin ? (
<>
This Connector will be visible and accessible to the groups
selected below
</>
) : (
<>
Curators must select one or more groups to give access to
this Connector
</>
)}
</Text>
)}
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
{userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
) : (
userGroups &&
userGroups.map((userGroup: UserGroup) => {
const ind = groups.value.indexOf(userGroup.id);
let isSelected = ind !== -1;
return (
<div
key={userGroup.id}
className={`
px-3
py-1
rounded-lg
border
border-border
w-fit
flex
cursor-pointer
${isSelected ? "bg-background-strong" : "hover:bg-hover"}
`}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(userGroup.id);
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" />{" "}
{userGroup.name}
</div>
</div>
);
})
)}
</div>
)}
/>
<ErrorMessage
name="groups"
component="div"
className="text-error text-sm mt-1"
/>
</>
)}
</div>
);
}
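The render gate above reduces to a small predicate; restated as a sketch (not in the PR):

import { AccessType } from "@/lib/types";

// Sketch: the group picker renders when the connector is private or the
// current user is a curator, and at least one user group exists.
const showGroupPicker = (
  accessType: AccessType,
  isCurator: boolean,
  userGroupCount: number
): boolean => (accessType === "private" || isCurator) && userGroupCount > 0;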

View File

@ -0,0 +1,30 @@
import { TextFormField } from "@/components/admin/connectors/Field";
import { useFormikContext } from "formik";
import { ValidAutoSyncSources } from "@/lib/types";
import { Divider } from "@tremor/react";
import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields";
export function AutoSyncOptions({
connectorType,
}: {
connectorType: ValidAutoSyncSources;
}) {
return (
<div>
<Divider />
<>
{Object.entries(autoSyncConfigBySource[connectorType]).map(
([key, config]) => (
<div key={key} className="mb-4">
<TextFormField
name={`auto_sync_options.${key}`}
label={config.label}
subtext={config.subtext}
/>
</div>
)
)}
</>
</div>
);
}
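Because each field is registered under auto_sync_options.${key}, the submitted Formik values nest in a single object; for google_drive the serialized shape would look roughly like this (placeholder values):

// Sketch of the values Formik submits (all values are placeholders).
const exampleValues = {
  access_type: "sync",
  auto_sync_options: {
    customer_id: "C012abc34", // Google Workspace Customer ID
    company_domain: "danswer.ai", // Google Workspace company domain
  },
};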

View File

@ -1,382 +0,0 @@
"use client";
import React, { useState } from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Popup, usePopup } from "./Popup";
import { ValidInputTypes, ValidSources } from "@/lib/types";
import { deleteConnectorIfExistsAndIsUnlinked } from "@/lib/connector";
import { FormBodyBuilder, RequireAtLeastOne } from "./types";
import { BooleanFormField, TextFormField } from "./Field";
import { createCredential, linkCredential } from "@/lib/credential";
import { useSWRConfig } from "swr";
import { Button, Divider } from "@tremor/react";
import IsPublicField from "./IsPublicField";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { Connector, ConnectorBase } from "@/lib/connectors/connectors";
const BASE_CONNECTOR_URL = "/api/manage/admin/connector";
export async function submitConnector<T>(
connector: ConnectorBase<T>,
connectorId?: number,
fakeCredential?: boolean,
isPublicCcpair?: boolean // exclusively for mock credentials, when also need to specify ccpair details
): Promise<{ message: string; isSuccess: boolean; response?: Connector<T> }> {
const isUpdate = connectorId !== undefined;
if (!connector.connector_specific_config) {
connector.connector_specific_config = {} as T;
}
try {
if (fakeCredential) {
const response = await fetch(
"/api/manage/admin/connector-with-mock-credential",
{
method: isUpdate ? "PATCH" : "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ ...connector, is_public: isPublicCcpair }),
}
);
if (response.ok) {
const responseJson = await response.json();
return { message: "Success!", isSuccess: true, response: responseJson };
} else {
const errorData = await response.json();
return { message: `Error: ${errorData.detail}`, isSuccess: false };
}
} else {
const response = await fetch(
BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""),
{
method: isUpdate ? "PATCH" : "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(connector),
}
);
if (response.ok) {
const responseJson = await response.json();
return { message: "Success!", isSuccess: true, response: responseJson };
} else {
const errorData = await response.json();
return { message: `Error: ${errorData.detail}`, isSuccess: false };
}
}
} catch (error) {
return { message: `Error: ${error}`, isSuccess: false };
}
}
const CCPairNameHaver = Yup.object().shape({
cc_pair_name: Yup.string().required("Please enter a name for the connector"),
});
interface BaseProps<T extends Yup.AnyObject> {
nameBuilder: (values: T) => string;
ccPairNameBuilder?: (values: T) => string | null;
source: ValidSources;
inputType: ValidInputTypes;
// if specified, will automatically try and link the credential
credentialId?: number;
// If both are specified, will render formBody and then formBodyBuilder
formBody?: JSX.Element | null;
formBodyBuilder?: FormBodyBuilder<T>;
validationSchema: Yup.ObjectSchema<T>;
validate?: (values: T) => Record<string, string>;
initialValues: T;
onSubmit?: (
isSuccess: boolean,
responseJson: Connector<T> | undefined
) => void;
refreshFreq?: number;
pruneFreq?: number;
indexingStart?: Date;
// If specified, then we will create an empty credential and associate
// the connector with it. If credentialId is specified, then this will be ignored
shouldCreateEmptyCredentialForConnector?: boolean;
}
type ConnectorFormProps<T extends Yup.AnyObject> = RequireAtLeastOne<
BaseProps<T>,
"formBody" | "formBodyBuilder"
>;
export function ConnectorForm<T extends Yup.AnyObject>({
nameBuilder,
ccPairNameBuilder,
source,
inputType,
credentialId,
formBody,
formBodyBuilder,
validationSchema,
validate,
initialValues,
refreshFreq,
pruneFreq,
indexingStart,
onSubmit,
shouldCreateEmptyCredentialForConnector,
}: ConnectorFormProps<T>): JSX.Element {
const { mutate } = useSWRConfig();
const { popup, setPopup } = usePopup();
// only show this option for EE, since groups are not supported in CE
const showNonPublicOption = usePaidEnterpriseFeaturesEnabled();
const shouldHaveNameInput = credentialId !== undefined && !ccPairNameBuilder;
const ccPairNameInitialValue = shouldHaveNameInput
? { cc_pair_name: "" }
: {};
const publicOptionInitialValue = showNonPublicOption
? { is_public: false }
: {};
let finalValidationSchema =
validationSchema as Yup.ObjectSchema<Yup.AnyObject>;
if (shouldHaveNameInput) {
finalValidationSchema = finalValidationSchema.concat(CCPairNameHaver);
}
if (showNonPublicOption) {
finalValidationSchema = finalValidationSchema.concat(
Yup.object().shape({
is_public: Yup.boolean(),
})
);
}
return (
<>
{popup}
<Formik
initialValues={{
...publicOptionInitialValue,
...ccPairNameInitialValue,
...initialValues,
}}
validationSchema={finalValidationSchema}
validate={validate}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const connectorName = nameBuilder(values);
const connectorConfig = Object.fromEntries(
Object.keys(initialValues).map((key) => [key, values[key]])
) as T;
// best effort check to see if existing connector exists
// delete it if:
// 1. it exists
// 2. AND it has no credentials linked to it
// If the ^ are true, that means things have gotten into a bad
// state, and we should delete the connector to recover
const errorMsg = await deleteConnectorIfExistsAndIsUnlinked({
source,
name: connectorName,
});
if (errorMsg) {
setPopup({
message: `Unable to delete existing connector - ${errorMsg}`,
type: "error",
});
return;
}
const { message, isSuccess, response } = await submitConnector<T>({
name: connectorName,
source,
input_type: inputType,
connector_specific_config: connectorConfig,
refresh_freq: refreshFreq || 0,
prune_freq: pruneFreq ?? null,
indexing_start: indexingStart || null,
});
if (!isSuccess || !response) {
setPopup({ message, type: "error" });
formikHelpers.setSubmitting(false);
return;
}
let credentialIdToLinkTo = credentialId;
// create empty credential if specified
if (
shouldCreateEmptyCredentialForConnector &&
credentialIdToLinkTo === undefined
) {
const createCredentialResponse = await createCredential({
credential_json: {},
admin_public: true,
source: source,
});
if (!createCredentialResponse.ok) {
const errorMsg = await createCredentialResponse.text();
setPopup({
message: `Error creating credential for CC Pair - ${errorMsg}`,
type: "error",
});
formikHelpers.setSubmitting(false);
return;
}
credentialIdToLinkTo = (await createCredentialResponse.json()).id;
}
if (credentialIdToLinkTo !== undefined) {
const ccPairName = ccPairNameBuilder
? ccPairNameBuilder(values)
: values.cc_pair_name;
const linkCredentialResponse = await linkCredential(
response.id,
credentialIdToLinkTo,
ccPairName as string,
values.is_public
);
if (!linkCredentialResponse.ok) {
const linkCredentialErrorMsg =
await linkCredentialResponse.text();
setPopup({
message: `Error linking credential - ${linkCredentialErrorMsg}`,
type: "error",
});
formikHelpers.setSubmitting(false);
return;
}
}
mutate("/api/manage/admin/connector/indexing-status");
setPopup({ message, type: isSuccess ? "success" : "error" });
formikHelpers.setSubmitting(false);
if (isSuccess) {
formikHelpers.resetForm();
}
if (onSubmit) {
onSubmit(isSuccess, response);
}
}}
>
{({ isSubmitting, values }) => (
<Form>
{shouldHaveNameInput && (
<TextFormField
name="cc_pair_name"
label="Connector Name"
autoCompleteDisabled={true}
subtext={`A descriptive name for the connector. This will be used to identify the connector in the Admin UI.`}
/>
)}
{formBody && formBody}
{formBodyBuilder && formBodyBuilder(values)}
{showNonPublicOption && (
<>
<Divider />
<IsPublicField />
<Divider />
</>
)}
<div className="flex">
<Button
type="submit"
size="xs"
color="green"
disabled={isSubmitting}
className="mx-auto w-64"
>
Connect
</Button>
</div>
</Form>
)}
</Formik>
</>
);
}
interface UpdateConnectorBaseProps<T extends Yup.AnyObject> {
nameBuilder?: (values: T) => string;
existingConnector: Connector<T>;
// If both are specified, uses formBody
formBody?: JSX.Element | null;
formBodyBuilder?: FormBodyBuilder<T>;
validationSchema: Yup.ObjectSchema<T>;
onSubmit?: (isSuccess: boolean, responseJson?: Connector<T>) => void;
}
type UpdateConnectorFormProps<T extends Yup.AnyObject> = RequireAtLeastOne<
UpdateConnectorBaseProps<T>,
"formBody" | "formBodyBuilder"
>;
export function UpdateConnectorForm<T extends Yup.AnyObject>({
nameBuilder,
existingConnector,
formBody,
formBodyBuilder,
validationSchema,
onSubmit,
}: UpdateConnectorFormProps<T>): JSX.Element {
const [popup, setPopup] = useState<{
message: string;
type: "success" | "error";
} | null>(null);
return (
<>
{popup && <Popup message={popup.message} type={popup.type} />}
<Formik
initialValues={existingConnector.connector_specific_config}
validationSchema={validationSchema}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const { message, isSuccess, response } = await submitConnector<T>(
{
name: nameBuilder ? nameBuilder(values) : existingConnector.name,
source: existingConnector.source,
input_type: existingConnector.input_type,
connector_specific_config: values,
refresh_freq: existingConnector.refresh_freq,
prune_freq: existingConnector.prune_freq,
indexing_start: existingConnector.indexing_start,
},
existingConnector.id
);
setPopup({ message, type: isSuccess ? "success" : "error" });
formikHelpers.setSubmitting(false);
if (isSuccess) {
formikHelpers.resetForm();
}
setTimeout(() => {
setPopup(null);
}, 4000);
if (onSubmit) {
onSubmit(isSuccess, response);
}
}}
>
{({ isSubmitting, values }) => (
<Form>
{formBody ? formBody : formBodyBuilder && formBodyBuilder(values)}
<div className="flex">
<Button
type="submit"
color="green"
size="xs"
disabled={isSubmitting}
className="mx-auto w-64"
>
Update
</Button>
</div>
</Form>
)}
</Formik>
</>
);
}

View File

@ -28,7 +28,6 @@ import {
ConfluenceCredentialJson,
Credential,
} from "@/lib/connectors/credentials";
import { UserGroup } from "@/lib/types"; // Added this import
export default function CredentialSection({
ccPair,

View File

@ -232,7 +232,7 @@ export default function CreateCredential({
setShowAdvancedOptions={setShowAdvancedOptions}
/>
)}
{showAdvancedOptions && (
{(showAdvancedOptions || !isAdmin) && (
<IsPublicGroupSelector
formikProps={formikProps}
objectName="credential"

View File

@ -0,0 +1,46 @@
import { ValidAutoSyncSources } from "@/lib/types";
// The first key is the connector type, and the second key is the field name
export const autoSyncConfigBySource: Record<
ValidAutoSyncSources,
Record<
string,
{
label: string;
subtext: JSX.Element;
}
>
> = {
google_drive: {
customer_id: {
label: "Google Workspace Customer ID",
subtext: (
<>
The unique identifier for your Google Workspace account. To find this,
check out the{" "}
<a
href="https://support.google.com/cloudidentity/answer/10070793"
target="_blank"
className="text-link"
>
guide from Google
</a>
.
</>
),
},
company_domain: {
label: "Google Workspace Company Domain",
subtext: (
<>
The email domain for your Google Workspace account.
<br />
<br />
For example, if your Google Workspace email looks something like
chris@danswer.ai, then your company domain is{" "}
<b>danswer.ai</b>
</>
),
},
},
};

View File

@ -29,6 +29,7 @@ export interface Option {
export interface SelectOption extends Option {
type: "select";
options?: StringWithDescription[];
default?: string;
}
export interface ListOption extends Option {
@ -599,7 +600,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
{ name: "articles", value: "articles" },
{ name: "tickets", value: "tickets" },
],
default: 0,
default: "articles",
},
],
},

View File

@ -1,4 +1,5 @@
import { CredentialBase } from "./connectors/credentials";
import { AccessType } from "@/lib/types";
export async function createCredential(credential: CredentialBase<any>) {
return await fetch(`/api/manage/credential`, {
@ -47,8 +48,9 @@ export function linkCredential(
connectorId: number,
credentialId: number,
name?: string,
isPublic?: boolean,
groups?: number[]
accessType?: AccessType,
groups?: number[],
autoSyncOptions?: Record<string, any>
) {
return fetch(
`/api/manage/connector/${connectorId}/credential/${credentialId}`,
@ -59,8 +61,9 @@ export function linkCredential(
},
body: JSON.stringify({
name: name || null,
is_public: isPublic !== undefined ? isPublic : true,
access_type: accessType !== undefined ? accessType : "public",
groups: groups || null,
auto_sync_options: autoSyncOptions || null,
}),
}
);
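A sketch of calling the updated helper for an auto-sync connector; the IDs and option values are placeholders:

// Sketch: link a credential to a connector with permission syncing enabled.
async function linkWithAutoSync() {
  const response = await linkCredential(
    42, // connectorId (placeholder)
    7, // credentialId (placeholder)
    "My Drive Connector", // cc pair name
    "sync", // AccessType
    undefined, // groups are not required for "sync"
    { customer_id: "C012abc34", company_domain: "danswer.ai" }
  );
  if (!response.ok) {
    console.error(await response.text());
  }
}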

View File

@ -49,6 +49,7 @@ export type ValidStatuses =
| "not_started";
export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
export type Feedback = "like" | "dislike";
export type AccessType = "public" | "private" | "sync";
export type SessionType = "Chat" | "Search" | "Slack";
export interface DocumentBoostStatus {
@ -90,7 +91,7 @@ export interface ConnectorIndexingStatus<
cc_pair_status: ConnectorCredentialPairStatus;
connector: Connector<ConnectorConfigType>;
credential: Credential<ConnectorCredentialType>;
public_doc: boolean;
access_type: AccessType;
owner: string;
groups: number[];
last_finished_status: ValidStatuses | null;
@ -258,3 +259,7 @@ export type ConfigurableSources = Exclude<
ValidSources,
"not_applicable" | "ingestion_api"
>;
// The sources that have auto-sync support on the backend
export const validAutoSyncSources = ["google_drive"] as const;
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
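The as const tuple plus (typeof ...)[number] pattern keeps the runtime list and the type in lockstep, so enabling another auto-sync source is a one-line change here plus a matching config entry; a hypothetical extension (confluence is not actually supported by this PR):

// Hypothetical: adding a second auto-sync source.
export const validAutoSyncSources = ["google_drive", "confluence"] as const;
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
// autoSyncConfigBySource is typed Record<ValidAutoSyncSources, ...>, so the
// compiler would now require a "confluence" entry before the build passes.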