danswer/backend/ee/onyx/db/user_group.py
pablonyx 7490250e91
Fix user group edge case (#4159)
* fix user group

* k
2025-02-28 23:55:21 +00:00

733 lines
26 KiB
Python

from collections.abc import Sequence
from operator import and_
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
# Module-level logger, obtained from the project's `setup_logger` helper.
logger = setup_logger()
def _cleanup_user__user_group_relationships__no_commit(
    db_session: Session,
    user_group_id: int,
    user_ids: list[UUID] | None = None,
) -> None:
    """Stage deletion of membership rows (User__UserGroup) of a user group.

    Args:
        db_session: Session on which the deletions are staged.
        user_group_id: ID of the group whose membership rows are removed.
        user_ids: When None (the default), every membership row of the group
            is removed. When a list is given, only the rows of those users are
            removed; an explicitly empty list removes nothing.

    NOTE: does not commit the transaction.
    """
    if user_ids is not None and not user_ids:
        # An explicitly empty selection means "no users", not "all users".
        # Treating `[]` like `None` would silently empty the whole group.
        return

    where_clause = User__UserGroup.user_group_id == user_group_id
    if user_ids is not None:
        where_clause &= User__UserGroup.user_id.in_(user_ids)

    user__user_group_relationships = db_session.scalars(
        select(User__UserGroup).where(where_clause)
    ).all()
    for user__user_group_relationship in user__user_group_relationships:
        db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
    db_session: Session,
    user_group_id: int,
) -> None:
    """Bulk-delete every Credential__UserGroup row pointing at the group.

    NOTE: does not commit the transaction.
    """
    belongs_to_group = Credential__UserGroup.user_group_id == user_group_id
    rows_for_group = db_session.query(Credential__UserGroup).filter(belongs_to_group)
    rows_for_group.delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Bulk-delete every LLMProvider__UserGroup row pointing at the group.

    NOTE: does not commit the transaction.
    """
    belongs_to_group = LLMProvider__UserGroup.user_group_id == user_group_id
    rows_for_group = db_session.query(LLMProvider__UserGroup).filter(belongs_to_group)
    rows_for_group.delete(synchronize_session=False)
def _cleanup_persona__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Bulk-delete every Persona__UserGroup row pointing at the group.

    NOTE: does not commit the transaction.
    """
    belongs_to_group = Persona__UserGroup.user_group_id == user_group_id
    rows_for_group = db_session.query(Persona__UserGroup).filter(belongs_to_group)
    rows_for_group.delete(synchronize_session=False)
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Stage deletion of every TokenRateLimit__UserGroup row of the group.

    The rows are loaded and deleted one by one through the session.

    NOTE: does not commit the transaction.
    """
    stmt = select(TokenRateLimit__UserGroup).where(
        TokenRateLimit__UserGroup.user_group_id == user_group_id
    )
    for relationship in db_session.scalars(stmt).all():
        db_session.delete(relationship)
def _cleanup_user_group__cc_pair_relationships__no_commit(
    db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
    """Stage deletion of UserGroup__ConnectorCredentialPair rows of the group.

    When `outdated_only` is true, only rows whose `is_current` flag is False
    are deleted; otherwise every row of the group is deleted.

    NOTE: does not commit the transaction.
    """
    filters = [UserGroup__ConnectorCredentialPair.user_group_id == user_group_id]
    if outdated_only:
        filters.append(
            UserGroup__ConnectorCredentialPair.is_current == False  # noqa: E712
        )

    stmt = select(UserGroup__ConnectorCredentialPair).where(*filters)
    for relationship in db_session.scalars(stmt):
        db_session.delete(relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Execute a DELETE for every DocumentSet__UserGroup row of the group.

    NOTE: does not commit the transaction.
    """
    delete_stmt = delete(DocumentSet__UserGroup).where(
        DocumentSet__UserGroup.user_group_id == user_group_id
    )
    db_session.execute(delete_stmt)
