Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-07 19:38:19 +02:00)
Added permission syncing (#2340)
* Added permission syncing on the backend
* Reworked to work with Celery; Alembic fix; fixed test
* Frontend changes
* Got groups working
* Added comments and fixed public docs
* Fixed merge issues
* Frontend complete!
* Frontend cleanup and mypy fixes
* Refactored connector access_type selection
* Mypy fixes
* Minor refactor and frontend improvements
* Get to fetch
* Renames and comments
* Minor change to var names
* Got curator stuff working
* Addressed Pablo's comments
* Refactored user_external_group to reference users table
* Implemented polling
* Small refactor
* Fixed a whoopsie on the frontend
* Added scripts to seed dummy docs and test query times
* Fixed frontend build issue
* Alembic fix
* Handled is_public overlap
* Yuhong feedback
* Added more checks for sync
* black
* mypy
* Fixed circular import
* TODOs
* Alembic fix
* Alembic
This commit is contained in:
parent ef104e9a82
commit 2274cab554
@@ -1,7 +1,7 @@
 """Add last synced and last modified to document table
 
 Revision ID: 52a219fb5233
-Revises: f17bf3b0d9f1
+Revises: f7e58d357687
 Create Date: 2024-08-28 17:40:46.077470
 
 """
backend/alembic/versions/61ff3651add4_add_permission_syncing.py (new file, +162 lines)
@@ -0,0 +1,162 @@
+"""Add Permission Syncing
+
+Revision ID: 61ff3651add4
+Revises: 1b8206b29c5d
+Create Date: 2024-09-05 13:57:11.770413
+
+"""
+import fastapi_users_db_sqlalchemy
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "61ff3651add4"
+down_revision = "1b8206b29c5d"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # Admin user who set up connectors will lose access to the docs temporarily
+    # only way currently to give back access is to rerun from beginning
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column(
+            "access_type",
+            sa.String(),
+            nullable=True,
+        ),
+    )
+    op.execute(
+        "UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
+    )
+    op.execute(
+        "UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
+    )
+    op.alter_column("connector_credential_pair", "access_type", nullable=False)
+
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column(
+            "auto_sync_options",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
+    )
+    op.drop_column("connector_credential_pair", "is_public")
+
+    op.add_column(
+        "document",
+        sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
+    )
+    op.add_column(
+        "document",
+        sa.Column(
+            "external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
+        ),
+    )
+    op.add_column(
+        "document",
+        sa.Column("is_public", sa.Boolean(), nullable=True),
+    )
+
+    op.create_table(
+        "user__external_user_group_id",
+        sa.Column(
+            "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
+        ),
+        sa.Column("external_user_group_id", sa.String(), nullable=False),
+        sa.Column("cc_pair_id", sa.Integer(), nullable=False),
+        sa.PrimaryKeyConstraint("user_id"),
+    )
+
+    op.drop_column("external_permission", "user_id")
+    op.drop_column("email_to_external_user_cache", "user_id")
+    op.drop_table("permission_sync_run")
+    op.drop_table("external_permission")
+    op.drop_table("email_to_external_user_cache")
+
+
+def downgrade() -> None:
+    op.add_column(
+        "connector_credential_pair",
+        sa.Column("is_public", sa.BOOLEAN(), nullable=True),
+    )
+    op.execute(
+        "UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
+    )
+    op.alter_column("connector_credential_pair", "is_public", nullable=False)
+
+    op.drop_column("connector_credential_pair", "auto_sync_options")
+    op.drop_column("connector_credential_pair", "access_type")
+    op.drop_column("connector_credential_pair", "last_time_perm_sync")
+    op.drop_column("document", "external_user_emails")
+    op.drop_column("document", "external_user_group_ids")
+    op.drop_column("document", "is_public")
+
+    op.drop_table("user__external_user_group_id")
+
+    # Recreate the permission sync tables that were dropped in the upgrade
+    op.create_table(
+        "permission_sync_run",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column(
+            "source_type",
+            sa.String(),
+            nullable=False,
+        ),
+        sa.Column("update_type", sa.String(), nullable=False),
+        sa.Column("cc_pair_id", sa.Integer(), nullable=True),
+        sa.Column(
+            "status",
+            sa.String(),
+            nullable=False,
+        ),
+        sa.Column("error_msg", sa.Text(), nullable=True),
+        sa.Column(
+            "updated_at",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(
+            ["cc_pair_id"],
+            ["connector_credential_pair.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_table(
+        "external_permission",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("user_id", sa.UUID(), nullable=True),
+        sa.Column("user_email", sa.String(), nullable=False),
+        sa.Column(
+            "source_type",
+            sa.String(),
+            nullable=False,
+        ),
+        sa.Column("external_permission_group", sa.String(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["user_id"],
+            ["user.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_table(
+        "email_to_external_user_cache",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("external_user_id", sa.String(), nullable=False),
+        sa.Column("user_id", sa.UUID(), nullable=True),
+        sa.Column("user_email", sa.String(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["user_id"],
+            ["user.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
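The upgrade is one-way lossy in an interesting sense: the boolean is_public collapses into a three-valued access_type, and SYNC is never produced by the backfill. A minimal sketch (not part of the commit) of the backfill semantics:

    def backfill_access_type(is_public: bool) -> str:
        # SYNC only exists for new connectors that opt into automatic
        # permission syncing; the migration never produces it.
        return "PUBLIC" if is_public else "PRIVATE"

    assert backfill_access_type(True) == "PUBLIC"
    assert backfill_access_type(False) == "PRIVATE"

The downgrade inverts this with is_public = (access_type = 'PUBLIC'), so a SYNC pair degrades to private.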
@ -1,7 +1,7 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 52a219fb5233
|
||||
Revises: 0ebb1d516877
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
|
@@ -1,7 +1,7 @@
 """add has_web_login column to user
 
 Revision ID: f7e58d357687
-Revises: bceb1e139447
+Revises: ba98eba0f66a
 Create Date: 2024-09-07 20:20:54.522620
 
 """
@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session
 
 from danswer.access.models import DocumentAccess
-from danswer.access.utils import prefix_user
+from danswer.access.utils import prefix_user_email
 from danswer.configs.constants import PUBLIC_DOC_PAT
 from danswer.db.document import get_access_info_for_document
 from danswer.db.document import get_access_info_for_documents
@@ -18,10 +18,13 @@ def _get_access_for_document(
         document_id=document_id,
     )
 
-    if not info:
-        return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
-
-    return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
+    return DocumentAccess.build(
+        user_emails=info[1] if info and info[1] else [],
+        user_groups=[],
+        external_user_emails=[],
+        external_user_group_ids=[],
+        is_public=info[2] if info else False,
+    )
 
 
 def get_access_for_document(
@@ -34,6 +37,16 @@ def get_access_for_document(
     return versioned_get_access_for_document_fn(document_id, db_session)  # type: ignore
 
 
+def get_null_document_access() -> DocumentAccess:
+    return DocumentAccess(
+        user_emails=set(),
+        user_groups=set(),
+        is_public=False,
+        external_user_emails=set(),
+        external_user_group_ids=set(),
+    )
+
+
 def _get_access_for_documents(
     document_ids: list[str],
     db_session: Session,
@@ -42,13 +55,27 @@ def _get_access_for_documents(
         db_session=db_session,
         document_ids=document_ids,
    )
-    return {
-        document_id: DocumentAccess.build(
-            user_ids=user_ids, user_groups=[], is_public=is_public
+    doc_access = {
+        document_id: DocumentAccess(
+            user_emails=set([email for email in user_emails if email]),
+            # MIT version will wipe all groups and external groups on update
+            user_groups=set(),
+            is_public=is_public,
+            external_user_emails=set(),
+            external_user_group_ids=set(),
         )
-        for document_id, user_ids, is_public in document_access_info
+        for document_id, user_emails, is_public in document_access_info
     }
+
+    # Sometimes the document has not been indexed by the indexing job yet; in those cases
+    # the document does not exist, so we use the least permissive access. Specifically, the EE
+    # version checks the MIT version permissions and creates a superset. This ensures that this
+    # flow does not fail even if the Document has not yet been indexed.
+    for doc_id in document_ids:
+        if doc_id not in doc_access:
+            doc_access[doc_id] = get_null_document_access()
+    return doc_access
 
 
 def get_access_for_documents(
     document_ids: list[str],
@@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
     matches one entry in the returned set.
     """
     if user:
-        return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
+        return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
     return {PUBLIC_DOC_PAT}
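The fallback means callers can always index into the returned dict. A hypothetical illustration (the doc id is made up) of the least-permissive default:

    access = _get_access_for_documents(["not-yet-indexed-doc"], db_session)
    assert access["not-yet-indexed-doc"].is_public is False
    assert access["not-yet-indexed-doc"].user_emails == set()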
@@ -1,30 +1,72 @@
 from dataclasses import dataclass
-from uuid import UUID
 
-from danswer.access.utils import prefix_user
+from danswer.access.utils import prefix_external_group
+from danswer.access.utils import prefix_user_email
 from danswer.access.utils import prefix_user_group
 from danswer.configs.constants import PUBLIC_DOC_PAT
 
 
 @dataclass(frozen=True)
-class DocumentAccess:
-    user_ids: set[str]  # stringified UUIDs
-    user_groups: set[str]  # names of user groups associated with this document
+class ExternalAccess:
+    # Emails of external users with access to the doc externally
+    external_user_emails: set[str]
+    # Names or external IDs of groups with access to the doc
+    external_user_group_ids: set[str]
+    # Whether the document is public in the external system or Danswer
     is_public: bool
 
-    def to_acl(self) -> list[str]:
-        return (
-            [prefix_user(user_id) for user_id in self.user_ids]
+
+@dataclass(frozen=True)
+class DocumentAccess(ExternalAccess):
+    # User emails for Danswer users, None indicates admin
+    user_emails: set[str | None]
+    # Names of user groups associated with this document
+    user_groups: set[str]
+
+    def to_acl(self) -> set[str]:
+        return set(
+            [
+                prefix_user_email(user_email)
+                for user_email in self.user_emails
+                if user_email
+            ]
             + [prefix_user_group(group_name) for group_name in self.user_groups]
+            + [
+                prefix_user_email(user_email)
+                for user_email in self.external_user_emails
+            ]
+            + [
+                # The group names are already prefixed by the source type
+                # This adds an additional prefix of "external_group:"
+                prefix_external_group(group_name)
+                for group_name in self.external_user_group_ids
+            ]
             + ([PUBLIC_DOC_PAT] if self.is_public else [])
         )
 
     @classmethod
     def build(
-        cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
+        cls,
+        user_emails: list[str | None],
+        user_groups: list[str],
+        external_user_emails: list[str],
+        external_user_group_ids: list[str],
+        is_public: bool,
     ) -> "DocumentAccess":
         return cls(
-            user_ids={str(user_id) for user_id in user_ids if user_id},
+            external_user_emails={
+                prefix_user_email(external_email)
+                for external_email in external_user_emails
+            },
+            external_user_group_ids={
+                prefix_external_group(external_group_id)
+                for external_group_id in external_user_group_ids
+            },
+            user_emails={
+                prefix_user_email(user_email)
+                for user_email in user_emails
+                if user_email
+            },
             user_groups=set(user_groups),
             is_public=is_public,
         )
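For intuition, here is a sketch of what to_acl() yields for a directly constructed DocumentAccess (the emails and group names are made up; build() would additionally apply the prefixes itself):

    access = DocumentAccess(
        external_user_emails={"pablo@coolco.com"},
        external_user_group_ids={"GOOGLE_DRIVE_eng"},
        is_public=False,
        user_emails={"chris@danswer.ai"},
        user_groups={"support"},
    )
    sorted(access.to_acl())
    # ['external_group:GOOGLE_DRIVE_eng', 'group:support',
    #  'user_email:chris@danswer.ai', 'user_email:pablo@coolco.com']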
@@ -1,10 +1,24 @@
-def prefix_user(user_id: str) -> str:
-    """Prefixes a user ID to eliminate collision with group names.
-    This assumes that groups are prefixed with a different prefix."""
-    return f"user_id:{user_id}"
+from danswer.configs.constants import DocumentSource
+
+
+def prefix_user_email(user_email: str) -> str:
+    """Prefixes a user email to eliminate collision with group names.
+    This applies to both a Danswer user and an External user; it makes
+    query time more efficient"""
+    return f"user_email:{user_email}"
 
 
 def prefix_user_group(user_group_name: str) -> str:
-    """Prefixes a user group name to eliminate collision with user IDs.
+    """Prefixes a user group name to eliminate collision with user emails.
     This assumes that user ids are prefixed with a different prefix."""
     return f"group:{user_group_name}"
+
+
+def prefix_external_group(ext_group_name: str) -> str:
+    """Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
+    return f"external_group:{ext_group_name}"
+
+
+def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
+    """External groups may collide across sources, every source needs its own prefix."""
+    return f"{source.value.upper()}_{ext_group_name}"
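Taken together, an external group passes through two prefixes before it lands in an ACL: first the source prefix, then the generic external-group prefix. A small sketch with a made-up group name:

    from danswer.configs.constants import DocumentSource

    raw_group = "eng-team"
    source_scoped = prefix_group_w_source(raw_group, DocumentSource.GOOGLE_DRIVE)
    # -> "GOOGLE_DRIVE_eng-team"
    acl_entry = prefix_external_group(source_scoped)
    # -> "external_group:GOOGLE_DRIVE_eng-team"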
@@ -128,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             return
 
         runnable_connector = instantiate_connector(
-            cc_pair.connector.source,
-            InputType.PRUNE,
-            cc_pair.connector.connector_specific_config,
-            cc_pair.credential,
-            db_session,
+            db_session=db_session,
+            source=cc_pair.connector.source,
+            input_type=InputType.PRUNE,
+            connector_specific_config=cc_pair.connector.connector_specific_config,
+            credential=cc_pair.credential,
         )
 
         all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -56,11 +56,11 @@ def _get_connector_runner(
 
     try:
         runnable_connector = instantiate_connector(
-            attempt.connector_credential_pair.connector.source,
-            task,
-            attempt.connector_credential_pair.connector.connector_specific_config,
-            attempt.connector_credential_pair.credential,
-            db_session,
+            db_session=db_session,
+            source=attempt.connector_credential_pair.connector.source,
+            input_type=task,
+            connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
+            credential=attempt.connector_credential_pair.credential,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
@@ -124,11 +124,11 @@ def identify_connector_class(
 
 
 def instantiate_connector(
+    db_session: Session,
     source: DocumentSource,
     input_type: InputType,
     connector_specific_config: dict[str, Any],
     credential: Credential,
-    db_session: Session,
 ) -> BaseConnector:
     connector_class = identify_connector_class(source, input_type)
     connector = connector_class(**connector_specific_config)
@@ -6,7 +6,6 @@ from datetime import timezone
 from enum import Enum
 from itertools import chain
 from typing import Any
-from typing import cast
 
 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
 from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
-from danswer.connectors.google_drive.connector_auth import (
-    get_google_drive_creds_for_authorized_user,
-)
-from danswer.connectors.google_drive.connector_auth import (
-    get_google_drive_creds_for_service_account,
-)
+from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
 from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
 )
 from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
 )
-from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
@@ -407,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
         (2) A credential which holds a service account key JSON file, which
         can then be used to impersonate any user in the workspace.
         """
-        creds: OAuthCredentials | ServiceAccountCredentials | None = None
-        new_creds_dict = None
-        if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
-            access_token_json_str = cast(
-                str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
-            )
-            creds = get_google_drive_creds_for_authorized_user(
-                token_json_str=access_token_json_str
-            )
-
-            # tell caller to update token stored in DB if it has changed
-            # (e.g. the token has been refreshed)
-            new_creds_json_str = creds.to_json() if creds else ""
-            if new_creds_json_str != access_token_json_str:
-                new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
-
-        if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
-            service_account_key_json_str = credentials[
-                DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
-            ]
-            creds = get_google_drive_creds_for_service_account(
-                service_account_key_json_str=service_account_key_json_str
-            )
-
-            # "Impersonate" a user if one is specified
-            delegated_user_email = cast(
-                str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
-            )
-            if delegated_user_email:
-                creds = creds.with_subject(delegated_user_email) if creds else None  # type: ignore
-
-        if creds is None:
-            raise PermissionError(
-                "Unable to access Google Drive - unknown credential structure."
-            )
-
+        creds, new_creds_dict = get_google_drive_creds(credentials)
         self.creds = creds
         return new_creds_dict
@@ -509,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
                         file["modifiedTime"]
                     ).astimezone(timezone.utc),
                     metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
+                    additional_info=file.get("id"),
                 )
             )
         except Exception as e:
@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
 from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import KV_CRED_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
 from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
+from danswer.connectors.google_drive.constants import BASE_SCOPES
 from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
 )
@@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
 )
 from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
-from danswer.connectors.google_drive.constants import SCOPES
+from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
+from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
 from danswer.db.credentials import update_credential_json
 from danswer.db.models import User
 from danswer.dynamic_configs.factory import get_dynamic_config_store
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()
 
 
+def build_gdrive_scopes() -> list[str]:
+    base_scopes: list[str] = BASE_SCOPES
+    permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
+    groups_scopes: list[str] = FETCH_GROUPS_SCOPES
+
+    if ENTERPRISE_EDITION_ENABLED:
+        return base_scopes + permissions_scopes + groups_scopes
+    return base_scopes + permissions_scopes
+
+
 def _build_frontend_google_drive_redirect() -> str:
     return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
 
 
 def get_google_drive_creds_for_authorized_user(
-    token_json_str: str,
+    token_json_str: str, scopes: list[str] = build_gdrive_scopes()
 ) -> OAuthCredentials | None:
     creds_json = json.loads(token_json_str)
-    creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
+    creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
     if creds.valid:
         return creds
@@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
     return None
 
 
-def get_google_drive_creds_for_service_account(
-    service_account_key_json_str: str,
+def _get_google_drive_creds_for_service_account(
+    service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
 ) -> ServiceAccountCredentials | None:
     service_account_key = json.loads(service_account_key_json_str)
     creds = ServiceAccountCredentials.from_service_account_info(
-        service_account_key, scopes=SCOPES
+        service_account_key, scopes=scopes
     )
     if not creds.valid or not creds.expired:
         creds.refresh(Request())
     return creds if creds.valid else None
 
 
+def get_google_drive_creds(
+    credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
+) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
+    oauth_creds = None
+    service_creds = None
+    new_creds_dict = None
+    if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
+        access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
+        oauth_creds = get_google_drive_creds_for_authorized_user(
+            token_json_str=access_token_json_str, scopes=scopes
+        )
+
+        # tell caller to update token stored in DB if it has changed
+        # (e.g. the token has been refreshed)
+        new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
+        if new_creds_json_str != access_token_json_str:
+            new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
+
+    elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
+        service_account_key_json_str = credentials[
+            DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
+        ]
+        service_creds = _get_google_drive_creds_for_service_account(
+            service_account_key_json_str=service_account_key_json_str,
+            scopes=scopes,
+        )
+
+        # "Impersonate" a user if one is specified
+        delegated_user_email = cast(
+            str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
+        )
+        if delegated_user_email:
+            service_creds = (
+                service_creds.with_subject(delegated_user_email)
+                if service_creds
+                else None
+            )
+
+    creds: ServiceAccountCredentials | OAuthCredentials | None = (
+        oauth_creds or service_creds
+    )
+    if creds is None:
+        raise PermissionError(
+            "Unable to access Google Drive - unknown credential structure."
+        )
+
+    return creds, new_creds_dict
+
+
 def verify_csrf(credential_id: int, state: str) -> None:
     csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
     if csrf != state:
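Callers now go through a single entry point instead of branching on the credential shape themselves. A minimal usage sketch (the token payload is hypothetical):

    credentials = {DB_CREDENTIALS_DICT_TOKEN_KEY: "<oauth token json from the DB>"}
    creds, new_creds_dict = get_google_drive_creds(credentials)
    if new_creds_dict is not None:
        # the OAuth token was refreshed; the caller should persist
        # new_creds_dict back onto the stored credential
        ...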
@@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
     credential_json = json.loads(creds_str)
     flow = InstalledAppFlow.from_client_config(
         credential_json,
-        scopes=SCOPES,
+        scopes=build_gdrive_scopes(),
         redirect_uri=_build_frontend_google_drive_redirect(),
     )
     auth_url, _ = flow.authorization_url(prompt="consent")
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
     app_credentials = get_google_app_cred()
     flow = InstalledAppFlow.from_client_config(
         app_credentials.model_dump(),
-        scopes=SCOPES,
+        scopes=build_gdrive_scopes(),
         redirect_uri=_build_frontend_google_drive_redirect(),
     )
     flow.fetch_token(code=auth_code)
@@ -1,7 +1,7 @@
 DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
 DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
 DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
-SCOPES = [
-    "https://www.googleapis.com/auth/drive.readonly",
-    "https://www.googleapis.com/auth/drive.metadata.readonly",
-]
+
+BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
+FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
+FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
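Splitting the old SCOPES constant this way makes the scope list a function of the edition; the following property (restating build_gdrive_scopes from the hunk above) captures it:

    assert build_gdrive_scopes() == BASE_SCOPES + FETCH_PERMISSIONS_SCOPES + (
        FETCH_GROUPS_SCOPES if ENTERPRISE_EDITION_ENABLED else []
    )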
@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
     # The default title is semantic_identifier though unless otherwise specified
     title: str | None = None
     from_ingestion_api: bool = False
+    # Anything else that may be useful that is specific to this particular connector type that other
+    # parts of the code may need. If you're unsure, this can be left as None
+    additional_info: Any = None
 
     def get_title_for_document_index(
         self,
@@ -211,7 +211,7 @@ def handle_message(
 
     with Session(get_sqlalchemy_engine()) as db_session:
         if message_info.email:
-            add_non_web_user_if_not_exists(message_info.email, db_session)
+            add_non_web_user_if_not_exists(db_session, message_info.email)
 
         # first check if we need to respond with a standard answer
         used_standard_answer = handle_standard_answers(
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_connector_by_id
 from danswer.db.credentials import fetch_credential_by_id
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import IndexAttempt
@@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
 from danswer.db.models import UserRole
 from danswer.server.models import StatusResponse
 from danswer.utils.logger import setup_logger
+from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
+from ee.danswer.external_permissions.permission_sync_function_map import (
+    check_if_valid_sync_source,
+)
 
 logger = setup_logger()
@@ -74,7 +79,7 @@ def _add_user_filters(
             .correlate(ConnectorCredentialPair)
         )
     else:
-        where_clause |= ConnectorCredentialPair.is_public == True  # noqa: E712
+        where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
 
     return stmt.where(where_clause)
@@ -94,8 +99,7 @@ def get_connector_credential_pairs(
     )  # noqa
     if ids:
         stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
-    results = db_session.scalars(stmt)
-    return list(results.all())
+    return list(db_session.scalars(stmt).all())
 
 
 def add_deletion_failure_message(
@@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
     association = ConnectorCredentialPair(
         connector_id=0,
         credential_id=0,
+        access_type=AccessType.PUBLIC,
         name="DefaultCCPair",
         status=ConnectorCredentialPairStatus.ACTIVE,
-        is_public=True,
     )
     db_session.add(association)
     db_session.commit()
@@ -336,8 +340,9 @@ def add_credential_to_connector(
     connector_id: int,
     credential_id: int,
     cc_pair_name: str | None,
-    is_public: bool,
+    access_type: AccessType,
     groups: list[int] | None,
+    auto_sync_options: dict | None = None,
 ) -> StatusResponse:
     connector = fetch_connector_by_id(connector_id, db_session)
     credential = fetch_credential_by_id(credential_id, user, db_session)
@@ -345,6 +350,13 @@ def add_credential_to_connector(
     if connector is None:
         raise HTTPException(status_code=404, detail="Connector does not exist")
 
+    if access_type == AccessType.SYNC:
+        if not check_if_valid_sync_source(connector.source):
+            raise HTTPException(
+                status_code=400,
+                detail=f"Connector of type {connector.source} does not support SYNC access type",
+            )
+
     if credential is None:
         error_msg = (
             f"Credential {credential_id} does not exist or does not belong to user"
@@ -375,12 +387,13 @@ def add_credential_to_connector(
         credential_id=credential_id,
         name=cc_pair_name,
         status=ConnectorCredentialPairStatus.ACTIVE,
-        is_public=is_public,
+        access_type=access_type,
+        auto_sync_options=auto_sync_options,
     )
     db_session.add(association)
     db_session.flush()  # make sure the association has an id
 
-    if groups:
+    if groups and access_type != AccessType.SYNC:
         _relate_groups_to_cc_pair__no_commit(
             db_session=db_session,
             cc_pair_id=association.id,
@@ -423,6 +436,10 @@ def remove_credential_from_connector(
     )
 
     if association is not None:
+        delete_user__ext_group_for_cc_pair__no_commit(
+            db_session=db_session,
+            cc_pair_id=association.id,
+        )
         db_session.delete(association)
         db_session.commit()
         return StatusResponse(
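End to end, a SYNC pair is created like any other, except the source must pass check_if_valid_sync_source and group assignment is skipped. A call sketch with made-up ids (the auto_sync_options values reuse the google_drive example from the model comment in this commit):

    add_credential_to_connector(
        db_session=db_session,
        user=user,
        connector_id=42,
        credential_id=7,
        cc_pair_name="Drive with permission sync",
        access_type=AccessType.SYNC,
        auto_sync_options={"customer_id": "123567", "company_domain": "@danswer.ai"},
        groups=None,
    )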
@@ -4,7 +4,6 @@ from collections.abc import Generator
 from collections.abc import Sequence
 from datetime import datetime
 from datetime import timezone
-from uuid import UUID
 
 from sqlalchemy import and_
 from sqlalchemy import delete
@@ -17,14 +16,17 @@ from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.engine.util import TransactionalContext
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session
+from sqlalchemy.sql.expression import null
 
 from danswer.configs.constants import DEFAULT_BOOST
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Credential
 from danswer.db.models import Document as DbDocument
 from danswer.db.models import DocumentByConnectorCredentialPair
+from danswer.db.models import User
 from danswer.db.tag import delete_document_tags_for_documents__no_commit
 from danswer.db.utils import model_to_dict
 from danswer.document_index.interfaces import DocumentMetadata
@@ -186,16 +188,14 @@ def get_document_counts_for_cc_pairs(
 def get_access_info_for_document(
     db_session: Session,
     document_id: str,
-) -> tuple[str, list[UUID | None], bool] | None:
+) -> tuple[str, list[str | None], bool] | None:
     """Gets access info for a single document by calling the get_access_info_for_documents function
     and passing a list with a single document ID.
 
     Args:
         db_session (Session): The database session to use.
         document_id (str): The document ID to fetch access info for.
 
     Returns:
-        Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
+        Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
         and a boolean indicating if the document is globally public, or None if no results are found.
     """
     results = get_access_info_for_documents(db_session, [document_id])
@@ -208,19 +208,27 @@ def get_access_info_for_document(
 def get_access_info_for_documents(
     db_session: Session,
     document_ids: list[str],
-) -> Sequence[tuple[str, list[UUID | None], bool]]:
+) -> Sequence[tuple[str, list[str | None], bool]]:
     """Gets back all relevant access info for the given documents. This includes
     the user_ids for cc pairs that the document is associated with + whether any
     of the associated cc pairs are intending to make the document globally public.
+    Returns the list where each element contains:
+    - Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
+    - List of emails of Danswer users with direct access to the doc (includes a "None" element if
+      the connector was set up by an admin when auth was off)
+    - bool for whether the document is public (the document later can also be marked public by
+      the automatic permission sync step)
     """
-    stmt = (
-        select(
-            DocumentByConnectorCredentialPair.id,
-            func.array_agg(Credential.user_id).label("user_ids"),
-            func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
-        )
-        .where(DocumentByConnectorCredentialPair.id.in_(document_ids))
-        .join(
+    stmt = select(
+        DocumentByConnectorCredentialPair.id,
+        func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
+        func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
+            "public_doc"
+        ),
+    ).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
+
+    stmt = (
+        stmt.join(
             Credential,
             DocumentByConnectorCredentialPair.credential_id == Credential.id,
         )
@@ -233,6 +241,13 @@ def get_access_info_for_documents(
                 == ConnectorCredentialPair.credential_id,
             ),
         )
+        .outerjoin(
+            User,
+            and_(
+                Credential.user_id == User.id,
+                ConnectorCredentialPair.access_type != AccessType.SYNC,
+            ),
+        )
         # don't include CC pairs that are being deleted
         # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
         .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
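In plain terms, the aggregation collects one owner email per credential (NULL when the connector was set up with auth off, and forced to NULL for SYNC pairs via the outer-join condition) and ORs the public flags together. A toy model of the per-document result:

    rows = [("doc1", "chris@danswer.ai", False), ("doc1", None, True)]  # made-up data
    user_emails = [email for doc_id, email, _ in rows if doc_id == "doc1"]
    public_doc = any(public for doc_id, _, public in rows if doc_id == "doc1")
    assert user_emails == ["chris@danswer.ai", None] and public_doc is True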
@@ -278,9 +293,19 @@ def upsert_documents(
             for doc in seen_documents.values()
         ]
     )
-    # for now, there are no columns to update. If more metadata is added, then this
-    # needs to change to an `on_conflict_do_update`
-    on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
+
+    on_conflict_stmt = insert_stmt.on_conflict_do_update(
+        index_elements=["id"],  # Conflict target
+        set_={
+            "from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
+            "boost": insert_stmt.excluded.boost,
+            "hidden": insert_stmt.excluded.hidden,
+            "semantic_id": insert_stmt.excluded.semantic_id,
+            "link": insert_stmt.excluded.link,
+            "primary_owners": insert_stmt.excluded.primary_owners,
+            "secondary_owners": insert_stmt.excluded.secondary_owners,
+        },
+    )
     db_session.execute(on_conflict_stmt)
     db_session.commit()
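The switch from on_conflict_do_nothing to on_conflict_do_update is the standard Postgres upsert pattern in SQLAlchemy. A self-contained sketch against a hypothetical table:

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects.postgresql import insert

    docs = Table(
        "docs",
        MetaData(),
        Column("id", String, primary_key=True),
        Column("link", String),
    )

    stmt = insert(docs).values([{"id": "doc1", "link": "https://example.com"}])
    stmt = stmt.on_conflict_do_update(
        index_elements=["id"],              # conflict target: the primary key
        set_={"link": stmt.excluded.link},  # EXCLUDED holds the incoming row
    )
    # executing stmt INSERTs, or on an id collision UPDATEs link in place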
@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
 
 from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Document
@@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
         ids=missing_cc_pair_ids,
     )
     for cc_pair in cc_pairs:
-        if not cc_pair.is_public:
+        if cc_pair.access_type != AccessType.PUBLIC:
             raise ValueError(
                 f"Connector Credential Pair with ID: '{cc_pair.id}'"
                 " is not owned by the specified groups"
@@ -704,7 +705,7 @@ def check_document_sets_are_public(
             ConnectorCredentialPair.id.in_(
                 connector_credential_pair_ids  # type:ignore
             ),
-            ConnectorCredentialPair.is_public.is_(False),
+            ConnectorCredentialPair.access_type != AccessType.PUBLIC,
         )
         .limit(1)
         .first()
@@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
 
     def is_active(self) -> bool:
         return self == ConnectorCredentialPairStatus.ACTIVE
+
+
+class AccessType(str, PyEnum):
+    PUBLIC = "public"
+    PRIVATE = "private"
+    SYNC = "sync"
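PUBLIC and PRIVATE replace the old is_public=True/False, while SYNC marks pairs whose document permissions are mirrored from the external source. A query sketch for finding all auto-synced pairs:

    from sqlalchemy import select

    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.access_type == AccessType.SYNC
    )
    sync_cc_pairs = list(db_session.scalars(stmt).all())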
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.db.chat import get_chat_message
+from danswer.db.enums import AccessType
 from danswer.db.models import ChatMessageFeedback
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Document as DbDocument
@@ -94,7 +95,7 @@ def _add_user_filters(
             .correlate(CCPair)
         )
     else:
-        where_clause |= CCPair.is_public == True  # noqa: E712
+        where_clause |= CCPair.access_type == AccessType.PUBLIC
 
     return stmt.where(where_clause)
@@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import FileOrigin
 from danswer.configs.constants import MessageType
+from danswer.db.enums import AccessType
 from danswer.configs.constants import NotificationType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.constants import TokenRateLimitScope
@@ -388,10 +389,20 @@ class ConnectorCredentialPair(Base):
     # controls whether the documents indexed by this CC pair are visible to all
     # or if they are only visible to those that are given explicit access
     # (e.g. via owning the credential or being a part of a group that is given access)
-    is_public: Mapped[bool] = mapped_column(
-        Boolean,
-        default=True,
-        nullable=False,
+    access_type: Mapped[AccessType] = mapped_column(
+        Enum(AccessType, native_enum=False), nullable=False
     )
+
+    # special info needed for the auto-sync feature. The exact structure depends on the
+    # source type (defined in the connector's `source` field)
+    # E.g. for google_drive perm sync:
+    # {"customer_id": "123567", "company_domain": "@danswer.ai"}
+    auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
+    last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
     # Time finished, not used for calculating backend jobs which uses time started (created)
     last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
@@ -422,6 +433,7 @@ class ConnectorCredentialPair(Base):
 
 class Document(Base):
     __tablename__ = "document"
+    # NOTE: if more sensitive data is added here for display, make sure to add user/group permission
 
     # this should correspond to the ID of the document
     # (as is passed around in Danswer)
@@ -465,7 +477,18 @@ class Document(Base):
     secondary_owners: Mapped[list[str] | None] = mapped_column(
         postgresql.ARRAY(String), nullable=True
     )
-    # TODO if more sensitive data is added here for display, make sure to add user/group permission
+    # Permission sync columns
+    # Email addresses are saved at the document level for externally synced permissions
+    # This is because the normal flow of assigning permissions through the cc_pair
+    # doesn't apply here
+    external_user_emails: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+    # These group ids have been prefixed by the source type
+    external_user_group_ids: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+    is_public: Mapped[bool] = mapped_column(Boolean, default=False)
 
     retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
         "DocumentRetrievalFeedback", back_populates="document"
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
class PermissionSyncStatus(str, PyEnum):
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PermissionSyncJobType(str, PyEnum):
|
||||
USER_LEVEL = "user_level"
|
||||
GROUP_LEVEL = "group_level"
|
||||
|
||||
|
||||
class PermissionSyncRun(Base):
|
||||
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
|
||||
the users or it is sync-ing the groups"""
|
||||
|
||||
__tablename__ = "permission_sync_run"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# Not strictly needed but makes it easy to use without fetching from cc_pair
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
# Currently all sync jobs are handled as a group permission sync or a user permission sync
|
||||
update_type: Mapped[PermissionSyncJobType] = mapped_column(
|
||||
Enum(PermissionSyncJobType)
|
||||
)
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=True
|
||||
)
|
||||
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
|
||||
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
|
||||
|
||||
|
||||
class ExternalPermission(Base):
|
||||
class User__ExternalUserGroupId(Base):
|
||||
"""Maps user info both internal and external to the name of the external group
|
||||
This maps the user to all of their external groups so that the external group name can be
|
||||
attached to the ACL list matching during query time. User level permissions can be handled by
|
||||
directly adding the Danswer user to the doc ACL list"""
|
||||
|
||||
__tablename__ = "external_permission"
|
||||
__tablename__ = "user__external_user_group_id"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
external_permission_group: Mapped[str] = mapped_column(String)
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
class EmailToExternalUserCache(Base):
|
||||
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
|
||||
when the user joins. Used as a cache for when fetching external groups which have their own
|
||||
user ids, this can easily be mapped back to users already known in Danswer without needing
|
||||
to call external APIs to get the user emails.
|
||||
|
||||
This way when groups are updated in the external tool and we need to update the mapping of
|
||||
internal users to the groups, we can sync the internal users to the external groups they are
|
||||
part of using this.
|
||||
|
||||
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
|
||||
part of alpha in some external tool.
|
||||
"""
|
||||
|
||||
__tablename__ = "email_to_external_user_cache"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_user_id: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
|
||||
user = relationship("User")
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
|
@ -22,6 +22,17 @@ def list_users(
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_users_by_emails(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> tuple[list[User], list[str]]:
|
||||
# Use distinct to avoid duplicates
|
||||
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
found_users_emails = [user.email for user in found_users]
|
||||
missing_user_emails = [email for email in emails if email not in found_users_emails]
|
||||
return found_users, missing_user_emails
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = db_session.query(User).filter(User.email == email).first() # type: ignore
|
||||
|
||||
@ -34,20 +45,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
def _generate_non_web_user(email: str) -> User:
|
||||
fastapi_users_pw_helper = PasswordHelper()
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
user = User(
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
has_web_login=False,
|
||||
role=UserRole.BASIC,
|
||||
)
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.flush() # generate id
|
||||
return user
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
|
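The batch helper exists so a permission sync can stub out many external users with one flush and a single commit at the end of the sync step. Usage sketch with made-up emails:

    users = batch_add_non_web_user_if_not_exists__no_commit(
        db_session=db_session, emails=["a@corp.example", "b@corp.example"]
    )
    db_session.commit()  # commit once after the whole sync step succeeds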
@@ -265,7 +265,13 @@ def index_doc_batch(
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""
 
-    no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
+    no_access = DocumentAccess.build(
+        user_emails=[],
+        user_groups=[],
+        external_user_emails=[],
+        external_user_group_ids=[],
+        is_public=False,
+    )
 
     ctx = index_doc_batch_prepare(
         document_batch=document_batch,
@@ -18,6 +18,7 @@ from danswer.db.connector_credential_pair import (
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import get_session
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
@@ -201,7 +202,7 @@ def associate_credential_to_connector(
         db_session=db_session,
         user=user,
         target_group_ids=metadata.groups,
-        object_is_public=metadata.is_public,
+        object_is_public=metadata.access_type == AccessType.PUBLIC,
     )
 
     try:
@@ -211,7 +212,8 @@ def associate_credential_to_connector(
             connector_id=connector_id,
             credential_id=credential_id,
             cc_pair_name=metadata.name,
-            is_public=True if metadata.is_public is None else metadata.is_public,
+            access_type=metadata.access_type,
+            auto_sync_options=metadata.auto_sync_options,
             groups=metadata.groups,
         )
@@ -64,6 +64,7 @@ from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.document import get_document_counts_for_cc_pairs
 from danswer.db.engine import get_session
+from danswer.db.enums import AccessType
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempts_for_cc_pair
 from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
@@ -559,7 +560,7 @@ def get_connector_indexing_status(
                 cc_pair_status=cc_pair.status,
                 connector=ConnectorSnapshot.from_connector_db_model(connector),
                 credential=CredentialSnapshot.from_credential_db_model(credential),
-                public_doc=cc_pair.is_public,
+                access_type=cc_pair.access_type,
                 owner=credential.user.email if credential.user else "",
                 groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
                 last_finished_status=(
@@ -668,12 +669,15 @@ def create_connector_with_mock_credential(
         credential = create_credential(
             mock_credential, user=user, db_session=db_session
         )
+        access_type = (
+            AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
+        )
         response = add_credential_to_connector(
             db_session=db_session,
             user=user,
            connector_id=cast(int, connector_response.id),  # will always be an int
             credential_id=credential.id,
-            is_public=connector_data.is_public or False,
+            access_type=access_type,
             cc_pair_name=connector_data.name,
             groups=connector_data.groups,
         )
@@ -10,6 +10,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import DocumentErrorSummary
 from danswer.connectors.models import InputType
+from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import Connector
 from danswer.db.models import ConnectorCredentialPair
@@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel):
     number_of_index_attempts: int
     last_index_attempt_status: IndexingStatus | None
     latest_deletion_attempt: DeletionAttemptSnapshot | None
-    is_public: bool
+    access_type: AccessType
     is_editable_for_current_user: bool
     deletion_failure_message: str | None
@@ -261,7 +262,7 @@ class CCPairFullInfo(BaseModel):
             number_of_index_attempts=number_of_index_attempts,
             last_index_attempt_status=last_indexing_status,
             latest_deletion_attempt=latest_deletion_attempt,
-            is_public=cc_pair_model.is_public,
+            access_type=cc_pair_model.access_type,
             is_editable_for_current_user=is_editable_for_current_user,
             deletion_failure_message=cc_pair_model.deletion_failure_message,
         )
@@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel):
     credential: CredentialSnapshot
     owner: str
     groups: list[int]
-    public_doc: bool
+    access_type: AccessType
     last_finished_status: IndexingStatus | None
     last_status: IndexingStatus | None
     last_success: datetime | None
@@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel):
 
 class ConnectorCredentialPairMetadata(BaseModel):
     name: str | None = None
-    is_public: bool | None = None
+    access_type: AccessType
+    auto_sync_options: dict[str, Any] | None = None
     groups: list[int] = Field(default_factory=list)
@@ -47,7 +47,6 @@ class DocumentSet(BaseModel):
     description: str
     cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
     is_up_to_date: bool
-    contains_non_public: bool
     is_public: bool
     # For Private Document Sets, who should be able to access these
     users: list[UUID]
@@ -59,12 +58,6 @@ class DocumentSet(BaseModel):
             id=document_set_model.id,
             name=document_set_model.name,
             description=document_set_model.description,
-            contains_non_public=any(
-                [
-                    not cc_pair.is_public
-                    for cc_pair in document_set_model.connector_credential_pairs
-                ]
-            ),
             cc_pair_descriptors=[
                 ConnectorCredentialPairDescriptor(
                     id=cc_pair.id,
@@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel):
     api_url: str | None = None
     model_name: str | None = None
 
+    # This disables the "model_" protected namespace for pydantic
+    model_config = {"protected_namespaces": ()}
+
 
 class CloudEmbeddingProvider(BaseModel):
     provider_type: EmbeddingProvider
@@ -48,6 +48,7 @@ from danswer.server.models import InvitedUserSnapshot
 from danswer.server.models import MinimalUserSnapshot
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.api_key import is_api_key_email_address
+from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
 from ee.danswer.db.user_group import remove_curator_status__no_commit
 
 logger = setup_logger()
@@ -243,6 +244,11 @@ async def delete_user(
         for oauth_account in user_to_delete.oauth_accounts:
             db_session.delete(oauth_account)
 
+        delete_user__ext_group_for_user__no_commit(
+            db_session=db_session,
+            user_id=user_to_delete.id,
+        )
+
         db_session.query(DocumentSet__User).filter(
             DocumentSet__User.user_id == user_to_delete.id
         ).delete()
@ -5,8 +5,11 @@ from danswer.access.access import (
)
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_group
from danswer.db.document import get_documents_by_ids
from danswer.db.models import User
from ee.danswer.db.external_perm import fetch_external_groups_for_user
from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user

@ -17,7 +20,13 @@ def _get_access_for_document(
) -> DocumentAccess:
    id_to_access = _get_access_for_documents([document_id], db_session)
    if len(id_to_access) == 0:
        return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
        return DocumentAccess.build(
            user_emails=[],
            user_groups=[],
            external_user_emails=[],
            external_user_group_ids=[],
            is_public=False,
        )

    return next(iter(id_to_access.values()))

@ -30,22 +39,48 @@ def _get_access_for_documents(
        document_ids=document_ids,
        db_session=db_session,
    )
    user_group_info = {
    user_group_info: dict[str, list[str]] = {
        document_id: group_names
        for document_id, group_names in fetch_user_groups_for_documents(
            db_session=db_session,
            document_ids=document_ids,
        )
    }
    documents = get_documents_by_ids(
        db_session=db_session,
        document_ids=document_ids,
    )
    doc_id_map = {doc.id: doc for doc in documents}

    return {
        document_id: DocumentAccess(
            user_ids=non_ee_access.user_ids,
            user_groups=user_group_info.get(document_id, []),  # type: ignore
            is_public=non_ee_access.is_public,
    access_map = {}
    for document_id, non_ee_access in non_ee_access_dict.items():
        document = doc_id_map[document_id]

        ext_u_emails = (
            set(document.external_user_emails)
            if document.external_user_emails
            else set()
        )
        for document_id, non_ee_access in non_ee_access_dict.items()
    }

        ext_u_groups = (
            set(document.external_user_group_ids)
            if document.external_user_group_ids
            else set()
        )

        # If the document is determined to be "public" externally (through a SYNC connector)
        # then it's given the same access level as if it were marked public within Danswer
        is_public_anywhere = document.is_public or non_ee_access.is_public

        # To avoid collisions of group namings between connectors, they need to be prefixed
        access_map[document_id] = DocumentAccess(
            user_emails=non_ee_access.user_emails,
            user_groups=set(user_group_info.get(document_id, [])),
            is_public=is_public_anywhere,
            external_user_emails=ext_u_emails,
            external_user_group_ids=ext_u_groups,
        )
    return access_map


def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:

@ -56,7 +91,20 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:

    NOTE: is imported in danswer.access.access by `fetch_versioned_implementation`
    DO NOT REMOVE."""
    user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
    return set(
        [prefix_user_group(user_group.name) for user_group in user_groups]
    ).union(get_acl_for_user_without_groups(user, db_session))
    db_user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
    prefixed_user_groups = [
        prefix_user_group(db_user_group.name) for db_user_group in db_user_groups
    ]

    db_external_groups = (
        fetch_external_groups_for_user(db_session, user.id) if user else []
    )
    prefixed_external_groups = [
        prefix_external_group(db_external_group.external_user_group_id)
        for db_external_group in db_external_groups
    ]

    user_acl = set(prefixed_user_groups + prefixed_external_groups)
    user_acl.update(get_acl_for_user_without_groups(user, db_session))

    return user_acl
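
For intuition, the net effect of the override above: a user's final ACL is their email entry plus their prefixed Danswer groups plus their prefixed external groups. A self-contained sketch; the prefix strings here are assumptions based on the query-time test fixtures near the end of this diff, while the real ones come from danswer.access.utils:

def sketch_acl_for_user(email: str, groups: list[str], external_groups: list[str]) -> set[str]:
    # Assumed prefix formats ("user_email:", "group:", "external_group:");
    # compare _random_filters() in test_query_times.py below.
    acl = {f"user_email:{email}"}
    acl.update(f"group:{g}" for g in groups)
    acl.update(f"external_group:{g}" for g in external_groups)
    return acl

assert sketch_acl_for_user("alice@example.com", ["eng"], ["eng-drive"]) == {
    "user_email:alice@example.com",
    "group:eng",
    "external_group:eng-drive",
}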
@ -11,7 +11,13 @@ from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_perform_external_permissions_check
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
    run_permission_sync_entrypoint,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report

logger = setup_logger()

@ -20,6 +26,13 @@ logger = setup_logger()
global_version.set_ee()


@build_celery_task_wrapper(name_sync_external_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_permissions_task(cc_pair_id: int) -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id)


@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:

@ -30,6 +43,23 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
#####
# Periodic Tasks
#####
@celery_app.task(
    name="check_sync_external_permissions_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_permissions_task() -> None:
    """Runs periodically to sync external permissions"""
    with Session(get_sqlalchemy_engine()) as db_session:
        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
        for cc_pair in cc_pairs:
            if should_perform_external_permissions_check(
                cc_pair=cc_pair, db_session=db_session
            ):
                sync_external_permissions_task.apply_async(
                    kwargs=dict(cc_pair_id=cc_pair.id),
                )


@celery_app.task(
    name="check_ttl_management_task",
    soft_time_limit=JOB_TIMEOUT,

@ -64,7 +94,11 @@ def autogenerate_usage_report_task() -> None:
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
    "autogenerate-usage-report": {
    "sync-external-permissions": {
        "task": "check_sync_external_permissions_task",
        "schedule": timedelta(seconds=60),  # TODO: optimize this
    },
    "autogenerate_usage_report": {
        "task": "autogenerate_usage_report_task",
        "schedule": timedelta(days=30),  # TODO: change this to config flag
    },
@ -6,10 +6,13 @@ from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_sync_external_permissions_task
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced

@ -30,11 +33,30 @@ def should_perform_chat_ttl_check(
        return True

    if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
        logger.info("TTL check is already being performed. Skipping.")
        logger.debug(f"{task_name} is already being performed. Skipping.")
        return False
    return True


def should_perform_external_permissions_check(
    cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
    if cc_pair.access_type != AccessType.SYNC:
        return False

    task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id)

    latest_task = get_latest_task(task_name, db_session)
    if not latest_task:
        return True

    if check_task_is_live_and_not_timed_out(latest_task, db_session):
        logger.debug(f"{task_name} is already being performed. Skipping.")
        return False

    return True


def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None:
    """This function is likely to move in the worker refactor happening next."""
    key = key_bytes.decode("utf-8")
@ -1,224 +0,0 @@
import logging
import time
from datetime import datetime

import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session

from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_PERMISSIONS_APP_NAME
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS
from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP
from ee.danswer.db.connector import fetch_sources_with_connectors
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.permission_sync import create_perm_sync
from ee.danswer.db.permission_sync import expire_perm_sync_timed_out
from ee.danswer.db.permission_sync import get_perm_sync_attempt
from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed
from shared_configs.configs import LOG_LEVEL

logger = setup_logger()

# If the indexing dies, it's most likely due to resource constraints,
# restarting just delays the eventual failure, not useful to the user
dask.config.set({"distributed.scheduler.allowed-failures": 0})


def cleanup_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    # Just reusing the same timeout, fine for now
    timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    existing_jobs_copy = existing_jobs.copy()

    with Session(get_sqlalchemy_engine()) as db_session:
        # clean up completed jobs
        for (attempt_id, details), job in existing_jobs.items():
            perm_sync_attempt = get_perm_sync_attempt(
                attempt_id=attempt_id, db_session=db_session
            )

            # do nothing for ongoing jobs that haven't been stopped
            if (
                not job.done()
                and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS
            ):
                continue

            if job.status == "error":
                logger.error(job.exception())

            job.release()
            del existing_jobs_copy[(attempt_id, details)]

        # clean up in-progress jobs that were never completed
        expire_perm_sync_timed_out(
            timeout_hours=timeout_hours,
            db_session=db_session,
        )

    return existing_jobs_copy


def create_group_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Creates new relational DB group permission sync job for each source that:
    - has permission sync enabled
    - has at least 1 connector (enabled or paused)
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    sources_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], DocumentSource)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue
            if source_type in sources_w_runs:
                continue

            db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type]
            perm_sync = create_perm_sync(
                source_type=source_type,
                group_update=True,
                cc_pair_id=None,
                db_session=db_session,
            )

            run = client.submit(db_group_fnc, pure=False)

            logger.info(
                f"Kicked off group permission sync for source type {source_type}"
            )

            if run:
                existing_jobs_copy[(perm_sync.id, source_type)] = run

    return existing_jobs_copy


def create_connector_perm_sync_jobs(
    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob],
    client: Client | SimpleJobClient,
) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]:
    """Update Document Index ACL sync job for each cc-pair where:
    - source type has permission sync enabled
    - has no sync already running
    """
    existing_jobs_copy = existing_jobs.copy()
    cc_pairs_w_runs = [
        key[1]
        for key in existing_jobs_copy.keys()
        if isinstance(key[1], DocumentSource)
    ]
    with Session(get_sqlalchemy_engine()) as db_session:
        sources_w_connector = fetch_sources_with_connectors(db_session)
        for source_type in sources_w_connector:
            if source_type not in CONNECTOR_PERMISSION_FUNC_MAP:
                continue

            _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type]

            cc_pairs = get_cc_pairs_by_source(source_type, db_session)

            for cc_pair in cc_pairs:
                if cc_pair.id in cc_pairs_w_runs:
                    continue

                perm_sync = create_perm_sync(
                    source_type=source_type,
                    group_update=False,
                    cc_pair_id=cc_pair.id,
                    db_session=db_session,
                )

                run = client.submit(index_sync_fnc, cc_pair.id, pure=False)

                logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}")

                if run:
                    existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run

    return existing_jobs_copy


def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None:
    client: Client | SimpleJobClient
    if DASK_JOB_CLIENT_ENABLED:
        cluster_primary = LocalCluster(
            n_workers=num_workers,
            threads_per_worker=1,
            # there are warning about high memory usage + "Event loop unresponsive"
            # which are not relevant to us since our workers are expected to use a
            # lot of memory + involve CPU intensive tasks that will not relinquish
            # the event loop
            silence_logs=logging.ERROR,
        )
        client = Client(cluster_primary)
        if LOG_LEVEL.lower() == "debug":
            client.register_worker_plugin(ResourceLogger())
    else:
        client = SimpleJobClient(n_workers=num_workers)

    existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {}
    engine = get_sqlalchemy_engine()

    with Session(engine) as db_session:
        # Any jobs still in progress on restart must have died
        mark_all_inprogress_permission_sync_failed(db_session)

    while True:
        start = time.time()
        start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
        logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}")

        if existing_jobs:
            logger.debug(
                "Found existing permission sync jobs: "
                f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
            )

        try:
            # TODO turn this on when it works
            """
            existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs)
            existing_jobs = create_group_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            existing_jobs = create_connector_perm_sync_jobs(
                existing_jobs=existing_jobs, client=client
            )
            """
        except Exception as e:
            logger.exception(f"Failed to run update due to {e}")
        sleep_time = delay - (time.time() - start)
        if sleep_time > 0:
            time.sleep(sleep_time)


def update__main() -> None:
    logger.notice("Starting Permission Syncing Loop")
    init_sqlalchemy_engine(POSTGRES_PERMISSIONS_APP_NAME)
    permission_loop()


if __name__ == "__main__":
    update__main()
@ -1,6 +1,6 @@
def name_user_group_sync_task(user_group_id: int) -> str:
    return f"user_group_sync_task__{user_group_id}"


def name_chat_ttl_task(retention_limit_days: int) -> str:
    return f"chat_ttl_{retention_limit_days}_days"


def name_sync_external_permissions_task(cc_pair_id: int) -> str:
    return f"sync_external_permissions_task__{cc_pair_id}"
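
These builders produce the deduplication keys used by the gating checks in celery_utils above; illustrative values:

assert name_user_group_sync_task(7) == "user_group_sync_task__7"
assert name_chat_ttl_task(30) == "chat_ttl_30_days"
assert name_sync_external_permissions_task(42) == "sync_external_permissions_task__42"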
@ -1,12 +0,0 @@
from danswer.utils.logger import setup_logger


logger = setup_logger()


def confluence_update_db_group() -> None:
    logger.debug("Not yet implemented group sync for confluence, no-op")


def confluence_update_index_acl(cc_pair_id: int) -> None:
    logger.debug("Not yet implemented ACL sync for confluence, no-op")
@ -1,8 +0,0 @@
from danswer.configs.constants import DocumentSource
from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group
from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl


CONNECTOR_PERMISSION_FUNC_MAP = {
    DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl)
}
@ -3,6 +3,7 @@ from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import AccessType
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import UserGroup__ConnectorCredentialPair

@ -32,14 +33,30 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(


def get_cc_pairs_by_source(
    source_type: DocumentSource,
    db_session: Session,
    source_type: DocumentSource,
    only_sync: bool,
) -> list[ConnectorCredentialPair]:
    cc_pairs = (
    query = (
        db_session.query(ConnectorCredentialPair)
        .join(ConnectorCredentialPair.connector)
        .filter(Connector.source == source_type)
        .all()
    )

    if only_sync:
        query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)

    cc_pairs = query.all()
    return cc_pairs


def get_all_auto_sync_cc_pairs(
    db_session: Session,
) -> list[ConnectorCredentialPair]:
    return (
        db_session.query(ConnectorCredentialPair)
        .where(
            ConnectorCredentialPair.access_type == AccessType.SYNC,
        )
        .all()
    )
@ -1,14 +1,47 @@
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import Document
from danswer.access.models import ExternalAccess
from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import Document as DbDocument


def fetch_documents_from_ids(
    db_session: Session, document_ids: list[str]
) -> Sequence[Document]:
    return db_session.scalars(
        select(Document).where(Document.id.in_(document_ids))
    ).all()
def upsert_document_external_perms__no_commit(
    db_session: Session,
    doc_id: str,
    external_access: ExternalAccess,
    source_type: DocumentSource,
) -> None:
    """
    This sets the permissions for a document in postgres.
    NOTE: this will replace any existing external access, it will not do a union
    """
    document = db_session.scalars(
        select(DbDocument).where(DbDocument.id == doc_id)
    ).first()

    prefixed_external_groups = [
        prefix_group_w_source(
            ext_group_name=group_id,
            source=source_type,
        )
        for group_id in external_access.external_user_group_ids
    ]

    if not document:
        # If the document does not exist, still store the external access
        # so that if the document is added later, the external access is already stored
        document = DbDocument(
            id=doc_id,
            semantic_id="",
            external_user_emails=external_access.external_user_emails,
            external_user_group_ids=prefixed_external_groups,
            is_public=external_access.is_public,
        )
        db_session.add(document)
        return

    document.external_user_emails = list(external_access.external_user_emails)
    document.external_user_group_ids = prefixed_external_groups
    document.is_public = external_access.is_public
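
A minimal usage sketch for the function above, assuming an ExternalAccess shaped like the one constructed in the Google Drive doc sync below; the __no_commit suffix means nothing persists until the caller commits:

from danswer.access.models import ExternalAccess
from danswer.configs.constants import DocumentSource

def store_perms_for_doc(db_session, doc_id: str) -> None:
    # Illustrative values; real callers fetch these from the source system.
    access = ExternalAccess(
        external_user_emails={"alice@example.com"},
        external_user_group_ids={"eng-drive"},
        is_public=False,
    )
    upsert_document_external_perms__no_commit(
        db_session=db_session,
        doc_id=doc_id,
        external_access=access,
        source_type=DocumentSource.GOOGLE_DRIVE,
    )
    db_session.commit()  # the helper itself never commits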
77
backend/ee/danswer/db/external_perm.py
Normal file
@ -0,0 +1,77 @@
from collections.abc import Sequence
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import User__ExternalUserGroupId


class ExternalUserGroup(BaseModel):
    id: str
    user_ids: list[UUID]


def delete_user__ext_group_for_user__no_commit(
    db_session: Session,
    user_id: UUID,
) -> None:
    db_session.execute(
        delete(User__ExternalUserGroupId).where(
            User__ExternalUserGroupId.user_id == user_id
        )
    )


def delete_user__ext_group_for_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    db_session.execute(
        delete(User__ExternalUserGroupId).where(
            User__ExternalUserGroupId.cc_pair_id == cc_pair_id
        )
    )


def replace_user__ext_group_for_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
    group_defs: list[ExternalUserGroup],
    source: DocumentSource,
) -> None:
    """
    This function clears all existing external user group relations for a given cc_pair_id
    and replaces them with the new group definitions.
    """
    delete_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )

    new_external_permissions = [
        User__ExternalUserGroupId(
            user_id=user_id,
            external_user_group_id=prefix_group_w_source(external_group.id, source),
            cc_pair_id=cc_pair_id,
        )
        for external_group in group_defs
        for user_id in external_group.user_ids
    ]

    db_session.add_all(new_external_permissions)


def fetch_external_groups_for_user(
    db_session: Session,
    user_id: UUID,
) -> Sequence[User__ExternalUserGroupId]:
    return db_session.scalars(
        select(User__ExternalUserGroupId).where(
            User__ExternalUserGroupId.user_id == user_id
        )
    ).all()
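
Usage sketch for the replace helper, mirroring how the Google Drive group sync below calls it; the members_by_group shape is an assumption for illustration:

def sync_groups_for_cc_pair(db_session, cc_pair, members_by_group: dict[str, list]) -> None:
    group_defs = [
        ExternalUserGroup(id=group_id, user_ids=user_ids)
        for group_id, user_ids in members_by_group.items()
    ]
    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=group_defs,
        source=cc_pair.connector.source,
    )
    db_session.commit()  # the delete + add_all above are not committed by the helper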
@ -1,72 +0,0 @@
from datetime import timedelta

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import PermissionSyncRun
from danswer.db.models import PermissionSyncStatus
from danswer.utils.logger import setup_logger

logger = setup_logger()


def mark_all_inprogress_permission_sync_failed(
    db_session: Session,
) -> None:
    stmt = (
        update(PermissionSyncRun)
        .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS)
        .values(status=PermissionSyncStatus.FAILED)
    )
    db_session.execute(stmt)
    db_session.commit()


def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun:
    stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id)
    try:
        return db_session.scalars(stmt).one()
    except NoResultFound:
        raise ValueError(f"No PermissionSyncRun found with id {attempt_id}")


def expire_perm_sync_timed_out(
    timeout_hours: int,
    db_session: Session,
) -> None:
    cutoff_time = func.now() - timedelta(hours=timeout_hours)

    update_stmt = (
        update(PermissionSyncRun)
        .where(
            PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS,
            PermissionSyncRun.updated_at < cutoff_time,
        )
        .values(status=PermissionSyncStatus.FAILED, error_msg="timed out")
    )

    db_session.execute(update_stmt)
    db_session.commit()


def create_perm_sync(
    source_type: DocumentSource,
    group_update: bool,
    cc_pair_id: int | None,
    db_session: Session,
) -> PermissionSyncRun:
    new_run = PermissionSyncRun(
        source_type=source_type,
        status=PermissionSyncStatus.IN_PROGRESS,
        group_update=group_update,
        cc_pair_id=cc_pair_id,
    )

    db_session.add(new_run)
    db_session.commit()

    return new_run
@ -199,7 +199,7 @@ def fetch_documents_for_user_group_paginated(
def fetch_user_groups_for_documents(
    db_session: Session,
    document_ids: list[str],
) -> Sequence[tuple[int, list[str]]]:
) -> Sequence[tuple[str, list[str]]]:
    stmt = (
        select(Document.id, func.array_agg(UserGroup.name))
        .join(
@ -0,0 +1,19 @@
from typing import Any

from sqlalchemy.orm import Session

from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo


logger = setup_logger()


def confluence_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    docs_with_additional_info: list[DocsWithAdditionalInfo],
    sync_details: dict[str, Any],
) -> None:
    logger.debug("Not yet implemented ACL sync for confluence, no-op")
@ -0,0 +1,19 @@
from typing import Any

from sqlalchemy.orm import Session

from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo


logger = setup_logger()


def confluence_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    docs_with_additional_info: list[DocsWithAdditionalInfo],
    sync_details: dict[str, Any],
) -> None:
    logger.debug("Not yet implemented group sync for confluence, no-op")
148
backend/ee/danswer/external_permissions/google_drive/doc_sync.py
Normal file
@ -0,0 +1,148 @@
from collections.abc import Iterator
from typing import Any
from typing import cast

from googleapiclient.discovery import build  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore
from sqlalchemy.orm import Session

from danswer.access.models import ExternalAccess
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
    get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=5, delay=5, max_delay=30)


logger = setup_logger()


def _fetch_permissions_paginated(
    drive_service: Any, drive_file_id: str
) -> Iterator[dict[str, Any]]:
    next_token = None

    # Check if the file is trashed
    # Returning nothing here will cause the external permissions to
    # be empty, which will get written to vespa (failing shut)
    try:
        file_metadata = add_retries(
            lambda: drive_service.files()
            .get(fileId=drive_file_id, fields="id, trashed")
            .execute()
        )()
    except HttpError as e:
        if e.resp.status == 404 or e.resp.status == 403:
            return
        logger.error(f"Failed to fetch permissions: {e}")
        raise

    if file_metadata.get("trashed", False):
        logger.debug(f"File with ID {drive_file_id} is trashed")
        return

    # Get paginated permissions for the file id
    while True:
        try:
            permissions_resp: dict[str, Any] = add_retries(
                lambda: (
                    drive_service.permissions()
                    .list(
                        fileId=drive_file_id,
                        fields="permissions(id, emailAddress, role, type, domain)",
                        supportsAllDrives=True,
                        pageToken=next_token,
                    )
                    .execute()
                )
            )()
        except HttpError as e:
            if e.resp.status == 404 or e.resp.status == 403:
                break
            logger.error(f"Failed to fetch permissions: {e}")
            raise

        for permission in permissions_resp.get("permissions", []):
            yield permission

        next_token = permissions_resp.get("nextPageToken")
        if not next_token:
            break


def _fetch_google_permissions_for_document_id(
    db_session: Session,
    drive_file_id: str,
    raw_credentials_json: dict[str, str],
    company_google_domains: list[str],
) -> ExternalAccess:
    # Authenticate and construct service
    google_drive_creds, _ = get_google_drive_creds(
        raw_credentials_json, scopes=FETCH_PERMISSIONS_SCOPES
    )
    if not google_drive_creds.valid:
        raise ValueError("Invalid Google Drive credentials")

    drive_service = build("drive", "v3", credentials=google_drive_creds)

    user_emails: set[str] = set()
    group_emails: set[str] = set()
    public = False
    for permission in _fetch_permissions_paginated(drive_service, drive_file_id):
        permission_type = permission["type"]
        if permission_type == "user":
            user_emails.add(permission["emailAddress"])
        elif permission_type == "group":
            group_emails.add(permission["emailAddress"])
        elif permission_type == "domain":
            if permission["domain"] in company_google_domains:
                public = True
        elif permission_type == "anyone":
            public = True

    batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_emails,
        is_public=public,
    )


def gdrive_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    docs_with_additional_info: list[DocsWithAdditionalInfo],
    sync_details: dict[str, Any],
) -> None:
    """
    Adds the external permissions to the documents in postgres.
    If the document doesn't already exist in postgres, we create
    it in postgres so that when it gets created later, the permissions
    are already populated.
    """
    for doc in docs_with_additional_info:
        ext_access = _fetch_google_permissions_for_document_id(
            db_session=db_session,
            drive_file_id=doc.additional_info,
            raw_credentials_json=cc_pair.credential.credential_json,
            company_google_domains=[
                cast(dict[str, str], sync_details)["company_domain"]
            ],
        )
        upsert_document_external_perms__no_commit(
            db_session=db_session,
            doc_id=doc.id,
            external_access=ext_access,
            source_type=cc_pair.connector.source,
        )
@ -0,0 +1,147 @@
from collections.abc import Iterator
from typing import Any

from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.discovery import build  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore
from sqlalchemy.orm import Session

from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
    get_google_drive_creds,
)
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

logger = setup_logger()


# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=5, delay=5, max_delay=30)


def _fetch_groups_paginated(
    google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
    identity_source: str | None = None,
    customer_id: str | None = None,
) -> Iterator[dict[str, Any]]:
    # Note that Google Drive does not use or update the user_cache as the user email
    # comes directly with the call to fetch the groups, therefore this is not a valid
    # place to save on requests
    if identity_source is None and customer_id is None:
        raise ValueError(
            "Either identity_source or customer_id must be provided to fetch groups"
        )

    cloud_identity_service = build(
        "cloudidentity", "v1", credentials=google_drive_creds
    )
    parent = (
        f"identitysources/{identity_source}"
        if identity_source
        else f"customers/{customer_id}"
    )

    while True:
        try:
            groups_resp: dict[str, Any] = add_retries(
                lambda: (cloud_identity_service.groups().list(parent=parent).execute())
            )()
            for group in groups_resp.get("groups", []):
                yield group

            next_token = groups_resp.get("nextPageToken")
            if not next_token:
                break
        except HttpError as e:
            if e.resp.status == 404 or e.resp.status == 403:
                break
            logger.error(f"Error fetching groups: {e}")
            raise


def _fetch_group_members_paginated(
    google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
    group_name: str,
) -> Iterator[dict[str, Any]]:
    cloud_identity_service = build(
        "cloudidentity", "v1", credentials=google_drive_creds
    )
    next_token = None
    while True:
        try:
            membership_info = add_retries(
                lambda: (
                    cloud_identity_service.groups()
                    .memberships()
                    .searchTransitiveMemberships(
                        parent=group_name, pageToken=next_token
                    )
                    .execute()
                )
            )()

            for member in membership_info.get("memberships", []):
                yield member

            next_token = membership_info.get("nextPageToken")
            if not next_token:
                break
        except HttpError as e:
            if e.resp.status == 404 or e.resp.status == 403:
                break
            logger.error(f"Error fetching group members: {e}")
            raise


def gdrive_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    docs_with_additional_info: list[DocsWithAdditionalInfo],
    sync_details: dict[str, Any],
) -> None:
    google_drive_creds, _ = get_google_drive_creds(
        cc_pair.credential.credential_json,
        scopes=FETCH_GROUPS_SCOPES,
    )

    danswer_groups: list[ExternalUserGroup] = []
    for group in _fetch_groups_paginated(
        google_drive_creds,
        identity_source=sync_details.get("identity_source"),
        customer_id=sync_details.get("customer_id"),
    ):
        # The id is the group email
        group_email = group["groupKey"]["id"]

        group_member_emails: list[str] = []
        for member in _fetch_group_members_paginated(google_drive_creds, group["name"]):
            member_keys = member["preferredMemberKey"]
            member_emails = [member_key["id"] for member_key in member_keys]
            for member_email in member_emails:
                group_member_emails.append(member_email)

        group_members = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=group_member_emails
        )
        if group_members:
            danswer_groups.append(
                ExternalUserGroup(
                    id=group_email, user_ids=[user.id for user in group_members]
                )
            )

    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
126
backend/ee/danswer/external_permissions/permission_sync.py
Normal file
@ -0,0 +1,126 @@
from datetime import datetime
from datetime import timezone

from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.permission_sync_function_map import (
    DOC_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
    FULL_FETCH_PERIOD_IN_SECONDS,
)
from ee.danswer.external_permissions.permission_sync_function_map import (
    GROUP_PERMISSIONS_FUNC_MAP,
)
from ee.danswer.external_permissions.permission_sync_utils import (
    get_docs_with_additional_info,
)

logger = setup_logger()


def run_permission_sync_entrypoint(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    # TODO: separate out group and doc sync
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    if cc_pair is None:
        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

    source_type = cc_pair.connector.source

    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)

    if doc_sync_func is None:
        raise ValueError(
            f"No permission sync function found for source type: {source_type}"
        )

    sync_details = cc_pair.auto_sync_options
    if sync_details is None:
        raise ValueError(f"No auto sync options found for source type: {source_type}")

    # If the source type is not polling, we only fetch the permissions every
    # FULL_FETCH_PERIOD_IN_SECONDS seconds
    full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
    if full_fetch_period is not None:
        last_sync = cc_pair.last_time_perm_sync
        if (
            last_sync
            and (
                datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
            ).total_seconds()
            < full_fetch_period
        ):
            return

    # Here we run the connector to grab all the ids
    # this may grab ids before they are indexed but that is fine because
    # we create a document in postgres to hold the permissions info
    # until the indexing job has a chance to run
    docs_with_additional_info = get_docs_with_additional_info(
        db_session=db_session,
        cc_pair=cc_pair,
    )

    # This function updates:
    # - the user_email <-> external_user_group_id mapping
    # in postgres without committing
    logger.debug(f"Syncing groups for {source_type}")
    if group_sync_func is not None:
        group_sync_func(
            db_session,
            cc_pair,
            docs_with_additional_info,
            sync_details,
        )

    # This function updates:
    # - the user_email <-> document mapping
    # - the external_user_group_id <-> document mapping
    # in postgres without committing
    logger.debug(f"Syncing docs for {source_type}")
    doc_sync_func(
        db_session,
        cc_pair,
        docs_with_additional_info,
        sync_details,
    )

    # This function fetches the updated access for the documents
    # and returns a dictionary of document_ids and access
    # This is the access we want to update vespa with
    docs_access = get_access_for_documents(
        document_ids=[doc.id for doc in docs_with_additional_info],
        db_session=db_session,
    )

    # Then we build the update requests to update vespa
    update_reqs = [
        UpdateRequest(document_ids=[doc_id], access=doc_access)
        for doc_id, doc_access in docs_access.items()
    ]

    # Don't bother syncing secondary, it will be synced after switch anyway
    search_settings = get_current_search_settings(db_session)
    document_index = get_default_document_index(
        primary_index_name=search_settings.index_name,
        secondary_index_name=None,
    )

    try:
        # update vespa
        document_index.update(update_reqs)
        # update postgres
        db_session.commit()
    except Exception as e:
        logger.error(f"Error updating document index: {e}")
        db_session.rollback()
@ -0,0 +1,54 @@
from collections.abc import Callable
from typing import Any

from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync
from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo

GroupSyncFuncType = Callable[
    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
    None,
]

DocSyncFuncType = Callable[
    [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]],
    None,
]

# These functions update:
# - the user_email <-> document mapping
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
    DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
    DocumentSource.CONFLUENCE: confluence_doc_sync,
}

# These functions update:
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
    DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
    DocumentSource.CONFLUENCE: confluence_group_sync,
}


# None means that the connector supports polling from last_time_perm_sync to now
FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = {
    # Polling is supported
    DocumentSource.GOOGLE_DRIVE: None,
    # Polling is not supported so we fetch all doc permissions every 10 minutes
    DocumentSource.CONFLUENCE: 10 * 60,
}


def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
    return source_type in DOC_PERMISSIONS_FUNC_MAP
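
The period map above drives the throttling branch in run_permission_sync_entrypoint; distilled to a pure function (a sketch, not part of the diff):

from datetime import datetime

def needs_full_fetch(
    source_type: DocumentSource, last_sync: datetime | None, now: datetime
) -> bool:
    period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
    if period is None:
        return True  # connector supports polling, so every run proceeds
    if last_sync is None:
        return True  # never synced before, do the first full fetch
    return (now - last_sync).total_seconds() >= period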
@ -0,0 +1,56 @@
from datetime import datetime
from datetime import timezone
from typing import Any

from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.db.models import ConnectorCredentialPair
from danswer.utils.logger import setup_logger

logger = setup_logger()


class DocsWithAdditionalInfo(BaseModel):
    id: str
    additional_info: Any


def get_docs_with_additional_info(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> list[DocsWithAdditionalInfo]:
    # Get all document ids that need their permissions updated
    runnable_connector = instantiate_connector(
        db_session=db_session,
        source=cc_pair.connector.source,
        input_type=InputType.POLL,
        connector_specific_config=cc_pair.connector.connector_specific_config,
        credential=cc_pair.credential,
    )

    assert isinstance(runnable_connector, PollConnector)

    current_time = datetime.now(timezone.utc)
    start_time = (
        cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp()
        if cc_pair.last_time_perm_sync
        else 0
    )
    cc_pair.last_time_perm_sync = current_time

    doc_batch_generator = runnable_connector.poll_source(
        start=start_time, end=current_time.timestamp()
    )

    docs_with_additional_info = [
        DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info)
        for doc_batch in doc_batch_generator
        for doc in doc_batch
    ]
    logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}")

    return docs_with_additional_info
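
The poll window above degrades gracefully: with no recorded last_time_perm_sync the start is epoch 0, which makes the first run a full-history fetch. A tiny check of that arithmetic:

from datetime import datetime, timezone

last_sync = None  # cc_pair.last_time_perm_sync on a brand-new cc-pair
start = last_sync.replace(tzinfo=timezone.utc).timestamp() if last_sync else 0
end = datetime.now(timezone.utc).timestamp()
assert start == 0 and end > start  # first poll covers the connector's full history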
166
backend/scripts/query_time_check/seed_dummy_docs.py
Normal file
@ -0,0 +1,166 @@
"""
launch:
- api server
- postgres
- vespa
- model server (this is only needed so the api server can start up, no embedding is done)

Run this script to seed the database with dummy documents.
Then run test_query_times.py to test query times.
"""
import random
from datetime import datetime

from danswer.access.models import DocumentAccess
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import IndexChunk
from danswer.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding


TOTAL_DOC_SETS = 8
TOTAL_ACL_ENTRIES_PER_CATEGORY = 80


def generate_random_embedding(dim: int) -> Embedding:
    return [random.uniform(-1, 1) for _ in range(dim)]


def generate_random_identifier() -> str:
    return f"dummy_doc_{random.randint(1, 1000)}"


def generate_dummy_chunk(
    doc_id: str,
    chunk_id: int,
    embedding_dim: int,
    number_of_acl_entries: int,
    number_of_document_sets: int,
) -> DocMetadataAwareIndexChunk:
    document = Document(
        id=doc_id,
        source=DocumentSource.GOOGLE_DRIVE,
        sections=[],
        metadata={},
        semantic_identifier=generate_random_identifier(),
    )

    chunk = IndexChunk(
        chunk_id=chunk_id,
        blurb=f"Blurb for chunk {chunk_id} of document {doc_id}.",
        content=f"Content for chunk {chunk_id} of document {doc_id}. This is dummy text for testing purposes.",
        source_links={},
        section_continuation=False,
        source_document=document,
        title_prefix=f"Title prefix for doc {doc_id}",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        embeddings=ChunkEmbedding(
            full_embedding=generate_random_embedding(embedding_dim),
            mini_chunk_embeddings=[],
        ),
        title_embedding=generate_random_embedding(embedding_dim),
    )

    document_set_names = []
    for i in range(number_of_document_sets):
        document_set_names.append(f"Document Set {i}")

    user_emails: set[str | None] = set()
    user_groups: set[str] = set()
    external_user_emails: set[str] = set()
    external_user_group_ids: set[str] = set()
    for i in range(number_of_acl_entries):
        user_emails.add(f"user_{i}@example.com")
        user_groups.add(f"group_{i}")
        external_user_emails.add(f"external_user_{i}@example.com")
        external_user_group_ids.add(f"external_group_{i}")

    return DocMetadataAwareIndexChunk.from_index_chunk(
        index_chunk=chunk,
        access=DocumentAccess(
            user_emails=user_emails,
            user_groups=user_groups,
            external_user_emails=external_user_emails,
            external_user_group_ids=external_user_group_ids,
            is_public=random.choice([True, False]),
        ),
        document_sets={document_set for document_set in document_set_names},
        boost=random.randint(-1, 1),
    )


@log_function_time()
def do_insertion(
    vespa_index: VespaIndex, all_chunks: list[DocMetadataAwareIndexChunk]
) -> None:
    insertion_records = vespa_index.index(all_chunks)
    print(f"Indexed {len(insertion_records)} documents.")
    print(
        f"New documents: {sum(1 for record in insertion_records if not record.already_existed)}"
    )
    print(
        f"Existing documents updated: {sum(1 for record in insertion_records if record.already_existed)}"
    )


@log_function_time()
def seed_dummy_docs(
    number_of_document_sets: int,
    number_of_acl_entries: int,
    num_docs: int = 1000,
    chunks_per_doc: int = 5,
    batch_size: int = 100,
) -> None:
    with get_session_context_manager() as db_session:
        search_settings = get_current_search_settings(db_session)
        index_name = search_settings.index_name
        embedding_dim = search_settings.model_dim

    vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
    print(index_name)

    all_chunks = []
    chunk_count = 0
    for doc_num in range(num_docs):
        doc_id = f"dummy_doc_{doc_num}_{datetime.now().isoformat()}"
        for chunk_num in range(chunks_per_doc):
            chunk = generate_dummy_chunk(
                doc_id=doc_id,
                chunk_id=chunk_num,
                embedding_dim=embedding_dim,
                number_of_acl_entries=number_of_acl_entries,
                number_of_document_sets=number_of_document_sets,
            )
            all_chunks.append(chunk)
            chunk_count += 1

        if len(all_chunks) >= chunks_per_doc * batch_size:
            do_insertion(vespa_index, all_chunks)
            print(
                f"Indexed {chunk_count} chunks out of {num_docs * chunks_per_doc}."
            )
            print(
                f"percentage: {chunk_count / (num_docs * chunks_per_doc) * 100:.2f}% \n"
            )
            all_chunks = []

    if all_chunks:
        do_insertion(vespa_index, all_chunks)


if __name__ == "__main__":
    seed_dummy_docs(
        number_of_document_sets=TOTAL_DOC_SETS,
        number_of_acl_entries=TOTAL_ACL_ENTRIES_PER_CATEGORY,
        num_docs=100000,
        chunks_per_doc=5,
        batch_size=1000,
    )
122
backend/scripts/query_time_check/test_query_times.py
Normal file
@ -0,0 +1,122 @@
"""
RUN THIS AFTER SEED_DUMMY_DOCS.PY
"""
import random
import time

from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.vespa.index import VespaIndex
from danswer.search.models import IndexFilters
from scripts.query_time_check.seed_dummy_docs import TOTAL_ACL_ENTRIES_PER_CATEGORY
from scripts.query_time_check.seed_dummy_docs import TOTAL_DOC_SETS
from shared_configs.model_server_models import Embedding

# make sure these are smaller than TOTAL_ACL_ENTRIES_PER_CATEGORY and TOTAL_DOC_SETS, respectively
NUMBER_OF_ACL_ENTRIES_PER_QUERY = 6
NUMBER_OF_DOC_SETS_PER_QUERY = 2
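
# optional guard (a sketch, not part of the committed script): random.sample
# below raises ValueError if these exceed the totals seeded by seed_dummy_docs.py
assert NUMBER_OF_ACL_ENTRIES_PER_QUERY <= TOTAL_ACL_ENTRIES_PER_CATEGORY
assert NUMBER_OF_DOC_SETS_PER_QUERY <= TOTAL_DOC_SETS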


def get_slowest_99th_percentile(results: list[float]) -> float:
    return sorted(results)[int(0.99 * len(results))]
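
# e.g. with number_of_queries=1000, this returns sorted(results)[990]:
# the 10th-slowest time, i.e. the 99th-percentile latency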


# Generate random filters
def _random_filters() -> IndexFilters:
    """
    Generate random filters for the query containing:
    - 1 user email
    - NUMBER_OF_ACL_ENTRIES_PER_QUERY groups
    - NUMBER_OF_ACL_ENTRIES_PER_QUERY external groups
    - NUMBER_OF_DOC_SETS_PER_QUERY document sets
    """
    access_control_list = [
        f"user_email:user_{random.randint(0, TOTAL_ACL_ENTRIES_PER_CATEGORY - 1)}@example.com",
    ]
    acl_indices = random.sample(
        range(TOTAL_ACL_ENTRIES_PER_CATEGORY), NUMBER_OF_ACL_ENTRIES_PER_QUERY
    )
    for acl_index in acl_indices:
        access_control_list.append(f"group:group_{acl_index}")
        access_control_list.append(f"external_group:external_group_{acl_index}")

    doc_sets = []
    doc_set_indices = random.sample(
        range(TOTAL_DOC_SETS), NUMBER_OF_DOC_SETS_PER_QUERY
    )
    for doc_set_index in doc_set_indices:
        doc_sets.append(f"document_set:Document Set {doc_set_index}")

    return IndexFilters(
        source_type=[DocumentSource.GOOGLE_DRIVE],
        document_set=doc_sets,
        tags=[],
        access_control_list=access_control_list,
    )
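
# for reference, a generated filter might look like this (values are random;
# a sketch only):
# IndexFilters(
#     source_type=[DocumentSource.GOOGLE_DRIVE],
#     document_set=["document_set:Document Set 3", "document_set:Document Set 41"],
#     tags=[],
#     access_control_list=[
#         "user_email:user_17@example.com",
#         "group:group_5",
#         "external_group:external_group_5",
#         ...
#     ],
# )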


def test_hybrid_retrieval_times(
    number_of_queries: int,
) -> None:
    with get_session_context_manager() as db_session:
        search_settings = get_current_search_settings(db_session)
        index_name = search_settings.index_name

    vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)

    # Generate random queries
    queries = [f"Random Query {i}" for i in range(number_of_queries)]

    # Generate random embeddings
    embeddings = [
        Embedding([random.random() for _ in range(DOC_EMBEDDING_DIM)])
        for _ in range(number_of_queries)
    ]

    total_time = 0.0
    results = []
    for i in range(number_of_queries):
        start_time = time.time()

        vespa_index.hybrid_retrieval(
            query=queries[i],
            query_embedding=embeddings[i],
            final_keywords=None,
            filters=_random_filters(),
            hybrid_alpha=0.5,
            time_decay_multiplier=1.0,
            num_to_retrieve=50,
            offset=0,
            title_content_ratio=0.5,
        )

        end_time = time.time()
        query_time = end_time - start_time
        total_time += query_time
        results.append(query_time)

        print(f"Query {i+1}: {query_time:.4f} seconds")

    avg_time = total_time / number_of_queries
    fast_time = min(results)
    slow_time = max(results)
    ninety_ninth_percentile = get_slowest_99th_percentile(results)
    # Write results to a file
    _OUTPUT_PATH = "query_times_results_large_more.txt"
    with open(_OUTPUT_PATH, "w") as f:
        f.write(f"Average query time: {avg_time:.4f} seconds\n")
        f.write(f"Fastest query: {fast_time:.4f} seconds\n")
        f.write(f"Slowest query: {slow_time:.4f} seconds\n")
        f.write(f"99th percentile: {ninety_ninth_percentile:.4f} seconds\n")
    print(f"Results written to {_OUTPUT_PATH}")

    print(f"\nAverage query time: {avg_time:.4f} seconds")
    print(f"Fastest query: {fast_time:.4f} seconds")
    print(f"Slowest query: {slow_time:.4f} seconds")
    print(f"99th percentile: {ninety_ninth_percentile:.4f} seconds")


if __name__ == "__main__":
    test_hybrid_retrieval_times(number_of_queries=1000)
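
# assuming the repo's backend directory is the working directory (so that the
# "scripts." imports above resolve), this can be invoked as:
#   python -m scripts.query_time_check.test_query_times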
@ -5,6 +5,7 @@ from uuid import uuid4
import requests

from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
@ -22,7 +23,7 @@ def _cc_pair_creator(
    connector_id: int,
    credential_id: int,
    name: str | None = None,
    is_public: bool = True,
    access_type: AccessType = AccessType.PUBLIC,
    groups: list[int] | None = None,
    user_performing_action: TestUser | None = None,
) -> TestCCPair:
@ -30,7 +31,7 @@ def _cc_pair_creator(

    request = {
        "name": name,
        "is_public": is_public,
        "access_type": access_type,
        "groups": groups or [],
    }

@ -47,7 +48,7 @@ def _cc_pair_creator(
        name=name,
        connector_id=connector_id,
        credential_id=credential_id,
        is_public=is_public,
        access_type=access_type,
        groups=groups or [],
    )

@ -56,7 +57,7 @@ class CCPairManager:
    @staticmethod
    def create_from_scratch(
        name: str | None = None,
        is_public: bool = True,
        access_type: AccessType = AccessType.PUBLIC,
        groups: list[int] | None = None,
        source: DocumentSource = DocumentSource.FILE,
        input_type: InputType = InputType.LOAD_STATE,
@ -69,7 +70,7 @@ class CCPairManager:
            source=source,
            input_type=input_type,
            connector_specific_config=connector_specific_config,
            is_public=is_public,
            is_public=(access_type == AccessType.PUBLIC),
            groups=groups,
            user_performing_action=user_performing_action,
        )
@ -77,7 +78,7 @@ class CCPairManager:
            credential_json=credential_json,
            name=name,
            source=source,
            curator_public=is_public,
            curator_public=(access_type == AccessType.PUBLIC),
            groups=groups,
            user_performing_action=user_performing_action,
        )
@ -85,7 +86,7 @@ class CCPairManager:
            connector_id=connector.id,
            credential_id=credential.id,
            name=name,
            is_public=is_public,
            access_type=access_type,
            groups=groups,
            user_performing_action=user_performing_action,
        )
@ -95,7 +96,7 @@ class CCPairManager:
        connector_id: int,
        credential_id: int,
        name: str | None = None,
        is_public: bool = True,
        access_type: AccessType = AccessType.PUBLIC,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestCCPair:
@ -103,7 +104,7 @@ class CCPairManager:
            connector_id=connector_id,
            credential_id=credential_id,
            name=name,
            is_public=is_public,
            access_type=access_type,
            groups=groups,
            user_performing_action=user_performing_action,
        )
@ -172,7 +173,7 @@ class CCPairManager:
            retrieved_cc_pair.name == cc_pair.name
            and retrieved_cc_pair.connector.id == cc_pair.connector_id
            and retrieved_cc_pair.credential.id == cc_pair.credential_id
            and retrieved_cc_pair.public_doc == cc_pair.is_public
            and retrieved_cc_pair.access_type == cc_pair.access_type
            and set(retrieved_cc_pair.groups) == set(cc_pair.groups)
        ):
            return
@ -3,6 +3,7 @@ from uuid import uuid4
import requests

from danswer.configs.constants import DocumentSource
from danswer.db.enums import AccessType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
@ -22,7 +23,7 @@ def _verify_document_permissions(
) -> None:
    acl_keys = set(retrieved_doc["access_control_list"].keys())
    print(f"ACL keys: {acl_keys}")
    if cc_pair.is_public:
    if cc_pair.access_type == AccessType.PUBLIC:
        if "PUBLIC" not in acl_keys:
            raise ValueError(
                f"Document {retrieved_doc['document_id']} is public but"
@ -30,10 +31,10 @@ def _verify_document_permissions(
            )

    if doc_creating_user is not None:
        if f"user_id:{doc_creating_user.id}" not in acl_keys:
        if f"user_email:{doc_creating_user.email}" not in acl_keys:
            raise ValueError(
                f"Document {retrieved_doc['document_id']} was created by user"
                f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key"
                f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
            )

    if group_names is not None:
@ -5,6 +5,7 @@ from pydantic import BaseModel
from pydantic import Field

from danswer.auth.schemas import UserRole
from danswer.db.enums import AccessType
from danswer.search.enums import RecencyBiasSetting
from danswer.server.documents.models import DocumentSource
from danswer.server.documents.models import InputType
@ -67,7 +68,7 @@ class TestCCPair(BaseModel):
    name: str
    connector_id: int
    credential_id: int
    is_public: bool
    access_type: AccessType
    groups: list[int]
    documents: list[SimpleTestDocument] = Field(default_factory=list)
@ -5,6 +5,7 @@ the permissions of the curator manipulating connector-credential pairs.
import pytest
from requests.exceptions import HTTPError

from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
@ -91,8 +92,8 @@ def test_cc_pair_permissions(reset: None) -> None:
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_1",
            access_type=AccessType.PUBLIC,
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

@ -103,8 +104,8 @@ def test_cc_pair_permissions(reset: None) -> None:
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            access_type=AccessType.PRIVATE,
            groups=[user_group_1.id, user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

@ -115,8 +116,8 @@ def test_cc_pair_permissions(reset: None) -> None:
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            access_type=AccessType.PRIVATE,
            groups=[],
            is_public=False,
            user_performing_action=curator,
        )

@ -129,8 +130,8 @@ def test_cc_pair_permissions(reset: None) -> None:
    #     connector_id=connector_2.id,
    #     credential_id=credential_1.id,
    #     name="invalid_cc_pair_3",
    #     access_type=AccessType.PRIVATE,
    #     groups=[user_group_1.id],
    #     is_public=False,
    #     user_performing_action=curator,
    # )

@ -141,8 +142,8 @@ def test_cc_pair_permissions(reset: None) -> None:
            connector_id=connector_1.id,
            credential_id=credential_2.id,
            name="invalid_cc_pair_4",
            access_type=AccessType.PRIVATE,
            groups=[user_group_1.id],
            is_public=False,
            user_performing_action=curator,
        )

@ -154,8 +155,8 @@ def test_cc_pair_permissions(reset: None) -> None:
        name="valid_cc_pair",
        connector_id=connector_1.id,
        credential_id=credential_1.id,
        access_type=AccessType.PRIVATE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )

@ -1,6 +1,7 @@
import pytest
from requests.exceptions import HTTPError

from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
@ -47,14 +48,14 @@ def test_doc_set_permissions_setup(reset: None) -> None:

    # Admin creates a cc_pair
    private_cc_pair = CCPairManager.create_from_scratch(
        is_public=False,
        access_type=AccessType.PRIVATE,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    # Admin creates a public cc_pair
    public_cc_pair = CCPairManager.create_from_scratch(
        is_public=True,
        access_type=AccessType.PUBLIC,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
@ -1,6 +1,7 @@
"""
This test tests the happy path for curator permissions
"""
from danswer.db.enums import AccessType
from danswer.db.models import UserRole
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@ -64,8 +65,8 @@ def test_whole_curator_flow(reset: None) -> None:
        connector_id=test_connector.id,
        credential_id=test_credential.id,
        name="curator_test_cc_pair",
        access_type=AccessType.PRIVATE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )

@ -1,6 +1,6 @@
"use client";

import { errorHandlingFetcher } from "@/lib/fetcher";
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { HealthCheckBanner } from "@/components/health/healthcheck";

@ -12,7 +12,6 @@ import { useFormContext } from "@/components/context/FormContext";
import { getSourceDisplayName } from "@/lib/sources";
import { SourceIcon } from "@/components/SourceIcon";
import { useState } from "react";
import { submitConnector } from "@/components/admin/connectors/ConnectorForm";
import { deleteCredential, linkCredential } from "@/lib/credential";
import { submitFiles } from "./pages/utils/files";
import { submitGoogleSite } from "./pages/utils/google_site";
@ -38,7 +37,8 @@ import {
  useGoogleDriveCredentials,
} from "./pages/utils/hooks";
import { Formik } from "formik";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm";
import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector";
import NavigationRow from "./NavigationRow";

export interface AdvancedConfig {
@ -46,6 +46,64 @@ export interface AdvancedConfig {
  pruneFreq: number;
  indexingStart: string;
}
import { Connector, ConnectorBase } from "@/lib/connectors/connectors";

const BASE_CONNECTOR_URL = "/api/manage/admin/connector";

export async function submitConnector<T>(
  connector: ConnectorBase<T>,
  connectorId?: number,
  fakeCredential?: boolean,
  isPublicCcpair?: boolean // exclusively for mock credentials, when we also need to specify cc-pair details
): Promise<{ message: string; isSuccess: boolean; response?: Connector<T> }> {
  const isUpdate = connectorId !== undefined;
  if (!connector.connector_specific_config) {
    connector.connector_specific_config = {} as T;
  }

  try {
    if (fakeCredential) {
      const response = await fetch(
        "/api/manage/admin/connector-with-mock-credential",
        {
          method: isUpdate ? "PATCH" : "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({ ...connector, is_public: isPublicCcpair }),
        }
      );
      if (response.ok) {
        const responseJson = await response.json();
        return { message: "Success!", isSuccess: true, response: responseJson };
      } else {
        const errorData = await response.json();
        return { message: `Error: ${errorData.detail}`, isSuccess: false };
      }
    } else {
      const response = await fetch(
        BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""),
        {
          method: isUpdate ? "PATCH" : "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify(connector),
        }
      );

      if (response.ok) {
        const responseJson = await response.json();
        return { message: "Success!", isSuccess: true, response: responseJson };
      } else {
        const errorData = await response.json();
        return { message: `Error: ${errorData.detail}`, isSuccess: false };
      }
    }
  } catch (error) {
    return { message: `Error: ${error}`, isSuccess: false };
  }
}

export default function AddConnector({
  connector,
@ -84,10 +142,35 @@ export default function AddConnector({
  const { liveGDriveCredential } = useGoogleDriveCredentials();
  const { liveGmailCredential } = useGmailCredentials();

  const {
    data: appCredentialData,
    isLoading: isAppCredentialLoading,
    error: isAppCredentialError,
  } = useSWR<{ client_id: string }, FetchError>(
    "/api/manage/admin/connector/google-drive/app-credential",
    errorHandlingFetcher
  );
  const {
    data: serviceAccountKeyData,
    isLoading: isServiceAccountKeyLoading,
    error: isServiceAccountKeyError,
  } = useSWR<{ service_account_email: string }, FetchError>(
    "/api/manage/admin/connector/google-drive/service-account-key",
    errorHandlingFetcher
  );

  // Check if credential is activated
  const credentialActivated =
    (connector === "google_drive" && liveGDriveCredential) ||
    (connector === "gmail" && liveGmailCredential) ||
    (connector === "google_drive" &&
      (liveGDriveCredential ||
        appCredentialData ||
        serviceAccountKeyData ||
        currentCredential)) ||
    (connector === "gmail" &&
      (liveGmailCredential ||
        appCredentialData ||
        serviceAccountKeyData ||
        currentCredential)) ||
    currentCredential;

  // Check if there are no credentials
@ -159,10 +242,12 @@ export default function AddConnector({
    const {
      name,
      groups,
      is_public: isPublic,
      access_type,
      pruneFreq,
      indexingStart,
      refreshFreq,
      auto_sync_options,
      is_public,
      ...connector_specific_config
    } = values;

@ -204,6 +289,7 @@ export default function AddConnector({
        advancedConfiguration.refreshFreq,
        advancedConfiguration.pruneFreq,
        advancedConfiguration.indexingStart,
        values.access_type == "public",
        name
      );
      if (response) {
@ -219,7 +305,7 @@ export default function AddConnector({
        setPopup,
        setSelectedFiles,
        name,
        isPublic,
        access_type == "public",
        groups
      );
      if (response) {
@ -234,15 +320,15 @@ export default function AddConnector({
        input_type: connector == "web" ? "load_state" : "poll", // single case
        name: name,
        source: connector,
        is_public: access_type == "public",
        refresh_freq: advancedConfiguration.refreshFreq || null,
        prune_freq: advancedConfiguration.pruneFreq || null,
        indexing_start: advancedConfiguration.indexingStart || null,
        is_public: isPublic,
        groups: groups,
      },
      undefined,
      credentialActivated ? false : true,
      isPublic
      access_type == "public"
    );
    // If no credential
    if (!credentialActivated) {
@ -261,8 +347,9 @@ export default function AddConnector({
        response.id,
        credential?.id!,
        name,
        isPublic,
        groups
        access_type,
        groups,
        auto_sync_options
      );
      if (linkCredentialResponse.ok) {
        onSuccess();
@ -366,11 +453,8 @@ export default function AddConnector({
              selectedFiles={selectedFiles}
            />

            <IsPublicGroupSelector
              removeIndent
              formikProps={formikProps}
              objectName="Connector"
            />
            <AccessTypeForm connector={connector} />
            <AccessTypeGroupSelector />
          </Card>
        )}

@ -6,12 +6,14 @@ import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { credentialTemplates } from "@/lib/connectors/credentials";
import Link from "next/link";
import { useUser } from "@/components/user/UserProvider";
import { useContext } from "react";

export default function Sidebar() {
  const { formStep, setFormStep, connector, allowAdvanced, allowCreate } =
    useFormContext();
  const combinedSettings = useContext(SettingsContext);
  const { isLoadingUser, isAdmin } = useUser();
  if (!combinedSettings) {
    return null;
  }
@ -59,7 +61,9 @@ export default function Sidebar() {
            className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-200 cursor-pointer transition-all duration-150 flex gap-x-2"
          >
            <SettingsIcon className="flex-none " />
            <p className="my-auto flex items-center text-sm">Admin Page</p>
            <p className="my-auto flex items-center text-sm">
              {isAdmin ? "Admin Page" : "Curator Page"}
            </p>
          </Link>
        </div>

|
@ -222,36 +222,46 @@ export const DriveJsonUploadSection = ({
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{appCredentialData.client_id}</p>
|
||||
</div>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/app-credential"
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ export const submitFiles = async (
    connector.id,
    credentialId,
    name,
    isPublic,
    isPublic ? "public" : "private",
    groups
  );
  if (!credentialResponse.ok) {
@ -10,6 +10,7 @@ export const submitGoogleSite = async (
  refreshFreq: number,
  pruneFreq: number,
  indexingStart: Date,
  is_public: boolean,
  name?: string
) => {
  const uploadCreateAndTriggerConnector = async () => {
@ -42,6 +43,7 @@ export const submitGoogleSite = async (
        base_url: base_url,
        zip_path: filePaths[0],
      },
      is_public: is_public,
      refresh_freq: refreshFreq,
      prune_freq: pruneFreq,
      indexing_start: indexingStart,
@ -155,7 +155,7 @@ export const DocumentSetCreationForm = ({
    // Filter visible cc pairs
    const visibleCcPairs = localCcPairs.filter(
      (ccPair) =>
        ccPair.public_doc ||
        ccPair.access_type === "public" ||
        (ccPair.groups.length > 0 &&
          props.values.groups.every((group) =>
            ccPair.groups.includes(group)
@ -228,7 +228,7 @@ export const DocumentSetCreationForm = ({
    // Filter non-visible cc pairs
    const nonVisibleCcPairs = localCcPairs.filter(
      (ccPair) =>
        !ccPair.public_doc &&
        !(ccPair.access_type === "public") &&
        (ccPair.groups.length === 0 ||
          !props.values.groups.every((group) =>
            ccPair.groups.includes(group)
@ -232,7 +232,7 @@ function ConnectorRow({
      <TableCell>{getActivityBadge()}</TableCell>
      {isPaidEnterpriseFeaturesEnabled && (
        <TableCell>
          {ccPairsIndexingStatus.public_doc ? (
          {ccPairsIndexingStatus.access_type === "public" ? (
            <Badge
              size="md"
              color={isEditable ? "green" : "gray"}
@ -334,7 +334,8 @@ export function CCPairIndexingStatusTable({
        (status) =>
          status.cc_pair_status === ConnectorCredentialPairStatus.ACTIVE
      ).length,
      public: statuses.filter((status) => status.public_doc).length,
      public: statuses.filter((status) => status.access_type === "public")
        .length,
      totalDocsIndexed: statuses.reduce(
        (sum, status) => sum + status.docs_indexed,
        0
@ -420,7 +421,7 @@ export function CCPairIndexingStatusTable({
        credential_json: {},
        admin_public: false,
      },
      public_doc: true,
      access_type: "public",
      docs_indexed: 1000,
      last_success: "2023-07-01T12:00:00Z",
      last_finished_status: "success",
@ -16,7 +16,7 @@ export const ConnectorEditor = ({
    <div className="mb-3 flex gap-2 flex-wrap">
      {allCCPairs
        // remove public docs, since they don't make sense as part of a group
        .filter((ccPair) => !ccPair.public_doc)
        .filter((ccPair) => !(ccPair.access_type === "public"))
        .map((ccPair) => {
          const ind = selectedCCPairIds.indexOf(ccPair.cc_pair_id);
          let isSelected = ind !== -1;
@ -28,6 +28,11 @@ export const UserGroupCreationForm = ({
}: UserGroupCreationFormProps) => {
  const isUpdate = existingUserGroup !== undefined;

  // Filter out ccPairs that aren't access_type "private"
  const privateCcPairs = ccPairs.filter(
    (ccPair) => ccPair.access_type === "private"
  );

  return (
    <Modal className="w-fit" onOutsideClick={onClose}>
      <div className="px-8 py-6 bg-background">
@ -96,7 +101,7 @@ export const UserGroupCreationForm = ({
                <Divider />

                <h2 className="mb-1 font-medium">
                  Select which connectors this group has access to:
                  Select which private connectors this group has access to:
                </h2>
                <p className="mb-3 text-xs">
                  All documents indexed by the selected connectors will be
@ -104,7 +109,7 @@ export const UserGroupCreationForm = ({
                </p>

                <ConnectorEditor
                  allCCPairs={ccPairs}
                  allCCPairs={privateCcPairs}
                  selectedCCPairIds={values.cc_pair_ids}
                  setSetCCPairIds={(ccPairsIds) =>
                    setFieldValue("cc_pair_ids", ccPairsIds)
@ -76,7 +76,7 @@ export const AddConnectorForm: React.FC<AddConnectorFormProps> = ({
          .includes(ccPair.cc_pair_id)
      )
      // remove public docs, since they don't make sense as part of a group
      .filter((ccPair) => !ccPair.public_doc)
      .filter((ccPair) => !(ccPair.access_type === "public"))
      .map((ccPair) => {
        return {
          name: ccPair.name?.toString() || "",
88
web/src/components/admin/connectors/AccessTypeForm.tsx
Normal file
@ -0,0 +1,88 @@
import { DefaultDropdown } from "@/components/Dropdown";
import {
  AccessType,
  ValidAutoSyncSources,
  ConfigurableSources,
  validAutoSyncSources,
} from "@/lib/types";
import { Text, Title } from "@tremor/react";
import { useUser } from "@/components/user/UserProvider";
import { useField } from "formik";
import { AutoSyncOptions } from "./AutoSyncOptions";

function isValidAutoSyncSource(
  value: ConfigurableSources
): value is ValidAutoSyncSources {
  return validAutoSyncSources.includes(value as ValidAutoSyncSources);
}

export function AccessTypeForm({
  connector,
}: {
  connector: ConfigurableSources;
}) {
  const [access_type, meta, access_type_helpers] =
    useField<AccessType>("access_type");

  const isAutoSyncSupported = isValidAutoSyncSource(connector);
  const { isLoadingUser, isAdmin } = useUser();

  const options = [
    {
      name: "Private",
      value: "private",
      description:
        "Only users who have explicitly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
    },
  ];

  if (isAdmin) {
    options.push({
      name: "Public",
      value: "public",
      description:
        "Everyone with an account on Danswer can access the documents pulled in by this connector",
    });
  }

  if (isAutoSyncSupported && isAdmin) {
    options.push({
      name: "Auto Sync",
      value: "sync",
      description:
        "We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.",
    });
  }

  return (
    <div>
      <div className="flex gap-x-2 items-center">
        <label className="text-text-950 font-medium">Document Access</label>
      </div>
      <p className="text-sm text-text-500 mb-2">
        Control who has access to the documents indexed by this connector.
      </p>

      {isAdmin && (
        <>
          <DefaultDropdown
            options={options}
            selected={access_type.value}
            onSelect={(selected) =>
              access_type_helpers.setValue(selected as AccessType)
            }
            includeDefault={false}
          />

          {access_type.value === "sync" && isAutoSyncSupported && (
            <div className="mt-6">
              <AutoSyncOptions
                connectorType={connector as ValidAutoSyncSources}
              />
            </div>
          )}
        </>
      )}
    </div>
  );
}
147
web/src/components/admin/connectors/AccessTypeGroupSelector.tsx
Normal file
@ -0,0 +1,147 @@
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import React, { useState, useEffect } from "react";
import { FieldArray, ArrayHelpers, ErrorMessage, useField } from "formik";
import { Text, Divider } from "@tremor/react";
import { FiUsers } from "react-icons/fi";
import { UserGroup, User, UserRole } from "@/lib/types";
import { useUserGroups } from "@/lib/hooks";
import { AccessType } from "@/lib/types";
import { useUser } from "@/components/user/UserProvider";

// This should be included for all forms that require groups / public access
// to be set, and access to this / permissioning should be handled within this component itself.
export function AccessTypeGroupSelector({}: {}) {
  const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
  const { isAdmin, user, isLoadingUser, isCurator } = useUser();
  const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
  const [shouldHideContent, setShouldHideContent] = useState(false);

  const [access_type, meta, access_type_helpers] =
    useField<AccessType>("access_type");
  const [groups, groups_meta, groups_helpers] = useField<number[]>("groups");

  useEffect(() => {
    if (user && userGroups && isPaidEnterpriseFeaturesEnabled) {
      const isUserAdmin = user.role === UserRole.ADMIN;
      if (!isPaidEnterpriseFeaturesEnabled) {
        access_type_helpers.setValue("public");
        return;
      }
      if (!isUserAdmin) {
        access_type_helpers.setValue("private");
      }
      if (userGroups.length === 1 && !isUserAdmin) {
        groups_helpers.setValue([userGroups[0].id]);
        setShouldHideContent(true);
      } else if (access_type.value !== "private") {
        groups_helpers.setValue([]);
        setShouldHideContent(false);
      } else {
        setShouldHideContent(false);
      }
    }
  }, [user, userGroups, access_type.value]);

  if (isLoadingUser || userGroupsIsLoading) {
    return <div>Loading...</div>;
  }
  if (!isPaidEnterpriseFeaturesEnabled) {
    return null;
  }

  if (shouldHideContent) {
    return (
      <>
        {userGroups && (
          <div className="mb-1 font-medium text-base">
            This Connector will be assigned to group <b>{userGroups[0].name}</b>.
          </div>
        )}
      </>
    );
  }

  return (
    <div>
      {(access_type.value === "private" || isCurator) &&
        userGroups &&
        userGroups?.length > 0 && (
          <>
            <Divider />
            <div className="flex mt-4 gap-x-2 items-center">
              <div className="block font-medium text-base">
                Assign group access for this Connector
              </div>
            </div>
            {userGroupsIsLoading ? (
              <div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
            ) : (
              <Text className="mb-3">
                {isAdmin ? (
                  <>
                    This Connector will be visible/accessible by the groups
                    selected below
                  </>
                ) : (
                  <>
                    Curators must select one or more groups to give access to
                    this Connector
                  </>
                )}
              </Text>
            )}
            <FieldArray
              name="groups"
              render={(arrayHelpers: ArrayHelpers) => (
                <div className="flex gap-2 flex-wrap mb-4">
                  {userGroupsIsLoading ? (
                    <div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
                  ) : (
                    userGroups &&
                    userGroups.map((userGroup: UserGroup) => {
                      const ind = groups.value.indexOf(userGroup.id);
                      let isSelected = ind !== -1;
                      return (
                        <div
                          key={userGroup.id}
                          className={`
                            px-3
                            py-1
                            rounded-lg
                            border
                            border-border
                            w-fit
                            flex
                            cursor-pointer
                            ${isSelected ? "bg-background-strong" : "hover:bg-hover"}
                          `}
                          onClick={() => {
                            if (isSelected) {
                              arrayHelpers.remove(ind);
                            } else {
                              arrayHelpers.push(userGroup.id);
                            }
                          }}
                        >
                          <div className="my-auto flex">
                            <FiUsers className="my-auto mr-2" />{" "}
                            {userGroup.name}
                          </div>
                        </div>
                      );
                    })
                  )}
                </div>
              )}
            />
            <ErrorMessage
              name="groups"
              component="div"
              className="text-error text-sm mt-1"
            />
          </>
        )}
    </div>
  );
}
30
web/src/components/admin/connectors/AutoSyncOptions.tsx
Normal file
@ -0,0 +1,30 @@
import { TextFormField } from "@/components/admin/connectors/Field";
import { useFormikContext } from "formik";
import { ValidAutoSyncSources } from "@/lib/types";
import { Divider } from "@tremor/react";
import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields";

export function AutoSyncOptions({
  connectorType,
}: {
  connectorType: ValidAutoSyncSources;
}) {
  return (
    <div>
      <Divider />
      <>
        {Object.entries(autoSyncConfigBySource[connectorType]).map(
          ([key, config]) => (
            <div key={key} className="mb-4">
              <TextFormField
                name={`auto_sync_options.${key}`}
                label={config.label}
                subtext={config.subtext}
              />
            </div>
          )
        )}
      </>
    </div>
  );
}
@ -1,382 +0,0 @@
"use client";

import React, { useState } from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Popup, usePopup } from "./Popup";
import { ValidInputTypes, ValidSources } from "@/lib/types";
import { deleteConnectorIfExistsAndIsUnlinked } from "@/lib/connector";
import { FormBodyBuilder, RequireAtLeastOne } from "./types";
import { BooleanFormField, TextFormField } from "./Field";
import { createCredential, linkCredential } from "@/lib/credential";
import { useSWRConfig } from "swr";
import { Button, Divider } from "@tremor/react";
import IsPublicField from "./IsPublicField";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { Connector, ConnectorBase } from "@/lib/connectors/connectors";

const BASE_CONNECTOR_URL = "/api/manage/admin/connector";

export async function submitConnector<T>(
  connector: ConnectorBase<T>,
  connectorId?: number,
  fakeCredential?: boolean,
  isPublicCcpair?: boolean // exclusively for mock credentials, when we also need to specify cc-pair details
): Promise<{ message: string; isSuccess: boolean; response?: Connector<T> }> {
  const isUpdate = connectorId !== undefined;
  if (!connector.connector_specific_config) {
    connector.connector_specific_config = {} as T;
  }

  try {
    if (fakeCredential) {
      const response = await fetch(
        "/api/manage/admin/connector-with-mock-credential",
        {
          method: isUpdate ? "PATCH" : "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({ ...connector, is_public: isPublicCcpair }),
        }
      );
      if (response.ok) {
        const responseJson = await response.json();
        return { message: "Success!", isSuccess: true, response: responseJson };
      } else {
        const errorData = await response.json();
        return { message: `Error: ${errorData.detail}`, isSuccess: false };
      }
    } else {
      const response = await fetch(
        BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""),
        {
          method: isUpdate ? "PATCH" : "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify(connector),
        }
      );

      if (response.ok) {
        const responseJson = await response.json();
        return { message: "Success!", isSuccess: true, response: responseJson };
      } else {
        const errorData = await response.json();
        return { message: `Error: ${errorData.detail}`, isSuccess: false };
      }
    }
  } catch (error) {
    return { message: `Error: ${error}`, isSuccess: false };
  }
}

const CCPairNameHaver = Yup.object().shape({
  cc_pair_name: Yup.string().required("Please enter a name for the connector"),
});

interface BaseProps<T extends Yup.AnyObject> {
  nameBuilder: (values: T) => string;
  ccPairNameBuilder?: (values: T) => string | null;
  source: ValidSources;
  inputType: ValidInputTypes;
  // if specified, will automatically try and link the credential
  credentialId?: number;
  // If both are specified, will render formBody and then formBodyBuilder
  formBody?: JSX.Element | null;
  formBodyBuilder?: FormBodyBuilder<T>;
  validationSchema: Yup.ObjectSchema<T>;
  validate?: (values: T) => Record<string, string>;
  initialValues: T;
  onSubmit?: (
    isSuccess: boolean,
    responseJson: Connector<T> | undefined
  ) => void;
  refreshFreq?: number;
  pruneFreq?: number;
  indexingStart?: Date;
  // If specified, then we will create an empty credential and associate
  // the connector with it. If credentialId is specified, then this will be ignored
  shouldCreateEmptyCredentialForConnector?: boolean;
}

type ConnectorFormProps<T extends Yup.AnyObject> = RequireAtLeastOne<
  BaseProps<T>,
  "formBody" | "formBodyBuilder"
>;

export function ConnectorForm<T extends Yup.AnyObject>({
  nameBuilder,
  ccPairNameBuilder,
  source,
  inputType,
  credentialId,
  formBody,
  formBodyBuilder,
  validationSchema,
  validate,
  initialValues,
  refreshFreq,
  pruneFreq,
  indexingStart,
  onSubmit,
  shouldCreateEmptyCredentialForConnector,
}: ConnectorFormProps<T>): JSX.Element {
  const { mutate } = useSWRConfig();
  const { popup, setPopup } = usePopup();

  // only show this option for EE, since groups are not supported in CE
  const showNonPublicOption = usePaidEnterpriseFeaturesEnabled();

  const shouldHaveNameInput = credentialId !== undefined && !ccPairNameBuilder;

  const ccPairNameInitialValue = shouldHaveNameInput
    ? { cc_pair_name: "" }
    : {};
  const publicOptionInitialValue = showNonPublicOption
    ? { is_public: false }
    : {};

  let finalValidationSchema =
    validationSchema as Yup.ObjectSchema<Yup.AnyObject>;
  if (shouldHaveNameInput) {
    finalValidationSchema = finalValidationSchema.concat(CCPairNameHaver);
  }
  if (showNonPublicOption) {
    finalValidationSchema = finalValidationSchema.concat(
      Yup.object().shape({
        is_public: Yup.boolean(),
      })
    );
  }

  return (
    <>
      {popup}
      <Formik
        initialValues={{
          ...publicOptionInitialValue,
          ...ccPairNameInitialValue,
          ...initialValues,
        }}
        validationSchema={finalValidationSchema}
        validate={validate}
        onSubmit={async (values, formikHelpers) => {
          formikHelpers.setSubmitting(true);
          const connectorName = nameBuilder(values);
          const connectorConfig = Object.fromEntries(
            Object.keys(initialValues).map((key) => [key, values[key]])
          ) as T;

          // best effort check to see if existing connector exists
          // delete it if:
          // 1. it exists
          // 2. AND it has no credentials linked to it
          // If the ^ are true, that means things have gotten into a bad
          // state, and we should delete the connector to recover
          const errorMsg = await deleteConnectorIfExistsAndIsUnlinked({
            source,
            name: connectorName,
          });
          if (errorMsg) {
            setPopup({
              message: `Unable to delete existing connector - ${errorMsg}`,
              type: "error",
            });
            return;
          }

          const { message, isSuccess, response } = await submitConnector<T>({
            name: connectorName,
            source,
            input_type: inputType,
            connector_specific_config: connectorConfig,
            refresh_freq: refreshFreq || 0,
            prune_freq: pruneFreq ?? null,
            indexing_start: indexingStart || null,
          });

          if (!isSuccess || !response) {
            setPopup({ message, type: "error" });
            formikHelpers.setSubmitting(false);
            return;
          }

          let credentialIdToLinkTo = credentialId;
          // create empty credential if specified
          if (
            shouldCreateEmptyCredentialForConnector &&
            credentialIdToLinkTo === undefined
          ) {
            const createCredentialResponse = await createCredential({
              credential_json: {},
              admin_public: true,
              source: source,
            });

            if (!createCredentialResponse.ok) {
              const errorMsg = await createCredentialResponse.text();
              setPopup({
                message: `Error creating credential for CC Pair - ${errorMsg}`,
                type: "error",
              });
              formikHelpers.setSubmitting(false);
              return;
            }
            credentialIdToLinkTo = (await createCredentialResponse.json()).id;
          }

          if (credentialIdToLinkTo !== undefined) {
            const ccPairName = ccPairNameBuilder
              ? ccPairNameBuilder(values)
              : values.cc_pair_name;
            const linkCredentialResponse = await linkCredential(
              response.id,
              credentialIdToLinkTo,
              ccPairName as string,
              values.is_public
            );
            if (!linkCredentialResponse.ok) {
              const linkCredentialErrorMsg =
                await linkCredentialResponse.text();
              setPopup({
                message: `Error linking credential - ${linkCredentialErrorMsg}`,
                type: "error",
              });
              formikHelpers.setSubmitting(false);
              return;
            }
          }

          mutate("/api/manage/admin/connector/indexing-status");
          setPopup({ message, type: isSuccess ? "success" : "error" });
          formikHelpers.setSubmitting(false);
          if (isSuccess) {
            formikHelpers.resetForm();
          }
          if (onSubmit) {
            onSubmit(isSuccess, response);
          }
        }}
      >
        {({ isSubmitting, values }) => (
          <Form>
            {shouldHaveNameInput && (
              <TextFormField
                name="cc_pair_name"
                label="Connector Name"
                autoCompleteDisabled={true}
                subtext={`A descriptive name for the connector. This will be used to identify the connector in the Admin UI.`}
              />
            )}
            {formBody && formBody}
            {formBodyBuilder && formBodyBuilder(values)}
            {showNonPublicOption && (
              <>
                <Divider />
                <IsPublicField />
                <Divider />
              </>
            )}
            <div className="flex">
              <Button
                type="submit"
                size="xs"
                color="green"
                disabled={isSubmitting}
                className="mx-auto w-64"
              >
                Connect
              </Button>
            </div>
          </Form>
        )}
      </Formik>
    </>
  );
}

interface UpdateConnectorBaseProps<T extends Yup.AnyObject> {
  nameBuilder?: (values: T) => string;
  existingConnector: Connector<T>;
  // If both are specified, uses formBody
  formBody?: JSX.Element | null;
  formBodyBuilder?: FormBodyBuilder<T>;
  validationSchema: Yup.ObjectSchema<T>;
  onSubmit?: (isSuccess: boolean, responseJson?: Connector<T>) => void;
}

type UpdateConnectorFormProps<T extends Yup.AnyObject> = RequireAtLeastOne<
  UpdateConnectorBaseProps<T>,
  "formBody" | "formBodyBuilder"
>;

export function UpdateConnectorForm<T extends Yup.AnyObject>({
  nameBuilder,
  existingConnector,
  formBody,
  formBodyBuilder,
  validationSchema,
  onSubmit,
}: UpdateConnectorFormProps<T>): JSX.Element {
  const [popup, setPopup] = useState<{
    message: string;
    type: "success" | "error";
  } | null>(null);

  return (
    <>
      {popup && <Popup message={popup.message} type={popup.type} />}
      <Formik
        initialValues={existingConnector.connector_specific_config}
        validationSchema={validationSchema}
        onSubmit={async (values, formikHelpers) => {
          formikHelpers.setSubmitting(true);

          const { message, isSuccess, response } = await submitConnector<T>(
            {
              name: nameBuilder ? nameBuilder(values) : existingConnector.name,
              source: existingConnector.source,
              input_type: existingConnector.input_type,
              connector_specific_config: values,
              refresh_freq: existingConnector.refresh_freq,
              prune_freq: existingConnector.prune_freq,
              indexing_start: existingConnector.indexing_start,
            },
            existingConnector.id
          );

          setPopup({ message, type: isSuccess ? "success" : "error" });
          formikHelpers.setSubmitting(false);
          if (isSuccess) {
            formikHelpers.resetForm();
          }
          setTimeout(() => {
            setPopup(null);
          }, 4000);
          if (onSubmit) {
            onSubmit(isSuccess, response);
          }
        }}
      >
        {({ isSubmitting, values }) => (
          <Form>
            {formBody ? formBody : formBodyBuilder && formBodyBuilder(values)}
            <div className="flex">
              <Button
                type="submit"
                color="green"
                size="xs"
                disabled={isSubmitting}
                className="mx-auto w-64"
              >
                Update
              </Button>
            </div>
          </Form>
        )}
      </Formik>
    </>
  );
}
@ -28,7 +28,6 @@ import {
  ConfluenceCredentialJson,
  Credential,
} from "@/lib/connectors/credentials";
import { UserGroup } from "@/lib/types"; // Added this import

export default function CredentialSection({
  ccPair,
@ -232,7 +232,7 @@ export default function CreateCredential({
          setShowAdvancedOptions={setShowAdvancedOptions}
        />
      )}
      {showAdvancedOptions && (
      {(showAdvancedOptions || !isAdmin) && (
        <IsPublicGroupSelector
          formikProps={formikProps}
          objectName="credential"
46
web/src/lib/connectors/AutoSyncOptionFields.tsx
Normal file
@ -0,0 +1,46 @@
import { ValidAutoSyncSources } from "@/lib/types";

// The first key is the connector type, and the second key is the field name
export const autoSyncConfigBySource: Record<
  ValidAutoSyncSources,
  Record<
    string,
    {
      label: string;
      subtext: JSX.Element;
    }
  >
> = {
  google_drive: {
    customer_id: {
      label: "Google Workspace Customer ID",
      subtext: (
        <>
          The unique identifier for your Google Workspace account. To find
          this, check out the{" "}
          <a
            href="https://support.google.com/cloudidentity/answer/10070793"
            target="_blank"
            className="text-link"
          >
            guide from Google
          </a>
          .
        </>
      ),
    },
    company_domain: {
      label: "Google Workspace Company Domain",
      subtext: (
        <>
          The email domain for your Google Workspace account.
          <br />
          <br />
          For example, if your email provided through Google Workspace looks
          something like chris@danswer.ai, then your company domain is{" "}
          <b>danswer.ai</b>
        </>
      ),
    },
  },
};
@ -29,6 +29,7 @@ export interface Option {
export interface SelectOption extends Option {
  type: "select";
  options?: StringWithDescription[];
  default?: string;
}

export interface ListOption extends Option {
@ -599,7 +600,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
        { name: "articles", value: "articles" },
        { name: "tickets", value: "tickets" },
      ],
      default: 0,
      default: "articles",
    },
  ],
},
@ -1,4 +1,5 @@
import { CredentialBase } from "./connectors/credentials";
import { AccessType } from "@/lib/types";

export async function createCredential(credential: CredentialBase<any>) {
  return await fetch(`/api/manage/credential`, {
@ -47,8 +48,9 @@ export function linkCredential(
  connectorId: number,
  credentialId: number,
  name?: string,
  isPublic?: boolean,
  groups?: number[]
  accessType?: AccessType,
  groups?: number[],
  autoSyncOptions?: Record<string, any>
) {
  return fetch(
    `/api/manage/connector/${connectorId}/credential/${credentialId}`,
@ -59,8 +61,9 @@ export function linkCredential(
      },
      body: JSON.stringify({
        name: name || null,
        is_public: isPublic !== undefined ? isPublic : true,
        access_type: accessType !== undefined ? accessType : "public",
        groups: groups || null,
        auto_sync_options: autoSyncOptions || null,
      }),
    }
  );
|
||||
| "not_started";
|
||||
export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
|
||||
export type Feedback = "like" | "dislike";
|
||||
export type AccessType = "public" | "private" | "sync";
|
||||
export type SessionType = "Chat" | "Search" | "Slack";
|
||||
|
||||
export interface DocumentBoostStatus {
|
||||
@ -90,7 +91,7 @@ export interface ConnectorIndexingStatus<
|
||||
cc_pair_status: ConnectorCredentialPairStatus;
|
||||
connector: Connector<ConnectorConfigType>;
|
||||
credential: Credential<ConnectorCredentialType>;
|
||||
public_doc: boolean;
|
||||
access_type: AccessType;
|
||||
owner: string;
|
||||
groups: number[];
|
||||
last_finished_status: ValidStatuses | null;
|
||||
@ -258,3 +259,7 @@ export type ConfigurableSources = Exclude<
|
||||
ValidSources,
|
||||
"not_applicable" | "ingestion_api"
|
||||
>;
|
||||
|
||||
// The sources that have auto-sync support on the backend
|
||||
export const validAutoSyncSources = ["google_drive"] as const;
|
||||
export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number];
|
||||
|