Add User Groups (a.k.a. RBAC) (#4)

Chris Weaver
2023-10-09 09:45:07 -07:00
parent 92de6acc6f
commit 7503f8f37b
43 changed files with 2121 additions and 23 deletions

View File

@@ -0,0 +1,75 @@
from sqlalchemy.orm import Session
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from danswer.access.access import (
get_access_for_documents as get_access_for_documents_without_groups,
)
from danswer.access.models import DocumentAccess
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.user_group import fetch_user_groups_for_documents
from ee.danswer.db.user_group import fetch_user_groups_for_user
def _get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
db_session: Session,
) -> dict[str, DocumentAccess]:
access_dict = get_access_for_documents_without_groups(
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
db_session=db_session,
)
user_group_info = {
document_id: group_names
for document_id, group_names in fetch_user_groups_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
}
# overload user_ids a bit - use it for both actual User IDs + group IDs
return {
document_id: DocumentAccess(
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
is_public=access.is_public,
)
for document_id, access in access_dict.items()
}
def get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
return _get_access_for_documents(
document_ids, cc_pair_to_delete, db_session
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set.
NOTE: is imported in danswer.access.access by `fetch_versioned_implementation`
DO NOT REMOVE."""
user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
return set(
[prefix_user_group(user_group.name) for user_group in user_groups]
).union(get_acl_for_user_without_groups(user, db_session))
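
To make the group-aware ACL concrete, here is a minimal illustrative sketch (all values are made up; the non-group entries come from whatever the base _get_acl_for_user implementation returns) of how a document visibility check plays out under this scheme:

# Illustrative only -- entries are hypothetical. Group entries are produced by
# prefix_user_group(); the rest come from the base implementation.
user_acl_entries = {"group:engineering", "some-user-id"}

# ACL stored on a document whose cc_pair was added to the "engineering" group
document_acl_entries = {"group:engineering", "another-user-id"}

# access is granted if at least one entry overlaps
has_access = bool(user_acl_entries & document_acl_entries)  # True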

View File

@@ -0,0 +1,7 @@
from danswer.background.celery.celery import celery_app
from ee.danswer.user_groups.sync import sync_user_groups
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
def sync_user_group_task(user_group_id: int) -> None:
sync_user_groups(user_group_id=user_group_id)

View File

@@ -0,0 +1,54 @@
import time
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
from ee.danswer.background.celery.celery import sync_user_group_task
from ee.danswer.db.user_group import fetch_user_groups
logger = setup_logger()
_ExistingTaskCache: dict[int, AsyncResult] = {}
def _user_group_sync_loop() -> None:
# cleanup tasks
existing_tasks = list(_ExistingTaskCache.items())
for user_group_id, task in existing_tasks:
if task.ready():
logger.info(
f"User Group '{user_group_id}' is complete with status "
f"{task.status}. Cleaning up."
)
del _ExistingTaskCache[user_group_id]
# kick off new tasks
with Session(get_sqlalchemy_engine()) as db_session:
        # check if any user groups are not synced
user_groups = fetch_user_groups(db_session=db_session, only_current=False)
for user_group in user_groups:
if not user_group.is_up_to_date:
if user_group.id in _ExistingTaskCache:
logger.info(
f"User Group '{user_group.id}' is already syncing. Skipping."
)
continue
logger.info(f"User Group {user_group.id} is not synced. Syncing now!")
task = sync_user_group_task.apply_async(
kwargs=dict(user_group_id=user_group.id),
)
_ExistingTaskCache[user_group.id] = task
if __name__ == "__main__":
while True:
start = time.monotonic()
_user_group_sync_loop()
sleep_time = 30 - (time.monotonic() - start)
if sleep_time > 0:
time.sleep(sleep_time)

View File

@@ -1,14 +1,18 @@
import datetime
from uuid import UUID
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from danswer.db.models import Base
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import User
@@ -24,3 +28,69 @@ class SamlAccount(Base):
)
user: Mapped[User] = relationship("User")
"""Tables related to RBAC"""
class User__UserGroup(Base):
__tablename__ = "user__user_group"
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
class UserGroup__ConnectorCredentialPair(Base):
__tablename__ = "user_group__connector_credential_pair"
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
# if `True`, then is part of the current state of the UserGroup
# if `False`, then is a part of the prior state of the UserGroup
# rows with `is_current=False` should be deleted when the UserGroup
# is updated and should not exist for a given UserGroup if
# `UserGroup.is_up_to_date == True`
is_current: Mapped[bool] = mapped_column(
Boolean,
default=True,
primary_key=True,
)
cc_pair: Mapped[ConnectorCredentialPair] = relationship(
"ConnectorCredentialPair",
)
class UserGroup(Base):
__tablename__ = "user_group"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
    # whether or not changes to the UserGroup have been propagated to Vespa
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# tell the sync job to clean up the group
is_up_for_deletion: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
users: Mapped[list[User]] = relationship(
"User",
secondary=User__UserGroup.__table__,
)
cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
"ConnectorCredentialPair",
secondary=UserGroup__ConnectorCredentialPair.__table__,
viewonly=True,
)
cc_pair_relationships: Mapped[
list[UserGroup__ConnectorCredentialPair]
] = relationship(
"UserGroup__ConnectorCredentialPair",
viewonly=True,
)
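
As a rough sketch of how is_current is meant to be queried (a hypothetical helper, not part of this change; it assumes the models above are importable):

from sqlalchemy import select
from sqlalchemy.orm import Session


def current_cc_pair_relationships(
    db_session: Session, user_group_id: int
) -> list[UserGroup__ConnectorCredentialPair]:
    """Hypothetical sketch: fetch only the rows reflecting the group's current
    state; is_current=False rows are the pre-update state and are deleted once
    the group is marked as synced."""
    stmt = (
        select(UserGroup__ConnectorCredentialPair)
        .where(UserGroup__ConnectorCredentialPair.user_group_id == user_group_id)
        .where(UserGroup__ConnectorCredentialPair.is_current == True)  # noqa: E712
    )
    return list(db_session.scalars(stmt).all())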

View File

@@ -0,0 +1,298 @@
from collections.abc import Sequence
from operator import and_
from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.models import User__UserGroup
from ee.danswer.db.models import UserGroup
from ee.danswer.db.models import UserGroup__ConnectorCredentialPair
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
return db_session.scalar(stmt)
def fetch_user_groups(
db_session: Session, only_current: bool = True
) -> Sequence[UserGroup]:
stmt = select(UserGroup)
if only_current:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
.join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
.join(User, User.id == User__UserGroup.user_id) # type: ignore
.where(User.id == user_id) # type: ignore
)
return db_session.scalars(stmt).all()
def fetch_documents_for_user_group(
db_session: Session, user_group_id: int
) -> Sequence[Document]:
stmt = (
select(Document)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(
UserGroup__ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
)
.join(
UserGroup,
UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
)
.where(UserGroup.id == user_group_id)
)
return db_session.scalars(stmt).all()
def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> Sequence[tuple[str, list[str]]]:
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
.join(
UserGroup__ConnectorCredentialPair,
UserGroup.id == UserGroup__ConnectorCredentialPair.user_group_id,
)
.join(
ConnectorCredentialPair,
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
)
.join(
DocumentByConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.join(Document, Document.id == DocumentByConnectorCredentialPair.id)
.where(Document.id.in_(document_ids))
.where(UserGroup__ConnectorCredentialPair.is_current == True) # noqa: E712
.group_by(Document.id)
)
# pretend that the specified cc pair doesn't exist
if cc_pair_to_delete is not None:
stmt = stmt.where(
and_(
ConnectorCredentialPair.connector_id != cc_pair_to_delete.connector_id,
ConnectorCredentialPair.credential_id
!= cc_pair_to_delete.credential_id,
)
)
return db_session.execute(stmt).all() # type: ignore
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
if not user_group.is_up_to_date:
raise ValueError(
"Specified user group is currently syncing. Wait until the current "
"sync has finished before editing."
)
def _add_user__user_group_relationships__no_commit(
db_session: Session, user_group_id: int, user_ids: list[UUID]
) -> list[User__UserGroup]:
"""NOTE: does not commit the transaction."""
relationships = [
User__UserGroup(user_id=user_id, user_group_id=user_group_id)
for user_id in user_ids
]
db_session.add_all(relationships)
return relationships
def _add_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, cc_pair_ids: list[int]
) -> list[UserGroup__ConnectorCredentialPair]:
"""NOTE: does not commit the transaction."""
relationships = [
UserGroup__ConnectorCredentialPair(
user_group_id=user_group_id, cc_pair_id=cc_pair_id
)
for cc_pair_id in cc_pair_ids
]
db_session.add_all(relationships)
return relationships
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(name=user_group.name)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
user_ids=user_group.user_ids,
)
_add_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
cc_pair_ids=user_group.cc_pair_ids,
)
db_session.commit()
return db_user_group
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
user_group__cc_pair_relationships = db_session.scalars(
select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
user_group__cc_pair_relationship.is_current = False
def update_user_group(
db_session: Session, user_group_id: int, user_group: UserGroupUpdate
) -> UserGroup:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
_check_user_group_is_modifiable(db_user_group)
existing_cc_pairs = db_user_group.cc_pairs
cc_pairs_updated = set([cc_pair.id for cc_pair in existing_cc_pairs]) != set(
user_group.cc_pair_ids
)
users_updated = set([user.id for user in db_user_group.users]) != set(
user_group.user_ids
)
if users_updated:
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
user_ids=user_group.user_ids,
)
if cc_pairs_updated:
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_add_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
cc_pair_ids=user_group.cc_pair_ids,
)
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
db_user_group.is_up_to_date = False
db_session.commit()
return db_user_group
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
if db_user_group is None:
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
_check_user_group_is_modifiable(db_user_group)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
# cleanup outdated relationships
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id, outdated_only=True
)
user_group.is_up_to_date = True
db_session.commit()
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group.id,
outdated_only=False,
)
db_session.delete(user_group)
db_session.commit()
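
A hypothetical end-to-end use of the helpers above (IDs and names are made up; UserGroupCreate and UserGroupUpdate are the request models added elsewhere in this commit):

from uuid import UUID

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate

alice_id = UUID("00000000-0000-0000-0000-000000000001")  # made-up user ID

with Session(get_sqlalchemy_engine()) as db_session:
    # create the group; it starts out with is_up_to_date=False until the
    # background sync job has pushed its access info to Vespa
    group = insert_user_group(
        db_session,
        UserGroupCreate(name="engineering", user_ids=[alice_id], cc_pair_ids=[1]),
    )
    # later: change the attached connectors; since the cc_pairs changed, the
    # group is marked out-of-date again and re-synced by the background job
    update_user_group(
        db_session,
        user_group_id=group.id,
        user_group=UserGroupUpdate(user_ids=[alice_id], cc_pair_ids=[1, 2]),
    )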

View File

@@ -17,6 +17,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
from ee.danswer.server.saml import router as saml_router
from ee.danswer.server.user_group.api import router as user_group_router
logger = setup_logger()
@@ -51,6 +52,9 @@ def get_ee_application() -> FastAPI:
elif AUTH_TYPE == AuthType.SAML:
application.include_router(saml_router)
# RBAC / group access control
application.include_router(user_group_router)
return application

View File

@@ -0,0 +1,71 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.db.user_group import insert_user_group
from ee.danswer.db.user_group import prepare_user_group_for_deletion
from ee.danswer.db.user_group import update_user_group
from ee.danswer.server.user_group.models import UserGroup
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
router = APIRouter(prefix="/manage")
@router.get("/admin/user-group")
def list_user_groups(
_: db_models.User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
user_groups = fetch_user_groups(db_session, only_current=False)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
_: db_models.User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
db_user_group = insert_user_group(db_session, user_group)
except IntegrityError:
raise HTTPException(
400,
f"User group with name '{user_group.name}' already exists. Please "
+ "choose a different name.",
)
return UserGroup.from_model(db_user_group)
@router.patch("/admin/user-group/{user_group_id}")
def patch_user_group(
user_group_id: int,
user_group: UserGroupUpdate,
_: db_models.User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
try:
return UserGroup.from_model(
update_user_group(db_session, user_group_id, user_group)
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/admin/user-group/{user_group_id}")
def delete_user_group(
user_group_id: int,
_: db_models.User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
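
For reference, a hypothetical request against the new create endpoint (base URL, cookie name, and IDs are made up; the caller must be an admin user):

import requests

resp = requests.post(
    "http://localhost:8080/manage/admin/user-group",
    json={
        "name": "engineering",
        "user_ids": ["00000000-0000-0000-0000-000000000001"],
        "cc_pair_ids": [1],
    },
    cookies={"session": "<admin-session-cookie>"},  # placeholder auth cookie
)
resp.raise_for_status()
# the response is the serialized UserGroup; is_up_to_date stays False until
# the background sync job has propagated the group's access info to Vespa
print(resp.json())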

View File

@@ -0,0 +1,65 @@
from uuid import UUID
from pydantic import BaseModel
from danswer.server.documents.models import (
ConnectorCredentialPairDescriptor,
ConnectorSnapshot,
CredentialSnapshot,
)
from danswer.server.manage.models import UserInfo
from ee.danswer.db.models import UserGroup as UserGroupModel
class UserGroup(BaseModel):
id: int
name: str
users: list[UserInfo]
cc_pairs: list[ConnectorCredentialPairDescriptor]
is_up_to_date: bool
is_up_for_deletion: bool
@classmethod
    def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
        return cls(
            id=user_group_model.id,
            name=user_group_model.name,
            users=[
                UserInfo(
                    id=str(user.id),
                    email=user.email,
                    is_active=user.is_active,
                    is_superuser=user.is_superuser,
                    is_verified=user.is_verified,
                    role=user.role,
                )
                for user in user_group_model.users
            ],
            cc_pairs=[
                ConnectorCredentialPairDescriptor(
                    id=cc_pair_relationship.cc_pair.id,
                    name=cc_pair_relationship.cc_pair.name,
                    connector=ConnectorSnapshot.from_connector_db_model(
                        cc_pair_relationship.cc_pair.connector
                    ),
                    credential=CredentialSnapshot.from_credential_db_model(
                        cc_pair_relationship.cc_pair.credential
                    ),
                )
                for cc_pair_relationship in user_group_model.cc_pair_relationships
                if cc_pair_relationship.is_current
            ],
            is_up_to_date=user_group_model.is_up_to_date,
            is_up_for_deletion=user_group_model.is_up_for_deletion,
)
class UserGroupCreate(BaseModel):
name: str
user_ids: list[UUID]
cc_pair_ids: list[int]
class UserGroupUpdate(BaseModel):
user_ids: list[UUID]
cc_pair_ids: list[int]

View File

@@ -0,0 +1,70 @@
from sqlalchemy.orm import Session
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from ee.danswer.access.access import get_access_for_documents
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import fetch_documents_for_user_group
from ee.danswer.db.user_group import fetch_user_group
from ee.danswer.db.user_group import mark_user_group_as_synced
logger = setup_logger()
_SYNC_BATCH_SIZE = 1000
def _sync_user_group_batch(
document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
        # get current state of access for these documents
document_id_to_access = get_access_for_documents(
document_ids=document_ids, db_session=db_session
)
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
access=document_id_to_access[document_id],
)
for document_id in document_ids
]
)
def sync_user_groups(user_group_id: int) -> None:
"""Sync the status of Postgres for the specified user group"""
document_index = get_default_document_index()
with Session(get_sqlalchemy_engine()) as db_session:
user_group = fetch_user_group(
db_session=db_session, user_group_id=user_group_id
)
if user_group is None:
raise ValueError(f"User group '{user_group_id}' does not exist")
documents_to_update = fetch_documents_for_user_group(
db_session=db_session,
user_group_id=user_group_id,
)
for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
_sync_user_group_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)
if user_group.is_up_for_deletion:
delete_user_group(db_session=db_session, user_group=user_group)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
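
One assumption worth calling out: batch_generator is used above to chunk the documents into groups of at most _SYNC_BATCH_SIZE before each Vespa update. A tiny sketch of that usage pattern (output shown assuming it yields list batches):

from danswer.utils.batching import batch_generator

for batch in batch_generator([1, 2, 3, 4, 5], 2):
    print(batch)  # [1, 2] then [3, 4] then [5]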