def validate_object_creation_for_user(
    db_session: Session,
    user: User | None,
    target_group_ids: list[int] | None = None,
    object_is_public: bool | None = None,
    object_is_perm_sync: bool | None = None,
) -> None:
    """Check that `user` may create/edit an object with the given visibility.

    Allowed without further checks:
    - permission synced objects that do not specify any group
    - any action when there is no user or the user is an admin

    Everyone else is rejected (HTTP 400) when the object:
    - is public
    - has no groups
    - belongs to a group the user does not curate (global curators may use
      any group they are a member of)

    Raises:
        HTTPException: status 400 with a message describing the violation.
    """
    if object_is_perm_sync and not target_group_ids:
        return
    if not user or user.role == UserRole.ADMIN:
        return

    # Work out which rule (if any) is violated, then log + raise in one place.
    violation: str | None = None
    if object_is_public:
        violation = "User does not have permission to create public credentials"
    elif not target_group_ids:
        violation = "Curators must specify 1+ groups"
    else:
        curated_groups = fetch_user_groups_for_user(
            db_session=db_session,
            user_id=user.id,
            # Global curators can curate all groups they are member of
            only_curator_groups=user.role != UserRole.GLOBAL_CURATOR,
        )
        curated_group_ids = {group.id for group in curated_groups}
        if not set(target_group_ids).issubset(curated_group_ids):
            violation = "Curators cannot control groups they don't curate"

    if violation is None:
        return

    logger.error(violation)
    raise HTTPException(
        status_code=400,
        detail=violation,
    )
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
    """Return the UserGroup with the given ID, or None when the query yields no row."""
    return db_session.scalar(
        select(UserGroup).where(UserGroup.id == user_group_id)
    )
