diff --git a/backend/alembic/versions/3b25685ff73c_move_is_public_to_cc_pair.py b/backend/alembic/versions/3b25685ff73c_move_is_public_to_cc_pair.py new file mode 100644 index 000000000..937d926e4 --- /dev/null +++ b/backend/alembic/versions/3b25685ff73c_move_is_public_to_cc_pair.py @@ -0,0 +1,49 @@ +"""Move is_public to cc_pair + +Revision ID: 3b25685ff73c +Revises: e0a68a81d434 +Create Date: 2023-10-05 18:47:09.582849 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "3b25685ff73c" +down_revision = "e0a68a81d434" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + # fill in is_public for existing rows + op.execute( + "UPDATE connector_credential_pair SET is_public = true WHERE is_public IS NULL" + ) + op.alter_column("connector_credential_pair", "is_public", nullable=False) + + op.add_column( + "credential", + sa.Column("is_admin", sa.Boolean(), nullable=True), + ) + op.execute("UPDATE credential SET is_admin = true WHERE is_admin IS NULL") + op.alter_column("credential", "is_admin", nullable=False) + + op.drop_column("credential", "public_doc") + + +def downgrade() -> None: + op.add_column( + "credential", + sa.Column("public_doc", sa.Boolean(), nullable=True), + ) + # setting public_doc to false for all existing rows to be safe + # NOTE: this is likely not the correct state of the world but it's the best we can do + op.execute("UPDATE credential SET public_doc = false WHERE public_doc IS NULL") + op.alter_column("credential", "public_doc", nullable=False) + op.drop_column("connector_credential_pair", "is_public") + op.drop_column("credential", "is_admin") diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 586ed2d8d..628254eee 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -287,13 +287,9 @@ def _run_indexing( f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}" ) - index_user_id = ( - None if db_credential.public_doc else db_credential.user_id - ) new_docs, total_batch_chunks = indexing_pipeline( documents=doc_batch, index_attempt_metadata=IndexAttemptMetadata( - user_id=index_user_id, connector_id=db_connector.id, credential_id=db_credential.id, ), diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 02d489c1d..c4924c0cc 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -130,7 +130,7 @@ def build_service_account_creds( return CredentialBase( credential_json=credential_dict, - public_doc=True, + is_admin=True, ) diff --git a/backend/danswer/connectors/models.py b/backend/danswer/connectors/models.py index 7bd0c6834..65213fc95 100644 --- a/backend/danswer/connectors/models.py +++ b/backend/danswer/connectors/models.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from enum import Enum from typing import Any -from uuid import UUID from danswer.configs.constants import DocumentSource @@ -41,6 +40,5 @@ class InputType(str, Enum): @dataclass class IndexAttemptMetadata: - user_id: UUID | None connector_id: int credential_id: int diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 0c2446fe7..96f491518 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -1,5 +1,6 @@ from typing import Any +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.sql.expression import or_ @@ -19,18 +20,30 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select: + """Attaches filters to the statement to ensure that the user can only + access the appropriate credentials""" + if user: + if user.role == UserRole.ADMIN: + stmt = stmt.where( + or_( + Credential.user_id == user.id, + Credential.user_id.is_(None), + Credential.is_admin == True, # noqa: E712 + ) + ) + else: + stmt = stmt.where(Credential.user_id == user.id) + + return stmt + + def fetch_credentials( db_session: Session, user: User | None = None, - public_only: bool | None = None, ) -> list[Credential]: stmt = select(Credential) - if user: - stmt = stmt.where( - or_(Credential.user_id == user.id, Credential.user_id.is_(None)) - ) - if public_only is not None: - stmt = stmt.where(Credential.public_doc == public_only) + stmt = _attach_user_filters(stmt, user) results = db_session.scalars(stmt) return list(results.all()) @@ -39,20 +52,7 @@ def fetch_credential_by_id( credential_id: int, user: User | None, db_session: Session ) -> Credential | None: stmt = select(Credential).where(Credential.id == credential_id) - if user: - # admins have access to all public credentials + credentials they own - if user.role == UserRole.ADMIN: - stmt = stmt.where( - or_( - Credential.user_id == user.id, - Credential.user_id.is_(None), - Credential.public_doc == True, # noqa: E712 - ) - ) - else: - stmt = stmt.where( - or_(Credential.user_id == user.id, Credential.user_id.is_(None)) - ) + stmt = _attach_user_filters(stmt, user) result = db_session.execute(stmt) credential = result.scalar_one_or_none() return credential @@ -60,13 +60,13 @@ def fetch_credential_by_id( def create_credential( credential_data: CredentialBase, - user: User, + user: User | None, db_session: Session, ) -> ObjectCreationIdResponse: credential = Credential( credential_json=credential_data.credential_json, user_id=user.id if user else None, - public_doc=credential_data.public_doc, + is_admin=credential_data.is_admin, ) db_session.add(credential) db_session.commit() @@ -86,7 +86,6 @@ def update_credential( credential.credential_json = credential_data.credential_json credential.user_id = user.id if user is not None else None - credential.public_doc = credential_data.public_doc db_session.commit() return credential @@ -144,13 +143,15 @@ def create_initial_public_credential() -> None: if first_credential is not None: if ( first_credential.credential_json != {} - or first_credential.public_doc is False + or first_credential.user is not None ): raise ValueError(error_msg) return credential = Credential( - id=public_cred_id, credential_json={}, user_id=None, public_doc=True + id=public_cred_id, + credential_json={}, + user_id=None, ) db_session.add(credential) db_session.commit() diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index c88e9f945..c81a4bd88 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from danswer.configs.constants import DEFAULT_BOOST from danswer.datastores.interfaces import DocumentMetadata from danswer.db.feedback import delete_document_feedback_for_documents +from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentByConnectorCredentialPair @@ -69,7 +70,7 @@ def get_acccess_info_for_documents( stmt = select( DocumentByConnectorCredentialPair.id, func.array_agg(Credential.user_id).label("user_ids"), - func.bool_or(Credential.public_doc).label("public_doc"), + func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"), ).where(DocumentByConnectorCredentialPair.id.in_(document_ids)) # pretend that the specified cc pair doesn't exist @@ -83,10 +84,22 @@ def get_acccess_info_for_documents( ) ) - stmt = stmt.join( - Credential, - DocumentByConnectorCredentialPair.credential_id == Credential.id, - ).group_by(DocumentByConnectorCredentialPair.id) + stmt = ( + stmt.join( + Credential, + DocumentByConnectorCredentialPair.credential_id == Credential.id, + ) + .join( + ConnectorCredentialPair, + and_( + DocumentByConnectorCredentialPair.connector_id + == ConnectorCredentialPair.connector_id, + DocumentByConnectorCredentialPair.credential_id + == ConnectorCredentialPair.credential_id, + ), + ) + .group_by(DocumentByConnectorCredentialPair.id) + ) return db_session.execute(stmt).all() # type: ignore diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 09bdad108..5ad9d8b44 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -144,6 +144,14 @@ class ConnectorCredentialPair(Base): credential_id: Mapped[int] = mapped_column( ForeignKey("credential.id"), primary_key=True ) + # controls whether the documents indexed by this CC pair are visible to all + # or if they are only visible to those with that are given explicit access + # (e.g. via owning the credential or being a part of a group that is given access) + is_public: Mapped[bool] = mapped_column( + Boolean, + default=True, + nullable=False, + ) # Time finished, not used for calculating backend jobs which uses time started (created) last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None @@ -206,7 +214,8 @@ class Credential(Base): id: Mapped[int] = mapped_column(primary_key=True) credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB()) user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) - public_doc: Mapped[bool] = mapped_column(Boolean, default=False) + # if `true`, then all Admins will have access to the credential + is_admin: Mapped[bool] = mapped_column(Boolean, default=True) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) diff --git a/backend/danswer/server/credential.py b/backend/danswer/server/credential.py index 591ff3882..1421d2fbe 100644 --- a/backend/danswer/server/credential.py +++ b/backend/danswer/server/credential.py @@ -3,6 +3,7 @@ from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session +from danswer.auth.schemas import UserRole from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.credentials import create_credential @@ -26,11 +27,11 @@ router = APIRouter(prefix="/manage") @router.get("/admin/credential") def list_credentials_admin( - _: User = Depends(current_admin_user), + user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[CredentialSnapshot]: """Lists all public credentials""" - credentials = fetch_credentials(db_session=db_session, public_only=True) + credentials = fetch_credentials(db_session=db_session, user=user) return [ CredentialSnapshot.from_credential_db_model(credential) for credential in credentials @@ -65,6 +66,21 @@ def list_credentials( ] +@router.post("/credential") +def create_credential_from_model( + credential_info: CredentialBase, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> ObjectCreationIdResponse: + if user and user.role != UserRole.ADMIN: + raise HTTPException( + status_code=400, + detail="Non-admin cannot create admin credential", + ) + + return create_credential(credential_info, user, db_session) + + @router.get("/credential/{credential_id}") def get_credential_by_id( credential_id: int, @@ -81,15 +97,6 @@ def get_credential_by_id( return CredentialSnapshot.from_credential_db_model(credential) -@router.post("/credential") -def create_credential_from_model( - connector_info: CredentialBase, - user: User = Depends(current_user), - db_session: Session = Depends(get_session), -) -> ObjectCreationIdResponse: - return create_credential(connector_info, user, db_session) - - @router.patch("/credential/{credential_id}") def update_credential_from_model( credential_id: int, @@ -110,7 +117,7 @@ def update_credential_from_model( id=updated_credential.id, credential_json=updated_credential.credential_json, user_id=updated_credential.user_id, - public_doc=updated_credential.public_doc, + is_admin=updated_credential.is_admin, time_created=updated_credential.time_created, time_updated=updated_credential.time_updated, ) diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 82dc35b9d..d8c1d0595 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -215,7 +215,7 @@ def delete_google_service_account_key( @router.put("/admin/connector/google-drive/service-account-credential") def upsert_service_account_credential( service_account_credential_request: GoogleServiceAccountCredentialRequest, - user: User = Depends(current_admin_user), + user: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: """Special API which allows the creation of a credential for a service account. @@ -225,12 +225,12 @@ def upsert_service_account_credential( credential_base = build_service_account_creds( delegated_user_email=service_account_credential_request.google_drive_delegated_user ) - print(credential_base) except ConfigNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) # first delete all existing service account credentials delete_google_drive_service_account_credentials(user, db_session) + # `user=None` since this credential is not a personal credential return create_credential( credential_data=credential_base, user=user, db_session=db_session ) @@ -322,7 +322,7 @@ def get_connector_indexing_status( name=cc_pair.name, connector=ConnectorSnapshot.from_connector_db_model(connector), credential=CredentialSnapshot.from_credential_db_model(credential), - public_doc=credential.public_doc, + public_doc=cc_pair.is_public, owner=credential.user.email if credential.user else "", last_status=cc_pair.last_attempt_status, last_success=cc_pair.last_successful_index_time, diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index d4e3c762c..ff65590fa 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -336,7 +336,7 @@ class RunConnectorRequest(BaseModel): class CredentialBase(BaseModel): credential_json: dict[str, Any] - public_doc: bool + is_admin: bool class CredentialSnapshot(CredentialBase): @@ -353,7 +353,7 @@ class CredentialSnapshot(CredentialBase): if MASK_CREDENTIAL_PREFIX else credential.credential_json, user_id=credential.user_id, - public_doc=credential.public_doc, + is_admin=credential.is_admin, time_created=credential.time_created, time_updated=credential.time_updated, ) diff --git a/web/src/app/admin/connectors/google-drive/page.tsx b/web/src/app/admin/connectors/google-drive/page.tsx index 6107bcf94..cb2af82df 100644 --- a/web/src/app/admin/connectors/google-drive/page.tsx +++ b/web/src/app/admin/connectors/google-drive/page.tsx @@ -321,7 +321,8 @@ const Main = () => { | Credential | undefined = credentialsData.find( (credential) => - credential.credential_json?.google_drive_tokens && credential.public_doc + credential.credential_json?.google_drive_tokens && + credential.user_id === null ); const googleDriveServiceAccountCredential: | Credential diff --git a/web/src/app/admin/indexing/status/page.tsx b/web/src/app/admin/indexing/status/page.tsx index 07963d0d4..4fb7382af 100644 --- a/web/src/app/admin/indexing/status/page.tsx +++ b/web/src/app/admin/indexing/status/page.tsx @@ -152,7 +152,7 @@ function Main() { ), diff --git a/web/src/app/user/connectors/GoogleDriveCard.tsx b/web/src/app/user/connectors/GoogleDriveCard.tsx index ee304779d..1bc97e4c3 100644 --- a/web/src/app/user/connectors/GoogleDriveCard.tsx +++ b/web/src/app/user/connectors/GoogleDriveCard.tsx @@ -20,8 +20,9 @@ export const GoogleDriveCard = ({ const existingCredential: Credential | undefined = userCredentials?.find( (credential) => + // user_id is set => credential is not a public credential credential.credential_json?.google_drive_tokens !== undefined && - !credential.public_doc + credential.user_id !== null ); const credentialIsLinked = diff --git a/web/src/components/admin/connectors/CredentialForm.tsx b/web/src/components/admin/connectors/CredentialForm.tsx index ebc11be09..77d89b000 100644 --- a/web/src/components/admin/connectors/CredentialForm.tsx +++ b/web/src/components/admin/connectors/CredentialForm.tsx @@ -57,7 +57,7 @@ export function CredentialForm({ formikHelpers.setSubmitting(true); submitCredential({ credential_json: values, - public_doc: true, + is_admin: true, }).then(({ message, isSuccess }) => { setPopup({ message, type: isSuccess ? "success" : "error" }); formikHelpers.setSubmitting(false); diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 0e265a8b9..e107954ec 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -139,7 +139,7 @@ export interface ConnectorIndexingStatus< // CREDENTIALS export interface CredentialBase { credential_json: T; - public_doc: boolean; + is_admin: boolean; } export interface Credential extends CredentialBase {