From c042a19c0063aba0cc88efe0fb4b8cb37e91453b Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 22 Aug 2024 18:39:37 -0700 Subject: [PATCH] Curator role (#2166) * Added backend support for curator role * modal refactor * finalized first 2 commits same as before finally what was it for * added credential, cc_pair, and cleanup mypy is super helpful hahahahahahahahahahahaha * curator support for personas * added connector management permission checks * fixed the connector creation flow * added document access to curator * small cleanup added comments and started ui * groups and assistant editor * Persona frontend * Document set frontend * cleaned up the entire frontend * alembic fix * Minor fixes * credentials section * some credential updates * removed logging statements * fixed try catch * fixed model name * made everything happen in one db commit * Final cleanup * cleaned up fast code * mypy/build fixes * polish * more token rate limit polish * fixed weird credential permissions * Addressed chris feedback * addressed pablo feedback * fixed alembic * removed deduping and caching * polish!!!! --- .../351faebd379d_add_curator_fields.py | 90 +++++ backend/danswer/auth/schemas.py | 12 + backend/danswer/auth/users.py | 41 ++- .../slack/handlers/handle_regular_answer.py | 7 +- backend/danswer/db/connector.py | 4 +- .../danswer/db/connector_credential_pair.py | 122 ++++++- backend/danswer/db/credentials.py | 130 ++++++- backend/danswer/db/document_set.py | 62 +++- backend/danswer/db/feedback.py | 113 +++++- backend/danswer/db/llm.py | 8 +- backend/danswer/db/models.py | 23 ++ backend/danswer/db/persona.py | 244 +++++++------ backend/danswer/db/users.py | 32 +- backend/danswer/server/auth_check.py | 2 + backend/danswer/server/documents/cc_pair.py | 65 +++- backend/danswer/server/documents/connector.py | 113 ++++-- .../danswer/server/documents/credential.py | 83 +++-- backend/danswer/server/documents/models.py | 21 +- .../server/features/document_set/api.py | 22 +- .../danswer/server/features/persona/api.py | 15 +- .../danswer/server/manage/administrative.py | 57 +-- backend/danswer/server/manage/models.py | 5 + backend/danswer/server/manage/users.py | 59 ++- .../server/query_and_chat/query_backend.py | 4 +- backend/ee/danswer/db/token_limit.py | 84 ++++- backend/ee/danswer/db/user_group.py | 149 +++++++- .../danswer/server/token_rate_limits/api.py | 9 +- backend/ee/danswer/server/user_group/api.py | 44 ++- .../ee/danswer/server/user_group/models.py | 11 + .../app/admin/assistants/AssistantEditor.tsx | 246 ++++++------- web/src/app/admin/assistants/PersonaTable.tsx | 100 ++++-- web/src/app/admin/assistants/page.tsx | 20 +- web/src/app/admin/connector/[ccPairId]/lib.ts | 8 +- .../app/admin/connector/[ccPairId]/page.tsx | 82 +++-- .../app/admin/connector/[ccPairId]/types.ts | 2 + .../[connector]/AddConnectorPage.tsx | 34 +- .../connectors/[connector]/pages/Advanced.tsx | 6 +- .../connectors/[connector]/pages/Create.tsx | 154 +++++++- .../[connector]/pages/gdrive/Credential.tsx | 87 +++-- .../pages/gdrive/GoogleDrivePage.tsx | 54 ++- .../[connector]/pages/gmail/Credential.tsx | 85 +++-- .../[connector]/pages/gmail/GmailPage.tsx | 50 ++- .../[connector]/pages/utils/files.ts | 10 +- .../sets/DocumentSetCreationForm.tsx | 129 ++----- web/src/app/admin/documents/sets/hooks.tsx | 16 +- web/src/app/admin/documents/sets/lib.ts | 2 +- web/src/app/admin/documents/sets/page.tsx | 146 ++++++-- .../status/CCPairIndexingStatusTable.tsx | 66 +++- web/src/app/admin/indexing/status/page.tsx | 27 +- .../prompt-library/modals/AddPromptModal.tsx | 2 +- .../prompt-library/modals/EditPromptModal.tsx | 2 +- .../TokenRateLimitTables.tsx | 67 +++- web/src/app/chat/modal/DeleteChatModal.tsx | 2 +- web/src/app/chat/modal/FeedbackModal.tsx | 2 +- .../app/chat/modal/SetDefaultModelModal.tsx | 4 +- .../app/chat/modal/ShareChatSessionModal.tsx | 2 +- .../sessionSidebar/ChatSessionDisplay.tsx | 1 - .../ee/admin/api-key/DanswerApiKeyForm.tsx | 6 +- .../admin/groups/[groupId]/GroupDisplay.tsx | 257 ++++++++++---- web/src/app/ee/admin/groups/[groupId]/lib.ts | 16 +- web/src/app/ee/admin/groups/page.tsx | 36 +- web/src/app/ee/admin/groups/types.ts | 5 + web/src/components/IsPublicGroupSelector.tsx | 155 ++++++++ web/src/components/UserDropdown.tsx | 19 +- web/src/components/admin/ClientLayout.tsx | 336 ++++++++++-------- web/src/components/admin/Layout.tsx | 4 +- .../admin/connectors/CredentialForm.tsx | 2 + .../admin/users/SignedUpUserTable.tsx | 149 +++++--- .../credentials/CredentialSection.tsx | 10 +- .../credentials/actions/CreateCredential.tsx | 130 +++++-- .../credentials/actions/ModifyCredential.tsx | 36 +- .../components/modals/GenericConfirmModal.tsx | 42 +++ .../modals}/ModalWrapper.tsx | 0 web/src/components/table/DraggableRow.tsx | 14 +- web/src/components/table/DraggableTable.tsx | 41 ++- .../lib/admin/users/userMutationFetcher.ts | 6 +- web/src/lib/ccPair.ts | 27 +- web/src/lib/connectors/connectors.ts | 2 + web/src/lib/connectors/credentials.ts | 3 +- web/src/lib/credential.ts | 4 +- web/src/lib/types.ts | 8 +- web/src/lib/user.ts | 1 - 82 files changed, 3141 insertions(+), 1205 deletions(-) create mode 100644 backend/alembic/versions/351faebd379d_add_curator_fields.py create mode 100644 web/src/components/IsPublicGroupSelector.tsx create mode 100644 web/src/components/modals/GenericConfirmModal.tsx rename web/src/{app/chat/modal => components/modals}/ModalWrapper.tsx (100%) diff --git a/backend/alembic/versions/351faebd379d_add_curator_fields.py b/backend/alembic/versions/351faebd379d_add_curator_fields.py new file mode 100644 index 000000000..ae4e03bd4 --- /dev/null +++ b/backend/alembic/versions/351faebd379d_add_curator_fields.py @@ -0,0 +1,90 @@ +"""Add curator fields + +Revision ID: 351faebd379d +Revises: ee3f4b47fad5 +Create Date: 2024-08-15 22:37:08.397052 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "351faebd379d" +down_revision = "ee3f4b47fad5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add is_curator column to User__UserGroup table + op.add_column( + "user__user_group", + sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"), + ) + + # Use batch mode to modify the enum type + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.alter_column( # type: ignore[attr-defined] + "role", + type_=sa.Enum( + "BASIC", + "ADMIN", + "CURATOR", + "GLOBAL_CURATOR", + name="userrole", + native_enum=False, + ), + existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False), + existing_nullable=False, + ) + # Create the association table + op.create_table( + "credential__user_group", + sa.Column("credential_id", sa.Integer(), nullable=False), + sa.Column("user_group_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["credential_id"], + ["credential.id"], + ), + sa.ForeignKeyConstraint( + ["user_group_id"], + ["user_group.id"], + ), + sa.PrimaryKeyConstraint("credential_id", "user_group_id"), + ) + op.add_column( + "credential", + sa.Column( + "curator_public", sa.Boolean(), nullable=False, server_default="false" + ), + ) + + +def downgrade() -> None: + # Update existing records to ensure they fit within the BASIC/ADMIN roles + op.execute( + "UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')" + ) + + # Remove is_curator column from User__UserGroup table + op.drop_column("user__user_group", "is_curator") + + with op.batch_alter_table("user", schema=None) as batch_op: + batch_op.alter_column( # type: ignore[attr-defined] + "role", + type_=sa.Enum( + "BASIC", "ADMIN", name="userrole", native_enum=False, length=20 + ), + existing_type=sa.Enum( + "BASIC", + "ADMIN", + "CURATOR", + "GLOBAL_CURATOR", + name="userrole", + native_enum=False, + ), + existing_nullable=False, + ) + # Drop the association table + op.drop_table("credential__user_group") + op.drop_column("credential", "curator_public") diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index 79d9a7f80..9e0553991 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -5,8 +5,20 @@ from fastapi_users import schemas class UserRole(str, Enum): + """ + User roles + - Basic can't perform any admin actions + - Admin can perform all admin actions + - Curator can perform admin actions for + groups they are curators of + - Global Curator can perform admin actions + for all groups they are a member of + """ + BASIC = "basic" ADMIN = "admin" + CURATOR = "curator" + GLOBAL_CURATOR = "global_curator" class UserStatus(str, Enum): diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index bda3b3a32..ef6e0be1b 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -67,6 +67,23 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() +def validate_curator_request(groups: list | None, is_public: bool) -> None: + if is_public: + detail = "User does not have permission to create public credentials" + logger.error(detail) + raise HTTPException( + status_code=401, + detail=detail, + ) + if not groups: + detail = "Curators must specify 1+ groups" + logger.error(detail) + raise HTTPException( + status_code=401, + detail="Curators must specify 1+ groups", + ) + + def is_user_admin(user: User | None) -> bool: if AUTH_TYPE == AuthType.DISABLED: return True @@ -395,6 +412,28 @@ async def current_user( return await double_check_user(user) +async def current_curator_or_admin_user( + user: User | None = Depends(current_user), +) -> User | None: + if DISABLE_AUTH: + return None + + if not user or not hasattr(user, "role"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. User is not authenticated or lacks role information.", + ) + + allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN} + if user.role not in allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Access denied. User is not a curator or admin.", + ) + + return user + + async def current_admin_user(user: User | None = Depends(current_user)) -> User | None: if DISABLE_AUTH: return None @@ -402,7 +441,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access denied. User is not an admin.", + detail="Access denied. User must be an admin to perform this action.", ) return user diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index b7a3b818c..b2eda3661 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -146,7 +146,12 @@ def handle_regular_answer( if len(new_message_request.messages) > 1: persona = cast( Persona, - fetch_persona_by_id(db_session, new_message_request.persona_id), + fetch_persona_by_id( + db_session, + new_message_request.persona_id, + user=None, + get_editable=False, + ), ) llm, _ = get_llms_for_persona(persona) diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py index 3c06cd69e..89e697710 100644 --- a/backend/danswer/db/connector.py +++ b/backend/danswer/db/connector.py @@ -75,8 +75,8 @@ def fetch_ingestion_connector_by_name( def create_connector( - connector_data: ConnectorBase, db_session: Session, + connector_data: ConnectorBase, ) -> ObjectCreationIdResponse: if connector_by_name_source_exists( connector_data.name, connector_data.source, db_session @@ -132,8 +132,8 @@ def update_connector( def delete_connector( - connector_id: int, db_session: Session, + connector_id: int, ) -> StatusResponse[int]: """Only used in special cases (e.g. a connector is in a bad state and we need to delete it). Be VERY careful using this, as it could lead to a bad state if not used correctly. diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index ed1742b2f..7bf3324a7 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -3,7 +3,10 @@ from datetime import datetime from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import desc +from sqlalchemy import exists +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from danswer.configs.constants import DocumentSource @@ -16,16 +19,74 @@ from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus from danswer.db.models import IndexModelStatus from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup__ConnectorCredentialPair +from danswer.db.models import UserRole from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger logger = setup_logger() +def _add_user_filters( + stmt: Select, user: User | None, get_editable: bool = True +) -> Select: + # If user is None, assume the user is an admin or auth is disabled + if user is None or user.role == UserRole.ADMIN: + return stmt + + UG__CCpair = aliased(UserGroup__ConnectorCredentialPair) + User__UG = aliased(User__UserGroup) + + """ + Here we select cc_pairs by relation: + User -> User__UserGroup -> UserGroup__ConnectorCredentialPair -> + ConnectorCredentialPair + """ + stmt = stmt.outerjoin(UG__CCpair).outerjoin( + User__UG, + User__UG.user_group_id == UG__CCpair.user_group_id, + ) + + """ + Filter cc_pairs by: + - if the user is in the user_group that owns the cc_pair + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out cc_pairs that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all cc_pairs in the groups the user is a curator + for (as well as public cc_pairs) + """ + where_clause = User__UG.user_id == user.id + if user.role == UserRole.CURATOR and get_editable: + where_clause &= User__UG.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) + if user.role == UserRole.CURATOR: + user_groups = user_groups.where( + User__UserGroup.is_curator == True # noqa: E712 + ) + where_clause &= ( + ~exists() + .where(UG__CCpair.cc_pair_id == ConnectorCredentialPair.id) + .where(~UG__CCpair.user_group_id.in_(user_groups)) + .correlate(ConnectorCredentialPair) + ) + else: + where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712 + + return stmt.where(where_clause) + + def get_connector_credential_pairs( - db_session: Session, include_disabled: bool = True + db_session: Session, + include_disabled: bool = True, + user: User | None = None, + get_editable: bool = True, ) -> list[ConnectorCredentialPair]: stmt = select(ConnectorCredentialPair) + stmt = _add_user_filters(stmt, user, get_editable) if not include_disabled: stmt = stmt.where( ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE @@ -38,8 +99,11 @@ def get_connector_credential_pair( connector_id: int, credential_id: int, db_session: Session, + user: User | None = None, + get_editable: bool = True, ) -> ConnectorCredentialPair | None: stmt = select(ConnectorCredentialPair) + stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id) stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id) result = db_session.execute(stmt) @@ -49,8 +113,11 @@ def get_connector_credential_pair( def get_connector_credential_source_from_id( cc_pair_id: int, db_session: Session, + user: User | None = None, + get_editable: bool = True, ) -> DocumentSource | None: stmt = select(ConnectorCredentialPair) + stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id) result = db_session.execute(stmt) cc_pair = result.scalar_one_or_none() @@ -60,8 +127,11 @@ def get_connector_credential_source_from_id( def get_connector_credential_pair_from_id( cc_pair_id: int, db_session: Session, + user: User | None = None, + get_editable: bool = True, ) -> ConnectorCredentialPair | None: - stmt = select(ConnectorCredentialPair) + stmt = select(ConnectorCredentialPair).distinct() + stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id) result = db_session.execute(stmt) return result.scalar_one_or_none() @@ -217,14 +287,28 @@ def associate_default_cc_pair(db_session: Session) -> None: db_session.commit() +def _relate_groups_to_cc_pair__no_commit( + db_session: Session, + cc_pair_id: int, + user_group_ids: list[int], +) -> None: + for group_id in user_group_ids: + db_session.add( + UserGroup__ConnectorCredentialPair( + user_group_id=group_id, cc_pair_id=cc_pair_id + ) + ) + + def add_credential_to_connector( + db_session: Session, + user: User | None, connector_id: int, credential_id: int, cc_pair_name: str | None, is_public: bool, - user: User | None, - db_session: Session, -) -> StatusResponse[int]: + groups: list[int] | None, +) -> StatusResponse: connector = fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id(credential_id, user, db_session) @@ -260,12 +344,21 @@ def add_credential_to_connector( is_public=is_public, ) db_session.add(association) + db_session.flush() # make sure the association has an id + + if groups: + _relate_groups_to_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=association.id, + user_group_ids=groups, + ) + db_session.commit() return StatusResponse( - success=True, - message=f"New Credential {credential_id} added to Connector", - data=connector_id, + success=False, + message=f"Connector already has Credential {credential_id}", + data=association.id, ) @@ -287,13 +380,12 @@ def remove_credential_from_connector( detail="Credential does not exist or does not belong to user", ) - association = ( - db_session.query(ConnectorCredentialPair) - .filter( - ConnectorCredentialPair.connector_id == connector_id, - ConnectorCredentialPair.credential_id == credential_id, - ) - .one_or_none() + association = get_connector_credential_pair( + connector_id=connector_id, + credential_id=credential_id, + db_session=db_session, + user=user, + get_editable=True, ) if association is not None: diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 1d27e9c43..cf9af2c2e 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -1,5 +1,6 @@ from typing import Any +from sqlalchemy import exists from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update @@ -17,8 +18,10 @@ from danswer.connectors.google_drive.constants import ( ) from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential +from danswer.db.models import Credential__UserGroup from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import User +from danswer.db.models import User__UserGroup from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.utils.logger import setup_logger @@ -26,42 +29,122 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +# The credentials for these sources are not real so +# permissions are not enforced for them +CREDENTIAL_PERMISSIONS_TO_IGNORE = { + DocumentSource.FILE, + DocumentSource.WEB, + DocumentSource.NOT_APPLICABLE, + DocumentSource.GOOGLE_SITES, + DocumentSource.WIKIPEDIA, + DocumentSource.MEDIAWIKI, +} -def _attach_user_filters( - stmt: Select[tuple[Credential]], + +def _add_user_filters( + stmt: Select, user: User | None, assume_admin: bool = False, # Used with API key + get_editable: bool = True, ) -> Select: """Attaches filters to the statement to ensure that the user can only access the appropriate credentials""" - if user: - if user.role == UserRole.ADMIN: + if not user: + if assume_admin: + # apply admin filters minus the user_id check stmt = stmt.where( or_( - Credential.user_id == user.id, Credential.user_id.is_(None), Credential.admin_public == True, # noqa: E712 + Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE), ) ) - else: - stmt = stmt.where(Credential.user_id == user.id) - elif assume_admin: - stmt = stmt.where( + return stmt + + if user.role == UserRole.ADMIN: + # Admins can access all credentials that are public or owned by them + # or are not associated with any user + return stmt.where( or_( + Credential.user_id == user.id, Credential.user_id.is_(None), Credential.admin_public == True, # noqa: E712 + Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE), ) ) + if user.role == UserRole.BASIC: + # Basic users can only access credentials that are owned by them + return stmt.where(Credential.user_id == user.id) - return stmt + """ + THIS PART IS FOR CURATORS AND GLOBAL CURATORS + Here we select cc_pairs by relation: + User -> User__UserGroup -> Credential__UserGroup -> Credential + """ + stmt = stmt.outerjoin(Credential__UserGroup).outerjoin( + User__UserGroup, + User__UserGroup.user_group_id == Credential__UserGroup.user_group_id, + ) + """ + Filter Credentials by: + - if the user is in the user_group that owns the Credential + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out Credentials that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all Credentials in the groups the user is a curator + for (as well as public Credentials) + - if we are not editing, we return all Credentials directly connected to the user + """ + where_clause = User__UserGroup.user_id == user.id + if user.role == UserRole.CURATOR: + where_clause &= User__UserGroup.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UserGroup.user_group_id).where( + User__UserGroup.user_id == user.id + ) + if user.role == UserRole.CURATOR: + user_groups = user_groups.where( + User__UserGroup.is_curator == True # noqa: E712 + ) + where_clause &= ( + ~exists() + .where(Credential__UserGroup.credential_id == Credential.id) + .where(~Credential__UserGroup.user_group_id.in_(user_groups)) + .correlate(Credential) + ) + else: + where_clause |= Credential.curator_public == True # noqa: E712 + where_clause |= Credential.user_id == user.id # noqa: E712 + + where_clause |= Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE) + + return stmt.where(where_clause) + + +def _relate_credential_to_user_groups__no_commit( + db_session: Session, + credential_id: int, + user_group_ids: list[int], +) -> None: + credential_user_groups = [] + for group_id in user_group_ids: + credential_user_groups.append( + Credential__UserGroup( + credential_id=credential_id, + user_group_id=group_id, + ) + ) + db_session.add_all(credential_user_groups) def fetch_credentials( db_session: Session, user: User | None = None, + get_editable: bool = True, ) -> list[Credential]: stmt = select(Credential) - stmt = _attach_user_filters(stmt, user) + stmt = _add_user_filters(stmt, user, get_editable=get_editable) results = db_session.scalars(stmt) return list(results.all()) @@ -73,7 +156,7 @@ def fetch_credential_by_id( assume_admin: bool = False, ) -> Credential | None: stmt = select(Credential).where(Credential.id == credential_id) - stmt = _attach_user_filters(stmt, user, assume_admin=assume_admin) + stmt = _add_user_filters(stmt, user, assume_admin=assume_admin) result = db_session.execute(stmt) credential = result.scalar_one_or_none() return credential @@ -83,9 +166,10 @@ def fetch_credentials_by_source( db_session: Session, user: User | None, document_source: DocumentSource | None = None, + get_editable: bool = True, ) -> list[Credential]: base_query = select(Credential).where(Credential.source == document_source) - base_query = _attach_user_filters(base_query, user) + base_query = _add_user_filters(base_query, user, get_editable=get_editable) credentials = db_session.execute(base_query).scalars().all() return list(credentials) @@ -153,19 +237,38 @@ def create_credential( admin_public=credential_data.admin_public, source=credential_data.source, name=credential_data.name, + curator_public=credential_data.curator_public, ) db_session.add(credential) + db_session.flush() # This ensures the credential gets an ID + + _relate_credential_to_user_groups__no_commit( + db_session=db_session, + credential_id=credential.id, + user_group_ids=credential_data.groups, + ) + db_session.commit() return credential +def _cleanup_credential__user_group_relationships__no_commit( + db_session: Session, credential_id: int +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Credential__UserGroup).filter( + Credential__UserGroup.credential_id == credential_id + ).delete(synchronize_session=False) + + def alter_credential( credential_id: int, credential_data: CredentialDataUpdateRequest, user: User, db_session: Session, ) -> Credential | None: + # TODO: add user group relationship update credential = fetch_credential_by_id(credential_id, user, db_session) if credential is None: @@ -275,6 +378,7 @@ def delete_credential( else: logger.notice(f"Deleting credential {credential_id}") + _cleanup_credential__user_group_relationships__no_commit(db_session, credential_id) db_session.delete(credential) db_session.commit() diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 3893ee202..130c4b0a0 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -4,9 +4,12 @@ from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete +from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import or_ +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from danswer.db.enums import ConnectorCredentialPairStatus @@ -15,6 +18,10 @@ from danswer.db.models import Document from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import DocumentSet__ConnectorCredentialPair +from danswer.db.models import DocumentSet__UserGroup +from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserRole from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest from danswer.utils.variable_functionality import fetch_versioned_implementation @@ -341,9 +348,58 @@ def fetch_document_sets( ] -def fetch_all_document_sets(db_session: Session) -> Sequence[DocumentSetDBModel]: - """Used for Admin UI where they should have visibility into all document sets""" - return db_session.scalars(select(DocumentSetDBModel)).all() +def _add_user_filters( + stmt: Select, user: User | None, get_editable: bool = True +) -> Select: + # If user is None, assume the user is an admin or auth is disabled + if user is None or user.role == UserRole.ADMIN: + return stmt + + DocumentSet__UG = aliased(DocumentSet__UserGroup) + User__UG = aliased(User__UserGroup) + """ + Here we select cc_pairs by relation: + User -> User__UserGroup -> DocumentSet__UserGroup -> DocumentSet + """ + stmt = stmt.outerjoin(DocumentSet__UG).outerjoin( + User__UserGroup, + User__UserGroup.user_group_id == DocumentSet__UG.user_group_id, + ) + """ + Filter DocumentSets by: + - if the user is in the user_group that owns the DocumentSet + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out DocumentSets that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all DocumentSets in the groups the user is a curator + for (as well as public DocumentSets) + """ + where_clause = User__UserGroup.user_id == user.id + if user.role == UserRole.CURATOR and get_editable: + where_clause &= User__UserGroup.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) + if user.role == UserRole.CURATOR: + user_groups = user_groups.where(User__UG.is_curator == True) # noqa: E712 + where_clause &= ( + ~exists() + .where(DocumentSet__UG.document_set_id == DocumentSetDBModel.id) + .where(~DocumentSet__UG.user_group_id.in_(user_groups)) + .correlate(DocumentSetDBModel) + ) + else: + where_clause |= DocumentSetDBModel.is_public == True # noqa: E712 + + return stmt.where(where_clause) + + +def fetch_all_document_sets_for_user( + db_session: Session, user: User | None = None, get_editable: bool = True +) -> Sequence[DocumentSetDBModel]: + stmt = select(DocumentSetDBModel).distinct() + stmt = _add_user_filters(stmt, user, get_editable=get_editable) + return db_session.scalars(stmt).all() def fetch_user_document_sets( diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index bb7da0864..79557f209 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -1,22 +1,36 @@ from uuid import UUID +from fastapi import HTTPException +from sqlalchemy import and_ from sqlalchemy import asc from sqlalchemy import delete from sqlalchemy import desc +from sqlalchemy import exists +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.db.chat import get_chat_message from danswer.db.models import ChatMessageFeedback +from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document as DbDocument +from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import DocumentRetrievalFeedback +from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserGroup__ConnectorCredentialPair +from danswer.db.models import UserRole from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest +from danswer.utils.logger import setup_logger + +logger = setup_logger() -def fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument: +def _fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument: stmt = select(DbDocument).where(DbDocument.id == doc_id) result = db_session.execute(stmt) doc = result.scalar_one_or_none() @@ -27,15 +41,78 @@ def fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument: return doc +def _add_user_filters( + stmt: Select, user: User | None, get_editable: bool = True +) -> Select: + # If user is None, assume the user is an admin or auth is disabled + if user is None or user.role == UserRole.ADMIN: + return stmt + + DocByCC = aliased(DocumentByConnectorCredentialPair) + CCPair = aliased(ConnectorCredentialPair) + UG__CCpair = aliased(UserGroup__ConnectorCredentialPair) + User__UG = aliased(User__UserGroup) + + """ + Here we select documents by relation: + User -> User__UserGroup -> UserGroup__ConnectorCredentialPair -> + ConnectorCredentialPair -> DocumentByConnectorCredentialPair -> Document + """ + stmt = ( + stmt.outerjoin(DocByCC, DocByCC.id == DbDocument.id) + .outerjoin( + CCPair, + and_( + CCPair.connector_id == DocByCC.connector_id, + CCPair.credential_id == DocByCC.credential_id, + ), + ) + .outerjoin(UG__CCpair, UG__CCpair.cc_pair_id == CCPair.id) + .outerjoin(User__UG, User__UG.user_group_id == UG__CCpair.user_group_id) + ) + + """ + Filter Documents by: + - if the user is in the user_group that owns the object + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out objects that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all objects in the groups the user is a curator + for (as well as public objects as well) + """ + where_clause = User__UG.user_id == user.id + if user.role == UserRole.CURATOR and get_editable: + where_clause &= User__UG.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) + where_clause &= ( + ~exists() + .where(UG__CCpair.cc_pair_id == CCPair.id) + .where(~UG__CCpair.user_group_id.in_(user_groups)) + .correlate(CCPair) + ) + else: + where_clause |= CCPair.is_public == True # noqa: E712 + + return stmt.where(where_clause) + + def fetch_docs_ranked_by_boost( - db_session: Session, ascending: bool = False, limit: int = 100 + db_session: Session, + user: User | None = None, + ascending: bool = False, + limit: int = 100, ) -> list[DbDocument]: order_func = asc if ascending else desc - stmt = ( - select(DbDocument) - .order_by(order_func(DbDocument.boost), order_func(DbDocument.semantic_id)) - .limit(limit) + stmt = select(DbDocument) + + stmt = _add_user_filters(stmt=stmt, user=user, get_editable=False) + + stmt = stmt.order_by( + order_func(DbDocument.boost), order_func(DbDocument.semantic_id) ) + stmt = stmt.limit(limit) result = db_session.execute(stmt) doc_list = result.scalars().all() @@ -43,12 +120,19 @@ def fetch_docs_ranked_by_boost( def update_document_boost( - db_session: Session, document_id: str, boost: int, document_index: DocumentIndex + db_session: Session, + document_id: str, + boost: int, + document_index: DocumentIndex, + user: User | None = None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) + stmt = _add_user_filters(stmt, user, get_editable=True) result = db_session.execute(stmt).scalar_one_or_none() if result is None: - raise ValueError(f"No document found with ID: '{document_id}'") + raise HTTPException( + status_code=400, detail="Document is not editable by this user" + ) result.boost = boost @@ -63,12 +147,19 @@ def update_document_boost( def update_document_hidden( - db_session: Session, document_id: str, hidden: bool, document_index: DocumentIndex + db_session: Session, + document_id: str, + hidden: bool, + document_index: DocumentIndex, + user: User | None = None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) + stmt = _add_user_filters(stmt, user, get_editable=True) result = db_session.execute(stmt).scalar_one_or_none() if result is None: - raise ValueError(f"No document found with ID: '{document_id}'") + raise HTTPException( + status_code=400, detail="Document is not editable by this user" + ) result.hidden = hidden @@ -92,7 +183,7 @@ def create_doc_retrieval_feedback( feedback: SearchFeedbackType | None = None, ) -> None: """Creates a new Document feedback row and updates the boost value in Postgres and Vespa""" - db_doc = fetch_db_doc_by_id(document_id, db_session) + db_doc = _fetch_db_doc_by_id(document_id, db_session) retrieval_feedback = DocumentRetrievalFeedback( chat_message_id=message_id, diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 8410ed733..5be3273f6 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -107,16 +107,14 @@ def fetch_existing_llm_providers( if not user: return list(db_session.scalars(select(LLMProviderModel)).all()) stmt = select(LLMProviderModel).distinct() - user_groups_subquery = ( - select(User__UserGroup.user_group_id) - .where(User__UserGroup.user_id == user.id) - .subquery() + user_groups_select = select(User__UserGroup.user_group_id).where( + User__UserGroup.user_id == user.id ) access_conditions = or_( LLMProviderModel.is_public, LLMProviderModel.id.in_( # User is part of a group that has access select(LLMProvider__UserGroup.llm_provider_id).where( - LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore + LLMProvider__UserGroup.user_group_id.in_(user_groups_select) # type: ignore ) ), ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b3d85418f..500bf2b04 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -529,6 +529,8 @@ class Credential(Base): DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) + curator_public: Mapped[bool] = mapped_column(Boolean, default=False) + connectors: Mapped[list["ConnectorCredentialPair"]] = relationship( "ConnectorCredentialPair", back_populates="credential", @@ -1458,6 +1460,8 @@ class SamlAccount(Base): class User__UserGroup(Base): __tablename__ = "user__user_group" + is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) @@ -1522,6 +1526,17 @@ class DocumentSet__UserGroup(Base): ) +class Credential__UserGroup(Base): + __tablename__ = "credential__user_group" + + credential_id: Mapped[int] = mapped_column( + ForeignKey("credential.id"), primary_key=True + ) + user_group_id: Mapped[int] = mapped_column( + ForeignKey("user_group.id"), primary_key=True + ) + + class UserGroup(Base): __tablename__ = "user_group" @@ -1538,6 +1553,10 @@ class UserGroup(Base): "User", secondary=User__UserGroup.__table__, ) + user_group_relationships: Mapped[list[User__UserGroup]] = relationship( + "User__UserGroup", + viewonly=True, + ) cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship( "ConnectorCredentialPair", secondary=UserGroup__ConnectorCredentialPair.__table__, @@ -1559,6 +1578,10 @@ class UserGroup(Base): secondary=DocumentSet__UserGroup.__table__, viewonly=True, ) + credentials: Mapped[list[Credential]] = relationship( + "Credential", + secondary=Credential__UserGroup.__table__, + ) """Tables related to Token Rate Limiting diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index d8f71c3ea..bbf45a1d9 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -4,13 +4,15 @@ from uuid import UUID from fastapi import HTTPException from sqlalchemy import delete +from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import not_ from sqlalchemy import or_ +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update +from sqlalchemy.orm import aliased from sqlalchemy.orm import joinedload -from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from danswer.auth.schemas import UserRole @@ -38,6 +40,89 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() +def _add_user_filters( + stmt: Select, user: User | None, get_editable: bool = True +) -> Select: + # If user is None, assume the user is an admin or auth is disabled + if user is None or user.role == UserRole.ADMIN: + return stmt + + Persona__UG = aliased(Persona__UserGroup) + User__UG = aliased(User__UserGroup) + """ + Here we select cc_pairs by relation: + User -> User__UserGroup -> Persona__UserGroup -> Persona + """ + stmt = ( + stmt.outerjoin(Persona__UG) + .outerjoin( + User__UserGroup, + User__UserGroup.user_group_id == Persona__UG.user_group_id, + ) + .outerjoin( + Persona__User, + Persona__User.persona_id == Persona.id, + ) + ) + """ + Filter Personas by: + - if the user is in the user_group that owns the Persona + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out Personas that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all Personas in the groups the user is a curator + for (as well as public Personas) + - if we are not editing, we return all Personas directly connected to the user + """ + where_clause = User__UserGroup.user_id == user.id + if user.role == UserRole.CURATOR and get_editable: + where_clause &= User__UserGroup.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) + if user.role == UserRole.CURATOR: + user_groups = user_groups.where(User__UG.is_curator == True) # noqa: E712 + where_clause &= ( + ~exists() + .where(Persona__UG.persona_id == Persona.id) + .where(~Persona__UG.user_group_id.in_(user_groups)) + .correlate(Persona) + ) + else: + where_clause |= Persona.is_public == True # noqa: E712 + where_clause &= Persona.is_visible == True # noqa: E712 + where_clause |= Persona__User.user_id == user.id + where_clause |= Persona.user_id == user.id + + return stmt.where(where_clause) + + +def fetch_persona_by_id( + db_session: Session, persona_id: int, user: User | None, get_editable: bool = True +) -> Persona: + stmt = select(Persona).where(Persona.id == persona_id).distinct() + stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable) + persona = db_session.scalars(stmt).one_or_none() + if not persona: + raise HTTPException( + status_code=403, + detail=f"Persona with ID {persona_id} does not exist or user is not authorized to access it", + ) + return persona + + +def _get_persona_by_name( + persona_name: str, user: User | None, db_session: Session +) -> Persona | None: + """Admins can see all, regular users can only fetch their own. + If user is None, assume the user is an admin or auth is disabled.""" + stmt = select(Persona).where(Persona.name == persona_name) + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Persona.user_id == user.id) + result = db_session.execute(stmt).scalar_one_or_none() + return result + + def make_persona_private( persona_id: int, user_ids: list[UUID] | None, @@ -105,13 +190,9 @@ def update_persona_shared_users( """Simplified version of `create_update_persona` which only touches the accessibility rather than any of the logic (e.g. prompt, connected data sources, etc.).""" - persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id) - if not persona: - raise HTTPException( - status_code=404, detail=f"Persona with ID {persona_id} not found" - ) - - check_user_can_edit_persona(user=user, persona=persona) + persona = fetch_persona_by_id( + db_session=db_session, persona_id=persona_id, user=user, get_editable=True + ) if persona.is_public: raise HTTPException(status_code=400, detail="Cannot share public persona") @@ -129,10 +210,6 @@ def update_persona_shared_users( ) -def fetch_persona_by_id(db_session: Session, persona_id: int) -> Persona | None: - return db_session.scalar(select(Persona).where(Persona.id == persona_id)) - - def get_prompts( user_id: UUID | None, db_session: Session, @@ -152,36 +229,17 @@ def get_prompts( def get_personas( - # if user_id is `None` assume the user is an admin or auth is disabled - user_id: UUID | None, + # if user is `None` assume the user is an admin or auth is disabled + user: User | None, db_session: Session, + get_editable: bool = True, include_default: bool = True, include_slack_bot_personas: bool = False, include_deleted: bool = False, joinedload_all: bool = False, ) -> Sequence[Persona]: stmt = select(Persona).distinct() - if user_id is not None: - # Subquery to find all groups the user belongs to - user_groups_subquery = ( - select(User__UserGroup.user_group_id) - .where(User__UserGroup.user_id == user_id) - .subquery() - ) - - # Include personas where the user is directly related or part of a user group that has access - access_conditions = or_( - Persona.is_public == True, # noqa: E712 - Persona.id.in_( # User has access through list of users with access - select(Persona__User.persona_id).where(Persona__User.user_id == user_id) - ), - Persona.id.in_( # User is part of a group that has access - select(Persona__UserGroup.persona_id).where( - Persona__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore - ) - ), - ) - stmt = stmt.where(access_conditions) + stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable) if not include_default: stmt = stmt.where(Persona.default_persona.is_(False)) @@ -245,7 +303,7 @@ def update_all_personas_display_priority( db_session: Session, ) -> None: """Updates the display priority of all lives Personas""" - personas = get_personas(user_id=None, db_session=db_session) + personas = get_personas(user=None, db_session=db_session) available_persona_ids = {persona.id for persona in personas} if available_persona_ids != set(display_priority_map.keys()): raise ValueError("Invalid persona IDs provided") @@ -346,7 +404,7 @@ def upsert_persona( if persona_id is not None: persona = db_session.query(Persona).filter_by(id=persona_id).first() else: - persona = get_persona_by_name( + persona = _get_persona_by_name( persona_name=name, user=user, db_session=db_session ) @@ -383,7 +441,10 @@ def upsert_persona( if not default_persona and persona.default_persona: raise ValueError("Cannot update default persona with non-default.") - check_user_can_edit_persona(user=user, persona=persona) + # this checks if the user has permission to edit the persona + persona = fetch_persona_by_id( + db_session=db_session, persona_id=persona.id, user=user, get_editable=True + ) persona.name = name persona.description = description @@ -485,8 +546,11 @@ def update_persona_visibility( persona_id: int, is_visible: bool, db_session: Session, + user: User | None = None, ) -> None: - persona = get_persona_by_id(persona_id=persona_id, user=None, db_session=db_session) + persona = fetch_persona_by_id( + db_session=db_session, persona_id=persona_id, user=user, get_editable=True + ) persona.is_visible = is_visible db_session.commit() @@ -499,23 +563,6 @@ def validate_persona_tools(tools: list[Tool]) -> None: ) -def check_user_can_edit_persona(user: User | None, persona: Persona) -> None: - # if user is None, assume that no-auth is turned on - if user is None: - return - - # admins can edit everything - if user.role == UserRole.ADMIN: - return - - # otherwise, make sure user owns persona - if persona.user_id != user.id: - raise HTTPException( - status_code=403, - detail=f"User not authorized to edit persona with ID {persona.id}", - ) - - def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]: """Unsafe, can fetch prompts from all users""" if not prompt_ids: @@ -587,54 +634,53 @@ def get_persona_by_id( include_deleted: bool = False, is_for_edit: bool = True, # NOTE: assume true for safety ) -> Persona: - stmt = ( + persona_stmt = ( select(Persona) - .options(selectinload(Persona.users), selectinload(Persona.groups)) + .distinct() + .outerjoin(Persona.groups) + .outerjoin(Persona.users) + .outerjoin(UserGroup.user_group_relationships) .where(Persona.id == persona_id) ) - or_conditions = [] - - # if user is an admin, they should have access to all Personas - # and will skip the following clause - if user is not None and user.role != UserRole.ADMIN: - # the user is not an admin - isPersonaUnowned = Persona.user_id.is_( - None - ) # allow access if persona user id is None - isUserCreator = ( - Persona.user_id == user.id - ) # allow access if user created the persona - or_conditions.extend([isPersonaUnowned, isUserCreator]) - - # if we aren't editing, also give access if: - # 1. the user is authorized for this persona - # 2. the user is in an authorized group for this persona - # 3. if the persona is public - if not is_for_edit: - isSharedWithUser = Persona.users.any( - id=user.id - ) # allow access if user is in allowed users - isSharedWithGroup = Persona.groups.any( - UserGroup.users.any(id=user.id) - ) # allow access if user is in any allowed group - or_conditions.extend([isSharedWithUser, isSharedWithGroup]) - or_conditions.append(Persona.is_public.is_(True)) - - if or_conditions: - stmt = stmt.where(or_(*or_conditions)) - if not include_deleted: - stmt = stmt.where(Persona.deleted.is_(False)) + persona_stmt = persona_stmt.where(Persona.deleted.is_(False)) - result = db_session.execute(stmt) + if not user or user.role == UserRole.ADMIN: + result = db_session.execute(persona_stmt) + persona = result.scalar_one_or_none() + if persona is None: + raise ValueError( + f"Persona with ID {persona_id} does not exist or does not belong to user" + ) + return persona + + # or check if user owns persona + or_conditions = Persona.user_id == user.id + # allow access if persona user id is None + or_conditions |= Persona.user_id == None # noqa: E711 + if not is_for_edit: + # if the user is in a group related to the persona + or_conditions |= User__UserGroup.user_id == user.id + # if the user is in the .users of the persona + or_conditions |= User.id == user.id + or_conditions |= Persona.is_public == True # noqa: E712 + elif user.role == UserRole.GLOBAL_CURATOR: + # global curators can edit personas for the groups they are in + or_conditions |= User__UserGroup.user_id == user.id + elif user.role == UserRole.CURATOR: + # curators can edit personas for the groups they are curators of + or_conditions |= (User__UserGroup.user_id == user.id) & ( + User__UserGroup.is_curator == True # noqa: E712 + ) + + persona_stmt = persona_stmt.where(or_conditions) + result = db_session.execute(persona_stmt) persona = result.scalar_one_or_none() - if persona is None: raise ValueError( f"Persona with ID {persona_id} does not exist or does not belong to user" ) - return persona @@ -665,18 +711,6 @@ def get_prompt_by_name( return result -def get_persona_by_name( - persona_name: str, user: User | None, db_session: Session -) -> Persona | None: - """Admins can see all, regular users can only fetch their own. - If user is None, assume the user is an admin or auth is disabled.""" - stmt = select(Persona).where(Persona.name == persona_name) - if user and user.role != UserRole.ADMIN: - stmt = stmt.where(Persona.user_id == user.id) - result = db_session.execute(stmt).scalar_one_or_none() - return result - - def delete_persona_by_name( persona_name: str, db_session: Session, is_default: bool = True ) -> None: diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index f8a393802..515cbe070 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -1,21 +1,41 @@ from collections.abc import Sequence +from uuid import UUID +from sqlalchemy import select from sqlalchemy.orm import Session -from sqlalchemy.schema import Column from danswer.db.models import User +from danswer.db.models import User__UserGroup +from danswer.db.models import UserRole -def list_users(db_session: Session, q: str = "") -> Sequence[User]: +def list_users( + db_session: Session, email_filter_string: str = "", user: User | None = None +) -> Sequence[User]: """List all users. No pagination as of now, as the # of users is assumed to be relatively small (<< 1 million)""" - query = db_session.query(User) - if q: - query = query.filter(Column("email").ilike("%{}%".format(q))) - return query.all() + stmt = select(User) + + if email_filter_string: + stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore + + if user and user.role != UserRole.ADMIN: + stmt = stmt.join(User__UserGroup) + where_clause = User__UserGroup.user_id == user.id + if user.role == UserRole.CURATOR: + where_clause &= User__UserGroup.is_curator == True # noqa: E712 + stmt = stmt.where(where_clause) + + return db_session.scalars(stmt).unique().all() def get_user_by_email(email: str, db_session: Session) -> User | None: user = db_session.query(User).filter(User.email == email).first() # type: ignore return user + + +def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: + user = db_session.query(User).filter(User.id == user_id).first() # type: ignore + + return user diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 53ef572da..12258eba2 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -5,6 +5,7 @@ from fastapi.dependencies.models import Dependant from starlette.routing import BaseRoute from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.configs.app_configs import APP_API_PREFIX from danswer.server.danswer_api.ingestion import api_key_dep @@ -93,6 +94,7 @@ def check_router_auth( if ( depends_fn == current_user or depends_fn == current_admin_user + or depends_fn == current_curator_or_admin_user or depends_fn == api_key_dep ): found_auth = True diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 7f8dcef57..dd28eee1a 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session -from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot from danswer.db.connector_credential_pair import add_credential_to_connector @@ -21,10 +21,14 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import get_index_attempts_for_connector from danswer.db.models import User +from danswer.db.models import UserRole from danswer.server.documents.models import CCPairFullInfo from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata from danswer.server.models import StatusResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() router = APIRouter(prefix="/manage") @@ -32,18 +36,20 @@ router = APIRouter(prefix="/manage") @router.get("/admin/cc-pair/{cc_pair_id}") def get_cc_pair_full_info( cc_pair_id: int, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> CCPairFullInfo: cc_pair = get_connector_credential_pair_from_id( - cc_pair_id=cc_pair_id, - db_session=db_session, + cc_pair_id, db_session, user, get_editable=False ) - if cc_pair is None: + if not cc_pair: raise HTTPException( - status_code=400, - detail=f"Connector with ID {cc_pair_id} not found. Has it been deleted?", + status_code=404, detail="CC Pair not found for current user permissions" ) + editable_cc_pair = get_connector_credential_pair_from_id( + cc_pair_id, db_session, user, get_editable=True + ) + is_editable_for_current_user = editable_cc_pair is not None cc_pair_identifier = ConnectorCredentialPairIdentifier( connector_id=cc_pair.connector_id, @@ -74,6 +80,7 @@ def get_cc_pair_full_info( db_session=db_session, ), num_docs_indexed=documents_indexed, + is_editable_for_current_user=is_editable_for_current_user, ) @@ -85,9 +92,21 @@ class CCStatusUpdateRequest(BaseModel): def update_cc_pair_status( cc_pair_id: int, status_update_request: CCStatusUpdateRequest, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, + db_session=db_session, + user=user, + get_editable=True, + ) + if not cc_pair: + raise HTTPException( + status_code=400, + detail="Connection not found for current user's permissions", + ) + if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) @@ -105,12 +124,19 @@ def update_cc_pair_status( def update_cc_pair_name( cc_pair_id: int, new_name: str, - user: User | None = Depends(current_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, + db_session=db_session, + user=user, + get_editable=True, + ) if not cc_pair: - raise HTTPException(status_code=404, detail="CC Pair not found") + raise HTTPException( + status_code=400, detail="CC Pair not found for current user's permissions" + ) try: cc_pair.name = new_name @@ -128,18 +154,27 @@ def associate_credential_to_connector( connector_id: int, credential_id: int, metadata: ConnectorCredentialPairMetadata, - user: User | None = Depends(current_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: + if user and user.role != UserRole.ADMIN and metadata.is_public: + raise HTTPException( + status_code=400, + detail="Public connections cannot be created by non-admin users", + ) + try: - return add_credential_to_connector( + response = add_credential_to_connector( + db_session=db_session, + user=user, connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, is_public=metadata.is_public, - user=user, - db_session=db_session, + groups=metadata.groups, ) + + return response except IntegrityError: raise HTTPException(status_code=400, detail="Name must be unique") diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index abc9de1f9..05c18e65f 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -5,6 +5,7 @@ from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Query from fastapi import Request from fastapi import Response from fastapi import UploadFile @@ -12,6 +13,7 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES @@ -67,15 +69,16 @@ from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pair from danswer.db.index_attempt import get_latest_index_attempts from danswer.db.models import User +from danswer.db.models import UserRole from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.file_store.file_store import get_default_file_store from danswer.server.documents.models import AuthStatus from danswer.server.documents.models import AuthUrl from danswer.server.documents.models import ConnectorBase -from danswer.server.documents.models import ConnectorCredentialBase from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import ConnectorSnapshot +from danswer.server.documents.models import ConnectorUpdateRequest from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialSnapshot from danswer.server.documents.models import FileUploadResponse @@ -88,6 +91,9 @@ from danswer.server.documents.models import IndexAttemptSnapshot from danswer.server.documents.models import ObjectCreationIdResponse from danswer.server.documents.models import RunConnectorRequest from danswer.server.models import StatusResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() _GMAIL_CREDENTIAL_ID_COOKIE_NAME = "gmail_credential_id" _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id" @@ -101,7 +107,7 @@ router = APIRouter(prefix="/manage") @router.get("/admin/connector/gmail/app-credential") def check_google_app_gmail_credentials_exist( - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return {"client_id": get_google_app_gmail_cred().web.client_id} @@ -139,7 +145,7 @@ def delete_google_app_gmail_credentials( @router.get("/admin/connector/google-drive/app-credential") def check_google_app_credentials_exist( - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return {"client_id": get_google_app_cred().web.client_id} @@ -177,7 +183,7 @@ def delete_google_app_credentials( @router.get("/admin/connector/gmail/service-account-key") def check_google_service_gmail_account_key_exist( - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return {"service_account_email": get_gmail_service_account_key().client_email} @@ -217,7 +223,7 @@ def delete_google_service_gmail_account_key( @router.get("/admin/connector/google-drive/service-account-key") def check_google_service_account_key_exist( - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return {"service_account_email": get_service_account_key().client_email} @@ -258,7 +264,7 @@ def delete_google_service_account_key( @router.put("/admin/connector/google-drive/service-account-credential") def upsert_service_account_credential( service_account_credential_request: GoogleServiceAccountCredentialRequest, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: """Special API which allows the creation of a credential for a service account. @@ -284,7 +290,7 @@ def upsert_service_account_credential( @router.put("/admin/connector/gmail/service-account-credential") def upsert_gmail_service_account_credential( service_account_credential_request: GoogleServiceAccountCredentialRequest, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: """Special API which allows the creation of a credential for a service account. @@ -345,7 +351,7 @@ def admin_google_drive_auth( @router.post("/admin/connector/file/upload") def upload_files( files: list[UploadFile], - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> FileUploadResponse: for file in files: @@ -372,13 +378,21 @@ def upload_files( @router.get("/admin/connector/indexing-status") def get_connector_indexing_status( secondary_index: bool = False, - _: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), + get_editable: bool = Query( + False, description="If true, return editable document sets" + ), ) -> list[ConnectorIndexingStatus]: indexing_statuses: list[ConnectorIndexingStatus] = [] # TODO: make this one query - cc_pairs = get_connector_credential_pairs(db_session) + cc_pairs = get_connector_credential_pairs( + db_session=db_session, + user=user, + get_editable=get_editable, + ) + cc_pair_identifiers = [ ConnectorCredentialPairIdentifier( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id @@ -488,28 +502,74 @@ def _validate_connector_allowed(source: DocumentSource) -> None: ) +def _check_connector_permissions( + connector_data: ConnectorUpdateRequest, user: User | None +) -> ConnectorBase: + """ + This is not a proper permission check, but this should prevent curators creating bad situations + until a long-term solution is implemented (Replacing CC pairs/Connectors with Connections) + """ + if user and user.role != UserRole.ADMIN: + if connector_data.is_public: + raise HTTPException( + status_code=400, + detail="Public connectors can only be created by admins", + ) + if not connector_data.groups: + raise HTTPException( + status_code=400, + detail="Connectors created by curators must have groups", + ) + return ConnectorBase( + name=connector_data.name, + source=connector_data.source, + input_type=connector_data.input_type, + connector_specific_config=connector_data.connector_specific_config, + refresh_freq=connector_data.refresh_freq, + prune_freq=connector_data.prune_freq, + indexing_start=connector_data.indexing_start, + ) + + @router.post("/admin/connector") def create_connector_from_model( - connector_data: ConnectorBase, - _: User = Depends(current_admin_user), + connector_data: ConnectorUpdateRequest, + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: try: _validate_connector_allowed(connector_data.source) - return create_connector(connector_data, db_session) + connector_base = _check_connector_permissions(connector_data, user) + return create_connector( + db_session=db_session, + connector_data=connector_base, + ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @router.post("/admin/connector-with-mock-credential") def create_connector_with_mock_credential( - connector_data: ConnectorCredentialBase, - user: User = Depends(current_admin_user), + connector_data: ConnectorUpdateRequest, + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: + if user and user.role != UserRole.ADMIN: + if connector_data.is_public: + raise HTTPException( + status_code=401, + detail="User does not have permission to create public credentials", + ) + if not connector_data.groups: + raise HTTPException( + status_code=401, + detail="Curators must specify 1+ groups", + ) try: _validate_connector_allowed(connector_data.source) - connector_response = create_connector(connector_data, db_session) + connector_response = create_connector( + db_session=db_session, connector_data=connector_data + ) mock_credential = CredentialBase( credential_json={}, admin_public=True, source=connector_data.source ) @@ -517,12 +577,13 @@ def create_connector_with_mock_credential( mock_credential, user=user, db_session=db_session ) response = add_credential_to_connector( + db_session=db_session, + user=user, connector_id=cast(int, connector_response.id), # will aways be an int credential_id=credential.id, - is_public=connector_data.is_public, - user=user, - db_session=db_session, + is_public=connector_data.is_public or False, cc_pair_name=connector_data.name, + groups=connector_data.groups, ) return response @@ -533,16 +594,17 @@ def create_connector_with_mock_credential( @router.patch("/admin/connector/{connector_id}") def update_connector_from_model( connector_id: int, - connector_data: ConnectorBase, - _: User = Depends(current_admin_user), + connector_data: ConnectorUpdateRequest, + user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorSnapshot | StatusResponse[int]: try: _validate_connector_allowed(connector_data.source) + connector_base = _check_connector_permissions(connector_data, user) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) - updated_connector = update_connector(connector_id, connector_data, db_session) + updated_connector = update_connector(connector_id, connector_base, db_session) if updated_connector is None: raise HTTPException( status_code=404, detail=f"Connector {connector_id} does not exist" @@ -573,7 +635,10 @@ def delete_connector_by_id( ) -> StatusResponse[int]: try: with db_session.begin(): - return delete_connector(db_session=db_session, connector_id=connector_id) + return delete_connector( + db_session=db_session, + connector_id=connector_id, + ) except AssertionError: raise HTTPException(status_code=400, detail="Connector is not deletable") @@ -581,7 +646,7 @@ def delete_connector_by_id( @router.post("/admin/connector/run-once") def connector_run_once( run_info: RunConnectorRequest, - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[list[int]]: connector_id = run_info.connector_id diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index 8525aeb93..ba30b65f2 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -1,13 +1,16 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Query from sqlalchemy.orm import Session -from danswer.auth.schemas import UserRole from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user +from danswer.auth.users import validate_curator_request from danswer.db.credentials import alter_credential from danswer.db.credentials import create_credential +from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE from danswer.db.credentials import delete_credential from danswer.db.credentials import fetch_credential_by_id from danswer.db.credentials import fetch_credentials @@ -17,27 +20,39 @@ from danswer.db.credentials import update_credential from danswer.db.engine import get_session from danswer.db.models import DocumentSource from danswer.db.models import User +from danswer.db.models import UserRole from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.server.documents.models import CredentialSnapshot from danswer.server.documents.models import CredentialSwapRequest from danswer.server.documents.models import ObjectCreationIdResponse from danswer.server.models import StatusResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() router = APIRouter(prefix="/manage") +def _ignore_credential_permissions(source: DocumentSource) -> bool: + return source in CREDENTIAL_PERMISSIONS_TO_IGNORE + + """Admin-only endpoints""" @router.get("/admin/credential") def list_credentials_admin( - user: User = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[CredentialSnapshot]: """Lists all public credentials""" - credentials = fetch_credentials(db_session=db_session, user=user) + credentials = fetch_credentials( + db_session=db_session, + user=user, + get_editable=False, + ) return [ CredentialSnapshot.from_credential_db_model(credential) for credential in credentials @@ -47,13 +62,18 @@ def list_credentials_admin( @router.get("/admin/similar-credentials/{source_type}") def get_cc_source_full_info( source_type: DocumentSource, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), + get_editable: bool = Query( + False, description="If true, return editable credentials" + ), ) -> list[CredentialSnapshot]: credentials = fetch_credentials_by_source( - db_session=db_session, user=user, document_source=source_type + db_session=db_session, + user=user, + document_source=source_type, + get_editable=get_editable, ) - return [ CredentialSnapshot.from_credential_db_model(credential) for credential in credentials @@ -87,13 +107,13 @@ def delete_credential_by_id_admin( @router.put("/admin/credentials/swap") def swap_credentials_for_connector( - credentail_swap_req: CredentialSwapRequest, + credential_swap_req: CredentialSwapRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse: connector_credential_pair = swap_credentials_connector( - new_credential_id=credentail_swap_req.new_credential_id, - connector_id=credentail_swap_req.connector_id, + new_credential_id=credential_swap_req.new_credential_id, + connector_id=credential_swap_req.connector_id, db_session=db_session, user=user, ) @@ -105,6 +125,29 @@ def swap_credentials_for_connector( ) +@router.post("/credential") +def create_credential_from_model( + credential_info: CredentialBase, + user: User | None = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> ObjectCreationIdResponse: + if ( + user + and user.role != UserRole.ADMIN + and not _ignore_credential_permissions(credential_info.source) + ): + validate_curator_request( + groups=credential_info.groups, + is_public=credential_info.curator_public, + ) + + credential = create_credential(credential_info, user, db_session) + return ObjectCreationIdResponse( + id=credential.id, + credential=CredentialSnapshot.from_credential_db_model(credential), + ) + + """Endpoints for all""" @@ -120,26 +163,6 @@ def list_credentials( ] -@router.post("/credential") -def create_credential_from_model( - credential_info: CredentialBase, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> ObjectCreationIdResponse: - if user and user.role != UserRole.ADMIN and credential_info.admin_public: - raise HTTPException( - status_code=400, - detail="Non-admin cannot create admin credential", - ) - - credential = create_credential(credential_info, user, db_session) - - return ObjectCreationIdResponse( - id=credential.id, - credential=CredentialSnapshot.from_credential_db_model(credential), - ) - - @router.get("/credential/{credential_id}") def get_credential_by_id( credential_id: int, @@ -195,9 +218,11 @@ def update_credential_from_model( id=updated_credential.id, credential_json=updated_credential.credential_json, user_id=updated_credential.user_id, + name=updated_credential.name, admin_public=updated_credential.admin_public, time_created=updated_credential.time_created, time_updated=updated_credential.time_updated, + curator_public=updated_credential.curator_public, ) diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ed23b79d6..32a882979 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -45,8 +45,9 @@ class ConnectorBase(BaseModel): indexing_start: datetime | None -class ConnectorCredentialBase(ConnectorBase): - is_public: bool +class ConnectorUpdateRequest(ConnectorBase): + is_public: bool | None = None + groups: list[int] | None = None class ConnectorSnapshot(ConnectorBase): @@ -91,6 +92,8 @@ class CredentialBase(BaseModel): admin_public: bool source: DocumentSource name: str | None = None + curator_public: bool = False + groups: list[int] = [] class CredentialSnapshot(CredentialBase): @@ -98,6 +101,11 @@ class CredentialSnapshot(CredentialBase): user_id: UUID | None time_created: datetime time_updated: datetime + name: str | None + source: DocumentSource + credential_json: dict[str, Any] + admin_public: bool + curator_public: bool @classmethod def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot": @@ -105,7 +113,7 @@ class CredentialSnapshot(CredentialBase): id=credential.id, credential_json=( mask_credential_dict(credential.credential_json) - if MASK_CREDENTIAL_PREFIX + if MASK_CREDENTIAL_PREFIX and credential.credential_json else credential.credential_json ), user_id=credential.user_id, @@ -114,6 +122,7 @@ class CredentialSnapshot(CredentialBase): time_updated=credential.time_updated, source=credential.source or DocumentSource.NOT_APPLICABLE, name=credential.name, + curator_public=credential.curator_public, ) @@ -185,6 +194,8 @@ class CCPairFullInfo(BaseModel): credential: CredentialSnapshot index_attempts: list[IndexAttemptSnapshot] latest_deletion_attempt: DeletionAttemptSnapshot | None + is_public: bool + is_editable_for_current_user: bool @classmethod def from_models( @@ -193,6 +204,7 @@ class CCPairFullInfo(BaseModel): index_attempt_models: list[IndexAttempt], latest_deletion_attempt: DeletionAttemptSnapshot | None, num_docs_indexed: int, # not ideal, but this must be computed separately + is_editable_for_current_user: bool, ) -> "CCPairFullInfo": return cls( id=cc_pair_model.id, @@ -210,6 +222,8 @@ class CCPairFullInfo(BaseModel): for index_attempt_model in index_attempt_models ], latest_deletion_attempt=latest_deletion_attempt, + is_public=cc_pair_model.is_public, + is_editable_for_current_user=is_editable_for_current_user, ) @@ -241,6 +255,7 @@ class ConnectorCredentialPairIdentifier(BaseModel): class ConnectorCredentialPairMetadata(BaseModel): name: str | None is_public: bool + groups: list[int] | None class ConnectorCredentialPairDescriptor(BaseModel): diff --git a/backend/danswer/server/features/document_set/api.py b/backend/danswer/server/features/document_set/api.py index f939329bf..cbce90997 100644 --- a/backend/danswer/server/features/document_set/api.py +++ b/backend/danswer/server/features/document_set/api.py @@ -1,18 +1,22 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Query from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user +from danswer.auth.users import validate_curator_request from danswer.db.document_set import check_document_sets_are_public -from danswer.db.document_set import fetch_all_document_sets +from danswer.db.document_set import fetch_all_document_sets_for_user from danswer.db.document_set import fetch_user_document_sets from danswer.db.document_set import insert_document_set from danswer.db.document_set import mark_document_set_as_to_be_deleted from danswer.db.document_set import update_document_set from danswer.db.engine import get_session from danswer.db.models import User +from danswer.db.models import UserRole from danswer.server.documents.models import ConnectorCredentialPairDescriptor from danswer.server.documents.models import ConnectorSnapshot from danswer.server.documents.models import CredentialSnapshot @@ -29,9 +33,14 @@ router = APIRouter(prefix="/manage") @router.post("/admin/document-set") def create_document_set( document_set_creation_request: DocumentSetCreationRequest, - user: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> int: + if user and user.role != UserRole.ADMIN: + validate_curator_request( + groups=document_set_creation_request.groups, + is_public=document_set_creation_request.is_public, + ) try: document_set_db_model, _ = insert_document_set( document_set_creation_request=document_set_creation_request, @@ -74,12 +83,17 @@ def delete_document_set( @router.get("/admin/document-set") def list_document_sets_admin( - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), + get_editable: bool = Query( + False, description="If true, return editable document sets" + ), ) -> list[DocumentSet]: return [ DocumentSet.from_model(ds) - for ds in fetch_all_document_sets(db_session=db_session) + for ds in fetch_all_document_sets_for_user( + db_session=db_session, user=user, get_editable=get_editable + ) ] diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index 2ea68f581..72b16d719 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -3,11 +3,13 @@ from uuid import UUID from fastapi import APIRouter from fastapi import Depends +from fastapi import Query from fastapi import UploadFile from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.configs.constants import FileOrigin from danswer.db.engine import get_session @@ -45,13 +47,14 @@ class IsVisibleRequest(BaseModel): def patch_persona_visibility( persona_id: int, is_visible_request: IsVisibleRequest, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: update_persona_visibility( persona_id=persona_id, is_visible=is_visible_request.is_visible, db_session=db_session, + user=user, ) @@ -69,15 +72,17 @@ def patch_persona_display_priority( @admin_router.get("") def list_personas_admin( - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), include_deleted: bool = False, + get_editable: bool = Query(False, description="If true, return editable personas"), ) -> list[PersonaSnapshot]: return [ PersonaSnapshot.from_model(persona) for persona in get_personas( db_session=db_session, - user_id=None, # user_id = None -> give back all personas + user=user, + get_editable=get_editable, include_deleted=include_deleted, joinedload_all=True, ) @@ -187,13 +192,13 @@ def list_personas( db_session: Session = Depends(get_session), include_deleted: bool = False, ) -> list[PersonaSnapshot]: - user_id = user.id if user is not None else None return [ PersonaSnapshot.from_model(persona) for persona in get_personas( - user_id=user_id, + user=user, include_deleted=include_deleted, db_session=db_session, + get_editable=False, joinedload_all=True, ) ] diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index b931db6df..0ac90ba8d 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -9,6 +9,7 @@ from fastapi import HTTPException from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import DocumentSource from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME @@ -35,6 +36,7 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.manage.models import BoostDoc from danswer.server.manage.models import BoostUpdateRequest from danswer.server.manage.models import HiddenUpdateRequest +from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger router = APIRouter(prefix="/manage") @@ -47,11 +49,14 @@ logger = setup_logger() def get_most_boosted_docs( ascending: bool, limit: int, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[BoostDoc]: boost_docs = fetch_docs_ranked_by_boost( - ascending=ascending, limit=limit, db_session=db_session + ascending=ascending, + limit=limit, + db_session=db_session, + user=user, ) return [ BoostDoc( @@ -69,45 +74,43 @@ def get_most_boosted_docs( @router.post("/admin/doc-boosts") def document_boost_update( boost_update: BoostUpdateRequest, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> None: +) -> StatusResponse: curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name ) - try: - update_document_boost( - db_session=db_session, - document_id=boost_update.document_id, - boost=boost_update.boost, - document_index=document_index, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + update_document_boost( + db_session=db_session, + document_id=boost_update.document_id, + boost=boost_update.boost, + document_index=document_index, + user=user, + ) + return StatusResponse(success=True, message="Updated document boost") @router.post("/admin/doc-hidden") def document_hidden_update( hidden_update: HiddenUpdateRequest, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> None: +) -> StatusResponse: curr_ind_name, sec_ind_name = get_both_index_names(db_session) document_index = get_default_document_index( primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name ) - try: - update_document_hidden( - db_session=db_session, - document_id=hidden_update.document_id, - hidden=hidden_update.hidden, - document_index=document_index, - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + update_document_hidden( + db_session=db_session, + document_id=hidden_update.document_id, + hidden=hidden_update.hidden, + document_index=document_index, + user=user, + ) + return StatusResponse(success=True, message="Updated document boost") @router.get("/admin/genai-api-key/validate") @@ -145,7 +148,7 @@ def validate_existing_genai_api_key( @router.post("/admin/deletion-attempt") def create_deletion_attempt_for_connector_id( connector_credential_pair_identifier: ConnectorCredentialPairIdentifier, - _: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: from danswer.background.celery.celery_app import ( @@ -159,6 +162,8 @@ def create_deletion_attempt_for_connector_id( db_session=db_session, connector_id=connector_id, credential_id=credential_id, + user=user, + get_editable=True, ) if cc_pair is None: raise HTTPException( @@ -196,5 +201,5 @@ def create_deletion_attempt_for_connector_id( if cc_pair.connector.source == DocumentSource.FILE: connector = cc_pair.connector file_store = get_default_file_store(db_session) - for file_name in connector.connector_specific_config["file_locations"]: + for file_name in connector.connector_specific_config.get("file_locations", []): file_store.delete_file(file_name) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 70f4a5ada..c080e4275 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -89,6 +89,11 @@ class UserByEmail(BaseModel): user_email: str +class UserRoleUpdateRequest(BaseModel): + user_email: str + new_role: UserRole + + class UserRoleResponse(BaseModel): role: str diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index dac134946..16ebc3e2f 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -22,6 +22,7 @@ from danswer.auth.noauth_user import set_no_auth_user_preferences from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserStatus from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE @@ -38,11 +39,13 @@ from danswer.server.manage.models import AllUsersResponse from danswer.server.manage.models import UserByEmail from danswer.server.manage.models import UserInfo from danswer.server.manage.models import UserRoleResponse +from danswer.server.manage.models import UserRoleUpdateRequest from danswer.server.models import FullUserSnapshot from danswer.server.models import InvitedUserSnapshot from danswer.server.models import MinimalUserSnapshot from danswer.utils.logger import setup_logger from ee.danswer.db.api_key import is_api_key_email_address +from ee.danswer.db.user_group import remove_curator_status__no_commit logger = setup_logger() @@ -52,42 +55,38 @@ router = APIRouter() USERS_PAGE_SIZE = 10 -@router.patch("/manage/promote-user-to-admin") -def promote_admin( - user_email: UserByEmail, - _: User = Depends(current_admin_user), +@router.patch("/manage/set-user-role") +def set_user_role( + user_role_update_request: UserRoleUpdateRequest, + current_user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: - user_to_promote = get_user_by_email( - email=user_email.user_email, db_session=db_session + user_to_update = get_user_by_email( + email=user_role_update_request.user_email, db_session=db_session ) - if not user_to_promote: + if not user_to_update: raise HTTPException(status_code=404, detail="User not found") - user_to_promote.role = UserRole.ADMIN - db_session.add(user_to_promote) - db_session.commit() - - -@router.patch("/manage/demote-admin-to-basic") -async def demote_admin( - user_email: UserByEmail, - user: User = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> None: - user_to_demote = get_user_by_email( - email=user_email.user_email, db_session=db_session - ) - if not user_to_demote: - raise HTTPException(status_code=404, detail="User not found") - - if user_to_demote.id == user.id: + if user_role_update_request.new_role == UserRole.CURATOR: raise HTTPException( - status_code=400, detail="Cannot demote yourself from admin role!" + status_code=400, + detail="Curator role must be set via the User Group Menu", ) - user_to_demote.role = UserRole.BASIC - db_session.add(user_to_demote) + if user_to_update.role == user_role_update_request.new_role: + return + + if current_user.id == user_to_update.id: + raise HTTPException( + status_code=400, + detail="An admin cannot demote themselves from admin role!", + ) + + if user_to_update.role == UserRole.CURATOR: + remove_curator_status__no_commit(db_session, user_to_update) + + user_to_update.role = user_role_update_request.new_role.value + db_session.commit() @@ -96,7 +95,7 @@ def list_all_users( q: str | None = None, accepted_page: int | None = None, invited_page: int | None = None, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> AllUsersResponse: if not q: @@ -104,7 +103,7 @@ def list_all_users( users = [ user - for user in list_users(db_session, q=q) + for user in list_users(db_session, email_filter_string=q, user=user) if not is_api_key_email_address(user.email) ] accepted_emails = {user.email for user in users} diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 71cf3092c..ec170013c 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -4,7 +4,7 @@ from fastapi import HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user from danswer.configs.constants import DocumentSource from danswer.configs.constants import MessageType @@ -50,7 +50,7 @@ basic_router = APIRouter(prefix="/query") @admin_router.post("/search") def admin_search( question: AdminSearchRequest, - user: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> AdminSearchResponse: query = question.query diff --git a/backend/ee/danswer/db/token_limit.py b/backend/ee/danswer/db/token_limit.py index 9b1538116..95dd00118 100644 --- a/backend/ee/danswer/db/token_limit.py +++ b/backend/ee/danswer/db/token_limit.py @@ -1,16 +1,70 @@ from collections.abc import Sequence +from sqlalchemy import exists from sqlalchemy import Row +from sqlalchemy import Select from sqlalchemy import select +from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from danswer.configs.constants import TokenRateLimitScope from danswer.db.models import TokenRateLimit from danswer.db.models import TokenRateLimit__UserGroup +from danswer.db.models import User +from danswer.db.models import User__UserGroup from danswer.db.models import UserGroup +from danswer.db.models import UserRole from danswer.server.token_rate_limits.models import TokenRateLimitArgs +def _add_user_filters( + stmt: Select, user: User | None, get_editable: bool = True +) -> Select: + # If user is None, assume the user is an admin or auth is disabled + if user is None or user.role == UserRole.ADMIN: + return stmt + + TRLimit_UG = aliased(TokenRateLimit__UserGroup) + User__UG = aliased(User__UserGroup) + + """ + Here we select token_rate_limits by relation: + User -> User__UserGroup -> TokenRateLimit__UserGroup -> + TokenRateLimit + """ + stmt = stmt.outerjoin(TRLimit_UG).outerjoin( + User__UG, + User__UG.user_group_id == TRLimit_UG.user_group_id, + ) + + """ + Filter token_rate_limits by: + - if the user is in the user_group that owns the token_rate_limit + - if the user is not a global_curator, they must also have a curator relationship + to the user_group + - if editing is being done, we also filter out token_rate_limits that are owned by groups + that the user isn't a curator for + - if we are not editing, we show all token_rate_limits in the groups the user curates + """ + where_clause = User__UG.user_id == user.id + if user.role == UserRole.CURATOR and get_editable: + where_clause &= User__UG.is_curator == True # noqa: E712 + if get_editable: + user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) + if user.role == UserRole.CURATOR: + user_groups = user_groups.where( + User__UserGroup.is_curator == True # noqa: E712 + ) + where_clause &= ( + ~exists() + .where(TRLimit_UG.rate_limit_id == TokenRateLimit.id) + .where(~TRLimit_UG.user_group_id.in_(user_groups)) + .correlate(TokenRateLimit) + ) + + return stmt.where(where_clause) + + def fetch_all_user_token_rate_limits( db_session: Session, enabled_only: bool = False, @@ -48,29 +102,25 @@ def fetch_all_global_token_rate_limits( return token_rate_limits -def fetch_all_user_group_token_rate_limits( - db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True +def fetch_user_group_token_rate_limits( + db_session: Session, + group_id: int, + user: User | None = None, + enabled_only: bool = False, + ordered: bool = True, + get_editable: bool = True, ) -> Sequence[TokenRateLimit]: - query = ( - select(TokenRateLimit) - .join( - TokenRateLimit__UserGroup, - TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, - ) - .where( - TokenRateLimit__UserGroup.user_group_id == group_id, - TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP, - ) - ) + stmt = select(TokenRateLimit) + stmt = stmt.where(User__UserGroup.user_group_id == group_id) + stmt = _add_user_filters(stmt, user, get_editable) if enabled_only: - query = query.where(TokenRateLimit.enabled.is_(True)) + stmt = stmt.where(TokenRateLimit.enabled.is_(True)) if ordered: - query = query.order_by(TokenRateLimit.created_at.desc()) + stmt = stmt.order_by(TokenRateLimit.created_at.desc()) - token_rate_limits = db_session.scalars(query).all() - return token_rate_limits + return db_session.scalars(stmt).all() def fetch_all_user_group_token_rate_limits_by_group( diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index bdcb296fa..93cfff36e 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -5,11 +5,13 @@ from uuid import UUID from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import select +from sqlalchemy import update from sqlalchemy.orm import Session from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair +from danswer.db.models import Credential__UserGroup from danswer.db.models import Document from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.models import LLMProvider__UserGroup @@ -18,9 +20,15 @@ from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.db.models import UserGroup from danswer.db.models import UserGroup__ConnectorCredentialPair +from danswer.db.models import UserRole +from danswer.db.users import fetch_user_by_id +from danswer.utils.logger import setup_logger +from ee.danswer.server.user_group.models import SetCuratorRequest from ee.danswer.server.user_group.models import UserGroupCreate from ee.danswer.server.user_group.models import UserGroupUpdate +logger = setup_logger() + def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) @@ -37,7 +45,7 @@ def fetch_user_groups( def fetch_user_groups_for_user( - db_session: Session, user_id: UUID + db_session: Session, user_id: UUID, only_curator_groups: bool = False ) -> Sequence[UserGroup]: stmt = ( select(UserGroup) @@ -45,6 +53,8 @@ def fetch_user_groups_for_user( .join(User, User.id == User__UserGroup.user_id) # type: ignore .where(User.id == user_id) # type: ignore ) + if only_curator_groups: + stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712 return db_session.scalars(stmt).all() @@ -179,16 +189,32 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG def _cleanup_user__user_group_relationships__no_commit( - db_session: Session, user_group_id: int + db_session: Session, + user_group_id: int, + user_ids: list[UUID] | None = None, ) -> None: """NOTE: does not commit the transaction.""" + where_clause = User__UserGroup.user_group_id == user_group_id + if user_ids: + where_clause &= User__UserGroup.user_id.in_(user_ids) + user__user_group_relationships = db_session.scalars( - select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id) + select(User__UserGroup).where(where_clause) ).all() for user__user_group_relationship in user__user_group_relationships: db_session.delete(user__user_group_relationship) +def _cleanup_credential__user_group_relationships__no_commit( + db_session: Session, + user_group_id: int, +) -> None: + """NOTE: does not commit the transaction.""" + db_session.query(Credential__UserGroup).filter( + Credential__UserGroup.user_group_id == user_group_id + ).delete(synchronize_session=False) + + def _cleanup_llm_provider__user_group_relationships__no_commit( db_session: Session, user_group_id: int ) -> None: @@ -211,8 +237,84 @@ def _mark_user_group__cc_pair_relationships_outdated__no_commit( user_group__cc_pair_relationship.is_current = False +def _validate_curator_status__no_commit( + db_session: Session, + users: list[User], +) -> None: + for user in users: + # Check if the user is a curator in any of their groups + curator_relationships = ( + db_session.query(User__UserGroup) + .filter( + User__UserGroup.user_id == user.id, + User__UserGroup.is_curator == True, # noqa: E712 + ) + .all() + ) + + if curator_relationships: + user.role = UserRole.CURATOR + elif user.role == UserRole.CURATOR: + user.role = UserRole.BASIC + db_session.add(user) + + +def remove_curator_status__no_commit(db_session: Session, user: User) -> None: + stmt = ( + update(User__UserGroup) + .where(User__UserGroup.user_id == user.id) + .values(is_curator=False) + ) + db_session.execute(stmt) + _validate_curator_status__no_commit(db_session, [user]) + + +def update_user_curator_relationship( + db_session: Session, + user_group_id: int, + set_curator_request: SetCuratorRequest, +) -> None: + user = fetch_user_by_id(db_session, set_curator_request.user_id) + if not user: + raise ValueError(f"User with id '{set_curator_request.user_id}' not found") + requested_user_groups = fetch_user_groups_for_user( + db_session=db_session, + user_id=set_curator_request.user_id, + only_curator_groups=False, + ) + + group_ids = [group.id for group in requested_user_groups] + if user_group_id not in group_ids: + raise ValueError(f"user is not in group '{user_group_id}'") + + relationship_to_update = ( + db_session.query(User__UserGroup) + .filter( + User__UserGroup.user_group_id == user_group_id, + User__UserGroup.user_id == set_curator_request.user_id, + ) + .first() + ) + + if relationship_to_update: + relationship_to_update.is_curator = set_curator_request.is_curator + else: + relationship_to_update = User__UserGroup( + user_group_id=user_group_id, + user_id=set_curator_request.user_id, + is_curator=True, + ) + db_session.add(relationship_to_update) + + _validate_curator_status__no_commit(db_session, [user]) + db_session.commit() + + def update_user_group( - db_session: Session, user_group_id: int, user_group: UserGroupUpdate + db_session: Session, + user: User | None, + user_group_id: int, + user_group_update: UserGroupUpdate, ) -> UserGroup: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) @@ -221,23 +323,33 @@ def update_user_group( _check_user_group_is_modifiable(db_user_group) - existing_cc_pairs = db_user_group.cc_pairs - cc_pairs_updated = set([cc_pair.id for cc_pair in existing_cc_pairs]) != set( - user_group.cc_pair_ids - ) - users_updated = set([user.id for user in db_user_group.users]) != set( - user_group.user_ids - ) + current_user_ids = set([user.id for user in db_user_group.users]) + updated_user_ids = set(user_group_update.user_ids) + added_user_ids = list(updated_user_ids - current_user_ids) + removed_user_ids = list(current_user_ids - updated_user_ids) - if users_updated: + if (removed_user_ids or added_user_ids) and ( + not user or user.role != UserRole.ADMIN + ): + raise ValueError("Only admins can add or remove users from user groups") + + if removed_user_ids: _cleanup_user__user_group_relationships__no_commit( - db_session=db_session, user_group_id=user_group_id + db_session=db_session, + user_group_id=user_group_id, + user_ids=removed_user_ids, ) + + if added_user_ids: _add_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id, - user_ids=user_group.user_ids, + user_ids=added_user_ids, ) + + cc_pairs_updated = set([cc_pair.id for cc_pair in db_user_group.cc_pairs]) != set( + user_group_update.cc_pair_ids + ) if cc_pairs_updated: _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session=db_session, user_group_id=user_group_id @@ -245,13 +357,17 @@ def update_user_group( _add_user_group__cc_pair_relationships__no_commit( db_session=db_session, user_group_id=db_user_group.id, - cc_pair_ids=user_group.cc_pair_ids, + cc_pair_ids=user_group_update.cc_pair_ids, ) # only needs to sync with Vespa if the cc_pairs have been updated if cc_pairs_updated: db_user_group.is_up_to_date = False + removed_users = db_session.scalars( + select(User).where(User.id.in_(removed_user_ids)) # type: ignore + ).unique() + _validate_curator_status__no_commit(db_session, list(removed_users)) db_session.commit() return db_user_group @@ -279,6 +395,9 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> _check_user_group_is_modifiable(db_user_group) + _cleanup_credential__user_group_relationships__no_commit( + db_session=db_session, user_group_id=user_group_id + ) _cleanup_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) diff --git a/backend/ee/danswer/server/token_rate_limits/api.py b/backend/ee/danswer/server/token_rate_limits/api.py index aac3ebb16..97f1f15fa 100644 --- a/backend/ee/danswer/server/token_rate_limits/api.py +++ b/backend/ee/danswer/server/token_rate_limits/api.py @@ -5,14 +5,15 @@ from fastapi import Depends from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.db.engine import get_session from danswer.db.models import User from danswer.server.query_and_chat.token_limit import any_rate_limit_exists from danswer.server.token_rate_limits.models import TokenRateLimitArgs from danswer.server.token_rate_limits.models import TokenRateLimitDisplay -from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits +from ee.danswer.db.token_limit import fetch_user_group_token_rate_limits from ee.danswer.db.token_limit import insert_user_group_token_rate_limit from ee.danswer.db.token_limit import insert_user_token_rate_limit @@ -45,13 +46,13 @@ def get_all_group_token_limit_settings( @router.get("/user-group/{group_id}") def get_group_token_limit_settings( group_id: int, - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[TokenRateLimitDisplay]: return [ TokenRateLimitDisplay.from_db(token_rate_limit) - for token_rate_limit in fetch_all_user_group_token_rate_limits( - db_session, group_id + for token_rate_limit in fetch_user_group_token_rate_limits( + db_session, group_id, user ) ] diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py index 36e101001..e18487d54 100644 --- a/backend/ee/danswer/server/user_group/api.py +++ b/backend/ee/danswer/server/user_group/api.py @@ -5,12 +5,17 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_curator_or_admin_user from danswer.db.engine import get_session from danswer.db.models import User +from danswer.db.models import UserRole from ee.danswer.db.user_group import fetch_user_groups +from ee.danswer.db.user_group import fetch_user_groups_for_user from ee.danswer.db.user_group import insert_user_group from ee.danswer.db.user_group import prepare_user_group_for_deletion +from ee.danswer.db.user_group import update_user_curator_relationship from ee.danswer.db.user_group import update_user_group +from ee.danswer.server.user_group.models import SetCuratorRequest from ee.danswer.server.user_group.models import UserGroup from ee.danswer.server.user_group.models import UserGroupCreate from ee.danswer.server.user_group.models import UserGroupUpdate @@ -20,10 +25,17 @@ router = APIRouter(prefix="/manage") @router.get("/admin/user-group") def list_user_groups( - _: User | None = Depends(current_admin_user), + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[UserGroup]: - user_groups = fetch_user_groups(db_session, only_current=False) + if user is None or user.role == UserRole.ADMIN: + user_groups = fetch_user_groups(db_session, only_current=False) + else: + user_groups = fetch_user_groups_for_user( + db_session=db_session, + user_id=user.id, + only_curator_groups=user.role == UserRole.CURATOR, + ) return [UserGroup.from_model(user_group) for user_group in user_groups] @@ -47,13 +59,35 @@ def create_user_group( @router.patch("/admin/user-group/{user_group_id}") def patch_user_group( user_group_id: int, - user_group: UserGroupUpdate, - _: User | None = Depends(current_admin_user), + user_group_update: UserGroupUpdate, + user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> UserGroup: try: return UserGroup.from_model( - update_user_group(db_session, user_group_id, user_group) + update_user_group( + db_session=db_session, + user=user, + user_group_id=user_group_id, + user_group_update=user_group_update, + ) + ) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/admin/user-group/{user_group_id}/set-curator") +def set_user_curator( + user_group_id: int, + set_curator_request: SetCuratorRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + try: + update_user_curator_relationship( + db_session=db_session, + user_group_id=user_group_id, + set_curator_request=set_curator_request, ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/ee/danswer/server/user_group/models.py b/backend/ee/danswer/server/user_group/models.py index 22a6d55f5..077a217e9 100644 --- a/backend/ee/danswer/server/user_group/models.py +++ b/backend/ee/danswer/server/user_group/models.py @@ -16,6 +16,7 @@ class UserGroup(BaseModel): id: int name: str users: list[UserInfo] + curator_ids: list[UUID] cc_pairs: list[ConnectorCredentialPairDescriptor] document_sets: list[DocumentSet] personas: list[PersonaSnapshot] @@ -42,6 +43,11 @@ class UserGroup(BaseModel): ) for user in user_group_model.users ], + curator_ids=[ + user.user_id + for user in user_group_model.user_group_relationships + if user.is_curator and user.user_id is not None + ], cc_pairs=[ ConnectorCredentialPairDescriptor( id=cc_pair_relationship.cc_pair.id, @@ -78,3 +84,8 @@ class UserGroupCreate(BaseModel): class UserGroupUpdate(BaseModel): user_ids: list[UUID] cc_pair_ids: list[int] + + +class SetCuratorRequest(BaseModel): + user_id: UUID + is_curator: bool diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index c9e4e2be2..8d6e81006 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -4,6 +4,7 @@ import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils"; import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types"; import { Button, Divider, Italic, Text } from "@tremor/react"; +import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; import { ArrayHelpers, ErrorMessage, @@ -11,6 +12,7 @@ import { FieldArray, Form, Formik, + FormikProps, } from "formik"; import { @@ -21,10 +23,8 @@ import { } from "@/components/admin/connectors/Field"; import { usePopup } from "@/components/admin/connectors/Popup"; import { getDisplayNameForModel } from "@/lib/hooks"; -import { Bubble } from "@/components/Bubble"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { Option } from "@/components/Dropdown"; -import { GroupsIcon } from "@/components/icons/icons"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences"; import { useUserGroups } from "@/lib/hooks"; @@ -232,6 +232,9 @@ export function AssistantEditor({ const [existingPersonaImageId, setExistingPersonaImageId] = useState< string | null >(existingPersona?.uploaded_image_id || null); + + const [isRequestSuccessful, setIsRequestSuccessful] = useState(false); + return (
{popup} @@ -414,10 +417,16 @@ export function AssistantEditor({ ? `/admin/assistants?u=${Date.now()}` : `/chat?assistantId=${assistantId}` ); + setIsRequestSuccessful(true); } }} > - {({ isSubmitting, values, setFieldValue }) => { + {({ + isSubmitting, + values, + setFieldValue, + ...formikProps + }: FormikProps) => { function toggleToolInValues(toolId: number) { const updatedEnabledToolsMap = { ...values.enabled_tools_map, @@ -891,24 +900,28 @@ export function AssistantEditor({
{values.starter_messages && values.starter_messages.length > 0 && - values.starter_messages.map((_, index) => { - return ( -
-
-
-
- - - Shows up as the "title" for - this Starter Message. For example, - "Write an email". - - { + return ( +
+
+
+
+ + + Shows up as the "title" + for this Starter Message. For + example, "Write an email". + + - -
+ autoComplete="off" + /> + +
-
- - - A description which tells the user - what they might want to use this - Starter Message for. For example - "to a client about a new - feature" - - + + + A description which tells the user + what they might want to use this + Starter Message for. For example + "to a client about a new + feature" + + - -
+ autoComplete="off" + /> + +
-
- - - The actual message to be sent as the - initial user message if a user selects - this starter prompt. For example, - "Write me an email to a client - about a new billing feature we just - released." - - + + + The actual message to be sent as the + initial user message if a user + selects this starter prompt. For + example, "Write me an email to + a client about a new billing feature + we just released." + + - + +
+
+
+ + arrayHelpers.remove(index) + } />
-
- - arrayHelpers.remove(index) - } - /> -
-
- ); - })} + ); + } + )} diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index b5360e50e..9c41f100f 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -5,12 +5,14 @@ import { Persona } from "./interfaces"; import { useRouter } from "next/navigation"; import { CustomCheckbox } from "@/components/CustomCheckbox"; import { usePopup } from "@/components/admin/connectors/Popup"; -import { useState } from "react"; +import { useState, useMemo, useEffect } from "react"; import { UniqueIdentifier } from "@dnd-kit/core"; import { DraggableTable } from "@/components/table/DraggableTable"; import { deletePersona, personaComparator } from "./lib"; import { FiEdit2 } from "react-icons/fi"; import { TrashIcon } from "@/components/icons/icons"; +import { getCurrentUser } from "@/lib/user"; +import { UserRole, User } from "@/lib/types"; function PersonaTypeDisplay({ persona }: { persona: Persona }) { if (persona.default_persona) { @@ -28,21 +30,67 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) { return Personal {persona.owner && <>({persona.owner.email})}; } -export function PersonasTable({ personas }: { personas: Persona[] }) { +const togglePersonaVisibility = async ( + personaId: number, + isVisible: boolean +) => { + const response = await fetch(`/api/admin/persona/${personaId}/visible`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + is_visible: !isVisible, + }), + }); + return response; +}; + +export function PersonasTable({ + allPersonas, + editablePersonas, +}: { + allPersonas: Persona[]; + editablePersonas: Persona[]; +}) { const router = useRouter(); const { popup, setPopup } = usePopup(); - const availablePersonaIds = new Set( - personas.map((persona) => persona.id.toString()) + const [currentUser, setCurrentUser] = useState(null); + const isAdmin = currentUser?.role === UserRole.ADMIN; + useEffect(() => { + const fetchCurrentUser = async () => { + try { + const user = await getCurrentUser(); + if (user) { + setCurrentUser(user); + } else { + console.error("Failed to fetch current user"); + } + } catch (error) { + console.error("Error fetching current user:", error); + } + }; + fetchCurrentUser(); + }, []); + + const editablePersonaIds = new Set( + editablePersonas.map((p) => p.id.toString()) ); - const sortedPersonas = [...personas]; - sortedPersonas.sort(personaComparator); + + const sortedPersonas = useMemo(() => { + const editable = editablePersonas.sort(personaComparator); + const nonEditable = allPersonas + .filter((p) => !editablePersonaIds.has(p.id.toString())) + .sort(personaComparator); + return [...editable, ...nonEditable]; + }, [allPersonas, editablePersonas]); const [finalPersonas, setFinalPersonas] = useState( sortedPersonas.map((persona) => persona.id.toString()) ); const finalPersonaValues = finalPersonas - .filter((id) => availablePersonaIds.has(id)) + .filter((id) => new Set(allPersonas.map((p) => p.id.toString())).has(id)) .map((id) => { return sortedPersonas.find( (persona) => persona.id.toString() === id @@ -82,12 +130,14 @@ export function PersonasTable({ personas }: { personas: Persona[] }) { Assistants will be displayed as options on the Chat / Search interfaces in the order they are displayed below. Assistants marked as hidden will - not be displayed. + not be displayed. Editable assistants are shown at the top. { + const isEditable = editablePersonaIds.has(persona.id.toString()); return { id: persona.id.toString(), cells: [ @@ -116,28 +166,22 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
{ - const response = await fetch( - `/api/admin/persona/${persona.id}/visible`, - { - method: "PATCH", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - is_visible: !persona.is_visible, - }), + if (isEditable) { + const response = await togglePersonaVisibility( + persona.id, + persona.is_visible + ); + if (response.ok) { + router.refresh(); + } else { + setPopup({ + type: "error", + message: `Failed to update persona - ${await response.text()}`, + }); } - ); - if (response.ok) { - router.refresh(); - } else { - setPopup({ - type: "error", - message: `Failed to update persona - ${await response.text()}`, - }); } }} - className="px-1 py-0.5 hover:bg-hover-light rounded flex cursor-pointer select-none w-fit" + className={`px-1 py-0.5 rounded flex ${isEditable ? "hover:bg-hover cursor-pointer" : ""} select-none w-fit`} >
{!persona.is_visible ? ( @@ -152,7 +196,7 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
,
- {!persona.default_persona ? ( + {!persona.default_persona && isEditable ? (
{ diff --git a/web/src/app/admin/assistants/page.tsx b/web/src/app/admin/assistants/page.tsx index 206d8da5a..159094705 100644 --- a/web/src/app/admin/assistants/page.tsx +++ b/web/src/app/admin/assistants/page.tsx @@ -9,18 +9,25 @@ import { AssistantsIcon, RobotIcon } from "@/components/icons/icons"; import { AdminPageTitle } from "@/components/admin/Title"; export default async function Page() { - const personaResponse = await fetchSS("/admin/persona"); + const allPersonaResponse = await fetchSS("/admin/persona"); + const editablePersonaResponse = await fetchSS( + "/admin/persona?get_editable=true" + ); - if (!personaResponse.ok) { + if (!allPersonaResponse.ok || !editablePersonaResponse.ok) { return ( ); } - const personas = (await personaResponse.json()) as Persona[]; + const allPersonas = (await allPersonaResponse.json()) as Persona[]; + const editablePersonas = (await editablePersonaResponse.json()) as Persona[]; return (
@@ -57,7 +64,10 @@ export default async function Page() { Existing Assistants - +
); diff --git a/web/src/app/admin/connector/[ccPairId]/lib.ts b/web/src/app/admin/connector/[ccPairId]/lib.ts index a7e7fb264..c2d02b23d 100644 --- a/web/src/app/admin/connector/[ccPairId]/lib.ts +++ b/web/src/app/admin/connector/[ccPairId]/lib.ts @@ -4,6 +4,10 @@ export function buildCCPairInfoUrl(ccPairId: string | number) { return `/api/manage/admin/cc-pair/${ccPairId}`; } -export function buildSimilarCredentialInfoURL(source_type: ValidSources) { - return `/api/manage/admin/similar-credentials/${source_type}`; +export function buildSimilarCredentialInfoURL( + source_type: ValidSources, + get_editable: boolean = false +) { + const base = `/api/manage/admin/similar-credentials/${source_type}`; + return get_editable ? `${base}?get_editable=True` : base; } diff --git a/web/src/app/admin/connector/[ccPairId]/page.tsx b/web/src/app/admin/connector/[ccPairId]/page.tsx index 7cff858e5..130237edb 100644 --- a/web/src/app/admin/connector/[ccPairId]/page.tsx +++ b/web/src/app/admin/connector/[ccPairId]/page.tsx @@ -132,7 +132,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
- {isEditing ? ( + {ccPair.is_editable_for_current_user && isEditing ? (
) : (

startEditing()} - className="group flex cursor-pointer text-3xl text-emphasis gap-x-2 items-center font-bold" + onClick={() => + ccPair.is_editable_for_current_user && startEditing() + } + className={`group flex ${ccPair.is_editable_for_current_user ? "cursor-pointer" : ""} text-3xl text-emphasis gap-x-2 items-center font-bold`} > {ccPair.name} - + {ccPair.is_editable_for_current_user && ( + + )}

)} -
- {!CONNECTOR_TYPES_THAT_CANT_REINDEX.includes( - ccPair.connector.source - ) && ( - - )} - {!isDeleting && } -
+ {ccPair.is_editable_for_current_user && ( +
+ {!CONNECTOR_TYPES_THAT_CANT_REINDEX.includes( + ccPair.connector.source + ) && ( + + )} + {!isDeleting && } +
+ )}
{totalDocsIndexed}
- {credentialTemplates[ccPair.connector.source] && ( - <> - - - Credentials - - refresh()} - /> - + {!ccPair.is_editable_for_current_user && ( +
+ {ccPair.is_public + ? "Public connectors are not editable by curators." + : "This connector belongs to groups where you don't have curator permissions, so it's not editable."} +
)} + {credentialTemplates[ccPair.connector.source] && + ccPair.is_editable_for_current_user && ( + <> + + + Credentials + + refresh()} + /> + + )}
- + {ccPair.is_editable_for_current_user && ( + + )}
diff --git a/web/src/app/admin/connector/[ccPairId]/types.ts b/web/src/app/admin/connector/[ccPairId]/types.ts index 2fa1af6c9..1cc43311e 100644 --- a/web/src/app/admin/connector/[ccPairId]/types.ts +++ b/web/src/app/admin/connector/[ccPairId]/types.ts @@ -17,4 +17,6 @@ export interface CCPairFullInfo { credential: Credential; index_attempts: IndexAttemptSnapshot[]; latest_deletion_attempt: DeletionAttemptSnapshot | null; + is_public: boolean; + is_editable_for_current_user: boolean; } diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index d0bc07da1..b13a5af96 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -12,7 +12,7 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { useFormContext } from "@/components/context/FormContext"; import { getSourceDisplayName } from "@/lib/sources"; import { SourceIcon } from "@/components/SourceIcon"; -import { useRef, useState } from "react"; +import { useRef, useState, useEffect } from "react"; import { submitConnector } from "@/components/admin/connectors/ConnectorForm"; import { deleteCredential, linkCredential } from "@/lib/credential"; import { submitFiles } from "./pages/utils/files"; @@ -59,6 +59,11 @@ export default function AddConnector({ errorHandlingFetcher, { refreshInterval: 5000 } ); + const { data: editableCredentials } = useSWR[]>( + buildSimilarCredentialInfoURL(connector, true), + errorHandlingFetcher, + { refreshInterval: 5000 } + ); const [selectedFiles, setSelectedFiles] = useState([]); const credentialTemplate = credentialTemplates[connector]; @@ -95,6 +100,7 @@ export default function AddConnector({ const [pruneFreq, setPruneFreq] = useState(defaultPrune); const [indexingStart, setIndexingStart] = useState(null); const [isPublic, setIsPublic] = useState(true); + const [groups, setGroups] = useState([]); const [createConnectorToggle, setCreateConnectorToggle] = useState(false); const formRef = useRef>(null); const [advancedFormPageState, setAdvancedFormPageState] = useState(true); @@ -110,7 +116,9 @@ export default function AddConnector({ const { liveGmailCredential } = useGmailCredentials(); const credentialActivated = - liveGDriveCredential || liveGmailCredential || currentCredential; + (connector === "google_drive" && liveGDriveCredential) || + (connector === "gmail" && liveGmailCredential) || + currentCredential; const noCredentials = credentialTemplate == null; if (noCredentials && 1 != formStep) { @@ -170,7 +178,8 @@ export default function AddConnector({ setSelectedFiles, name, AdvancedConfig, - isPublic + isPublic, + groups ); if (response) { setTimeout(() => { @@ -189,6 +198,8 @@ export default function AddConnector({ refresh_freq: refreshFreq * 60 || null, prune_freq: pruneFreq * 60 * 60 * 24 || null, indexing_start: indexingStart, + is_public: isPublic, + groups: groups, }, undefined, credentialActivated ? false : true, @@ -218,7 +229,8 @@ export default function AddConnector({ response.id, credential?.id!, name, - isPublic + isPublic, + groups ); if (linkCredentialResponse.ok) { setPopup({ @@ -247,7 +259,7 @@ export default function AddConnector({ }; const displayName = getSourceDisplayName(connector) || connector; - if (!credentials) { + if (!credentials || !editableCredentials) { return <>; } @@ -350,6 +362,7 @@ export default function AddConnector({ source={connector} defaultedCredential={currentCredential!} credentials={credentials} + editableCredentials={editableCredentials} onDeleteCredential={onDeleteCredential} onSwitch={onSwap} /> @@ -411,6 +424,8 @@ export default function AddConnector({ setName={setName} config={configuration} isPublic={isPublic} + groups={groups} + setGroups={setGroups} defaultValues={values} initialName={name} onFormStatusChange={handleFormStatusChange} @@ -430,7 +445,10 @@ export default function AddConnector({ )}
-
- If you want to update these credentials, delete the existing - credentials through the button below, and then upload a new - credentials JSON. -
- + {isAdmin ? ( + <> +
+ If you want to update these credentials, delete the existing + credentials through the button below, and then upload a new + credentials JSON. +
+ + + ) : ( + <> +
+ To change these credentials, please contact an administrator. +
+ + )}
); } @@ -242,6 +254,17 @@ export const DriveJsonUploadSection = ({ ); } + if (!isAdmin) { + return ( +
+

+ Curators are unable to set up the google drive credentials. To add a + Google Drive connector, please contact an administrator. +

+
+ ); + } + return (

diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx index 8632f1c40..8f84105e2 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/GoogleDrivePage.tsx @@ -1,12 +1,15 @@ "use client"; import React from "react"; +import { useState, useEffect } from "react"; import useSWR from "swr"; import { FetchError, errorHandlingFetcher } from "@/lib/fetcher"; import { ErrorCallout } from "@/components/ErrorCallout"; import { LoadingAnimation } from "@/components/Loading"; import { usePopup } from "@/components/admin/connectors/Popup"; import { ConnectorIndexingStatus } from "@/lib/types"; +import { getCurrentUser } from "@/lib/user"; +import { User, UserRole } from "@/lib/types"; import { usePublicCredentials } from "@/lib/hooks"; import { Title } from "@tremor/react"; import { DriveJsonUploadSection, DriveOAuthSection } from "./Credential"; @@ -18,6 +21,24 @@ import { import { GoogleDriveConfig } from "@/lib/connectors/connectors"; const GDriveMain = ({}: {}) => { + const [currentUser, setCurrentUser] = useState(null); + const isAdmin = currentUser?.role === UserRole.ADMIN; + + useEffect(() => { + const fetchCurrentUser = async () => { + try { + const user = await getCurrentUser(); + if (user) { + setCurrentUser(user); + } else { + console.error("Failed to fetch current user"); + } + } catch (error) { + console.error("Error fetching current user:", error); + } + }; + fetchCurrentUser(); + }, []); const { data: appCredentialData, isLoading: isAppCredentialLoading, @@ -119,22 +140,27 @@ const GDriveMain = ({}: {}) => { setPopup={setPopup} appCredentialData={appCredentialData} serviceAccountCredentialData={serviceAccountKeyData} + isAdmin={isAdmin} /> - - Step 2: Authenticate with Danswer - - 0} - /> + {isAdmin && ( + <> + + Step 2: Authenticate with Danswer + + 0} + /> + + )} ); }; diff --git a/web/src/app/admin/connectors/[connector]/pages/gmail/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gmail/Credential.tsx index ac699574d..8b456884f 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gmail/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gmail/Credential.tsx @@ -145,12 +145,14 @@ interface DriveJsonUploadSectionProps { setPopup: (popupSpec: PopupSpec | null) => void; appCredentialData?: { client_id: string }; serviceAccountCredentialData?: { service_account_email: string }; + isAdmin: boolean; } export const GmailJsonUploadSection = ({ setPopup, appCredentialData, serviceAccountCredentialData, + isAdmin, }: DriveJsonUploadSectionProps) => { const { mutate } = useSWRConfig(); @@ -163,36 +165,48 @@ export const GmailJsonUploadSection = ({ {serviceAccountCredentialData.service_account_email}

-
- If you want to update these credentials, delete the existing - credentials through the button below, and then upload a new - credentials JSON. -
- + {isAdmin ? ( + <> +
+ If you want to update these credentials, delete the existing + credentials through the button below, and then upload a new + credentials JSON. +
+ + + ) : ( + <> +
+ To change these credentials, please contact an administrator. +
+ + )}
); } @@ -238,6 +252,17 @@ export const GmailJsonUploadSection = ({ ); } + if (!isAdmin) { + return ( +
+

+ Curators are unable to set up the Gmail credentials. To add a Gmail + connector, please contact an administrator. +

+
+ ); + } + return (

diff --git a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx index 1be7c6f22..ebfdbe832 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx @@ -5,6 +5,8 @@ import { errorHandlingFetcher } from "@/lib/fetcher"; import { LoadingAnimation } from "@/components/Loading"; import { usePopup } from "@/components/admin/connectors/Popup"; import { ConnectorIndexingStatus } from "@/lib/types"; +import { getCurrentUser } from "@/lib/user"; +import { User, UserRole } from "@/lib/types"; import { Credential, GmailCredentialJson, @@ -14,8 +16,27 @@ import { GmailOAuthSection, GmailJsonUploadSection } from "./Credential"; import { usePublicCredentials } from "@/lib/hooks"; import { Title } from "@tremor/react"; import { GmailConfig } from "@/lib/connectors/connectors"; +import { useState, useEffect } from "react"; export const GmailMain = () => { + const [currentUser, setCurrentUser] = useState(null); + const isAdmin = currentUser?.role === UserRole.ADMIN; + + useEffect(() => { + const fetchCurrentUser = async () => { + try { + const user = await getCurrentUser(); + if (user) { + setCurrentUser(user); + } else { + console.error("Failed to fetch current user"); + } + } catch (error) { + console.error("Error fetching current user:", error); + } + }; + fetchCurrentUser(); + }, []); const { data: appCredentialData, isLoading: isAppCredentialLoading, @@ -126,20 +147,25 @@ export const GmailMain = () => { setPopup={setPopup} appCredentialData={appCredentialData} serviceAccountCredentialData={serviceAccountKeyData} + isAdmin={isAdmin} /> - - Step 2: Authenticate with Danswer - - 0} - /> + {isAdmin && ( + <> + + Step 2: Authenticate with Danswer + + 0} + /> + + )} ); }; diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts index 7ef335643..bd7ee8bd5 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts @@ -10,7 +10,8 @@ export const submitFiles = async ( setSelectedFiles: (files: File[]) => void, name: string, advancedConfig: AdvancedConfig, - isPublic: boolean + isPublic: boolean, + groups?: number[] ) => { const formData = new FormData(); @@ -43,6 +44,8 @@ export const submitFiles = async ( refresh_freq: null, prune_freq: null, indexing_start: null, + is_public: isPublic, + groups: groups, }); if (connectorErrorMsg || !connector) { setPopup({ @@ -60,6 +63,8 @@ export const submitFiles = async ( credential_json: {}, admin_public: true, source: "file", + curator_public: isPublic, + groups: groups, name, }); if (!createCredentialResponse.ok) { @@ -77,7 +82,8 @@ export const submitFiles = async ( connector.id, credentialId, name, - isPublic + isPublic, + groups ); if (!credentialResponse.ok) { const credentialResponseJson = await credentialResponse.json(); diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index 89a73bf3c..615b2cc0a 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -3,16 +3,17 @@ import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; import * as Yup from "yup"; import { PopupSpec } from "@/components/admin/connectors/Popup"; -import { createDocumentSet, updateDocumentSet } from "./lib"; -import { ConnectorIndexingStatus, DocumentSet, UserGroup } from "@/lib/types"; import { - BooleanFormField, - TextFormField, -} from "@/components/admin/connectors/Field"; + createDocumentSet, + updateDocumentSet, + DocumentSetCreationRequest, +} from "./lib"; +import { ConnectorIndexingStatus, DocumentSet, UserGroup } from "@/lib/types"; +import { TextFormField } from "@/components/admin/connectors/Field"; import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle"; import { Button, Divider, Text } from "@tremor/react"; -import { FiUsers } from "react-icons/fi"; import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; interface SetCreationPopupProps { ccPairs: ConnectorIndexingStatus[]; @@ -35,22 +36,17 @@ export const DocumentSetCreationForm = ({ return (

- initialValues={{ - name: existingDocumentSet ? existingDocumentSet.name : "", - description: existingDocumentSet - ? existingDocumentSet.description - : "", - cc_pair_ids: existingDocumentSet - ? existingDocumentSet.cc_pair_descriptors.map( - (ccPairDescriptor) => { - return ccPairDescriptor.id; - } - ) - : ([] as number[]), - is_public: existingDocumentSet ? existingDocumentSet.is_public : true, - users: existingDocumentSet ? existingDocumentSet.users : [], - groups: existingDocumentSet ? existingDocumentSet.groups : [], + name: existingDocumentSet?.name ?? "", + description: existingDocumentSet?.description ?? "", + cc_pair_ids: + existingDocumentSet?.cc_pair_descriptors.map( + (ccPairDescriptor) => ccPairDescriptor.id + ) ?? [], + is_public: existingDocumentSet?.is_public ?? true, + users: existingDocumentSet?.users ?? [], + groups: existingDocumentSet?.groups ?? [], }} validationSchema={Yup.object().shape({ name: Yup.string().required("Please enter a name for the set"), @@ -74,6 +70,7 @@ export const DocumentSetCreationForm = ({ response = await updateDocumentSet({ id: existingDocumentSet.id, ...processedValues, + users: processedValues.users, }); } else { response = await createDocumentSet(processedValues); @@ -98,7 +95,7 @@ export const DocumentSetCreationForm = ({ } }} > - {({ isSubmitting, values }) => ( + {(props) => (
(
{ccPairs.map((ccPair) => { - const ind = values.cc_pair_ids.indexOf(ccPair.cc_pair_id); + const ind = props.values.cc_pair_ids.indexOf( + ccPair.cc_pair_id + ); let isSelected = ind !== -1; return (
0 && ( -
- - - - If the document set is public, then it will be visible - to all users. If it is not public, then only - users in the specified groups will be able to see it. - - } - /> - - -

- Groups with Access -

- {!values.is_public ? ( - <> - - If any groups are specified, then this Document Set will - only be visible to the specified groups. If no groups - are specified, then the Document Set will be visible to - all users. - - ( -
- {userGroups.map((userGroup) => { - const ind = values.groups.indexOf(userGroup.id); - let isSelected = ind !== -1; - return ( -
{ - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push(userGroup.id); - } - }} - > -
- {" "} - {userGroup.name} -
-
- ); - })} -
- )} - /> - - ) : ( - - This Document Set is public, so this does not apply. If - you want to control which user groups see this Document - Set, mark it as non-public! - - )} -
+ )}