def fetch_user_groups(
    db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
    """Load user groups from the database.

    Args:
        db_session (Session): The SQLAlchemy session used to run the query.
        only_up_to_date (bool, optional): When `True` (the default), only groups
            whose `is_up_to_date` column is `True` are returned; when `False`,
            all groups are returned.

    Returns:
        Sequence[UserGroup]: The `UserGroup` rows matching the criteria.
    """
    all_groups = select(UserGroup)
    stmt = (
        all_groups.where(UserGroup.is_up_to_date == True)  # noqa: E712
        if only_up_to_date
        else all_groups
    )
    return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
    db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
    """Return the groups the given user belongs to.

    When `only_curator_groups` is true, the result is narrowed to groups whose
    membership row for this user has `is_curator` set.
    """
    memberships_of_user = (
        select(UserGroup)
        .join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
        .join(User, User.id == User__UserGroup.user_id)  # type: ignore
        .where(User.id == user_id)  # type: ignore
    )
    if not only_curator_groups:
        return db_session.scalars(memberships_of_user).all()

    curated_only = memberships_of_user.where(
        User__UserGroup.is_curator == True  # noqa: E712
    )
    return db_session.scalars(curated_only).all()
def construct_document_id_select_by_usergroup(
    user_group_id: int,
) -> Select:
    """Build (but do not execute) a select of the distinct document IDs that
    are reachable from the given user group through its cc_pairs.

    The returned statement should be executed using .yield_per() to minimize
    overhead. The primary consumers of this function are background processing
    task generators.
    """
    # A document row is tied to a cc_pair via its (connector_id, credential_id).
    same_connector_and_credential = and_(
        DocumentByConnectorCredentialPair.connector_id
        == ConnectorCredentialPair.connector_id,
        DocumentByConnectorCredentialPair.credential_id
        == ConnectorCredentialPair.credential_id,
    )

    return (
        select(Document.id)
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
        .join(ConnectorCredentialPair, same_connector_and_credential)
        .join(
            UserGroup__ConnectorCredentialPair,
            UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            UserGroup,
            UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
        )
        .where(UserGroup.id == user_group_id)
        .order_by(Document.id)
        .distinct()
    )
def fetch_documents_for_user_group_paginated(
    db_session: Session,
    user_group_id: int,
    last_document_id: str | None = None,
    limit: int = 100,
) -> tuple[Sequence[Document], str | None]:
    """Fetch one page of documents reachable from a user group's cc_pairs.

    Documents are ordered by ID. When `last_document_id` is given, only
    documents with a greater ID are returned. At most `limit` documents are
    fetched.

    Returns:
        A tuple of (documents of this page, ID of the last document in the
        page or None when the page is empty).
    """
    # A document row is tied to a cc_pair via its (connector_id, credential_id).
    same_connector_and_credential = and_(
        DocumentByConnectorCredentialPair.connector_id
        == ConnectorCredentialPair.connector_id,
        DocumentByConnectorCredentialPair.credential_id
        == ConnectorCredentialPair.credential_id,
    )

    query = (
        select(Document)
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
        .join(ConnectorCredentialPair, same_connector_and_credential)
        .join(
            UserGroup__ConnectorCredentialPair,
            UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            UserGroup,
            UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
        )
        .where(UserGroup.id == user_group_id)
    )
    if last_document_id is not None:
        # resume after the cursor handed out with the previous page
        query = query.where(Document.id > last_document_id)
    query = query.order_by(Document.id).limit(limit).distinct()

    page = db_session.scalars(query).all()
    next_cursor = page[-1].id if page else None
    return page, next_cursor
def fetch_user_groups_for_documents(
    db_session: Session,
    document_ids: list[str],
) -> Sequence[tuple[str, list[str]]]:
    """
    Fetches, per document, the names of all user groups that have access to it.

    Only relationships marked `is_current` are considered, and cc_pairs in the
    DELETING state are skipped.
    NOTE: this doesn't include groups if the cc_pair is access type SYNC
    """
    group_to_link = UserGroup.id == UserGroup__ConnectorCredentialPair.user_group_id
    link_to_non_sync_cc_pair = and_(
        ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
        ConnectorCredentialPair.access_type != AccessType.SYNC,
    )
    cc_pair_to_document_link = and_(
        DocumentByConnectorCredentialPair.connector_id
        == ConnectorCredentialPair.connector_id,
        DocumentByConnectorCredentialPair.credential_id
        == ConnectorCredentialPair.credential_id,
    )

    stmt = (
        select(Document.id, func.array_agg(UserGroup.name))
        .join(UserGroup__ConnectorCredentialPair, group_to_link)
        .join(ConnectorCredentialPair, link_to_non_sync_cc_pair)
        .join(DocumentByConnectorCredentialPair, cc_pair_to_document_link)
        .join(Document, Document.id == DocumentByConnectorCredentialPair.id)
        .where(Document.id.in_(document_ids))
        .where(UserGroup__ConnectorCredentialPair.is_current == True)  # noqa: E712
        # don't include CC pairs that are being deleted
        # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
        .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
        .group_by(Document.id)
    )
    rows = db_session.execute(stmt).all()
    return rows  # type: ignore
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
    """Raise ValueError unless the group's `is_up_to_date` flag is set."""
    if user_group.is_up_to_date:
        return

    raise ValueError(
        "Specified user group is currently syncing. Wait until the current "
        "sync has finished before editing."
    )
def _add_user__user_group_relationships__no_commit(
    db_session: Session, user_group_id: int, user_ids: list[UUID]
) -> list[User__UserGroup]:
    """Add one User__UserGroup row per user ID to the session and return them.

    NOTE: does not commit the transaction.
    """
    new_memberships: list[User__UserGroup] = []
    for member_id in user_ids:
        new_memberships.append(
            User__UserGroup(user_id=member_id, user_group_id=user_group_id)
        )
    db_session.add_all(new_memberships)
    return new_memberships
def _add_user_group__cc_pair_relationships__no_commit(
    db_session: Session, user_group_id: int, cc_pair_ids: list[int]
) -> list[UserGroup__ConnectorCredentialPair]:
    """Add one UserGroup__ConnectorCredentialPair row per cc_pair ID to the
    session and return them.

    NOTE: does not commit the transaction.
    """
    new_links: list[UserGroup__ConnectorCredentialPair] = []
    for linked_cc_pair_id in cc_pair_ids:
        new_links.append(
            UserGroup__ConnectorCredentialPair(
                user_group_id=user_group_id, cc_pair_id=linked_cc_pair_id
            )
        )
    db_session.add_all(new_links)
    return new_links
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
    """Create a user group together with its user and cc_pair relationships.

    Commits the transaction and returns the newly created UserGroup.
    """
    new_group = UserGroup(
        name=user_group.name, time_last_modified_by_user=func.now()
    )
    db_session.add(new_group)
    # flush so the group gets an ID the relationship rows can reference
    db_session.flush()

    new_group_id = new_group.id
    _add_user__user_group_relationships__no_commit(
        db_session=db_session,
        user_group_id=new_group_id,
        user_ids=user_group.user_ids,
    )
    _add_user_group__cc_pair_relationships__no_commit(
        db_session=db_session,
        user_group_id=new_group_id,
        cc_pair_ids=user_group.cc_pair_ids,
    )

    db_session.commit()
    return new_group
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
    db_session: Session, user_group_id: int
) -> None:
    """Set `is_current = False` on every cc_pair relationship of the group.

    NOTE: does not commit the transaction.
    """
    stmt = select(UserGroup__ConnectorCredentialPair).where(
        UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
    )
    for relationship in db_session.scalars(stmt):
        relationship.is_current = False
