addressed CW comments

This commit is contained in:
Evan Lohn 2025-02-26 09:11:23 -08:00 committed by pablonyx
parent a5831ae375
commit 7736c351e4
2 changed files with 21 additions and 5 deletions

View File

@ -182,8 +182,10 @@ def validate_object_creation_for_user(
def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[UserGroup]]:
return stmt.options(
# Which users are in this group
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
# Which CC pairs this group has access to
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.joinedload(ConnectorCredentialPair.credential),
@ -191,6 +193,7 @@ def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[User
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
# Which document sets this group has access to
selectinload(UserGroup.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.credential),
@ -198,6 +201,9 @@ def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[User
.selectinload(DocumentSet.connector_credential_pairs)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
# Which personas this group has access to. Each persona has
# its own set of associated data similar to the above per-user-group
# associations; TODO: do we really need to load all of this?
selectinload(UserGroup.personas).selectinload(Persona.user),
selectinload(UserGroup.personas).selectinload(Persona.prompts),
selectinload(UserGroup.personas).selectinload(Persona.tools),
@ -222,7 +228,7 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
def fetch_user_groups(
db_session: Session, only_up_to_date: bool = True
db_session: Session, only_up_to_date: bool = True, eager_load_all: bool = False
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@ -243,12 +249,17 @@ def fetch_user_groups(
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
stmt = eager_usergroup_options(stmt)
if eager_load_all:
stmt = eager_usergroup_options(stmt)
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID, only_curator_groups: bool = False
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_all: bool = False,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@ -259,7 +270,9 @@ def fetch_user_groups_for_user(
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
stmt = eager_usergroup_options(stmt)
if eager_load_all:
stmt = eager_usergroup_options(stmt)
stmt = stmt.options(contains_eager(UserGroup.users))
return db_session.scalars(stmt).all()

View File

@ -32,12 +32,15 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user is None or user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_all=True
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_all=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]