diff --git a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py index f284c7b4b..068342095 100644 --- a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py +++ b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py @@ -1,7 +1,7 @@ """Add last synced and last modified to document table Revision ID: 52a219fb5233 -Revises: f17bf3b0d9f1 +Revises: f7e58d357687 Create Date: 2024-08-28 17:40:46.077470 """ diff --git a/backend/alembic/versions/61ff3651add4_add_permission_syncing.py b/backend/alembic/versions/61ff3651add4_add_permission_syncing.py new file mode 100644 index 000000000..697e1060e --- /dev/null +++ b/backend/alembic/versions/61ff3651add4_add_permission_syncing.py @@ -0,0 +1,162 @@ +"""Add Permission Syncing + +Revision ID: 61ff3651add4 +Revises: 1b8206b29c5d +Create Date: 2024-09-05 13:57:11.770413 + +""" +import fastapi_users_db_sqlalchemy + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "61ff3651add4" +down_revision = "1b8206b29c5d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Admin user who set up connectors will lose access to the docs temporarily + # only way currently to give back access is to rerun from beginning + op.add_column( + "connector_credential_pair", + sa.Column( + "access_type", + sa.String(), + nullable=True, + ), + ) + op.execute( + "UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true" + ) + op.execute( + "UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false" + ) + op.alter_column("connector_credential_pair", "access_type", nullable=False) + + op.add_column( + "connector_credential_pair", + sa.Column( + "auto_sync_options", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "connector_credential_pair", + sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True), + ) + op.drop_column("connector_credential_pair", "is_public") + + op.add_column( + "document", + sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True), + ) + op.add_column( + "document", + sa.Column( + "external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True + ), + ) + op.add_column( + "document", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + + op.create_table( + "user__external_user_group_id", + sa.Column( + "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False + ), + sa.Column("external_user_group_id", sa.String(), nullable=False), + sa.Column("cc_pair_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + ) + + op.drop_column("external_permission", "user_id") + op.drop_column("email_to_external_user_cache", "user_id") + op.drop_table("permission_sync_run") + op.drop_table("external_permission") + op.drop_table("email_to_external_user_cache") + + +def downgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column("is_public", sa.BOOLEAN(), nullable=True), + ) + op.execute( + "UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')" + ) + op.alter_column("connector_credential_pair", "is_public", nullable=False) + + op.drop_column("connector_credential_pair", "auto_sync_options") + 
op.drop_column("connector_credential_pair", "access_type") + op.drop_column("connector_credential_pair", "last_time_perm_sync") + op.drop_column("document", "external_user_emails") + op.drop_column("document", "external_user_group_ids") + op.drop_column("document", "is_public") + + op.drop_table("user__external_user_group_id") + + # Drop the enum type at the end of the downgrade + op.create_table( + "permission_sync_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + sa.Column("update_type", sa.String(), nullable=False), + sa.Column("cc_pair_id", sa.Integer(), nullable=True), + sa.Column( + "status", + sa.String(), + nullable=False, + ), + sa.Column("error_msg", sa.Text(), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["cc_pair_id"], + ["connector_credential_pair.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "external_permission", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + sa.Column("external_permission_group", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "email_to_external_user_cache", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("external_user_id", sa.String(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) diff --git a/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py index c85bb68a3..e67d31b81 100644 --- a/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py +++ b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py @@ -1,7 +1,7 @@ """standard answer match_regex flag Revision ID: efb35676026c -Revises: 52a219fb5233 +Revises: 0ebb1d516877 Create Date: 2024-09-11 13:55:46.101149 """ diff --git a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py index 2d8e7402e..c2a131d60 100644 --- a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py +++ b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py @@ -1,7 +1,7 @@ """add has_web_login column to user Revision ID: f7e58d357687 -Revises: bceb1e139447 +Revises: ba98eba0f66a Create Date: 2024-09-07 20:20:54.522620 """ diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 9088ddf84..7c8790995 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess -from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_user_email from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_access_info_for_document from danswer.db.document import get_access_info_for_documents @@ -18,10 +18,13 @@ def _get_access_for_document( document_id=document_id, ) - if not info: - return DocumentAccess.build(user_ids=[], 
user_groups=[], is_public=False) - - return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2]) + return DocumentAccess.build( + user_emails=info[1] if info and info[1] else [], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=info[2] if info else False, + ) def get_access_for_document( @@ -34,6 +37,16 @@ def get_access_for_document( return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore +def get_null_document_access() -> DocumentAccess: + return DocumentAccess( + user_emails=set(), + user_groups=set(), + is_public=False, + external_user_emails=set(), + external_user_group_ids=set(), + ) + + def _get_access_for_documents( document_ids: list[str], db_session: Session, @@ -42,13 +55,27 @@ def _get_access_for_documents( db_session=db_session, document_ids=document_ids, ) - return { - document_id: DocumentAccess.build( - user_ids=user_ids, user_groups=[], is_public=is_public + doc_access = { + document_id: DocumentAccess( + user_emails=set([email for email in user_emails if email]), + # MIT version will wipe all groups and external groups on update + user_groups=set(), + is_public=is_public, + external_user_emails=set(), + external_user_group_ids=set(), ) - for document_id, user_ids, is_public in document_access_info + for document_id, user_emails, is_public in document_access_info } + # Sometimes the document has not be indexed by the indexing job yet, in those cases + # the document does not exist and so we use least permissive. Specifically the EE version + # checks the MIT version permissions and creates a superset. This ensures that this flow + # does not fail even if the Document has not yet been indexed. + for doc_id in document_ids: + if doc_id not in doc_access: + doc_access[doc_id] = get_null_document_access() + return doc_access + def get_access_for_documents( document_ids: list[str], @@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: matches one entry in the returned set. 
""" if user: - return {prefix_user(str(user.id)), PUBLIC_DOC_PAT} + return {prefix_user_email(user.email), PUBLIC_DOC_PAT} return {PUBLIC_DOC_PAT} diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index a87e2d94f..af5a021ca 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -1,30 +1,72 @@ from dataclasses import dataclass -from uuid import UUID -from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_external_group +from danswer.access.utils import prefix_user_email from danswer.access.utils import prefix_user_group from danswer.configs.constants import PUBLIC_DOC_PAT @dataclass(frozen=True) -class DocumentAccess: - user_ids: set[str] # stringified UUIDs - user_groups: set[str] # names of user groups associated with this document +class ExternalAccess: + # Emails of external users with access to the doc externally + external_user_emails: set[str] + # Names or external IDs of groups with access to the doc + external_user_group_ids: set[str] + # Whether the document is public in the external system or Danswer is_public: bool - def to_acl(self) -> list[str]: - return ( - [prefix_user(user_id) for user_id in self.user_ids] + +@dataclass(frozen=True) +class DocumentAccess(ExternalAccess): + # User emails for Danswer users, None indicates admin + user_emails: set[str | None] + # Names of user groups associated with this document + user_groups: set[str] + + def to_acl(self) -> set[str]: + return set( + [ + prefix_user_email(user_email) + for user_email in self.user_emails + if user_email + ] + [prefix_user_group(group_name) for group_name in self.user_groups] + + [ + prefix_user_email(user_email) + for user_email in self.external_user_emails + ] + + [ + # The group names are already prefixed by the source type + # This adds an additional prefix of "external_group:" + prefix_external_group(group_name) + for group_name in self.external_user_group_ids + ] + ([PUBLIC_DOC_PAT] if self.is_public else []) ) @classmethod def build( - cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool + cls, + user_emails: list[str | None], + user_groups: list[str], + external_user_emails: list[str], + external_user_group_ids: list[str], + is_public: bool, ) -> "DocumentAccess": return cls( - user_ids={str(user_id) for user_id in user_ids if user_id}, + external_user_emails={ + prefix_user_email(external_email) + for external_email in external_user_emails + }, + external_user_group_ids={ + prefix_external_group(external_group_id) + for external_group_id in external_user_group_ids + }, + user_emails={ + prefix_user_email(user_email) + for user_email in user_emails + if user_email + }, user_groups=set(user_groups), is_public=is_public, ) diff --git a/backend/danswer/access/utils.py b/backend/danswer/access/utils.py index 060560eae..82abf9785 100644 --- a/backend/danswer/access/utils.py +++ b/backend/danswer/access/utils.py @@ -1,10 +1,24 @@ -def prefix_user(user_id: str) -> str: - """Prefixes a user ID to eliminate collision with group names. - This assumes that groups are prefixed with a different prefix.""" - return f"user_id:{user_id}" +from danswer.configs.constants import DocumentSource + + +def prefix_user_email(user_email: str) -> str: + """Prefixes a user email to eliminate collision with group names. 
+    This applies to both a Danswer user and an external user; sharing one prefix makes
+    query-time filtering more efficient."""
+    return f"user_email:{user_email}"
 
 
 def prefix_user_group(user_group_name: str) -> str:
-    """Prefixes a user group name to eliminate collision with user IDs.
-    This assumes that user ids are prefixed with a different prefix."""
+    """Prefixes a user group name to eliminate collision with user emails.
+    This assumes that user emails are prefixed with a different prefix."""
     return f"group:{user_group_name}"
+
+
+def prefix_external_group(ext_group_name: str) -> str:
+    """Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
+    return f"external_group:{ext_group_name}"
+
+
+def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
+    """External groups may collide across sources, so every source needs its own prefix."""
+    return f"{source.value.upper()}_{ext_group_name}"
diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index aa7c53aa3..cdd9b1d5c 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -128,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             return
 
         runnable_connector = instantiate_connector(
-            cc_pair.connector.source,
-            InputType.PRUNE,
-            cc_pair.connector.connector_specific_config,
-            cc_pair.credential,
-            db_session,
+            db_session=db_session,
+            source=cc_pair.connector.source,
+            input_type=InputType.PRUNE,
+            connector_specific_config=cc_pair.connector.connector_specific_config,
+            credential=cc_pair.credential,
         )
 
         all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index 86b428536..a29ddd76c 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -56,11 +56,11 @@ def _get_connector_runner(
 
     try:
         runnable_connector = instantiate_connector(
-            attempt.connector_credential_pair.connector.source,
-            task,
-            attempt.connector_credential_pair.connector.connector_specific_config,
-            attempt.connector_credential_pair.credential,
-            db_session,
+            db_session=db_session,
+            source=attempt.connector_credential_pair.connector.source,
+            input_type=task,
+            connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
+            credential=attempt.connector_credential_pair.credential,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py
index 1a3d605d3..42d3b0bd9 100644
--- a/backend/danswer/connectors/factory.py
+++ b/backend/danswer/connectors/factory.py
@@ -124,11 +124,11 @@ def identify_connector_class(
 
 
 def instantiate_connector(
+    db_session: Session,
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
-    db_session: Session,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
     connector = connector_class(**connector_specific_config)
diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py
index 80674b5a3..bf267ab77 100644
--- a/backend/danswer/connectors/google_drive/connector.py
+++ b/backend/danswer/connectors/google_drive/connector.py
@@ -6,7 +6,6 @@ from datetime import timezone
 from enum import Enum
 from itertools import chain
 from typing
import Any -from typing import cast from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore @@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource from danswer.configs.constants import IGNORE_FOR_QA from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds_for_authorized_user, -) -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds_for_service_account, -) +from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) -from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector @@ -407,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector): (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. """ - creds: OAuthCredentials | ServiceAccountCredentials | None = None - new_creds_dict = None - if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: - access_token_json_str = cast( - str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] - ) - creds = get_google_drive_creds_for_authorized_user( - token_json_str=access_token_json_str - ) - - # tell caller to update token stored in DB if it has changed - # (e.g. the token has been refreshed) - new_creds_json_str = creds.to_json() if creds else "" - if new_creds_json_str != access_token_json_str: - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} - - if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: - service_account_key_json_str = credentials[ - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - ] - creds = get_google_drive_creds_for_service_account( - service_account_key_json_str=service_account_key_json_str - ) - - # "Impersonate" a user if one is specified - delegated_user_email = cast( - str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) - ) - if delegated_user_email: - creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore - - if creds is None: - raise PermissionError( - "Unable to access Google Drive - unknown credential structure." 
- ) - + creds, new_creds_dict = get_google_drive_creds(credentials) self.creds = creds return new_creds_dict @@ -509,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector): file["modifiedTime"] ).astimezone(timezone.utc), metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + additional_info=file.get("id"), ) ) except Exception as e: diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 0f47727e6..cc68fec54 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore from sqlalchemy.orm import Session +from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import DocumentSource from danswer.configs.constants import KV_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY +from danswer.connectors.google_drive.constants import BASE_SCOPES from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) @@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY -from danswer.connectors.google_drive.constants import SCOPES +from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User from danswer.dynamic_configs.factory import get_dynamic_config_store @@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +def build_gdrive_scopes() -> list[str]: + base_scopes: list[str] = BASE_SCOPES + permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES + groups_scopes: list[str] = FETCH_GROUPS_SCOPES + + if ENTERPRISE_EDITION_ENABLED: + return base_scopes + permissions_scopes + groups_scopes + return base_scopes + permissions_scopes + + def _build_frontend_google_drive_redirect() -> str: return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback" def get_google_drive_creds_for_authorized_user( - token_json_str: str, + token_json_str: str, scopes: list[str] = build_gdrive_scopes() ) -> OAuthCredentials | None: creds_json = json.loads(token_json_str) - creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES) + creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes) if creds.valid: return creds @@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user( return None -def get_google_drive_creds_for_service_account( - service_account_key_json_str: str, +def _get_google_drive_creds_for_service_account( + service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes() ) -> ServiceAccountCredentials | None: service_account_key = json.loads(service_account_key_json_str) creds = ServiceAccountCredentials.from_service_account_info( - service_account_key, scopes=SCOPES + service_account_key, scopes=scopes ) if not creds.valid or not creds.expired: creds.refresh(Request()) return creds if creds.valid else None +def get_google_drive_creds( + 
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes() +) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: + oauth_creds = None + service_creds = None + new_creds_dict = None + if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: + access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) + oauth_creds = get_google_drive_creds_for_authorized_user( + token_json_str=access_token_json_str, scopes=scopes + ) + + # tell caller to update token stored in DB if it has changed + # (e.g. the token has been refreshed) + new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" + if new_creds_json_str != access_token_json_str: + new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} + + elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: + service_account_key_json_str = credentials[ + DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY + ] + service_creds = _get_google_drive_creds_for_service_account( + service_account_key_json_str=service_account_key_json_str, + scopes=scopes, + ) + + # "Impersonate" a user if one is specified + delegated_user_email = cast( + str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) + ) + if delegated_user_email: + service_creds = ( + service_creds.with_subject(delegated_user_email) + if service_creds + else None + ) + + creds: ServiceAccountCredentials | OAuthCredentials | None = ( + oauth_creds or service_creds + ) + if creds is None: + raise PermissionError( + "Unable to access Google Drive - unknown credential structure." + ) + + return creds, new_creds_dict + + def verify_csrf(credential_id: int, state: str) -> None: csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id))) if csrf != state: @@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str: credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, - scopes=SCOPES, + scopes=build_gdrive_scopes(), redirect_uri=_build_frontend_google_drive_redirect(), ) auth_url, _ = flow.authorization_url(prompt="consent") @@ -107,7 +169,7 @@ def update_credential_access_tokens( app_credentials = get_google_app_cred() flow = InstalledAppFlow.from_client_config( app_credentials.model_dump(), - scopes=SCOPES, + scopes=build_gdrive_scopes(), redirect_uri=_build_frontend_google_drive_redirect(), ) flow.fetch_token(code=auth_code) diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 214bfd5cb..0cca65c13 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -1,7 +1,7 @@ DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key" DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" -SCOPES = [ - "https://www.googleapis.com/auth/drive.readonly", - "https://www.googleapis.com/auth/drive.metadata.readonly", -] + +BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] +FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] +FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"] diff --git a/backend/danswer/connectors/models.py b/backend/danswer/connectors/models.py index 192aa1b20..7d86d2198 100644 --- a/backend/danswer/connectors/models.py +++ b/backend/danswer/connectors/models.py @@ -113,6 +113,9 @@ class DocumentBase(BaseModel): # The default title is 
semantic_identifier though unless otherwise specified title: str | None = None from_ingestion_api: bool = False + # Anything else that may be useful that is specific to this particular connector type that other + # parts of the code may need. If you're unsure, this can be left as None + additional_info: Any = None def get_title_for_document_index( self, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index cce45331e..088279620 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -211,7 +211,7 @@ def handle_message( with Session(get_sqlalchemy_engine()) as db_session: if message_info.email: - add_non_web_user_if_not_exists(message_info.email, db_session) + add_non_web_user_if_not_exists(db_session, message_info.email) # first check if we need to respond with a standard answer used_standard_answer = handle_standard_answers( diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 004b5a754..ec9a3a08e 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from danswer.configs.constants import DocumentSource from danswer.db.connector import fetch_connector_by_id from danswer.db.credentials import fetch_credential_by_id +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt @@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair from danswer.db.models import UserRole from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit +from ee.danswer.external_permissions.permission_sync_function_map import ( + check_if_valid_sync_source, +) logger = setup_logger() @@ -74,7 +79,7 @@ def _add_user_filters( .correlate(ConnectorCredentialPair) ) else: - where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712 + where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC return stmt.where(where_clause) @@ -94,8 +99,7 @@ def get_connector_credential_pairs( ) # noqa if ids: stmt = stmt.where(ConnectorCredentialPair.id.in_(ids)) - results = db_session.scalars(stmt) - return list(results.all()) + return list(db_session.scalars(stmt).all()) def add_deletion_failure_message( @@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None: association = ConnectorCredentialPair( connector_id=0, credential_id=0, + access_type=AccessType.PUBLIC, name="DefaultCCPair", status=ConnectorCredentialPairStatus.ACTIVE, - is_public=True, ) db_session.add(association) db_session.commit() @@ -336,8 +340,9 @@ def add_credential_to_connector( connector_id: int, credential_id: int, cc_pair_name: str | None, - is_public: bool, + access_type: AccessType, groups: list[int] | None, + auto_sync_options: dict | None = None, ) -> StatusResponse: connector = fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id(credential_id, user, db_session) @@ -345,6 +350,13 @@ def add_credential_to_connector( if connector is None: raise HTTPException(status_code=404, detail="Connector does not exist") + if access_type == AccessType.SYNC: + if not 
check_if_valid_sync_source(connector.source): + raise HTTPException( + status_code=400, + detail=f"Connector of type {connector.source} does not support SYNC access type", + ) + if credential is None: error_msg = ( f"Credential {credential_id} does not exist or does not belong to user" @@ -375,12 +387,13 @@ def add_credential_to_connector( credential_id=credential_id, name=cc_pair_name, status=ConnectorCredentialPairStatus.ACTIVE, - is_public=is_public, + access_type=access_type, + auto_sync_options=auto_sync_options, ) db_session.add(association) db_session.flush() # make sure the association has an id - if groups: + if groups and access_type != AccessType.SYNC: _relate_groups_to_cc_pair__no_commit( db_session=db_session, cc_pair_id=association.id, @@ -423,6 +436,10 @@ def remove_credential_from_connector( ) if association is not None: + delete_user__ext_group_for_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=association.id, + ) db_session.delete(association) db_session.commit() return StatusResponse( diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 0d5bc276b..bffffddd7 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -4,7 +4,6 @@ from collections.abc import Generator from collections.abc import Sequence from datetime import datetime from datetime import timezone -from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete @@ -17,14 +16,17 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine.util import TransactionalContext from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import null from danswer.configs.constants import DEFAULT_BOOST +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.feedback import delete_document_feedback_for_documents__no_commit from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentByConnectorCredentialPair +from danswer.db.models import User from danswer.db.tag import delete_document_tags_for_documents__no_commit from danswer.db.utils import model_to_dict from danswer.document_index.interfaces import DocumentMetadata @@ -186,16 +188,14 @@ def get_document_counts_for_cc_pairs( def get_access_info_for_document( db_session: Session, document_id: str, -) -> tuple[str, list[UUID | None], bool] | None: +) -> tuple[str, list[str | None], bool] | None: """Gets access info for a single document by calling the get_access_info_for_documents function and passing a list with a single document ID. - Args: db_session (Session): The database session to use. document_id (str): The document ID to fetch access info for. - Returns: - Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs, + Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails, and a boolean indicating if the document is globally public, or None if no results are found. """ results = get_access_info_for_documents(db_session, [document_id]) @@ -208,19 +208,27 @@ def get_access_info_for_document( def get_access_info_for_documents( db_session: Session, document_ids: list[str], -) -> Sequence[tuple[str, list[UUID | None], bool]]: +) -> Sequence[tuple[str, list[str | None], bool]]: """Gets back all relevant access info for the given documents. 
This includes the emails of Danswer users for the cc pairs that the document is associated with,
+    plus whether any of the associated cc pairs intend to make the document globally public.
+    Returns a list where each element contains:
+    - Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
+    - List of emails of Danswer users with direct access to the doc (includes a "None" element
+      if the connector was set up by an admin when auth was off)
+    - bool for whether the document is public (the document can also later be marked public by
+      the automatic permission sync step)
     """
+    stmt = select(
+        DocumentByConnectorCredentialPair.id,
+        func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
+        func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
+            "public_doc"
+        ),
+    ).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
+
     stmt = (
-        select(
-            DocumentByConnectorCredentialPair.id,
-            func.array_agg(Credential.user_id).label("user_ids"),
-            func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
-        )
-        .where(DocumentByConnectorCredentialPair.id.in_(document_ids))
-        .join(
+        stmt.join(
             Credential,
             DocumentByConnectorCredentialPair.credential_id == Credential.id,
         )
@@ -233,6 +241,13 @@
                 == ConnectorCredentialPair.credential_id,
             ),
         )
+        .outerjoin(
+            User,
+            and_(
+                Credential.user_id == User.id,
+                ConnectorCredentialPair.access_type != AccessType.SYNC,
+            ),
+        )
         # don't include CC pairs that are being deleted
         # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
         .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
@@ -278,9 +293,19 @@
             for doc in seen_documents.values()
         ]
     )
-    # for now, there are no columns to update.
If more metadata is added, then this - # needs to change to an `on_conflict_do_update` - on_conflict_stmt = insert_stmt.on_conflict_do_nothing() + + on_conflict_stmt = insert_stmt.on_conflict_do_update( + index_elements=["id"], # Conflict target + set_={ + "from_ingestion_api": insert_stmt.excluded.from_ingestion_api, + "boost": insert_stmt.excluded.boost, + "hidden": insert_stmt.excluded.hidden, + "semantic_id": insert_stmt.excluded.semantic_id, + "link": insert_stmt.excluded.link, + "primary_owners": insert_stmt.excluded.primary_owners, + "secondary_owners": insert_stmt.excluded.secondary_owners, + }, + ) db_session.execute(on_conflict_stmt) db_session.commit() diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index a8c1e4ebb..0ba6c4e9a 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document @@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups( ids=missing_cc_pair_ids, ) for cc_pair in cc_pairs: - if not cc_pair.is_public: + if cc_pair.access_type != AccessType.PUBLIC: raise ValueError( f"Connector Credential Pair with ID: '{cc_pair.id}'" " is not owned by the specified groups" @@ -704,7 +705,7 @@ def check_document_sets_are_public( ConnectorCredentialPair.id.in_( connector_credential_pair_ids # type:ignore ), - ConnectorCredentialPair.is_public.is_(False), + ConnectorCredentialPair.access_type != AccessType.PUBLIC, ) .limit(1) .first() diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py index eac048e10..8d9515d38 100644 --- a/backend/danswer/db/enums.py +++ b/backend/danswer/db/enums.py @@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum): def is_active(self) -> bool: return self == ConnectorCredentialPairStatus.ACTIVE + + +class AccessType(str, PyEnum): + PUBLIC = "public" + PRIVATE = "private" + SYNC = "sync" diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 6df1f1f50..219e24747 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -16,6 +16,7 @@ from sqlalchemy.orm import Session from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.db.chat import get_chat_message +from danswer.db.enums import AccessType from danswer.db.models import ChatMessageFeedback from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document as DbDocument @@ -94,7 +95,7 @@ def _add_user_filters( .correlate(CCPair) ) else: - where_clause |= CCPair.is_public == True # noqa: E712 + where_clause |= CCPair.access_type == AccessType.PUBLIC return stmt.where(where_clause) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 33d8d8bc0..fd2d1344a 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType +from danswer.db.enums import AccessType from danswer.configs.constants import NotificationType 
from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.constants import TokenRateLimitScope
@@ -388,10 +389,19 @@
     # controls whether the documents indexed by this CC pair are visible to all
     # or if they are only visible to those that are given explicit access
     # (e.g. via owning the credential or being a part of a group that is given access)
-    is_public: Mapped[bool] = mapped_column(
-        Boolean,
-        default=True,
-        nullable=False,
+    access_type: Mapped[AccessType] = mapped_column(
+        Enum(AccessType, native_enum=False), nullable=False
+    )
+
+    # special info needed for the auto-sync feature. The exact structure depends on the
+    # source type (defined in the connector's `source` field)
+    # E.g. for google_drive perm sync:
+    # {"customer_id": "123567", "company_domain": "@danswer.ai"}
+    auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
+    last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
     )
     # Time finished, not used for calculating backend jobs which uses time started (created)
     last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
@@ -422,6 +433,7 @@
 
 class Document(Base):
     __tablename__ = "document"
+    # NOTE: if more sensitive data is added here for display, make sure to add user/group permission
 
     # this should correspond to the ID of the document
     # (as is passed around in Danswer)
@@ -465,7 +477,18 @@
     secondary_owners: Mapped[list[str] | None] = mapped_column(
         postgresql.ARRAY(String), nullable=True
     )
-    # TODO if more sensitive data is added here for display, make sure to add user/group permission
+    # Permission sync columns
+    # Email addresses are saved at the document level for externally synced permissions
+    # This is because the normal flow of assigning permissions through the cc_pair
+    # doesn't apply here
+    external_user_emails: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+    # These group ids have been prefixed by the source type
+    external_user_group_ids: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+    is_public: Mapped[bool] = mapped_column(Boolean, default=False)
 
     retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
         "DocumentRetrievalFeedback", back_populates="document"
@@ -1674,95 +1697,18 @@ class StandardAnswer(Base):
 
 
 """Tables related to Permission Sync"""
 
 
-class PermissionSyncStatus(str, PyEnum):
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-class PermissionSyncJobType(str, PyEnum):
-    USER_LEVEL = "user_level"
-    GROUP_LEVEL = "group_level"
-
-
-class PermissionSyncRun(Base):
-    """Represents one run of a permission sync job.
For some given cc_pair, it is either sync-ing - the users or it is sync-ing the groups""" - - __tablename__ = "permission_sync_run" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - # Not strictly needed but makes it easy to use without fetching from cc_pair - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) - # Currently all sync jobs are handled as a group permission sync or a user permission sync - update_type: Mapped[PermissionSyncJobType] = mapped_column( - Enum(PermissionSyncJobType) - ) - cc_pair_id: Mapped[int | None] = mapped_column( - ForeignKey("connector_credential_pair.id"), nullable=True - ) - status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus)) - error_msg: Mapped[str | None] = mapped_column(Text, default=None) - updated_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now() - ) - - cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair") - - -class ExternalPermission(Base): +class User__ExternalUserGroupId(Base): """Maps user info both internal and external to the name of the external group This maps the user to all of their external groups so that the external group name can be attached to the ACL list matching during query time. User level permissions can be handled by directly adding the Danswer user to the doc ACL list""" - __tablename__ = "external_permission" + __tablename__ = "user__external_user_group_id" - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - # Email is needed because we want to keep track of users not in Danswer to simplify process - # when the user joins - user_email: Mapped[str] = mapped_column(String) - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) - external_permission_group: Mapped[str] = mapped_column(String) - user = relationship("User") - - -class EmailToExternalUserCache(Base): - """A way to map users IDs in the external tool to a user in Danswer or at least an email for - when the user joins. Used as a cache for when fetching external groups which have their own - user ids, this can easily be mapped back to users already known in Danswer without needing - to call external APIs to get the user emails. - - This way when groups are updated in the external tool and we need to update the mapping of - internal users to the groups, we can sync the internal users to the external groups they are - part of using this. - - Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer - part of alpha in some external tool. 
- """ - - __tablename__ = "email_to_external_user_cache" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - external_user_id: Mapped[str] = mapped_column(String) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - # Email is needed because we want to keep track of users not in Danswer to simplify process - # when the user joins - user_email: Mapped[str] = mapped_column(String) - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) - - user = relationship("User") + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + # These group ids have been prefixed by the source type + external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True) + cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id")) class UsageReport(Base): diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index 61ba6e475..a6319481b 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -22,6 +22,17 @@ def list_users( return db_session.scalars(stmt).unique().all() +def get_users_by_emails( + db_session: Session, emails: list[str] +) -> tuple[list[User], list[str]]: + # Use distinct to avoid duplicates + stmt = select(User).filter(User.email.in_(emails)) # type: ignore + found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list + found_users_emails = [user.email for user in found_users] + missing_user_emails = [email for email in emails if email not in found_users_emails] + return found_users, missing_user_emails + + def get_user_by_email(email: str, db_session: Session) -> User | None: user = db_session.query(User).filter(User.email == email).first() # type: ignore @@ -34,20 +45,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: return user -def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User: - user = get_user_by_email(email, db_session) - if user is not None: - return user - +def _generate_non_web_user(email: str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) - user = User( + return User( email=email, hashed_password=hashed_pass, has_web_login=False, role=UserRole.BASIC, ) + + +def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + user = _generate_non_web_user(email=email) db_session.add(user) db_session.commit() return user + + +def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + user = _generate_non_web_user(email=email) + db_session.add(user) + db_session.flush() # generate id + return user + + +def batch_add_non_web_user_if_not_exists__no_commit( + db_session: Session, emails: list[str] +) -> list[User]: + found_users, missing_user_emails = get_users_by_emails(db_session, emails) + + new_users: list[User] = [] + for email in missing_user_emails: + new_users.append(_generate_non_web_user(email=email)) + + db_session.add_all(new_users) + db_session.flush() # generate ids + + return found_users + new_users diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 51cd23e74..5d7412ea9 100644 --- 
a/backend/danswer/indexing/indexing_pipeline.py
+++ b/backend/danswer/indexing/indexing_pipeline.py
@@ -265,7 +265,13 @@
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""
 
-    no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
+    no_access = DocumentAccess.build(
+        user_emails=[],
+        user_groups=[],
+        external_user_emails=[],
+        external_user_group_ids=[],
+        is_public=False,
+    )
 
     ctx = index_doc_batch_prepare(
         document_batch=document_batch,
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 876886ca2..b06aacc1e 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -18,6 +18,7 @@ from danswer.db.connector_credential_pair import (
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import get_session
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
@@ -201,7 +202,7 @@
         db_session=db_session,
         user=user,
         target_group_ids=metadata.groups,
-        object_is_public=metadata.is_public,
+        object_is_public=metadata.access_type == AccessType.PUBLIC,
     )
 
     try:
@@ -211,7 +212,8 @@
             connector_id=connector_id,
             credential_id=credential_id,
             cc_pair_name=metadata.name,
-            is_public=True if metadata.is_public is None else metadata.is_public,
+            access_type=metadata.access_type,
+            auto_sync_options=metadata.auto_sync_options,
             groups=metadata.groups,
         )
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index 73e28b8fb..58dcf7e76 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -64,6 +64,7 @@ from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import get_session
+from danswer.db.enums import AccessType
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempts_for_cc_pair
 from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
@@ -559,7 +560,7 @@
                 cc_pair_status=cc_pair.status,
                 connector=ConnectorSnapshot.from_connector_db_model(connector),
                 credential=CredentialSnapshot.from_credential_db_model(credential),
-                public_doc=cc_pair.is_public,
+                access_type=cc_pair.access_type,
                 owner=credential.user.email if credential.user else "",
                 groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
                 last_finished_status=(
@@ -668,12 +669,15 @@
             credential = create_credential(
                 mock_credential, user=user, db_session=db_session
             )
+            access_type = (
+                AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
+            )
             response = add_credential_to_connector(
                 db_session=db_session,
                 user=user,
                 connector_id=cast(int, connector_response.id),  # will always be an int
                 credential_id=credential.id,
-                is_public=connector_data.is_public or False,
+                access_type=access_type,
                 cc_pair_name=connector_data.name,
                 groups=connector_data.groups,
             )
diff --git
a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index 517813892..b4052303b 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -10,6 +10,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX from danswer.configs.constants import DocumentSource from danswer.connectors.models import DocumentErrorSummary from danswer.connectors.models import InputType +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair @@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel): number_of_index_attempts: int last_index_attempt_status: IndexingStatus | None latest_deletion_attempt: DeletionAttemptSnapshot | None - is_public: bool + access_type: AccessType is_editable_for_current_user: bool deletion_failure_message: str | None @@ -261,7 +262,7 @@ class CCPairFullInfo(BaseModel): number_of_index_attempts=number_of_index_attempts, last_index_attempt_status=last_indexing_status, latest_deletion_attempt=latest_deletion_attempt, - is_public=cc_pair_model.is_public, + access_type=cc_pair_model.access_type, is_editable_for_current_user=is_editable_for_current_user, deletion_failure_message=cc_pair_model.deletion_failure_message, ) @@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel): credential: CredentialSnapshot owner: str groups: list[int] - public_doc: bool + access_type: AccessType last_finished_status: IndexingStatus | None last_status: IndexingStatus | None last_success: datetime | None @@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel): class ConnectorCredentialPairMetadata(BaseModel): name: str | None = None - is_public: bool | None = None + access_type: AccessType + auto_sync_options: dict[str, Any] | None = None groups: list[int] = Field(default_factory=list) diff --git a/backend/danswer/server/features/document_set/models.py b/backend/danswer/server/features/document_set/models.py index 55f337654..740cb6906 100644 --- a/backend/danswer/server/features/document_set/models.py +++ b/backend/danswer/server/features/document_set/models.py @@ -47,7 +47,6 @@ class DocumentSet(BaseModel): description: str cc_pair_descriptors: list[ConnectorCredentialPairDescriptor] is_up_to_date: bool - contains_non_public: bool is_public: bool # For Private Document Sets, who should be able to access these users: list[UUID] @@ -59,12 +58,6 @@ class DocumentSet(BaseModel): id=document_set_model.id, name=document_set_model.name, description=document_set_model.description, - contains_non_public=any( - [ - not cc_pair.is_public - for cc_pair in document_set_model.connector_credential_pairs - ] - ), cc_pair_descriptors=[ ConnectorCredentialPairDescriptor( id=cc_pair.id, diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index b4ca7862b..d6210118d 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel): api_url: str | None = None model_name: str | None = None + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} + class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 46beee856..04d2c1244 100644 --- 
a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -48,6 +48,7 @@ from danswer.server.models import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot from danswer.utils.logger import setup_logger from ee.danswer.db.api_key import is_api_key_email_address +from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit from ee.danswer.db.user_group import remove_curator_status__no_commit logger = setup_logger() @@ -243,6 +244,11 @@ async def delete_user( for oauth_account in user_to_delete.oauth_accounts: db_session.delete(oauth_account) + delete_user__ext_group_for_user__no_commit( + db_session=db_session, + user_id=user_to_delete.id, + ) + db_session.query(DocumentSet__User).filter( DocumentSet__User.user_id == user_to_delete.id ).delete() diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py index 2b3cdb7a9..094298677 100644 --- a/backend/ee/danswer/access/access.py +++ b/backend/ee/danswer/access/access.py @@ -5,8 +5,11 @@ from danswer.access.access import ( ) from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups from danswer.access.models import DocumentAccess +from danswer.access.utils import prefix_external_group from danswer.access.utils import prefix_user_group +from danswer.db.document import get_documents_by_ids from danswer.db.models import User +from ee.danswer.db.external_perm import fetch_external_groups_for_user from ee.danswer.db.user_group import fetch_user_groups_for_documents from ee.danswer.db.user_group import fetch_user_groups_for_user @@ -17,7 +20,13 @@ def _get_access_for_document( ) -> DocumentAccess: id_to_access = _get_access_for_documents([document_id], db_session) if len(id_to_access) == 0: - return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + return DocumentAccess.build( + user_emails=[], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=False, + ) return next(iter(id_to_access.values())) @@ -30,22 +39,48 @@ def _get_access_for_documents( document_ids=document_ids, db_session=db_session, ) - user_group_info = { + user_group_info: dict[str, list[str]] = { document_id: group_names for document_id, group_names in fetch_user_groups_for_documents( db_session=db_session, document_ids=document_ids, ) } + documents = get_documents_by_ids( + db_session=db_session, + document_ids=document_ids, + ) + doc_id_map = {doc.id: doc for doc in documents} - return { - document_id: DocumentAccess( - user_ids=non_ee_access.user_ids, - user_groups=user_group_info.get(document_id, []), # type: ignore - is_public=non_ee_access.is_public, + access_map = {} + for document_id, non_ee_access in non_ee_access_dict.items(): + document = doc_id_map[document_id] + + ext_u_emails = ( + set(document.external_user_emails) + if document.external_user_emails + else set() ) - for document_id, non_ee_access in non_ee_access_dict.items() - } + + ext_u_groups = ( + set(document.external_user_group_ids) + if document.external_user_group_ids + else set() + ) + + # If the document is determined to be "public" externally (through a SYNC connector) + # then it's given the same access level as if it were marked public within Danswer + is_public_anywhere = document.is_public or non_ee_access.is_public + + # To avoid collisions of group namings between connectors, they need to be prefixed + access_map[document_id] = DocumentAccess( + user_emails=non_ee_access.user_emails, + 
user_groups=set(user_group_info.get(document_id, [])), + is_public=is_public_anywhere, + external_user_emails=ext_u_emails, + external_user_group_ids=ext_u_groups, + ) + return access_map def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: @@ -56,7 +91,20 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: NOTE: is imported in danswer.access.access by `fetch_versioned_implementation` DO NOT REMOVE.""" - user_groups = fetch_user_groups_for_user(db_session, user.id) if user else [] - return set( - [prefix_user_group(user_group.name) for user_group in user_groups] - ).union(get_acl_for_user_without_groups(user, db_session)) + db_user_groups = fetch_user_groups_for_user(db_session, user.id) if user else [] + prefixed_user_groups = [ + prefix_user_group(db_user_group.name) for db_user_group in db_user_groups + ] + + db_external_groups = ( + fetch_external_groups_for_user(db_session, user.id) if user else [] + ) + prefixed_external_groups = [ + prefix_external_group(db_external_group.external_user_group_id) + for db_external_group in db_external_groups + ] + + user_acl = set(prefixed_user_groups + prefixed_external_groups) + user_acl.update(get_acl_for_user_without_groups(user, db_session)) + + return user_acl diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 2b4c96ccb..2b3475bee 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -11,7 +11,13 @@ from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.background.celery_utils import should_perform_chat_ttl_check +from ee.danswer.background.celery_utils import should_perform_external_permissions_check from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_sync_external_permissions_task +from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.danswer.external_permissions.permission_sync import ( + run_permission_sync_entrypoint, +) from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report logger = setup_logger() @@ -20,6 +26,13 @@ logger = setup_logger() global_version.set_ee() +@build_celery_task_wrapper(name_sync_external_permissions_task) +@celery_app.task(soft_time_limit=JOB_TIMEOUT) +def sync_external_permissions_task(cc_pair_id: int) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id) + + @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) def perform_ttl_management_task(retention_limit_days: int) -> None: @@ -30,6 +43,23 @@ def perform_ttl_management_task(retention_limit_days: int) -> None: ##### # Periodic Tasks ##### +@celery_app.task( + name="check_sync_external_permissions_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_sync_external_permissions_task() -> None: + """Runs periodically to sync external permissions""" + with Session(get_sqlalchemy_engine()) as db_session: + cc_pairs = get_all_auto_sync_cc_pairs(db_session) + for cc_pair in cc_pairs: + if should_perform_external_permissions_check( + cc_pair=cc_pair, db_session=db_session + ): + sync_external_permissions_task.apply_async( + kwargs=dict(cc_pair_id=cc_pair.id), + ) + + @celery_app.task( 
name="check_ttl_management_task", soft_time_limit=JOB_TIMEOUT, @@ -64,7 +94,11 @@ def autogenerate_usage_report_task() -> None: # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "autogenerate-usage-report": { + "sync-external-permissions": { + "task": "check_sync_external_permissions_task", + "schedule": timedelta(seconds=60), # TODO: optimize this + }, + "autogenerate_usage_report": { "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index 879487180..bd1c5d874 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -6,10 +6,13 @@ from sqlalchemy.orm import Session from danswer.background.celery.celery_app import task_logger from danswer.background.celery.celery_redis import RedisUserGroup from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.enums import AccessType +from danswer.db.models import ConnectorCredentialPair from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task from danswer.utils.logger import setup_logger from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_sync_external_permissions_task from ee.danswer.db.user_group import delete_user_group from ee.danswer.db.user_group import fetch_user_group from ee.danswer.db.user_group import mark_user_group_as_synced @@ -30,11 +33,30 @@ def should_perform_chat_ttl_check( return True if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session): - logger.info("TTL check is already being performed. Skipping.") + logger.debug(f"{task_name} is already being performed. Skipping.") return False return True +def should_perform_external_permissions_check( + cc_pair: ConnectorCredentialPair, db_session: Session +) -> bool: + if cc_pair.access_type != AccessType.SYNC: + return False + + task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id) + + latest_task = get_latest_task(task_name, db_session) + if not latest_task: + return True + + if check_task_is_live_and_not_timed_out(latest_task, db_session): + logger.debug(f"{task_name} is already being performed. 
Skipping.") + return False + + return True + + def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None: """This function is likely to move in the worker refactor happening next.""" key = key_bytes.decode("utf-8") diff --git a/backend/ee/danswer/background/permission_sync.py b/backend/ee/danswer/background/permission_sync.py deleted file mode 100644 index c14094b60..000000000 --- a/backend/ee/danswer/background/permission_sync.py +++ /dev/null @@ -1,224 +0,0 @@ -import logging -import time -from datetime import datetime - -import dask -from dask.distributed import Client -from dask.distributed import Future -from distributed import LocalCluster -from sqlalchemy.orm import Session - -from danswer.background.indexing.dask_utils import ResourceLogger -from danswer.background.indexing.job_client import SimpleJob -from danswer.background.indexing.job_client import SimpleJobClient -from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT -from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED -from danswer.configs.constants import DocumentSource -from danswer.configs.constants import POSTGRES_PERMISSIONS_APP_NAME -from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import init_sqlalchemy_engine -from danswer.db.models import PermissionSyncStatus -from danswer.utils.logger import setup_logger -from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS -from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP -from ee.danswer.db.connector import fetch_sources_with_connectors -from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source -from ee.danswer.db.permission_sync import create_perm_sync -from ee.danswer.db.permission_sync import expire_perm_sync_timed_out -from ee.danswer.db.permission_sync import get_perm_sync_attempt -from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed -from shared_configs.configs import LOG_LEVEL - -logger = setup_logger() - -# If the indexing dies, it's most likely due to resource constraints, -# restarting just delays the eventual failure, not useful to the user -dask.config.set({"distributed.scheduler.allowed-failures": 0}) - - -def cleanup_perm_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - # Just reusing the same timeout, fine for now - timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - existing_jobs_copy = existing_jobs.copy() - - with Session(get_sqlalchemy_engine()) as db_session: - # clean up completed jobs - for (attempt_id, details), job in existing_jobs.items(): - perm_sync_attempt = get_perm_sync_attempt( - attempt_id=attempt_id, db_session=db_session - ) - - # do nothing for ongoing jobs that haven't been stopped - if ( - not job.done() - and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS - ): - continue - - if job.status == "error": - logger.error(job.exception()) - - job.release() - del existing_jobs_copy[(attempt_id, details)] - - # clean up in-progress jobs that were never completed - expire_perm_sync_timed_out( - timeout_hours=timeout_hours, - db_session=db_session, - ) - - return existing_jobs_copy - - -def create_group_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - client: Client | SimpleJobClient, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - """Creates new relational DB group permission sync job for each source that: - - has permission 
sync enabled - - has at least 1 connector (enabled or paused) - - has no sync already running - """ - existing_jobs_copy = existing_jobs.copy() - sources_w_runs = [ - key[1] - for key in existing_jobs_copy.keys() - if isinstance(key[1], DocumentSource) - ] - with Session(get_sqlalchemy_engine()) as db_session: - sources_w_connector = fetch_sources_with_connectors(db_session) - for source_type in sources_w_connector: - if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: - continue - if source_type in sources_w_runs: - continue - - db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type] - perm_sync = create_perm_sync( - source_type=source_type, - group_update=True, - cc_pair_id=None, - db_session=db_session, - ) - - run = client.submit(db_group_fnc, pure=False) - - logger.info( - f"Kicked off group permission sync for source type {source_type}" - ) - - if run: - existing_jobs_copy[(perm_sync.id, source_type)] = run - - return existing_jobs_copy - - -def create_connector_perm_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - client: Client | SimpleJobClient, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - """Update Document Index ACL sync job for each cc-pair where: - - source type has permission sync enabled - - has no sync already running - """ - existing_jobs_copy = existing_jobs.copy() - cc_pairs_w_runs = [ - key[1] - for key in existing_jobs_copy.keys() - if isinstance(key[1], DocumentSource) - ] - with Session(get_sqlalchemy_engine()) as db_session: - sources_w_connector = fetch_sources_with_connectors(db_session) - for source_type in sources_w_connector: - if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: - continue - - _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type] - - cc_pairs = get_cc_pairs_by_source(source_type, db_session) - - for cc_pair in cc_pairs: - if cc_pair.id in cc_pairs_w_runs: - continue - - perm_sync = create_perm_sync( - source_type=source_type, - group_update=False, - cc_pair_id=cc_pair.id, - db_session=db_session, - ) - - run = client.submit(index_sync_fnc, cc_pair.id, pure=False) - - logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}") - - if run: - existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run - - return existing_jobs_copy - - -def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None: - client: Client | SimpleJobClient - if DASK_JOB_CLIENT_ENABLED: - cluster_primary = LocalCluster( - n_workers=num_workers, - threads_per_worker=1, - # there are warning about high memory usage + "Event loop unresponsive" - # which are not relevant to us since our workers are expected to use a - # lot of memory + involve CPU intensive tasks that will not relinquish - # the event loop - silence_logs=logging.ERROR, - ) - client = Client(cluster_primary) - if LOG_LEVEL.lower() == "debug": - client.register_worker_plugin(ResourceLogger()) - else: - client = SimpleJobClient(n_workers=num_workers) - - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {} - engine = get_sqlalchemy_engine() - - with Session(engine) as db_session: - # Any jobs still in progress on restart must have died - mark_all_inprogress_permission_sync_failed(db_session) - - while True: - start = time.time() - start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S") - logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}") - - if existing_jobs: - logger.debug( - "Found existing permission sync jobs: " - f"{[(attempt_id, 
job.status) for attempt_id, job in existing_jobs.items()]}" - ) - - try: - # TODO turn this on when it works - """ - existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs) - existing_jobs = create_group_sync_jobs( - existing_jobs=existing_jobs, client=client - ) - existing_jobs = create_connector_perm_sync_jobs( - existing_jobs=existing_jobs, client=client - ) - """ - except Exception as e: - logger.exception(f"Failed to run update due to {e}") - sleep_time = delay - (time.time() - start) - if sleep_time > 0: - time.sleep(sleep_time) - - -def update__main() -> None: - logger.notice("Starting Permission Syncing Loop") - init_sqlalchemy_engine(POSTGRES_PERMISSIONS_APP_NAME) - permission_loop() - - -if __name__ == "__main__": - update__main() diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py index 4f1046adb..75a5aa36b 100644 --- a/backend/ee/danswer/background/task_name_builders.py +++ b/backend/ee/danswer/background/task_name_builders.py @@ -1,6 +1,6 @@ -def name_user_group_sync_task(user_group_id: int) -> str: - return f"user_group_sync_task__{user_group_id}" - - def name_chat_ttl_task(retention_limit_days: int) -> str: return f"chat_ttl_{retention_limit_days}_days" + + +def name_sync_external_permissions_task(cc_pair_id: int) -> str: + return f"sync_external_permissions_task__{cc_pair_id}" diff --git a/backend/ee/danswer/connectors/confluence/perm_sync.py b/backend/ee/danswer/connectors/confluence/perm_sync.py deleted file mode 100644 index 2985b47b0..000000000 --- a/backend/ee/danswer/connectors/confluence/perm_sync.py +++ /dev/null @@ -1,12 +0,0 @@ -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def confluence_update_db_group() -> None: - logger.debug("Not yet implemented group sync for confluence, no-op") - - -def confluence_update_index_acl(cc_pair_id: int) -> None: - logger.debug("Not yet implemented ACL sync for confluence, no-op") diff --git a/backend/ee/danswer/connectors/factory.py b/backend/ee/danswer/connectors/factory.py deleted file mode 100644 index 52f932494..000000000 --- a/backend/ee/danswer/connectors/factory.py +++ /dev/null @@ -1,8 +0,0 @@ -from danswer.configs.constants import DocumentSource -from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group -from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl - - -CONNECTOR_PERMISSION_FUNC_MAP = { - DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl) -} diff --git a/backend/ee/danswer/db/connector_credential_pair.py b/backend/ee/danswer/db/connector_credential_pair.py index a21729134..bb91c0de7 100644 --- a/backend/ee/danswer/db/connector_credential_pair.py +++ b/backend/ee/danswer/db/connector_credential_pair.py @@ -3,6 +3,7 @@ from sqlalchemy.orm import Session from danswer.configs.constants import DocumentSource from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.enums import AccessType from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import UserGroup__ConnectorCredentialPair @@ -32,14 +33,30 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit( def get_cc_pairs_by_source( - source_type: DocumentSource, db_session: Session, + source_type: DocumentSource, + only_sync: bool, ) -> list[ConnectorCredentialPair]: - cc_pairs = ( + query = ( db_session.query(ConnectorCredentialPair) 
.join(ConnectorCredentialPair.connector) .filter(Connector.source == source_type) - .all() ) + if only_sync: + query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC) + + cc_pairs = query.all() return cc_pairs + + +def get_all_auto_sync_cc_pairs( + db_session: Session, +) -> list[ConnectorCredentialPair]: + return ( + db_session.query(ConnectorCredentialPair) + .where( + ConnectorCredentialPair.access_type == AccessType.SYNC, + ) + .all() + ) diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py index 5a368ea17..d67bc0e57 100644 --- a/backend/ee/danswer/db/document.py +++ b/backend/ee/danswer/db/document.py @@ -1,14 +1,47 @@ -from collections.abc import Sequence - from sqlalchemy import select from sqlalchemy.orm import Session -from danswer.db.models import Document +from danswer.access.models import ExternalAccess +from danswer.access.utils import prefix_group_w_source +from danswer.configs.constants import DocumentSource +from danswer.db.models import Document as DbDocument -def fetch_documents_from_ids( - db_session: Session, document_ids: list[str] -) -> Sequence[Document]: - return db_session.scalars( - select(Document).where(Document.id.in_(document_ids)) - ).all() +def upsert_document_external_perms__no_commit( + db_session: Session, + doc_id: str, + external_access: ExternalAccess, + source_type: DocumentSource, +) -> None: + """ + This sets the permissions for a document in postgres. + NOTE: this will replace any existing external access, it will not do a union + """ + document = db_session.scalars( + select(DbDocument).where(DbDocument.id == doc_id) + ).first() + + prefixed_external_groups = [ + prefix_group_w_source( + ext_group_name=group_id, + source=source_type, + ) + for group_id in external_access.external_user_group_ids + ] + + if not document: + # If the document does not exist, still store the external access + # So that if the document is added later, the external access is already stored + document = DbDocument( + id=doc_id, + semantic_id="", + external_user_emails=external_access.external_user_emails, + external_user_group_ids=prefixed_external_groups, + is_public=external_access.is_public, + ) + db_session.add(document) + return + + document.external_user_emails = list(external_access.external_user_emails) + document.external_user_group_ids = prefixed_external_groups + document.is_public = external_access.is_public diff --git a/backend/ee/danswer/db/external_perm.py b/backend/ee/danswer/db/external_perm.py new file mode 100644 index 000000000..25881df55 --- /dev/null +++ b/backend/ee/danswer/db/external_perm.py @@ -0,0 +1,77 @@ +from collections.abc import Sequence +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy import delete +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.access.utils import prefix_group_w_source +from danswer.configs.constants import DocumentSource +from danswer.db.models import User__ExternalUserGroupId + + +class ExternalUserGroup(BaseModel): + id: str + user_ids: list[UUID] + + +def delete_user__ext_group_for_user__no_commit( + db_session: Session, + user_id: UUID, +) -> None: + db_session.execute( + delete(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.user_id == user_id + ) + ) + + +def delete_user__ext_group_for_cc_pair__no_commit( + db_session: Session, + cc_pair_id: int, +) -> None: + db_session.execute( + delete(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.cc_pair_id == cc_pair_id + ) + ) + + +def 
replace_user__ext_group_for_cc_pair__no_commit( + db_session: Session, + cc_pair_id: int, + group_defs: list[ExternalUserGroup], + source: DocumentSource, +) -> None: + """ + This function clears all existing external user group relations for a given cc_pair_id + and replaces them with the new group definitions. + """ + delete_user__ext_group_for_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) + + new_external_permissions = [ + User__ExternalUserGroupId( + user_id=user_id, + external_user_group_id=prefix_group_w_source(external_group.id, source), + cc_pair_id=cc_pair_id, + ) + for external_group in group_defs + for user_id in external_group.user_ids + ] + + db_session.add_all(new_external_permissions) + + +def fetch_external_groups_for_user( + db_session: Session, + user_id: UUID, +) -> Sequence[User__ExternalUserGroupId]: + return db_session.scalars( + select(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.user_id == user_id + ) + ).all() diff --git a/backend/ee/danswer/db/permission_sync.py b/backend/ee/danswer/db/permission_sync.py deleted file mode 100644 index 7642bb653..000000000 --- a/backend/ee/danswer/db/permission_sync.py +++ /dev/null @@ -1,72 +0,0 @@ -from datetime import timedelta - -from sqlalchemy import func -from sqlalchemy import select -from sqlalchemy import update -from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session - -from danswer.configs.constants import DocumentSource -from danswer.db.models import PermissionSyncRun -from danswer.db.models import PermissionSyncStatus -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def mark_all_inprogress_permission_sync_failed( - db_session: Session, -) -> None: - stmt = ( - update(PermissionSyncRun) - .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS) - .values(status=PermissionSyncStatus.FAILED) - ) - db_session.execute(stmt) - db_session.commit() - - -def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun: - stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id) - try: - return db_session.scalars(stmt).one() - except NoResultFound: - raise ValueError(f"No PermissionSyncRun found with id {attempt_id}") - - -def expire_perm_sync_timed_out( - timeout_hours: int, - db_session: Session, -) -> None: - cutoff_time = func.now() - timedelta(hours=timeout_hours) - - update_stmt = ( - update(PermissionSyncRun) - .where( - PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS, - PermissionSyncRun.updated_at < cutoff_time, - ) - .values(status=PermissionSyncStatus.FAILED, error_msg="timed out") - ) - - db_session.execute(update_stmt) - db_session.commit() - - -def create_perm_sync( - source_type: DocumentSource, - group_update: bool, - cc_pair_id: int | None, - db_session: Session, -) -> PermissionSyncRun: - new_run = PermissionSyncRun( - source_type=source_type, - status=PermissionSyncStatus.IN_PROGRESS, - group_update=group_update, - cc_pair_id=cc_pair_id, - ) - - db_session.add(new_run) - db_session.commit() - - return new_run diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index ab666f747..f62fe19a7 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -199,7 +199,7 @@ def fetch_documents_for_user_group_paginated( def fetch_user_groups_for_documents( db_session: Session, document_ids: list[str], -) -> Sequence[tuple[int, list[str]]]: +) -> Sequence[tuple[str, list[str]]]: stmt = ( 
select(Document.id, func.array_agg(UserGroup.name)) .join( diff --git a/backend/ee/danswer/connectors/__init__.py b/backend/ee/danswer/external_permissions/__init__.py similarity index 100% rename from backend/ee/danswer/connectors/__init__.py rename to backend/ee/danswer/external_permissions/__init__.py diff --git a/backend/ee/danswer/connectors/confluence/__init__.py b/backend/ee/danswer/external_permissions/confluence/__init__.py similarity index 100% rename from backend/ee/danswer/connectors/confluence/__init__.py rename to backend/ee/danswer/external_permissions/confluence/__init__.py diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py new file mode 100644 index 000000000..2044a2e58 --- /dev/null +++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py @@ -0,0 +1,19 @@ +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + + +logger = setup_logger() + + +def confluence_doc_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + logger.debug("Not yet implemented ACL sync for confluence, no-op") diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py new file mode 100644 index 000000000..6e9e6777d --- /dev/null +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -0,0 +1,19 @@ +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + + +logger = setup_logger() + + +def confluence_group_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + logger.debug("Not yet implemented group sync for confluence, no-op") diff --git a/backend/ee/danswer/external_permissions/google_drive/__init__.py b/backend/ee/danswer/external_permissions/google_drive/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py new file mode 100644 index 000000000..80ffe471e --- /dev/null +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -0,0 +1,148 @@ +from collections.abc import Iterator +from typing import Any +from typing import cast + +from googleapiclient.discovery import build # type: ignore +from googleapiclient.errors import HttpError # type: ignore +from sqlalchemy.orm import Session + +from danswer.access.models import ExternalAccess +from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.google_drive.connector_auth import ( + get_google_drive_creds, +) +from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES +from danswer.db.models import ConnectorCredentialPair +from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit +from danswer.utils.logger import setup_logger +from ee.danswer.db.document import 
upsert_document_external_perms__no_commit +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +# Google Drive APIs are quite flakey and may 500 for an +# extended period of time. Trying to combat here by adding a very +# long retry period (~20 minutes of trying every minute) +add_retries = retry_builder(tries=5, delay=5, max_delay=30) + + +logger = setup_logger() + + +def _fetch_permissions_paginated( + drive_service: Any, drive_file_id: str +) -> Iterator[dict[str, Any]]: + next_token = None + + # Check if the file is trashed + # Returning nothing here will cause the external permissions to + # be empty which will get written to vespa (failing shut) + try: + file_metadata = add_retries( + lambda: drive_service.files() + .get(fileId=drive_file_id, fields="id, trashed") + .execute() + )() + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + return + logger.error(f"Failed to fetch permissions: {e}") + raise + + if file_metadata.get("trashed", False): + logger.debug(f"File with ID {drive_file_id} is trashed") + return + + # Get paginated permissions for the file id + while True: + try: + permissions_resp: dict[str, Any] = add_retries( + lambda: ( + drive_service.permissions() + .list( + fileId=drive_file_id, + fields="permissions(id, emailAddress, role, type, domain)", + supportsAllDrives=True, + pageToken=next_token, + ) + .execute() + ) + )() + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + break + logger.error(f"Failed to fetch permissions: {e}") + raise + + for permission in permissions_resp.get("permissions", []): + yield permission + + next_token = permissions_resp.get("nextPageToken") + if not next_token: + break + + +def _fetch_google_permissions_for_document_id( + db_session: Session, + drive_file_id: str, + raw_credentials_json: dict[str, str], + company_google_domains: list[str], +) -> ExternalAccess: + # Authenticate and construct service + google_drive_creds, _ = get_google_drive_creds( + raw_credentials_json, scopes=FETCH_PERMISSIONS_SCOPES + ) + if not google_drive_creds.valid: + raise ValueError("Invalid Google Drive credentials") + + drive_service = build("drive", "v3", credentials=google_drive_creds) + + user_emails: set[str] = set() + group_emails: set[str] = set() + public = False + for permission in _fetch_permissions_paginated(drive_service, drive_file_id): + permission_type = permission["type"] + if permission_type == "user": + user_emails.add(permission["emailAddress"]) + elif permission_type == "group": + group_emails.add(permission["emailAddress"]) + elif permission_type == "domain": + if permission["domain"] in company_google_domains: + public = True + elif permission_type == "anyone": + public = True + + batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails)) + + return ExternalAccess( + external_user_emails=user_emails, + external_user_group_ids=group_emails, + is_public=public, + ) + + +def gdrive_doc_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + """ + Adds the external permissions to the documents in postgres. + If the document doesn't already exist in postgres, we create + it in postgres so that when it gets created later, the permissions are + already populated. + """ + for doc in docs_with_additional_info: + ext_access = _fetch_google_permissions_for_document_id( + db_session=db_session, + drive_file_id=doc.additional_info,
+ raw_credentials_json=cc_pair.credential.credential_json, + company_google_domains=[ + cast(dict[str, str], sync_details)["company_domain"] + ], + ) + upsert_document_external_perms__no_commit( + db_session=db_session, + doc_id=doc.id, + external_access=ext_access, + source_type=cc_pair.connector.source, + ) diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py new file mode 100644 index 000000000..a0fe068be --- /dev/null +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -0,0 +1,147 @@ +from collections.abc import Iterator +from typing import Any + +from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore +from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore +from googleapiclient.discovery import build # type: ignore +from googleapiclient.errors import HttpError # type: ignore +from sqlalchemy.orm import Session + +from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.google_drive.connector_auth import ( + get_google_drive_creds, +) +from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.db.models import ConnectorCredentialPair +from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit +from danswer.utils.logger import setup_logger +from ee.danswer.db.external_perm import ExternalUserGroup +from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +logger = setup_logger() + + +# Google Drive APIs are quite flakey and may 500 for an +# extended period of time. 
Trying to combat here by adding a very +# long retry period (~20 minutes of trying every minute) +add_retries = retry_builder(tries=5, delay=5, max_delay=30) + + +def _fetch_groups_paginated( + google_drive_creds: ServiceAccountCredentials | OAuthCredentials, + identity_source: str | None = None, + customer_id: str | None = None, +) -> Iterator[dict[str, Any]]: + # Note that Google Drive does not use or update the user_cache as the user email + # comes directly with the call to fetch the groups; therefore, this is not a valid + # place to save on requests + if identity_source is None and customer_id is None: + raise ValueError( + "Either identity_source or customer_id must be provided to fetch groups" + ) + + cloud_identity_service = build( + "cloudidentity", "v1", credentials=google_drive_creds + ) + parent = ( + f"identitysources/{identity_source}" + if identity_source + else f"customers/{customer_id}" + ) + + while True: + try: + groups_resp: dict[str, Any] = add_retries( + lambda: (cloud_identity_service.groups().list(parent=parent).execute()) + )() + for group in groups_resp.get("groups", []): + yield group + + next_token = groups_resp.get("nextPageToken") + if not next_token: + break + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + break + logger.error(f"Error fetching groups: {e}") + raise + + +def _fetch_group_members_paginated( + google_drive_creds: ServiceAccountCredentials | OAuthCredentials, + group_name: str, +) -> Iterator[dict[str, Any]]: + cloud_identity_service = build( + "cloudidentity", "v1", credentials=google_drive_creds + ) + next_token = None + while True: + try: + membership_info = add_retries( + lambda: ( + cloud_identity_service.groups() + .memberships() + .searchTransitiveMemberships( + parent=group_name, pageToken=next_token + ) + .execute() + ) + )() + + for member in membership_info.get("memberships", []): + yield member + + next_token = membership_info.get("nextPageToken") + if not next_token: + break + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + break + logger.error(f"Error fetching group members: {e}") + raise + + +def gdrive_group_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + google_drive_creds, _ = get_google_drive_creds( + cc_pair.credential.credential_json, + scopes=FETCH_GROUPS_SCOPES, + ) + + danswer_groups: list[ExternalUserGroup] = [] + for group in _fetch_groups_paginated( + google_drive_creds, + identity_source=sync_details.get("identity_source"), + customer_id=sync_details.get("customer_id"), + ): + # The id is the group email + group_email = group["groupKey"]["id"] + + group_member_emails: list[str] = [] + for member in _fetch_group_members_paginated(google_drive_creds, group["name"]): + member_keys = member["preferredMemberKey"] + member_emails = [member_key["id"] for member_key in member_keys] + for member_email in member_emails: + group_member_emails.append(member_email) + + group_members = batch_add_non_web_user_if_not_exists__no_commit( + db_session=db_session, emails=group_member_emails + ) + if group_members: + danswer_groups.append( + ExternalUserGroup( + id=group_email, user_ids=[user.id for user in group_members] + ) + ) + + replace_user__ext_group_for_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=cc_pair.id, + group_defs=danswer_groups, + source=cc_pair.connector.source, + ) diff --git
a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py new file mode 100644 index 000000000..a3829b64e --- /dev/null +++ b/backend/ee/danswer/external_permissions/permission_sync.py @@ -0,0 +1,126 @@ +from datetime import datetime +from datetime import timezone + +from sqlalchemy.orm import Session + +from danswer.access.access import get_access_for_documents +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.search_settings import get_current_search_settings +from danswer.document_index.factory import get_default_document_index +from danswer.document_index.interfaces import UpdateRequest +from danswer.utils.logger import setup_logger +from ee.danswer.external_permissions.permission_sync_function_map import ( + DOC_PERMISSIONS_FUNC_MAP, +) +from ee.danswer.external_permissions.permission_sync_function_map import ( + FULL_FETCH_PERIOD_IN_SECONDS, +) +from ee.danswer.external_permissions.permission_sync_function_map import ( + GROUP_PERMISSIONS_FUNC_MAP, +) +from ee.danswer.external_permissions.permission_sync_utils import ( + get_docs_with_additional_info, +) + +logger = setup_logger() + + +def run_permission_sync_entrypoint( + db_session: Session, + cc_pair_id: int, +) -> None: + # TODO: separate out group and doc sync + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if cc_pair is None: + raise ValueError(f"No connector credential pair found for id: {cc_pair_id}") + + source_type = cc_pair.connector.source + + doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type) + group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type) + + if doc_sync_func is None: + raise ValueError( + f"No permission sync function found for source type: {source_type}" + ) + + sync_details = cc_pair.auto_sync_options + if sync_details is None: + raise ValueError(f"No auto sync options found for source type: {source_type}") + + # If the source type is not polling, we only fetch the permissions every + # FULL_FETCH_PERIOD_IN_SECONDS seconds + full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type] + if full_fetch_period is not None: + last_sync = cc_pair.last_time_perm_sync + if ( + last_sync + and ( + datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc) + ).total_seconds() + < full_fetch_period + ): + return + + # Here we run the connector to grab all the ids. + # This may grab ids before they are indexed but that is fine because + # we create a document in postgres to hold the permissions info + # until the indexing job has a chance to run + docs_with_additional_info = get_docs_with_additional_info( + db_session=db_session, + cc_pair=cc_pair, + ) + + # This function updates: + # - the user_email <-> external_user_group_id mapping + # in postgres without committing + logger.debug(f"Syncing groups for {source_type}") + if group_sync_func is not None: + group_sync_func( + db_session, + cc_pair, + docs_with_additional_info, + sync_details, + ) + + # This function updates: + # - the user_email <-> document mapping + # - the external_user_group_id <-> document mapping + # in postgres without committing + logger.debug(f"Syncing docs for {source_type}") + doc_sync_func( + db_session, + cc_pair, + docs_with_additional_info, + sync_details, + ) + + # This function fetches the updated access for the documents + # and returns a dictionary of document_ids and access + # This is the access we want to update vespa with + docs_access = get_access_for_documents(
+ document_ids=[doc.id for doc in docs_with_additional_info], + db_session=db_session, + ) + + # Then we build the update requests to update vespa + update_reqs = [ + UpdateRequest(document_ids=[doc_id], access=doc_access) + for doc_id, doc_access in docs_access.items() + ] + + # Don't bother sync-ing secondary, it will be sync-ed after switch anyway + search_settings = get_current_search_settings(db_session) + document_index = get_default_document_index( + primary_index_name=search_settings.index_name, + secondary_index_name=None, + ) + + try: + # update vespa + document_index.update(update_reqs) + # update postgres + db_session.commit() + except Exception as e: + logger.error(f"Error updating document index: {e}") + db_session.rollback() diff --git a/backend/ee/danswer/external_permissions/permission_sync_function_map.py b/backend/ee/danswer/external_permissions/permission_sync_function_map.py new file mode 100644 index 000000000..9bedbbc63 --- /dev/null +++ b/backend/ee/danswer/external_permissions/permission_sync_function_map.py @@ -0,0 +1,54 @@ +from collections.abc import Callable +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.configs.constants import DocumentSource +from danswer.db.models import ConnectorCredentialPair +from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync +from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync +from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync +from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +GroupSyncFuncType = Callable[ + [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]], + None, +] + +DocSyncFuncType = Callable[ + [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]], + None, +] + +# These functions update: +# - the user_email <-> document mapping +# - the external_user_group_id <-> document mapping +# in postgres without committing +# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK +DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = { + DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync, + DocumentSource.CONFLUENCE: confluence_doc_sync, +} + +# These functions update: +# - the user_email <-> external_user_group_id mapping +# in postgres without committing +# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS +GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = { + DocumentSource.GOOGLE_DRIVE: gdrive_group_sync, + DocumentSource.CONFLUENCE: confluence_group_sync, +} + + +# None means that the connector supports polling from last_time_perm_sync to now +FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = { + # Polling is supported + DocumentSource.GOOGLE_DRIVE: None, + # Polling is not supported so we fetch all doc permissions every 10 minutes + DocumentSource.CONFLUENCE: 10 * 60, +} + + +def check_if_valid_sync_source(source_type: DocumentSource) -> bool: + return source_type in DOC_PERMISSIONS_FUNC_MAP diff --git a/backend/ee/danswer/external_permissions/permission_sync_utils.py b/backend/ee/danswer/external_permissions/permission_sync_utils.py new file mode 100644 index 000000000..7be817bdb --- /dev/null +++ b/backend/ee/danswer/external_permissions/permission_sync_utils.py @@ -0,0 +1,56 @@ +from datetime import datetime +from datetime import timezone +from typing import Any + +from pydantic import BaseModel 
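The try/except above is the single commit point for the whole sync pass: Vespa is updated first, and the postgres changes staged by the various __no_commit helpers are committed only if that succeeds. Separately, the two function maps are the one registration point for auto-sync, so supporting another source amounts to adding entries to them. A rough sketch of what that wiring could look like, with hypothetical slack_doc_sync / slack_group_sync functions that do not exist in this PR:

# Hypothetical registration sketch; these sync functions are invented for illustration.
DOC_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_doc_sync  # required for auto-sync
GROUP_PERMISSIONS_FUNC_MAP[DocumentSource.SLACK] = slack_group_sync  # optional per app
# If the source cannot poll from last_time_perm_sync, cap full fetches to every 30 min
FULL_FETCH_PERIOD_IN_SECONDS[DocumentSource.SLACK] = 30 * 60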
+from sqlalchemy.orm import Session + +from danswer.connectors.factory import instantiate_connector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.models import InputType +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class DocsWithAdditionalInfo(BaseModel): + id: str + additional_info: Any + + +def get_docs_with_additional_info( + db_session: Session, + cc_pair: ConnectorCredentialPair, +) -> list[DocsWithAdditionalInfo]: + # Get all document ids that need their permissions updated + runnable_connector = instantiate_connector( + db_session=db_session, + source=cc_pair.connector.source, + input_type=InputType.POLL, + connector_specific_config=cc_pair.connector.connector_specific_config, + credential=cc_pair.credential, + ) + + assert isinstance(runnable_connector, PollConnector) + + current_time = datetime.now(timezone.utc) + start_time = ( + cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() + if cc_pair.last_time_perm_sync + else 0 + ) + cc_pair.last_time_perm_sync = current_time + + doc_batch_generator = runnable_connector.poll_source( + start=start_time, end=current_time.timestamp() + ) + + docs_with_additional_info = [ + DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info) + for doc_batch in doc_batch_generator + for doc in doc_batch + ] + logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}") + + return docs_with_additional_info diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py new file mode 100644 index 000000000..96b6b4a01 --- /dev/null +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -0,0 +1,166 @@ +""" +launch: +- api server +- postgres +- vespa +- model server (this is only needed so the api server can startup, no embedding is done) + +Run this script to seed the database with dummy documents. +Then run test_query_times.py to test query times. +""" +import random +from datetime import datetime + +from danswer.access.models import DocumentAccess +from danswer.configs.constants import DocumentSource +from danswer.connectors.models import Document +from danswer.db.engine import get_session_context_manager +from danswer.db.search_settings import get_current_search_settings +from danswer.document_index.vespa.index import VespaIndex +from danswer.indexing.models import ChunkEmbedding +from danswer.indexing.models import DocMetadataAwareIndexChunk +from danswer.indexing.models import IndexChunk +from danswer.utils.timing import log_function_time +from shared_configs.model_server_models import Embedding + + +TOTAL_DOC_SETS = 8 +TOTAL_ACL_ENTRIES_PER_CATEGORY = 80 + + +def generate_random_embedding(dim: int) -> Embedding: + return [random.uniform(-1, 1) for _ in range(dim)] + + +def generate_random_identifier() -> str: + return f"dummy_doc_{random.randint(1, 1000)}" + + +def generate_dummy_chunk( + doc_id: str, + chunk_id: int, + embedding_dim: int, + number_of_acl_entries: int, + number_of_document_sets: int, +) -> DocMetadataAwareIndexChunk: + document = Document( + id=doc_id, + source=DocumentSource.GOOGLE_DRIVE, + sections=[], + metadata={}, + semantic_identifier=generate_random_identifier(), + ) + + chunk = IndexChunk( + chunk_id=chunk_id, + blurb=f"Blurb for chunk {chunk_id} of document {doc_id}.", + content=f"Content for chunk {chunk_id} of document {doc_id}. 
This is dummy text for testing purposes.", + source_links={}, + section_continuation=False, + source_document=document, + title_prefix=f"Title prefix for doc {doc_id}", + metadata_suffix_semantic="", + metadata_suffix_keyword="", + mini_chunk_texts=None, + embeddings=ChunkEmbedding( + full_embedding=generate_random_embedding(embedding_dim), + mini_chunk_embeddings=[], + ), + title_embedding=generate_random_embedding(embedding_dim), + ) + + document_set_names = [] + for i in range(number_of_document_sets): + document_set_names.append(f"Document Set {i}") + + user_emails: set[str | None] = set() + user_groups: set[str] = set() + external_user_emails: set[str] = set() + external_user_group_ids: set[str] = set() + for i in range(number_of_acl_entries): + user_emails.add(f"user_{i}@example.com") + user_groups.add(f"group_{i}") + external_user_emails.add(f"external_user_{i}@example.com") + external_user_group_ids.add(f"external_group_{i}") + + return DocMetadataAwareIndexChunk.from_index_chunk( + index_chunk=chunk, + access=DocumentAccess( + user_emails=user_emails, + user_groups=user_groups, + external_user_emails=external_user_emails, + external_user_group_ids=external_user_group_ids, + is_public=random.choice([True, False]), + ), + document_sets={document_set for document_set in document_set_names}, + boost=random.randint(-1, 1), + ) + + +@log_function_time() +def do_insertion( + vespa_index: VespaIndex, all_chunks: list[DocMetadataAwareIndexChunk] +) -> None: + insertion_records = vespa_index.index(all_chunks) + print(f"Indexed {len(insertion_records)} documents.") + print( + f"New documents: {sum(1 for record in insertion_records if not record.already_existed)}" + ) + print( + f"Existing documents updated: {sum(1 for record in insertion_records if record.already_existed)}" + ) + + +@log_function_time() +def seed_dummy_docs( + number_of_document_sets: int, + number_of_acl_entries: int, + num_docs: int = 1000, + chunks_per_doc: int = 5, + batch_size: int = 100, +) -> None: + with get_session_context_manager() as db_session: + search_settings = get_current_search_settings(db_session) + index_name = search_settings.index_name + embedding_dim = search_settings.model_dim + + vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None) + print(index_name) + + all_chunks = [] + chunk_count = 0 + for doc_num in range(num_docs): + doc_id = f"dummy_doc_{doc_num}_{datetime.now().isoformat()}" + for chunk_num in range(chunks_per_doc): + chunk = generate_dummy_chunk( + doc_id=doc_id, + chunk_id=chunk_num, + embedding_dim=embedding_dim, + number_of_acl_entries=number_of_acl_entries, + number_of_document_sets=number_of_document_sets, + ) + all_chunks.append(chunk) + chunk_count += 1 + + if len(all_chunks) >= chunks_per_doc * batch_size: + do_insertion(vespa_index, all_chunks) + print( + f"Indexed {chunk_count} chunks out of {num_docs * chunks_per_doc}." 
+ ) + print( + f"percentage: {chunk_count / (num_docs * chunks_per_doc) * 100:.2f}% \n" + ) + all_chunks = [] + + if all_chunks: + do_insertion(vespa_index, all_chunks) + + +if __name__ == "__main__": + seed_dummy_docs( + number_of_document_sets=TOTAL_DOC_SETS, + number_of_acl_entries=TOTAL_ACL_ENTRIES_PER_CATEGORY, + num_docs=100000, + chunks_per_doc=5, + batch_size=1000, + ) diff --git a/backend/scripts/query_time_check/test_query_times.py b/backend/scripts/query_time_check/test_query_times.py new file mode 100644 index 000000000..c839fc610 --- /dev/null +++ b/backend/scripts/query_time_check/test_query_times.py @@ -0,0 +1,122 @@ +""" +RUN THIS AFTER SEED_DUMMY_DOCS.PY +""" +import random +import time + +from danswer.configs.constants import DocumentSource +from danswer.configs.model_configs import DOC_EMBEDDING_DIM +from danswer.db.engine import get_session_context_manager +from danswer.db.search_settings import get_current_search_settings +from danswer.document_index.vespa.index import VespaIndex +from danswer.search.models import IndexFilters +from scripts.query_time_check.seed_dummy_docs import TOTAL_ACL_ENTRIES_PER_CATEGORY +from scripts.query_time_check.seed_dummy_docs import TOTAL_DOC_SETS +from shared_configs.model_server_models import Embedding + +# make sure these are smaller than TOTAL_ACL_ENTRIES_PER_CATEGORY and TOTAL_DOC_SETS, respectively +NUMBER_OF_ACL_ENTRIES_PER_QUERY = 6 +NUMBER_OF_DOC_SETS_PER_QUERY = 2 + + +def get_slowest_99th_percentile(results: list[float]) -> float: + return sorted(results)[int(0.99 * len(results))] + + +# Generate random filters +def _random_filters() -> IndexFilters: + """ + Generate random filters for the query containing: + - 1 user email + - NUMBER_OF_ACL_ENTRIES_PER_QUERY groups + - NUMBER_OF_ACL_ENTRIES_PER_QUERY external groups + - NUMBER_OF_DOC_SETS_PER_QUERY document sets + """ + access_control_list = [ + f"user_email:user_{random.randint(0, TOTAL_ACL_ENTRIES_PER_CATEGORY - 1)}@example.com", + ] + acl_indices = random.sample( + range(TOTAL_ACL_ENTRIES_PER_CATEGORY), NUMBER_OF_ACL_ENTRIES_PER_QUERY + ) + for i in acl_indices: + access_control_list.append(f"group:group_{i}") + access_control_list.append(f"external_group:external_group_{i}") + + doc_sets = [] + doc_set_indices = random.sample( + range(TOTAL_DOC_SETS), NUMBER_OF_DOC_SETS_PER_QUERY + ) + for i in doc_set_indices: + doc_sets.append(f"document_set:Document Set {i}") + + return IndexFilters( + source_type=[DocumentSource.GOOGLE_DRIVE], + document_set=doc_sets, + tags=[], + access_control_list=access_control_list, + ) + + +def test_hybrid_retrieval_times( + number_of_queries: int, +) -> None: + with get_session_context_manager() as db_session: + search_settings = get_current_search_settings(db_session) + index_name = search_settings.index_name + + vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None) + + # Generate random queries + queries = [f"Random Query {i}" for i in range(number_of_queries)] + + # Generate random embeddings + embeddings = [ + Embedding([random.random() for _ in range(DOC_EMBEDDING_DIM)]) + for _ in range(number_of_queries) + ] + + total_time = 0.0 + results = [] + for i in range(number_of_queries): + start_time = time.time() + + vespa_index.hybrid_retrieval( + query=queries[i], + query_embedding=embeddings[i], + final_keywords=None, + filters=_random_filters(), + hybrid_alpha=0.5, + time_decay_multiplier=1.0, + num_to_retrieve=50, + offset=0,
title_content_ratio=0.5, + ) + + end_time = time.time() + query_time = end_time - start_time + total_time += query_time + results.append(query_time) + + print(f"Query {i+1}: {query_time:.4f} seconds") + + avg_time = total_time / number_of_queries + fast_time = min(results) + slow_time = max(results) + ninety_ninth_percentile = get_slowest_99th_percentile(results) + # Write results to a file + _OUTPUT_PATH = "query_times_results_large_more.txt" + with open(_OUTPUT_PATH, "w") as f: + f.write(f"Average query time: {avg_time:.4f} seconds\n") + f.write(f"Fastest query: {fast_time:.4f} seconds\n") + f.write(f"Slowest query: {slow_time:.4f} seconds\n") + f.write(f"99th percentile: {ninety_ninth_percentile:.4f} seconds\n") + print(f"Results written to {_OUTPUT_PATH}") + + print(f"\nAverage query time: {avg_time:.4f} seconds") + print(f"Fastest query: {fast_time:.4f} seconds") + print(f"Slowest query: {max(results):.4f} seconds") + print(f"99th percentile: {get_slowest_99th_percentile(results):.4f} seconds") + + +if __name__ == "__main__": + test_hybrid_retrieval_times(number_of_queries=1000) diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index 6498252bb..5512eca71 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -5,6 +5,7 @@ from uuid import uuid4 import requests from danswer.connectors.models import InputType +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus @@ -22,7 +23,7 @@ def _cc_pair_creator( connector_id: int, credential_id: int, name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, groups: list[int] | None = None, user_performing_action: TestUser | None = None, ) -> TestCCPair: @@ -30,7 +31,7 @@ def _cc_pair_creator( request = { "name": name, - "is_public": is_public, + "access_type": access_type, "groups": groups or [], } @@ -47,7 +48,7 @@ def _cc_pair_creator( name=name, connector_id=connector_id, credential_id=credential_id, - is_public=is_public, + access_type=access_type, groups=groups or [], ) @@ -56,7 +57,7 @@ class CCPairManager: @staticmethod def create_from_scratch( name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, groups: list[int] | None = None, source: DocumentSource = DocumentSource.FILE, input_type: InputType = InputType.LOAD_STATE, @@ -69,7 +70,7 @@ class CCPairManager: source=source, input_type=input_type, connector_specific_config=connector_specific_config, - is_public=is_public, + is_public=(access_type == AccessType.PUBLIC), groups=groups, user_performing_action=user_performing_action, ) @@ -77,7 +78,7 @@ class CCPairManager: credential_json=credential_json, name=name, source=source, - curator_public=is_public, + curator_public=(access_type == AccessType.PUBLIC), groups=groups, user_performing_action=user_performing_action, ) @@ -85,7 +86,7 @@ class CCPairManager: connector_id=connector.id, credential_id=credential.id, name=name, - is_public=is_public, + access_type=access_type, groups=groups, user_performing_action=user_performing_action, ) @@ -95,7 +96,7 @@ class CCPairManager: connector_id: int, credential_id: int, name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, 
groups: list[int] | None = None, user_performing_action: TestUser | None = None, ) -> TestCCPair: @@ -103,7 +104,7 @@ class CCPairManager: connector_id=connector_id, credential_id=credential_id, name=name, - is_public=is_public, + access_type=access_type, groups=groups, user_performing_action=user_performing_action, ) @@ -172,7 +173,7 @@ class CCPairManager: retrieved_cc_pair.name == cc_pair.name and retrieved_cc_pair.connector.id == cc_pair.connector_id and retrieved_cc_pair.credential.id == cc_pair.credential_id - and retrieved_cc_pair.public_doc == cc_pair.is_public + and retrieved_cc_pair.access_type == cc_pair.access_type and set(retrieved_cc_pair.groups) == set(cc_pair.groups) ): return diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py index 3f691eca8..dcd8def5c 100644 --- a/backend/tests/integration/common_utils/managers/document.py +++ b/backend/tests/integration/common_utils/managers/document.py @@ -3,6 +3,7 @@ from uuid import uuid4 import requests from danswer.configs.constants import DocumentSource +from danswer.db.enums import AccessType from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import NUM_DOCS @@ -22,7 +23,7 @@ def _verify_document_permissions( ) -> None: acl_keys = set(retrieved_doc["access_control_list"].keys()) print(f"ACL keys: {acl_keys}") - if cc_pair.is_public: + if cc_pair.access_type == AccessType.PUBLIC: if "PUBLIC" not in acl_keys: raise ValueError( f"Document {retrieved_doc['document_id']} is public but" @@ -30,10 +31,10 @@ def _verify_document_permissions( ) if doc_creating_user is not None: - if f"user_id:{doc_creating_user.id}" not in acl_keys: + if f"user_email:{doc_creating_user.email}" not in acl_keys: raise ValueError( f"Document {retrieved_doc['document_id']} was created by user" - f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key" + f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key" ) if group_names is not None: diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 2d8744327..db46409ab 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from pydantic import Field from danswer.auth.schemas import UserRole +from danswer.db.enums import AccessType from danswer.search.enums import RecencyBiasSetting from danswer.server.documents.models import DocumentSource from danswer.server.documents.models import InputType @@ -67,7 +68,7 @@ class TestCCPair(BaseModel): name: str connector_id: int credential_id: int - is_public: bool + access_type: AccessType groups: list[int] documents: list[SimpleTestDocument] = Field(default_factory=list) diff --git a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py index c52c5826e..c29a9faa5 100644 --- a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py +++ b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py @@ -5,6 +5,7 @@ the permissions of the curator manipulating connector-credential pairs. 
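For context on what these permission tests assert: every chunk in Vespa carries a flat access_control_list, and the entries follow the prefix scheme exercised throughout this PR ("PUBLIC", "user_email:...", "group:...", "external_group:..."). A minimal sketch of how such an entry set is composed; the helper name is invented for illustration, and the real prefixing lives in danswer.access.utils and the DocumentAccess model:

# Illustrative sketch only (hypothetical helper, simplified types).
def sketch_acl_entries(access: "DocumentAccess") -> set[str]:
    entries: set[str] = set()
    if access.is_public:
        entries.add("PUBLIC")
    # Danswer-native users and user groups
    entries.update(f"user_email:{email}" for email in access.user_emails if email)
    entries.update(f"group:{group}" for group in access.user_groups)
    # Externally synced permissions; external emails are assumed here to share
    # the user_email prefix, and external groups are already source-prefixed
    entries.update(f"user_email:{email}" for email in access.external_user_emails)
    entries.update(
        f"external_group:{group}" for group in access.external_user_group_ids
    )
    return entries

These are the same prefixes the seeded query-time checks filter on above.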
import pytest from requests.exceptions import HTTPError +from danswer.db.enums import AccessType from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.connector import ConnectorManager @@ -91,8 +92,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_1", + access_type=AccessType.PUBLIC, groups=[user_group_1.id], - is_public=True, user_performing_action=curator, ) @@ -103,8 +104,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_2", + access_type=AccessType.PRIVATE, groups=[user_group_1.id, user_group_2.id], - is_public=False, user_performing_action=curator, ) @@ -115,8 +116,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_2", + access_type=AccessType.PRIVATE, groups=[], - is_public=False, user_performing_action=curator, ) @@ -129,8 +130,8 @@ def test_cc_pair_permissions(reset: None) -> None: # connector_id=connector_2.id, # credential_id=credential_1.id, # name="invalid_cc_pair_3", + # access_type=AccessType.PRIVATE, # groups=[user_group_1.id], - # is_public=False, # user_performing_action=curator, # ) @@ -141,8 +142,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_2.id, name="invalid_cc_pair_4", + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) @@ -154,8 +155,8 @@ def test_cc_pair_permissions(reset: None) -> None: name="valid_cc_pair", connector_id=connector_1.id, credential_id=credential_1.id, + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) diff --git a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py index 412b5d41f..b8913485e 100644 --- a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py +++ b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py @@ -1,6 +1,7 @@ import pytest from requests.exceptions import HTTPError +from danswer.db.enums import AccessType from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document_set import DocumentSetManager @@ -47,14 +48,14 @@ def test_doc_set_permissions_setup(reset: None) -> None: # Admin creates a cc_pair private_cc_pair = CCPairManager.create_from_scratch( - is_public=False, + access_type=AccessType.PRIVATE, source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) # Admin creates a public cc_pair public_cc_pair = CCPairManager.create_from_scratch( - is_public=True, + access_type=AccessType.PUBLIC, source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) diff --git a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py index 878ba1e17..00b154571 100644 --- a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py +++ b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py @@ -1,6 +1,7 @@ """ This test tests the happy path for curator permissions """ +from 
danswer.db.enums import AccessType from danswer.db.models import UserRole from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager @@ -64,8 +65,8 @@ def test_whole_curator_flow(reset: None) -> None: connector_id=test_connector.id, credential_id=test_credential.id, name="curator_test_cc_pair", + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 98d460161..614ae6c92 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -1,6 +1,6 @@ "use client"; -import { errorHandlingFetcher } from "@/lib/fetcher"; +import { FetchError, errorHandlingFetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; import { HealthCheckBanner } from "@/components/health/healthcheck"; @@ -12,7 +12,6 @@ import { useFormContext } from "@/components/context/FormContext"; import { getSourceDisplayName } from "@/lib/sources"; import { SourceIcon } from "@/components/SourceIcon"; import { useState } from "react"; -import { submitConnector } from "@/components/admin/connectors/ConnectorForm"; import { deleteCredential, linkCredential } from "@/lib/credential"; import { submitFiles } from "./pages/utils/files"; import { submitGoogleSite } from "./pages/utils/google_site"; @@ -38,7 +37,8 @@ import { useGoogleDriveCredentials, } from "./pages/utils/hooks"; import { Formik } from "formik"; -import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; +import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm"; +import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector"; import NavigationRow from "./NavigationRow"; export interface AdvancedConfig { @@ -46,6 +46,64 @@ export interface AdvancedConfig { pruneFreq: number; indexingStart: string; } +import { Connector, ConnectorBase } from "@/lib/connectors/connectors"; + +const BASE_CONNECTOR_URL = "/api/manage/admin/connector"; + +export async function submitConnector( + connector: ConnectorBase, + connectorId?: number, + fakeCredential?: boolean, + isPublicCcpair?: boolean // exclusively for mock credentials, when also need to specify ccpair details +): Promise<{ message: string; isSuccess: boolean; response?: Connector }> { + const isUpdate = connectorId !== undefined; + if (!connector.connector_specific_config) { + connector.connector_specific_config = {} as T; + } + + try { + if (fakeCredential) { + const response = await fetch( + "/api/manage/admin/connector-with-mock-credential", + { + method: isUpdate ? "PATCH" : "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ ...connector, is_public: isPublicCcpair }), + } + ); + if (response.ok) { + const responseJson = await response.json(); + return { message: "Success!", isSuccess: true, response: responseJson }; + } else { + const errorData = await response.json(); + return { message: `Error: ${errorData.detail}`, isSuccess: false }; + } + } else { + const response = await fetch( + BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""), + { + method: isUpdate ? 
"PATCH" : "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(connector), + } + ); + + if (response.ok) { + const responseJson = await response.json(); + return { message: "Success!", isSuccess: true, response: responseJson }; + } else { + const errorData = await response.json(); + return { message: `Error: ${errorData.detail}`, isSuccess: false }; + } + } + } catch (error) { + return { message: `Error: ${error}`, isSuccess: false }; + } +} export default function AddConnector({ connector, @@ -84,10 +142,35 @@ export default function AddConnector({ const { liveGDriveCredential } = useGoogleDriveCredentials(); const { liveGmailCredential } = useGmailCredentials(); + const { + data: appCredentialData, + isLoading: isAppCredentialLoading, + error: isAppCredentialError, + } = useSWR<{ client_id: string }, FetchError>( + "/api/manage/admin/connector/google-drive/app-credential", + errorHandlingFetcher + ); + const { + data: serviceAccountKeyData, + isLoading: isServiceAccountKeyLoading, + error: isServiceAccountKeyError, + } = useSWR<{ service_account_email: string }, FetchError>( + "/api/manage/admin/connector/google-drive/service-account-key", + errorHandlingFetcher + ); + // Check if credential is activated const credentialActivated = - (connector === "google_drive" && liveGDriveCredential) || - (connector === "gmail" && liveGmailCredential) || + (connector === "google_drive" && + (liveGDriveCredential || + appCredentialData || + serviceAccountKeyData || + currentCredential)) || + (connector === "gmail" && + (liveGmailCredential || + appCredentialData || + serviceAccountKeyData || + currentCredential)) || currentCredential; // Check if there are no credentials @@ -159,10 +242,12 @@ export default function AddConnector({ const { name, groups, - is_public: isPublic, + access_type, pruneFreq, indexingStart, refreshFreq, + auto_sync_options, + is_public, ...connector_specific_config } = values; @@ -204,6 +289,7 @@ export default function AddConnector({ advancedConfiguration.refreshFreq, advancedConfiguration.pruneFreq, advancedConfiguration.indexingStart, + values.access_type == "public", name ); if (response) { @@ -219,7 +305,7 @@ export default function AddConnector({ setPopup, setSelectedFiles, name, - isPublic, + access_type == "public", groups ); if (response) { @@ -234,15 +320,15 @@ export default function AddConnector({ input_type: connector == "web" ? "load_state" : "poll", // single case name: name, source: connector, + is_public: access_type == "public", refresh_freq: advancedConfiguration.refreshFreq || null, prune_freq: advancedConfiguration.pruneFreq || null, indexing_start: advancedConfiguration.indexingStart || null, - is_public: isPublic, groups: groups, }, undefined, credentialActivated ? 
false : true, - isPublic + access_type == "public" ); // If no credential if (!credentialActivated) { @@ -261,8 +347,9 @@ export default function AddConnector({ response.id, credential?.id!, name, - isPublic, - groups + access_type, + groups, + auto_sync_options ); if (linkCredentialResponse.ok) { onSuccess(); @@ -366,11 +453,8 @@ export default function AddConnector({ selectedFiles={selectedFiles} /> - + + )} diff --git a/web/src/app/admin/connectors/[connector]/Sidebar.tsx b/web/src/app/admin/connectors/[connector]/Sidebar.tsx index e0c85029f..4b7f2970d 100644 --- a/web/src/app/admin/connectors/[connector]/Sidebar.tsx +++ b/web/src/app/admin/connectors/[connector]/Sidebar.tsx @@ -6,12 +6,14 @@ import { Logo } from "@/components/Logo"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { credentialTemplates } from "@/lib/connectors/credentials"; import Link from "next/link"; +import { useUser } from "@/components/user/UserProvider"; import { useContext } from "react"; export default function Sidebar() { const { formStep, setFormStep, connector, allowAdvanced, allowCreate } = useFormContext(); const combinedSettings = useContext(SettingsContext); + const { isLoadingUser, isAdmin } = useUser(); if (!combinedSettings) { return null; } @@ -59,7 +61,9 @@ export default function Sidebar() { className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-200 cursor-pointer transition-all duration-150 flex gap-x-2" > -

-                Admin Page
+                {isAdmin ? "Admin Page" : "Curator Page"}
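
Reviewer note: the sidebar label above is the only place this page consumes the new UserProvider directly. For context, a minimal sketch of the hook shape these call sites assume (isAdmin, isCurator, isLoadingUser, user); the real provider in web/src/components/user/UserProvider may expose more, and how the booleans are derived is an assumption here:

import { createContext, useContext } from "react";
import { User, UserRole } from "@/lib/types";

// Sketch only: just the fields consumed by this PR.
interface UserContextShape {
  user: User | null;
  isLoadingUser: boolean;
  isAdmin: boolean; // e.g. user?.role === UserRole.ADMIN
  isCurator: boolean; // e.g. user?.role === UserRole.CURATOR
}

const UserContext = createContext<UserContextShape | null>(null);

export function useUser(): UserContextShape {
  const ctx = useContext(UserContext);
  if (!ctx) {
    throw new Error("useUser must be used within a UserProvider");
  }
  return ctx;
}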

diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 98e517e43..8fc1fa767 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -222,36 +222,46 @@ export const DriveJsonUploadSection = ({ Found existing app credentials with the following Client ID:

{appCredentialData.client_id}

-            If you want to update these credentials, delete the existing
-            credentials through the button below, and then upload a new
-            credentials JSON.
+          {isAdmin ? (
+            <>
+              If you want to update these credentials, delete the existing
+              credentials through the button below, and then upload a new
+              credentials JSON.
+            </>
+          ) : (
+            <>
+              To change these credentials, please contact an administrator.
+            </>
+ )} ); } diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts index 7535eec35..c7b36ac0f 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts @@ -80,7 +80,7 @@ export const submitFiles = async ( connector.id, credentialId, name, - isPublic, + isPublic ? "public" : "private", groups ); if (!credentialResponse.ok) { diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts index 11d7f46ec..abc1097cc 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts @@ -10,6 +10,7 @@ export const submitGoogleSite = async ( refreshFreq: number, pruneFreq: number, indexingStart: Date, + is_public: boolean, name?: string ) => { const uploadCreateAndTriggerConnector = async () => { @@ -42,6 +43,7 @@ export const submitGoogleSite = async ( base_url: base_url, zip_path: filePaths[0], }, + is_public: is_public, refresh_freq: refreshFreq, prune_freq: pruneFreq, indexing_start: indexingStart, diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index 2778103e3..fb7e56cc6 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -155,7 +155,7 @@ export const DocumentSetCreationForm = ({ // Filter visible cc pairs const visibleCcPairs = localCcPairs.filter( (ccPair) => - ccPair.public_doc || + ccPair.access_type === "public" || (ccPair.groups.length > 0 && props.values.groups.every((group) => ccPair.groups.includes(group) @@ -228,7 +228,7 @@ export const DocumentSetCreationForm = ({ // Filter non-visible cc pairs const nonVisibleCcPairs = localCcPairs.filter( (ccPair) => - !ccPair.public_doc && + !(ccPair.access_type === "public") && (ccPair.groups.length === 0 || !props.values.groups.every((group) => ccPair.groups.includes(group) diff --git a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx index 2a38e5edd..2b78e11de 100644 --- a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx +++ b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx @@ -232,7 +232,7 @@ function ConnectorRow({ {getActivityBadge()} {isPaidEnterpriseFeaturesEnabled && ( - {ccPairsIndexingStatus.public_doc ? ( + {ccPairsIndexingStatus.access_type === "public" ? ( status.cc_pair_status === ConnectorCredentialPairStatus.ACTIVE ).length, - public: statuses.filter((status) => status.public_doc).length, + public: statuses.filter((status) => status.access_type === "public") + .length, totalDocsIndexed: statuses.reduce( (sum, status) => sum + status.docs_indexed, 0 @@ -420,7 +421,7 @@ export function CCPairIndexingStatusTable({ credential_json: {}, admin_public: false, }, - public_doc: true, + access_type: "public", docs_indexed: 1000, last_success: "2023-07-01T12:00:00Z", last_finished_status: "success", diff --git a/web/src/app/ee/admin/groups/ConnectorEditor.tsx b/web/src/app/ee/admin/groups/ConnectorEditor.tsx index d632e5f6a..ab6bbb0be 100644 --- a/web/src/app/ee/admin/groups/ConnectorEditor.tsx +++ b/web/src/app/ee/admin/groups/ConnectorEditor.tsx @@ -16,7 +16,7 @@ export const ConnectorEditor = ({
{allCCPairs // remove public docs, since they don't make sense as part of a group - .filter((ccPair) => !ccPair.public_doc) + .filter((ccPair) => !(ccPair.access_type === "public")) .map((ccPair) => { const ind = selectedCCPairIds.indexOf(ccPair.cc_pair_id); let isSelected = ind !== -1; diff --git a/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx b/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx index c87bfb5e5..ec1ac52e6 100644 --- a/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx +++ b/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx @@ -28,6 +28,11 @@ export const UserGroupCreationForm = ({ }: UserGroupCreationFormProps) => { const isUpdate = existingUserGroup !== undefined; + // Filter out ccPairs that aren't access_type "private" + const privateCcPairs = ccPairs.filter( + (ccPair) => ccPair.access_type === "private" + ); + return (
@@ -96,7 +101,7 @@ export const UserGroupCreationForm = ({

-                Select which connectors this group has access to:
+                Select which private connectors this group has access to:

All documents indexed by the selected connectors will be @@ -104,7 +109,7 @@ export const UserGroupCreationForm = ({

setFieldValue("cc_pair_ids", ccPairsIds) diff --git a/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx b/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx index 3e1896edd..8b5166ed8 100644 --- a/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx +++ b/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx @@ -76,7 +76,7 @@ export const AddConnectorForm: React.FC = ({ .includes(ccPair.cc_pair_id) ) // remove public docs, since they don't make sense as part of a group - .filter((ccPair) => !ccPair.public_doc) + .filter((ccPair) => !(ccPair.access_type === "public")) .map((ccPair) => { return { name: ccPair.name?.toString() || "", diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx new file mode 100644 index 000000000..f3a7eef11 --- /dev/null +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -0,0 +1,88 @@ +import { DefaultDropdown } from "@/components/Dropdown"; +import { + AccessType, + ValidAutoSyncSources, + ConfigurableSources, + validAutoSyncSources, +} from "@/lib/types"; +import { Text, Title } from "@tremor/react"; +import { useUser } from "@/components/user/UserProvider"; +import { useField } from "formik"; +import { AutoSyncOptions } from "./AutoSyncOptions"; + +function isValidAutoSyncSource( + value: ConfigurableSources +): value is ValidAutoSyncSources { + return validAutoSyncSources.includes(value as ValidAutoSyncSources); +} + +export function AccessTypeForm({ + connector, +}: { + connector: ConfigurableSources; +}) { + const [access_type, meta, access_type_helpers] = + useField<AccessType>("access_type"); + + const isAutoSyncSupported = isValidAutoSyncSource(connector); + const { isLoadingUser, isAdmin } = useUser(); + + const options = [ + { + name: "Private", + value: "private", + description: + "Only users who have explicitly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector", + }, + ]; + + if (isAdmin) { + options.push({ + name: "Public", + value: "public", + description: + "Everyone with an account on Danswer can access the documents pulled in by this connector", + }); + } + + if (isAutoSyncSupported && isAdmin) { + options.push({ + name: "Auto Sync", + value: "sync", + description: + "We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.", + }); + } + + return (
+    <div>
+      <Text>
+        Control who has access to the documents indexed by this connector.
+      </Text>
+      {isAdmin && (
+        <>
+          <DefaultDropdown
+            options={options}
+            selected={access_type.value}
+            onSelect={(selected) =>
+              access_type_helpers.setValue(selected as AccessType)
+            }
+            includeDefault={false}
+          />
+          {access_type.value === "sync" && isAutoSyncSupported && (
+            <AutoSyncOptions connectorType={connector as ValidAutoSyncSources} />
+          )}
+        </>
+      )}
+    </div>
+ ); +} diff --git a/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx new file mode 100644 index 000000000..4a165515c --- /dev/null +++ b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx @@ -0,0 +1,147 @@ +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import React, { useState, useEffect } from "react"; +import { FieldArray, ArrayHelpers, ErrorMessage, useField } from "formik"; +import { Text, Divider } from "@tremor/react"; +import { FiUsers } from "react-icons/fi"; +import { UserGroup, User, UserRole } from "@/lib/types"; +import { useUserGroups } from "@/lib/hooks"; +import { AccessType } from "@/lib/types"; +import { useUser } from "@/components/user/UserProvider"; + +// This should be included for all forms that require groups / public access +// to be set, and access to this / permissioning should be handled within this component itself. +export function AccessTypeGroupSelector({}: {}) { + const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); + const { isAdmin, user, isLoadingUser, isCurator } = useUser(); + const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); + const [shouldHideContent, setShouldHideContent] = useState(false); + + const [access_type, meta, access_type_helpers] = + useField("access_type"); + const [groups, groups_meta, groups_helpers] = useField("groups"); + + useEffect(() => { + if (user && userGroups && isPaidEnterpriseFeaturesEnabled) { + const isUserAdmin = user.role === UserRole.ADMIN; + if (!isPaidEnterpriseFeaturesEnabled) { + access_type_helpers.setValue("public"); + return; + } + if (!isUserAdmin) { + access_type_helpers.setValue("private"); + } + if (userGroups.length === 1 && !isUserAdmin) { + groups_helpers.setValue([userGroups[0].id]); + setShouldHideContent(true); + } else if (access_type.value !== "private") { + groups_helpers.setValue([]); + setShouldHideContent(false); + } else { + setShouldHideContent(false); + } + } + }, [user, userGroups, access_type.value]); + + if (isLoadingUser || userGroupsIsLoading) { + return
<div>Loading...</div>
; + } + if (!isPaidEnterpriseFeaturesEnabled) { + return null; + } + + if (shouldHideContent) { + return ( + <> + {userGroups && ( +
+          <Text>
+            This Connector will be assigned to group <b>{userGroups[0].name}</b>.
+          </Text>
+ )} + + ); + } + + return ( +
+ {(access_type.value === "private" || isCurator) && + userGroups && + userGroups?.length > 0 && ( + <> + +
+
+ Assign group access for this Connector +
+
+ {userGroupsIsLoading ? ( +
+ ) : ( + + {isAdmin ? ( + <> + This Connector will be visible/accessible by the groups + selected below + + ) : ( + <> + Curators must select one or more groups to give access to + this Connector + + )} + + )} + ( +
+ {userGroupsIsLoading ? ( +
+ ) : ( + userGroups && + userGroups.map((userGroup: UserGroup) => { + const ind = groups.value.indexOf(userGroup.id); + let isSelected = ind !== -1; + return ( +
{ + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(userGroup.id); + } + }} + > +
+ {" "} + {userGroup.name} +
+
+ ); + }) + )} +
+ )} + /> + + + )} +
+ ); +} diff --git a/web/src/components/admin/connectors/AutoSyncOptions.tsx b/web/src/components/admin/connectors/AutoSyncOptions.tsx new file mode 100644 index 000000000..f060ad31f --- /dev/null +++ b/web/src/components/admin/connectors/AutoSyncOptions.tsx @@ -0,0 +1,30 @@ +import { TextFormField } from "@/components/admin/connectors/Field"; +import { useFormikContext } from "formik"; +import { ValidAutoSyncSources } from "@/lib/types"; +import { Divider } from "@tremor/react"; +import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields"; + +export function AutoSyncOptions({ + connectorType, +}: { + connectorType: ValidAutoSyncSources; +}) { + return ( +
+    <div>
+      <Divider />
+      <>
+        {Object.entries(autoSyncConfigBySource[connectorType]).map(
+          ([key, config]) => (
+            <div key={key}>
+              <TextFormField
+                name={`auto_sync_options.${key}`}
+                label={config.label}
+                subtext={config.subtext}
+              />
+            </div>
+          )
+        )}
+      </>
+    </div>
+ ); +} diff --git a/web/src/components/admin/connectors/ConnectorForm.tsx b/web/src/components/admin/connectors/ConnectorForm.tsx deleted file mode 100644 index 2befa6b2b..000000000 --- a/web/src/components/admin/connectors/ConnectorForm.tsx +++ /dev/null @@ -1,382 +0,0 @@ -"use client"; - -import React, { useState } from "react"; -import { Formik, Form } from "formik"; -import * as Yup from "yup"; -import { Popup, usePopup } from "./Popup"; -import { ValidInputTypes, ValidSources } from "@/lib/types"; -import { deleteConnectorIfExistsAndIsUnlinked } from "@/lib/connector"; -import { FormBodyBuilder, RequireAtLeastOne } from "./types"; -import { BooleanFormField, TextFormField } from "./Field"; -import { createCredential, linkCredential } from "@/lib/credential"; -import { useSWRConfig } from "swr"; -import { Button, Divider } from "@tremor/react"; -import IsPublicField from "./IsPublicField"; -import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; -import { Connector, ConnectorBase } from "@/lib/connectors/connectors"; - -const BASE_CONNECTOR_URL = "/api/manage/admin/connector"; - -export async function submitConnector( - connector: ConnectorBase, - connectorId?: number, - fakeCredential?: boolean, - isPublicCcpair?: boolean // exclusively for mock credentials, when also need to specify ccpair details -): Promise<{ message: string; isSuccess: boolean; response?: Connector }> { - const isUpdate = connectorId !== undefined; - if (!connector.connector_specific_config) { - connector.connector_specific_config = {} as T; - } - - try { - if (fakeCredential) { - const response = await fetch( - "/api/manage/admin/connector-with-mock-credential", - { - method: isUpdate ? "PATCH" : "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ ...connector, is_public: isPublicCcpair }), - } - ); - if (response.ok) { - const responseJson = await response.json(); - return { message: "Success!", isSuccess: true, response: responseJson }; - } else { - const errorData = await response.json(); - return { message: `Error: ${errorData.detail}`, isSuccess: false }; - } - } else { - const response = await fetch( - BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""), - { - method: isUpdate ? 
"PATCH" : "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(connector), - } - ); - - if (response.ok) { - const responseJson = await response.json(); - return { message: "Success!", isSuccess: true, response: responseJson }; - } else { - const errorData = await response.json(); - return { message: `Error: ${errorData.detail}`, isSuccess: false }; - } - } - } catch (error) { - return { message: `Error: ${error}`, isSuccess: false }; - } -} - -const CCPairNameHaver = Yup.object().shape({ - cc_pair_name: Yup.string().required("Please enter a name for the connector"), -}); - -interface BaseProps { - nameBuilder: (values: T) => string; - ccPairNameBuilder?: (values: T) => string | null; - source: ValidSources; - inputType: ValidInputTypes; - // if specified, will automatically try and link the credential - credentialId?: number; - // If both are specified, will render formBody and then formBodyBuilder - formBody?: JSX.Element | null; - formBodyBuilder?: FormBodyBuilder; - validationSchema: Yup.ObjectSchema; - validate?: (values: T) => Record; - initialValues: T; - onSubmit?: ( - isSuccess: boolean, - responseJson: Connector | undefined - ) => void; - refreshFreq?: number; - pruneFreq?: number; - indexingStart?: Date; - // If specified, then we will create an empty credential and associate - // the connector with it. If credentialId is specified, then this will be ignored - shouldCreateEmptyCredentialForConnector?: boolean; -} - -type ConnectorFormProps = RequireAtLeastOne< - BaseProps, - "formBody" | "formBodyBuilder" ->; - -export function ConnectorForm({ - nameBuilder, - ccPairNameBuilder, - source, - inputType, - credentialId, - formBody, - formBodyBuilder, - validationSchema, - validate, - initialValues, - refreshFreq, - pruneFreq, - indexingStart, - onSubmit, - shouldCreateEmptyCredentialForConnector, -}: ConnectorFormProps): JSX.Element { - const { mutate } = useSWRConfig(); - const { popup, setPopup } = usePopup(); - - // only show this option for EE, since groups are not supported in CE - const showNonPublicOption = usePaidEnterpriseFeaturesEnabled(); - - const shouldHaveNameInput = credentialId !== undefined && !ccPairNameBuilder; - - const ccPairNameInitialValue = shouldHaveNameInput - ? { cc_pair_name: "" } - : {}; - const publicOptionInitialValue = showNonPublicOption - ? { is_public: false } - : {}; - - let finalValidationSchema = - validationSchema as Yup.ObjectSchema; - if (shouldHaveNameInput) { - finalValidationSchema = finalValidationSchema.concat(CCPairNameHaver); - } - if (showNonPublicOption) { - finalValidationSchema = finalValidationSchema.concat( - Yup.object().shape({ - is_public: Yup.boolean(), - }) - ); - } - - return ( - <> - {popup} - { - formikHelpers.setSubmitting(true); - const connectorName = nameBuilder(values); - const connectorConfig = Object.fromEntries( - Object.keys(initialValues).map((key) => [key, values[key]]) - ) as T; - - // best effort check to see if existing connector exists - // delete it if: - // 1. it exists - // 2. 
AND it has no credentials linked to it - // If the ^ are true, that means things have gotten into a bad - // state, and we should delete the connector to recover - const errorMsg = await deleteConnectorIfExistsAndIsUnlinked({ - source, - name: connectorName, - }); - if (errorMsg) { - setPopup({ - message: `Unable to delete existing connector - ${errorMsg}`, - type: "error", - }); - return; - } - - const { message, isSuccess, response } = await submitConnector({ - name: connectorName, - source, - input_type: inputType, - connector_specific_config: connectorConfig, - refresh_freq: refreshFreq || 0, - prune_freq: pruneFreq ?? null, - indexing_start: indexingStart || null, - }); - - if (!isSuccess || !response) { - setPopup({ message, type: "error" }); - formikHelpers.setSubmitting(false); - return; - } - - let credentialIdToLinkTo = credentialId; - // create empty credential if specified - if ( - shouldCreateEmptyCredentialForConnector && - credentialIdToLinkTo === undefined - ) { - const createCredentialResponse = await createCredential({ - credential_json: {}, - admin_public: true, - source: source, - }); - - if (!createCredentialResponse.ok) { - const errorMsg = await createCredentialResponse.text(); - setPopup({ - message: `Error creating credential for CC Pair - ${errorMsg}`, - type: "error", - }); - formikHelpers.setSubmitting(false); - return; - } - credentialIdToLinkTo = (await createCredentialResponse.json()).id; - } - - if (credentialIdToLinkTo !== undefined) { - const ccPairName = ccPairNameBuilder - ? ccPairNameBuilder(values) - : values.cc_pair_name; - const linkCredentialResponse = await linkCredential( - response.id, - credentialIdToLinkTo, - ccPairName as string, - values.is_public - ); - if (!linkCredentialResponse.ok) { - const linkCredentialErrorMsg = - await linkCredentialResponse.text(); - setPopup({ - message: `Error linking credential - ${linkCredentialErrorMsg}`, - type: "error", - }); - formikHelpers.setSubmitting(false); - return; - } - } - - mutate("/api/manage/admin/connector/indexing-status"); - setPopup({ message, type: isSuccess ? "success" : "error" }); - formikHelpers.setSubmitting(false); - if (isSuccess) { - formikHelpers.resetForm(); - } - if (onSubmit) { - onSubmit(isSuccess, response); - } - }} - > - {({ isSubmitting, values }) => ( -
- {shouldHaveNameInput && ( - - )} - {formBody && formBody} - {formBodyBuilder && formBodyBuilder(values)} - {showNonPublicOption && ( - <> - - - - - )} -
- -
- - )} -
- - ); -} - -interface UpdateConnectorBaseProps { - nameBuilder?: (values: T) => string; - existingConnector: Connector; - // If both are specified, uses formBody - formBody?: JSX.Element | null; - formBodyBuilder?: FormBodyBuilder; - validationSchema: Yup.ObjectSchema; - onSubmit?: (isSuccess: boolean, responseJson?: Connector) => void; -} - -type UpdateConnectorFormProps = RequireAtLeastOne< - UpdateConnectorBaseProps, - "formBody" | "formBodyBuilder" ->; - -export function UpdateConnectorForm({ - nameBuilder, - existingConnector, - formBody, - formBodyBuilder, - validationSchema, - onSubmit, -}: UpdateConnectorFormProps): JSX.Element { - const [popup, setPopup] = useState<{ - message: string; - type: "success" | "error"; - } | null>(null); - - return ( - <> - {popup && } - { - formikHelpers.setSubmitting(true); - - const { message, isSuccess, response } = await submitConnector( - { - name: nameBuilder ? nameBuilder(values) : existingConnector.name, - source: existingConnector.source, - input_type: existingConnector.input_type, - connector_specific_config: values, - refresh_freq: existingConnector.refresh_freq, - prune_freq: existingConnector.prune_freq, - indexing_start: existingConnector.indexing_start, - }, - existingConnector.id - ); - - setPopup({ message, type: isSuccess ? "success" : "error" }); - formikHelpers.setSubmitting(false); - if (isSuccess) { - formikHelpers.resetForm(); - } - setTimeout(() => { - setPopup(null); - }, 4000); - if (onSubmit) { - onSubmit(isSuccess, response); - } - }} - > - {({ isSubmitting, values }) => ( -
- {formBody ? formBody : formBodyBuilder && formBodyBuilder(values)} -
- -
-
- )} -
- - ); -} diff --git a/web/src/components/credentials/CredentialSection.tsx b/web/src/components/credentials/CredentialSection.tsx index d26cc78c2..a6002fd9a 100644 --- a/web/src/components/credentials/CredentialSection.tsx +++ b/web/src/components/credentials/CredentialSection.tsx @@ -28,7 +28,6 @@ import { ConfluenceCredentialJson, Credential, } from "@/lib/connectors/credentials"; -import { UserGroup } from "@/lib/types"; // Added this import export default function CredentialSection({ ccPair, diff --git a/web/src/components/credentials/actions/CreateCredential.tsx b/web/src/components/credentials/actions/CreateCredential.tsx index 0a5d3cb23..5188f3d02 100644 --- a/web/src/components/credentials/actions/CreateCredential.tsx +++ b/web/src/components/credentials/actions/CreateCredential.tsx @@ -232,7 +232,7 @@ export default function CreateCredential({ setShowAdvancedOptions={setShowAdvancedOptions} /> )} - {showAdvancedOptions && ( + {(showAdvancedOptions || !isAdmin) && ( +> = { + google_drive: { + customer_id: { + label: "Google Workspace Customer ID", + subtext: ( + <> + The unique identifier for your Google Workspace account. To find this, + checkout the{" "} + + guide from Google + + . + + ), + }, + company_domain: { + label: "Google Workspace Company Domain", + subtext: ( + <> + The email domain for your Google Workspace account. +
+
+ For example, if your email provided through Google Workspace looks + something like chris@danswer.ai, then your company domain is{" "} + danswer.ai + + ), + }, + }, +}; diff --git a/web/src/lib/connectors/connectors.ts b/web/src/lib/connectors/connectors.ts index 7e56a498d..b526818b0 100644 --- a/web/src/lib/connectors/connectors.ts +++ b/web/src/lib/connectors/connectors.ts @@ -29,6 +29,7 @@ export interface Option { export interface SelectOption extends Option { type: "select"; options?: StringWithDescription[]; + default?: string; } export interface ListOption extends Option { @@ -599,7 +600,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to { name: "articles", value: "articles" }, { name: "tickets", value: "tickets" }, ], - default: 0, + default: "articles", }, ], }, diff --git a/web/src/lib/credential.ts b/web/src/lib/credential.ts index 03f6c6e75..c65968047 100644 --- a/web/src/lib/credential.ts +++ b/web/src/lib/credential.ts @@ -1,4 +1,5 @@ import { CredentialBase } from "./connectors/credentials"; +import { AccessType } from "@/lib/types"; export async function createCredential(credential: CredentialBase) { return await fetch(`/api/manage/credential`, { @@ -47,8 +48,9 @@ export function linkCredential( connectorId: number, credentialId: number, name?: string, - isPublic?: boolean, - groups?: number[] + accessType?: AccessType, + groups?: number[], + autoSyncOptions?: Record ) { return fetch( `/api/manage/connector/${connectorId}/credential/${credentialId}`, @@ -59,8 +61,9 @@ export function linkCredential( }, body: JSON.stringify({ name: name || null, - is_public: isPublic !== undefined ? isPublic : true, + access_type: accessType !== undefined ? accessType : "public", groups: groups || null, + auto_sync_options: autoSyncOptions || null, }), } ); diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 936e4e6c8..bf342ca40 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -49,6 +49,7 @@ export type ValidStatuses = | "not_started"; export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE"; export type Feedback = "like" | "dislike"; +export type AccessType = "public" | "private" | "sync"; export type SessionType = "Chat" | "Search" | "Slack"; export interface DocumentBoostStatus { @@ -90,7 +91,7 @@ export interface ConnectorIndexingStatus< cc_pair_status: ConnectorCredentialPairStatus; connector: Connector; credential: Credential; - public_doc: boolean; + access_type: AccessType; owner: string; groups: number[]; last_finished_status: ValidStatuses | null; @@ -258,3 +259,7 @@ export type ConfigurableSources = Exclude< ValidSources, "not_applicable" | "ingestion_api" >; + +// The sources that have auto-sync support on the backend +export const validAutoSyncSources = ["google_drive"] as const; +export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
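
A note on the pattern above: the `as const` tuple is the single source of truth, serving as both the runtime allowlist and the compile-time union, so enabling auto-sync for another source only means appending to the array (plus a config entry in autoSyncConfigBySource). A small usage sketch; the guard mirrors isValidAutoSyncSource in AccessTypeForm.tsx, and asAutoSyncSource itself is hypothetical:

import { validAutoSyncSources, ValidAutoSyncSources } from "@/lib/types";

// Narrow an arbitrary source string to the auto-sync union, or null.
function asAutoSyncSource(value: string): ValidAutoSyncSources | null {
  return (validAutoSyncSources as readonly string[]).includes(value)
    ? (value as ValidAutoSyncSources)
    : null;
}

asAutoSyncSource("google_drive"); // "google_drive"
asAutoSyncSource("slack"); // null until "slack" is added to the tuple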