def _validate_curator_status__no_commit(
    db_session: Session,
    users: list[User],
) -> None:
    """Re-derive each user's role from their curator memberships.

    A user with at least one `is_curator` membership gets the CURATOR role.
    A user with none is demoted to BASIC, but only if their current role is
    CURATOR; other roles are left as they are.

    NOTE: does not commit the transaction.
    """
    for candidate in users:
        curator_memberships = (
            db_session.query(User__UserGroup)
            .filter(
                User__UserGroup.user_id == candidate.id,
                User__UserGroup.is_curator == True,  # noqa: E712
            )
            .all()
        )
        curates_some_group = bool(curator_memberships)

        if curates_some_group:
            candidate.role = UserRole.CURATOR
        elif candidate.role == UserRole.CURATOR:
            # lost the last curator membership -> back to a regular user
            candidate.role = UserRole.BASIC

        db_session.add(candidate)
def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
    """Clear `is_curator` on all of the user's memberships, then re-derive
    the user's role from the (now empty) set of curator memberships.

    NOTE: does not commit the transaction.
    """
    db_session.execute(
        update(User__UserGroup)
        .where(User__UserGroup.user_id == user.id)
        .values(is_curator=False)
    )
    _validate_curator_status__no_commit(db_session, [user])
def _validate_curator_relationship_update_requester(
    db_session: Session,
    user_group_id: int,
    user_making_change: User | None = None,
) -> None:
    """
    Validates that the user making the change is allowed to update curator
    relationships in the given user group.

    No requester and admin requesters are always allowed. A requester with the
    CURATOR role must be a curator of the group; any other requester must at
    least be a member of the group.

    Raises:
        ValueError: if the requester does not qualify for the group.
    """
    if user_making_change is None:
        return
    if user_making_change.role == UserRole.ADMIN:
        return

    # only require curator membership if the requester is a curator;
    # otherwise, they are a global_curator and can update the curator
    # relationship for any group they are a member of
    must_curate_group = user_making_change.role == UserRole.CURATOR
    eligible_groups = fetch_user_groups_for_user(
        db_session=db_session,
        user_id=user_making_change.id,
        only_curator_groups=must_curate_group,
    )

    if any(group.id == user_group_id for group in eligible_groups):
        return

    raise ValueError(
        f"user making change {user_making_change.email} is not a curator,"
        f" admin, or global_curator for group '{user_group_id}'"
    )
def _validate_curator_relationship_update_request(
    db_session: Session,
    user_group_id: int,
    target_user: User,
) -> None:
    """
    Validates that the curator relationship update request itself makes sense:
    the target user must have the CURATOR or BASIC role and must be a member
    of the group in question.

    Raises:
        ValueError: if the target user's role or group membership rules out
            the update.
    """
    target_role = target_user.role

    if target_role == UserRole.ADMIN:
        raise ValueError(
            f"User '{target_user.email}' is an admin and therefore has all permissions "
            "of a curator. If you'd like this user to only have curator permissions, "
            "you must update their role to BASIC then assign them to be CURATOR in the "
            "appropriate groups."
        )

    if target_role == UserRole.GLOBAL_CURATOR:
        raise ValueError(
            f"User '{target_user.email}' is a global_curator and therefore has all "
            "permissions of a curator for all groups. If you'd like this user to only "
            "have curator permissions for a specific group, you must update their role "
            "to BASIC then assign them to be CURATOR in the appropriate groups."
        )

    if target_role not in [UserRole.CURATOR, UserRole.BASIC]:
        raise ValueError(
            f"This endpoint can only be used to update the curator relationship for "
            "users with the CURATOR or BASIC role. \n"
            f"Target user: {target_user.email} \n"
            f"Target user role: {target_user.role} \n"
        )

    # the target user has to be a member of the group being changed
    groups_of_target = fetch_user_groups_for_user(
        db_session=db_session,
        user_id=target_user.id,
        only_curator_groups=False,
    )
    is_member = any(group.id == user_group_id for group in groups_of_target)
    if not is_member:
        raise ValueError(
            f"target user {target_user.email} is not in group '{user_group_id}'"
        )
