Curator role (#2166)

* Added backend support for curator role

* modal refactor

* finalized first 2 commits

same as before

finally

what was it for

* added credential, cc_pair, and cleanup

mypy is super helpful hahahahahahahahahahahaha

* curator support for personas

* added connector management permission checks

* fixed the connector creation flow

* added document access to curator

* small cleanup added comments and started ui

* groups and assistant editor

* Persona frontend

* Document set frontend

* cleaned up the entire frontend

* alembic fix

* Minor fixes

* credentials section

* some credential updates

* removed logging statements

* fixed try catch

* fixed model name

* made everything happen in one db commit

* Final cleanup

* cleaned up fast code

* mypy/build fixes

* polish

* more token rate limit polish

* fixed weird credential permissions

* Addressed chris feedback

* addressed pablo feedback

* fixed alembic

* removed deduping and caching

* polish!!!!
This commit is contained in:
hagen-danswer
2024-08-22 18:39:37 -07:00
committed by GitHub
parent 5409777e0b
commit c042a19c00
82 changed files with 3141 additions and 1205 deletions

View File

@@ -1,16 +1,70 @@
from collections.abc import Sequence
from sqlalchemy import exists
from sqlalchemy import Row
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.models import UserRole
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
return stmt
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
"""
Here we select token_rate_limits by relation:
User -> User__UserGroup -> TokenRateLimit__UserGroup ->
TokenRateLimit
"""
stmt = stmt.outerjoin(TRLimit_UG).outerjoin(
User__UG,
User__UG.user_group_id == TRLimit_UG.user_group_id,
)
"""
Filter token_rate_limits by:
- if the user is in the user_group that owns the token_rate_limit
- if the user is not a global_curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out token_rate_limits that are owned by groups
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
if user.role == UserRole.CURATOR:
user_groups = user_groups.where(
User__UserGroup.is_curator == True # noqa: E712
)
where_clause &= (
~exists()
.where(TRLimit_UG.rate_limit_id == TokenRateLimit.id)
.where(~TRLimit_UG.user_group_id.in_(user_groups))
.correlate(TokenRateLimit)
)
return stmt.where(where_clause)
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
@@ -48,29 +102,25 @@ def fetch_all_global_token_rate_limits(
return token_rate_limits
def fetch_all_user_group_token_rate_limits(
db_session: Session, group_id: int, enabled_only: bool = False, ordered: bool = True
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
query = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(
TokenRateLimit__UserGroup.user_group_id == group_id,
TokenRateLimit.scope == TokenRateLimitScope.USER_GROUP,
)
)
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
return db_session.scalars(stmt).all()
def fetch_all_user_group_token_rate_limits_by_group(

View File

@@ -5,11 +5,13 @@ from uuid import UUID
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import LLMProvider__UserGroup
@@ -18,9 +20,15 @@ from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.db.users import fetch_user_by_id
from danswer.utils.logger import setup_logger
from ee.danswer.server.user_group.models import SetCuratorRequest
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
@@ -37,7 +45,7 @@ def fetch_user_groups(
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -45,6 +53,8 @@ def fetch_user_groups_for_user(
.join(User, User.id == User__UserGroup.user_id) # type: ignore
.where(User.id == user_id) # type: ignore
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
return db_session.scalars(stmt).all()
@@ -179,16 +189,32 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id)
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -211,8 +237,84 @@ def _mark_user_group__cc_pair_relationships_outdated__no_commit(
user_group__cc_pair_relationship.is_current = False
def _validate_curator_status__no_commit(
db_session: Session,
users: list[User],
) -> None:
for user in users:
# Check if the user is a curator in any of their groups
curator_relationships = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.is_curator == True, # noqa: E712
)
.all()
)
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
user.role = UserRole.BASIC
db_session.add(user)
def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
stmt = (
update(User__UserGroup)
.where(User__UserGroup.user_id == user.id)
.values(is_curator=False)
)
db_session.execute(stmt)
_validate_curator_status__no_commit(db_session, [user])
def update_user_curator_relationship(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
) -> None:
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,
only_curator_groups=False,
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(f"user is not in group '{user_group_id}'")
relationship_to_update = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_group_id == user_group_id,
User__UserGroup.user_id == set_curator_request.user_id,
)
.first()
)
if relationship_to_update:
relationship_to_update.is_curator = set_curator_request.is_curator
else:
relationship_to_update = User__UserGroup(
user_group_id=user_group_id,
user_id=set_curator_request.user_id,
is_curator=True,
)
db_session.add(relationship_to_update)
_validate_curator_status__no_commit(db_session, [user])
db_session.commit()
def update_user_group(
db_session: Session, user_group_id: int, user_group: UserGroupUpdate
db_session: Session,
user: User | None,
user_group_id: int,
user_group_update: UserGroupUpdate,
) -> UserGroup:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@@ -221,23 +323,33 @@ def update_user_group(
_check_user_group_is_modifiable(db_user_group)
existing_cc_pairs = db_user_group.cc_pairs
cc_pairs_updated = set([cc_pair.id for cc_pair in existing_cc_pairs]) != set(
user_group.cc_pair_ids
)
users_updated = set([user.id for user in db_user_group.users]) != set(
user_group.user_ids
)
current_user_ids = set([user.id for user in db_user_group.users])
updated_user_ids = set(user_group_update.user_ids)
added_user_ids = list(updated_user_ids - current_user_ids)
removed_user_ids = list(current_user_ids - updated_user_ids)
if users_updated:
if (removed_user_ids or added_user_ids) and (
not user or user.role != UserRole.ADMIN
):
raise ValueError("Only admins can add or remove users from user groups")
if removed_user_ids:
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
db_session=db_session,
user_group_id=user_group_id,
user_ids=removed_user_ids,
)
if added_user_ids:
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
user_ids=user_group.user_ids,
user_ids=added_user_ids,
)
cc_pairs_updated = set([cc_pair.id for cc_pair in db_user_group.cc_pairs]) != set(
user_group_update.cc_pair_ids
)
if cc_pairs_updated:
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
@@ -245,13 +357,17 @@ def update_user_group(
_add_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
cc_pair_ids=user_group.cc_pair_ids,
cc_pair_ids=user_group_update.cc_pair_ids,
)
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
db_session.commit()
return db_user_group
@@ -279,6 +395,9 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
_cleanup_credential__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)

View File

@@ -5,14 +5,15 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
from ee.danswer.db.token_limit import fetch_user_group_token_rate_limits
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
from ee.danswer.db.token_limit import insert_user_token_rate_limit
@@ -45,13 +46,13 @@ def get_all_group_token_limit_settings(
@router.get("/user-group/{group_id}")
def get_group_token_limit_settings(
group_id: int,
_: User | None = Depends(current_admin_user),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_all_user_group_token_rate_limits(
db_session, group_id
for token_rate_limit in fetch_user_group_token_rate_limits(
db_session, group_id, user
)
]

View File

@@ -5,12 +5,17 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.models import UserRole
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.db.user_group import fetch_user_groups_for_user
from ee.danswer.db.user_group import insert_user_group
from ee.danswer.db.user_group import prepare_user_group_for_deletion
from ee.danswer.db.user_group import update_user_curator_relationship
from ee.danswer.db.user_group import update_user_group
from ee.danswer.server.user_group.models import SetCuratorRequest
from ee.danswer.server.user_group.models import UserGroup
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
@@ -20,10 +25,17 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/user-group")
def list_user_groups(
_: User | None = Depends(current_admin_user),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
user_groups = fetch_user_groups(db_session, only_current=False)
if user is None or user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_current=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@@ -47,13 +59,35 @@ def create_user_group(
@router.patch("/admin/user-group/{user_group_id}")
def patch_user_group(
user_group_id: int,
user_group: UserGroupUpdate,
_: User | None = Depends(current_admin_user),
user_group_update: UserGroupUpdate,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
update_user_group(db_session, user_group_id, user_group)
update_user_group(
db_session=db_session,
user=user,
user_group_id=user_group_id,
user_group_update=user_group_update,
)
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/admin/user-group/{user_group_id}/set-curator")
def set_user_curator(
user_group_id: int,
set_curator_request: SetCuratorRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_user_curator_relationship(
db_session=db_session,
user_group_id=user_group_id,
set_curator_request=set_curator_request,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -16,6 +16,7 @@ class UserGroup(BaseModel):
id: int
name: str
users: list[UserInfo]
curator_ids: list[UUID]
cc_pairs: list[ConnectorCredentialPairDescriptor]
document_sets: list[DocumentSet]
personas: list[PersonaSnapshot]
@@ -42,6 +43,11 @@ class UserGroup(BaseModel):
)
for user in user_group_model.users
],
curator_ids=[
user.user_id
for user in user_group_model.user_group_relationships
if user.is_curator and user.user_id is not None
],
cc_pairs=[
ConnectorCredentialPairDescriptor(
id=cc_pair_relationship.cc_pair.id,
@@ -78,3 +84,8 @@ class UserGroupCreate(BaseModel):
class UserGroupUpdate(BaseModel):
user_ids: list[UUID]
cc_pair_ids: list[int]
class SetCuratorRequest(BaseModel):
user_id: UUID
is_curator: bool