mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Add User Groups (a.k.a. RBAC) (#4)
This commit is contained in:
75
backend/ee/danswer/access/access.py
Normal file
75
backend/ee/danswer/access/access.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups
|
||||
from danswer.access.access import (
|
||||
get_access_for_documents as get_access_for_documents_without_groups,
|
||||
)
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from ee.danswer.db.user_group import fetch_user_groups_for_documents
|
||||
from ee.danswer.db.user_group import fetch_user_groups_for_user
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
|
||||
db_session: Session,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
access_dict = get_access_for_documents_without_groups(
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
db_session=db_session,
|
||||
)
|
||||
user_group_info = {
|
||||
document_id: group_names
|
||||
for document_id, group_names in fetch_user_groups_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
}
|
||||
|
||||
# overload user_ids a bit - use it for both actual User IDs + group IDs
|
||||
return {
|
||||
document_id: DocumentAccess(
|
||||
user_ids=access.user_ids.union(user_group_info.get(document_id, [])), # type: ignore
|
||||
is_public=access.is_public,
|
||||
)
|
||||
for document_id, access in access_dict.items()
|
||||
}
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_access_for_documents(
|
||||
document_ids, cc_pair_to_delete, db_session
|
||||
)
|
||||
|
||||
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
|
||||
NOTE: is imported in danswer.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
|
||||
return set(
|
||||
[prefix_user_group(user_group.name) for user_group in user_groups]
|
||||
).union(get_acl_for_user_without_groups(user, db_session))
|
7
backend/ee/danswer/background/celery/celery.py
Normal file
7
backend/ee/danswer/background/celery/celery.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from danswer.background.celery.celery import celery_app
|
||||
from ee.danswer.user_groups.sync import sync_user_groups
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
|
||||
def sync_user_group_task(user_group_id: int) -> None:
|
||||
sync_user_groups(user_group_id=user_group_id)
|
54
backend/ee/danswer/background/user_group_sync_script.py
Normal file
54
backend/ee/danswer/background/user_group_sync_script.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import time
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.celery.celery import sync_user_group_task
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ExistingTaskCache: dict[int, AsyncResult] = {}
|
||||
|
||||
|
||||
def _user_group_sync_loop() -> None:
|
||||
# cleanup tasks
|
||||
existing_tasks = list(_ExistingTaskCache.items())
|
||||
for user_group_id, task in existing_tasks:
|
||||
if task.ready():
|
||||
logger.info(
|
||||
f"User Group '{user_group_id}' is complete with status "
|
||||
f"{task.status}. Cleaning up."
|
||||
)
|
||||
del _ExistingTaskCache[user_group_id]
|
||||
|
||||
# kick off new tasks
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
user_groups = fetch_user_groups(db_session=db_session, only_current=False)
|
||||
for user_group in user_groups:
|
||||
if not user_group.is_up_to_date:
|
||||
if user_group.id in _ExistingTaskCache:
|
||||
logger.info(
|
||||
f"User Group '{user_group.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"User Group {user_group.id} is not synced. Syncing now!")
|
||||
task = sync_user_group_task.apply_async(
|
||||
kwargs=dict(user_group_id=user_group.id),
|
||||
)
|
||||
_ExistingTaskCache[user_group.id] = task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
while True:
|
||||
start = time.monotonic()
|
||||
|
||||
_user_group_sync_loop()
|
||||
|
||||
sleep_time = 30 - (time.monotonic() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
@@ -1,14 +1,18 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from danswer.db.models import Base
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
@@ -24,3 +28,69 @@ class SamlAccount(Base):
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User")
|
||||
|
||||
|
||||
"""Tables related to RBAC"""
|
||||
|
||||
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
cc_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), primary_key=True
|
||||
)
|
||||
# if `True`, then is part of the current state of the UserGroup
|
||||
# if `False`, then is a part of the prior state of the UserGroup
|
||||
# rows with `is_current=False` should be deleted when the UserGroup
|
||||
# is updated and should not exist for a given UserGroup if
|
||||
# `UserGroup.is_up_to_date == True`
|
||||
is_current: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
cc_pair: Mapped[ConnectorCredentialPair] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
)
|
||||
|
||||
|
||||
class UserGroup(Base):
|
||||
__tablename__ = "user_group"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
# whether or not changes to the UserGroup have been propogated to Vespa
|
||||
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
# tell the sync job to clean up the group
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
users: Mapped[list[User]] = relationship(
|
||||
"User",
|
||||
secondary=User__UserGroup.__table__,
|
||||
)
|
||||
cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
secondary=UserGroup__ConnectorCredentialPair.__table__,
|
||||
viewonly=True,
|
||||
)
|
||||
cc_pair_relationships: Mapped[
|
||||
list[UserGroup__ConnectorCredentialPair]
|
||||
] = relationship(
|
||||
"UserGroup__ConnectorCredentialPair",
|
||||
viewonly=True,
|
||||
)
|
||||
|
298
backend/ee/danswer/db/user_group.py
Normal file
298
backend/ee/danswer/db/user_group.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from collections.abc import Sequence
|
||||
from operator import and_
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from ee.danswer.db.models import User__UserGroup
|
||||
from ee.danswer.db.models import UserGroup
|
||||
from ee.danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from ee.danswer.server.user_group.models import UserGroupCreate
|
||||
from ee.danswer.server.user_group.models import UserGroupUpdate
|
||||
|
||||
|
||||
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session, only_current: bool = True
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = select(UserGroup)
|
||||
if only_current:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session, user_id: UUID
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
.join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
|
||||
.join(User, User.id == User__UserGroup.user_id) # type: ignore
|
||||
.where(User.id == user_id) # type: ignore
|
||||
)
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_documents_for_user_group(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> Sequence[Document]:
|
||||
stmt = (
|
||||
select(Document)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
UserGroup__ConnectorCredentialPair,
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
UserGroup,
|
||||
UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
|
||||
)
|
||||
.where(UserGroup.id == user_group_id)
|
||||
)
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> Sequence[tuple[int, list[str]]]:
|
||||
stmt = (
|
||||
select(Document.id, func.array_agg(UserGroup.name))
|
||||
.join(
|
||||
UserGroup__ConnectorCredentialPair,
|
||||
UserGroup.id == UserGroup__ConnectorCredentialPair.user_group_id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
|
||||
)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(Document, Document.id == DocumentByConnectorCredentialPair.id)
|
||||
.where(Document.id.in_(document_ids))
|
||||
.where(UserGroup__ConnectorCredentialPair.is_current == True) # noqa: E712
|
||||
.group_by(Document.id)
|
||||
)
|
||||
|
||||
# pretend that the specified cc pair doesn't exist
|
||||
if cc_pair_to_delete is not None:
|
||||
stmt = stmt.where(
|
||||
and_(
|
||||
ConnectorCredentialPair.connector_id != cc_pair_to_delete.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
!= cc_pair_to_delete.credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
|
||||
if not user_group.is_up_to_date:
|
||||
raise ValueError(
|
||||
"Specified user group is currently syncing. Wait until the current "
|
||||
"sync has finished before editing."
|
||||
)
|
||||
|
||||
|
||||
def _add_user__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, user_ids: list[UUID]
|
||||
) -> list[User__UserGroup]:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
relationships = [
|
||||
User__UserGroup(user_id=user_id, user_group_id=user_group_id)
|
||||
for user_id in user_ids
|
||||
]
|
||||
db_session.add_all(relationships)
|
||||
return relationships
|
||||
|
||||
|
||||
def _add_user_group__cc_pair_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, cc_pair_ids: list[int]
|
||||
) -> list[UserGroup__ConnectorCredentialPair]:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
relationships = [
|
||||
UserGroup__ConnectorCredentialPair(
|
||||
user_group_id=user_group_id, cc_pair_id=cc_pair_id
|
||||
)
|
||||
for cc_pair_id in cc_pair_ids
|
||||
]
|
||||
db_session.add_all(relationships)
|
||||
return relationships
|
||||
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(name=user_group.name)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
user_ids=user_group.user_ids,
|
||||
)
|
||||
_add_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
|
||||
def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
user__user_group_relationships = db_session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_group_id == user_group_id)
|
||||
).all()
|
||||
for user__user_group_relationship in user__user_group_relationships:
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
user_group__cc_pair_relationships = db_session.scalars(
|
||||
select(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
|
||||
user_group__cc_pair_relationship.is_current = False
|
||||
|
||||
|
||||
def update_user_group(
|
||||
db_session: Session, user_group_id: int, user_group: UserGroupUpdate
|
||||
) -> UserGroup:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
if db_user_group is None:
|
||||
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
existing_cc_pairs = db_user_group.cc_pairs
|
||||
cc_pairs_updated = set([cc_pair.id for cc_pair in existing_cc_pairs]) != set(
|
||||
user_group.cc_pair_ids
|
||||
)
|
||||
users_updated = set([user.id for user in db_user_group.users]) != set(
|
||||
user_group.user_ids
|
||||
)
|
||||
|
||||
if users_updated:
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_ids=user_group.user_ids,
|
||||
)
|
||||
if cc_pairs_updated:
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_add_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
# only needs to sync with Vespa if the cc_pairs have been updated
|
||||
if cc_pairs_updated:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
|
||||
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
|
||||
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
|
||||
db_user_group = db_session.scalar(stmt)
|
||||
if db_user_group is None:
|
||||
raise ValueError(f"UserGroup with id '{user_group_id}' not found")
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, outdated_only: bool
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
stmt = select(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
|
||||
)
|
||||
if outdated_only:
|
||||
stmt = stmt.where(
|
||||
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
|
||||
)
|
||||
user_group__cc_pair_relationships = db_session.scalars(stmt)
|
||||
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
|
||||
db_session.delete(user_group__cc_pair_relationship)
|
||||
|
||||
|
||||
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
|
||||
# cleanup outdated relationships
|
||||
_cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id, outdated_only=True
|
||||
)
|
||||
user_group.is_up_to_date = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user_group__cc_pair_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group.id,
|
||||
outdated_only=False,
|
||||
)
|
||||
db_session.delete(user_group)
|
||||
db_session.commit()
|
@@ -17,6 +17,7 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.danswer.server.saml import router as saml_router
|
||||
from ee.danswer.server.user_group.api import router as user_group_router
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -51,6 +52,9 @@ def get_ee_application() -> FastAPI:
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
application.include_router(saml_router)
|
||||
|
||||
# RBAC / group access control
|
||||
application.include_router(user_group_router)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
71
backend/ee/danswer/server/user_group/api.py
Normal file
71
backend/ee/danswer/server/user_group/api.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.db.models as db_models
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
from ee.danswer.db.user_group import insert_user_group
|
||||
from ee.danswer.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.danswer.db.user_group import update_user_group
|
||||
from ee.danswer.server.user_group.models import UserGroup
|
||||
from ee.danswer.server.user_group.models import UserGroupCreate
|
||||
from ee.danswer.server.user_group.models import UserGroupUpdate
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
user_groups = fetch_user_groups(db_session, only_current=False)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
db_user_group = insert_user_group(db_session, user_group)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"User group with name '{user_group.name}' already exists. Please "
|
||||
+ "choose a different name.",
|
||||
)
|
||||
return UserGroup.from_model(db_user_group)
|
||||
|
||||
|
||||
@router.patch("/admin/user-group/{user_group_id}")
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
user_group: UserGroupUpdate,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
update_user_group(db_session, user_group_id, user_group)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
65
backend/ee/danswer/server/user_group/models.py
Normal file
65
backend/ee/danswer/server/user_group/models.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from danswer.server.documents.models import (
|
||||
ConnectorCredentialPairDescriptor,
|
||||
ConnectorSnapshot,
|
||||
CredentialSnapshot,
|
||||
)
|
||||
from danswer.server.manage.models import UserInfo
|
||||
|
||||
from ee.danswer.db.models import UserGroup as UserGroupModel
|
||||
|
||||
|
||||
class UserGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
users: list[UserInfo]
|
||||
cc_pairs: list[ConnectorCredentialPairDescriptor]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, document_set_model: UserGroupModel) -> "UserGroup":
|
||||
return cls(
|
||||
id=document_set_model.id,
|
||||
name=document_set_model.name,
|
||||
users=[
|
||||
UserInfo(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
is_active=user.is_active,
|
||||
is_superuser=user.is_superuser,
|
||||
is_verified=user.is_verified,
|
||||
role=user.role,
|
||||
)
|
||||
for user in document_set_model.users
|
||||
],
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairDescriptor(
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
)
|
||||
for cc_pair_relationship in document_set_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
],
|
||||
is_up_to_date=document_set_model.is_up_to_date,
|
||||
is_up_for_deletion=document_set_model.is_up_for_deletion,
|
||||
)
|
||||
|
||||
|
||||
class UserGroupCreate(BaseModel):
|
||||
name: str
|
||||
user_ids: list[UUID]
|
||||
cc_pair_ids: list[int]
|
||||
|
||||
|
||||
class UserGroupUpdate(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
cc_pair_ids: list[int]
|
70
backend/ee/danswer/user_groups/sync.py
Normal file
70
backend/ee/danswer/user_groups/sync.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.access.access import get_access_for_documents
|
||||
from ee.danswer.db.user_group import delete_user_group
|
||||
from ee.danswer.db.user_group import fetch_documents_for_user_group
|
||||
from ee.danswer.db.user_group import fetch_user_group
|
||||
from ee.danswer.db.user_group import mark_user_group_as_synced
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SYNC_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def _sync_user_group_batch(
|
||||
document_ids: list[str], document_index: DocumentIndex
|
||||
) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
# begin a transaction, release lock at the end
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
|
||||
|
||||
# get current state of document sets for these documents
|
||||
document_id_to_access = get_access_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(
|
||||
update_requests=[
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=document_id_to_access[document_id],
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def sync_user_groups(user_group_id: int) -> None:
|
||||
"""Sync the status of Postgres for the specified user group"""
|
||||
document_index = get_default_document_index()
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
user_group = fetch_user_group(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
if user_group is None:
|
||||
raise ValueError(f"User group '{user_group_id}' does not exist")
|
||||
|
||||
documents_to_update = fetch_documents_for_user_group(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
)
|
||||
for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
|
||||
_sync_user_group_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
Reference in New Issue
Block a user