def update_user_curator_relationship(
    db_session: Session,
    user_group_id: int,
    set_curator_request: SetCuratorRequest,
    user_making_change: User | None = None,
) -> None:
    """Set or clear a user's curator flag for one user group and commit.

    The request is validated first (target user's role and group membership),
    then the requester's permission for the group. After the membership row is
    updated, the target user's role is re-derived from their curator
    memberships.

    Args:
        db_session: Session used for all reads and writes; committed at the end.
        user_group_id: ID of the group whose curator relationship is changed.
        set_curator_request: Carries the target `user_id` and the desired
            `is_curator` value.
        user_making_change: The requester; None skips the requester check.

    Raises:
        ValueError: if the target user is not found, or if either validation
            step rejects the request.
    """
    target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
    if not target_user:
        raise ValueError(f"User with id '{set_curator_request.user_id}' not found")

    _validate_curator_relationship_update_request(
        db_session=db_session,
        user_group_id=user_group_id,
        target_user=target_user,
    )
    _validate_curator_relationship_update_requester(
        db_session=db_session,
        user_group_id=user_group_id,
        user_making_change=user_making_change,
    )

    logger.info(
        f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
        f"updating the curator relationship for user={target_user.email} "
        f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
    )

    relationship_to_update = (
        db_session.query(User__UserGroup)
        .filter(
            User__UserGroup.user_group_id == user_group_id,
            User__UserGroup.user_id == set_curator_request.user_id,
        )
        .first()
    )

    if relationship_to_update:
        relationship_to_update.is_curator = set_curator_request.is_curator
    else:
        # Fallback for a missing membership row: honor the requested value
        # rather than unconditionally making the user a curator.
        relationship_to_update = User__UserGroup(
            user_group_id=user_group_id,
            user_id=set_curator_request.user_id,
            is_curator=set_curator_request.is_curator,
        )
        db_session.add(relationship_to_update)

    # keep the user's role in line with their curator memberships
    _validate_curator_status__no_commit(db_session, [target_user])
    db_session.commit()
def update_user_group(
    db_session: Session,
    user: User | None,
    user_group_id: int,
    user_group_update: UserGroupUpdate,
) -> UserGroup:
    """Apply a membership / cc_pair update to a user group and commit.

    If the set of cc_pairs changed, this sets db_user_group.is_up_to_date = False.
    That will be processed by check_for_vespa_user_groups_sync_task and trigger
    a long running background sync to Vespa.

    Raises:
        ValueError: if the group does not exist or is currently syncing.
    """
    db_user_group = db_session.scalar(
        select(UserGroup).where(UserGroup.id == user_group_id)
    )
    if db_user_group is None:
        raise ValueError(f"UserGroup with id '{user_group_id}' not found")
    _check_user_group_is_modifiable(db_user_group)

    # --- membership diff -------------------------------------------------
    existing_member_ids = {member.id for member in db_user_group.users}
    requested_member_ids = set(user_group_update.user_ids)
    added_user_ids = list(requested_member_ids - existing_member_ids)
    removed_user_ids = list(existing_member_ids - requested_member_ids)

    # LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES
    # ACCESS TO DIFFERENT PERMISSIONS
    # if (removed_user_ids or added_user_ids) and (
    #     not user or user.role != UserRole.ADMIN
    # ):
    #     raise ValueError("Only admins can add or remove users from user groups")

    if removed_user_ids:
        _cleanup_user__user_group_relationships__no_commit(
            db_session=db_session,
            user_group_id=user_group_id,
            user_ids=removed_user_ids,
        )
    if added_user_ids:
        _add_user__user_group_relationships__no_commit(
            db_session=db_session,
            user_group_id=user_group_id,
            user_ids=added_user_ids,
        )

    # --- cc_pair diff ----------------------------------------------------
    existing_cc_pair_ids = {cc_pair.id for cc_pair in db_user_group.cc_pairs}
    cc_pairs_updated = existing_cc_pair_ids != set(user_group_update.cc_pair_ids)
    if cc_pairs_updated:
        # old relationships are kept but flagged; the new set is added next to them
        _mark_user_group__cc_pair_relationships_outdated__no_commit(
            db_session=db_session, user_group_id=user_group_id
        )
        _add_user_group__cc_pair_relationships__no_commit(
            db_session=db_session,
            user_group_id=db_user_group.id,
            cc_pair_ids=user_group_update.cc_pair_ids,
        )
        # only needs to sync with Vespa if the cc_pairs have been updated
        db_user_group.is_up_to_date = False

    # --- curator role upkeep for removed users ---------------------------
    removed_users = db_session.scalars(
        select(User).where(User.id.in_(removed_user_ids))  # type: ignore
    ).unique()
    # Filter out admin and global curator users before validating curator status
    exempt_roles = [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
    users_to_validate = [
        removed_user
        for removed_user in removed_users
        if removed_user.role not in exempt_roles
    ]
    if users_to_validate:
        _validate_curator_status__no_commit(db_session, users_to_validate)

    # update "time_updated" to now
    db_user_group.time_last_modified_by_user = func.now()
    db_session.commit()
    return db_user_group
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
    """Remove all relationships of a user group and flag it for deletion.

    Sets `is_up_to_date = False` and `is_up_for_deletion = True` on the group,
    then commits.

    Raises:
        ValueError: if the group does not exist or is currently syncing.
    """
    db_user_group = db_session.scalar(
        select(UserGroup).where(UserGroup.id == user_group_id)
    )
    if db_user_group is None:
        raise ValueError(f"UserGroup with id '{user_group_id}' not found")
    _check_user_group_is_modifiable(db_user_group)

    # These all share the (db_session, user_group_id) keyword signature and
    # run in this fixed order.
    leading_steps = (
        _mark_user_group__cc_pair_relationships_outdated__no_commit,
        _cleanup_credential__user_group_relationships__no_commit,
        _cleanup_user__user_group_relationships__no_commit,
        _cleanup_token_rate_limit__user_group_relationships__no_commit,
        _cleanup_document_set__user_group_relationships__no_commit,
        _cleanup_persona__user_group_relationships__no_commit,
    )
    for step in leading_steps:
        step(db_session=db_session, user_group_id=user_group_id)

    _cleanup_user_group__cc_pair_relationships__no_commit(
        db_session=db_session,
        user_group_id=user_group_id,
        outdated_only=False,
    )
    _cleanup_llm_provider__user_group_relationships__no_commit(
        db_session=db_session, user_group_id=user_group_id
    )

    db_user_group.is_up_to_date = False
    db_user_group.is_up_for_deletion = True
    db_session.commit()
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
    """
    Delete the user group row and commit.

    This assumes that all the fk cleanup has already been done.
    """
    group_to_remove = user_group
    db_session.delete(group_to_remove)
    db_session.commit()
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
    """Drop the group's outdated cc_pair relationships, set `is_up_to_date`
    to True and commit."""
    synced_group_id = user_group.id
    # cleanup outdated relationships
    _cleanup_user_group__cc_pair_relationships__no_commit(
        db_session=db_session,
        user_group_id=synced_group_id,
        outdated_only=True,
    )
    user_group.is_up_to_date = True
    db_session.commit()
def delete_user_group_cc_pair_relationship__no_commit(
    cc_pair_id: int, db_session: Session
) -> None:
    """Deletes all rows from UserGroup__ConnectorCredentialPair whose
    cc_pair_id matches the given cc_pair_id.

    Should be used very carefully (only for connectors that are being deleted).

    Raises:
        ValueError: if the cc_pair does not exist or is not in the DELETING state.

    NOTE: does not commit the transaction.
    """
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

    current_status = cc_pair.status
    if current_status != ConnectorCredentialPairStatus.DELETING:
        raise ValueError(
            f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state. status={cc_pair.status}"
        )

    db_session.execute(
        delete(UserGroup__ConnectorCredentialPair).where(
            UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id,
        )
    )