From 7736c351e42e802ba2e11ac4e4c53469ba08ca6e Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 26 Feb 2025 09:11:23 -0800 Subject: [PATCH] addressed CW comments --- backend/ee/onyx/db/user_group.py | 21 +++++++++++++++++---- backend/ee/onyx/server/user_group/api.py | 5 ++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/backend/ee/onyx/db/user_group.py b/backend/ee/onyx/db/user_group.py index e70255934..9de2121ae 100644 --- a/backend/ee/onyx/db/user_group.py +++ b/backend/ee/onyx/db/user_group.py @@ -182,8 +182,10 @@ def validate_object_creation_for_user( def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[UserGroup]]: return stmt.options( + # Which users are in this group selectinload(UserGroup.users), selectinload(UserGroup.user_group_relationships), + # Which CC pairs this group has access to selectinload(UserGroup.cc_pair_relationships) .selectinload(UserGroup__ConnectorCredentialPair.cc_pair) .joinedload(ConnectorCredentialPair.credential), @@ -191,6 +193,7 @@ def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[User .selectinload(UserGroup__ConnectorCredentialPair.cc_pair) .joinedload(ConnectorCredentialPair.connector) .contains_eager(Connector.credentials), + # Which document sets this group has access to selectinload(UserGroup.document_sets) .selectinload(DocumentSet.connector_credential_pairs) .selectinload(ConnectorCredentialPair.credential), @@ -198,6 +201,9 @@ def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[User .selectinload(DocumentSet.connector_credential_pairs) .joinedload(ConnectorCredentialPair.connector) .contains_eager(Connector.credentials), + # Which personas this group has access to. Each persona has + # its own set of associated data similar to the above per-user-group + # associations; TODO: do we really need to load all of this? selectinload(UserGroup.personas).selectinload(Persona.user), selectinload(UserGroup.personas).selectinload(Persona.prompts), selectinload(UserGroup.personas).selectinload(Persona.tools), @@ -222,7 +228,7 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non def fetch_user_groups( - db_session: Session, only_up_to_date: bool = True + db_session: Session, only_up_to_date: bool = True, eager_load_all: bool = False ) -> Sequence[UserGroup]: """ Fetches user groups from the database. @@ -243,12 +249,17 @@ def fetch_user_groups( if only_up_to_date: stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712 - stmt = eager_usergroup_options(stmt) + if eager_load_all: + stmt = eager_usergroup_options(stmt) + return db_session.scalars(stmt).all() def fetch_user_groups_for_user( - db_session: Session, user_id: UUID, only_curator_groups: bool = False + db_session: Session, + user_id: UUID, + only_curator_groups: bool = False, + eager_load_all: bool = False, ) -> Sequence[UserGroup]: stmt = ( select(UserGroup) @@ -259,7 +270,9 @@ def fetch_user_groups_for_user( if only_curator_groups: stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712 - stmt = eager_usergroup_options(stmt) + if eager_load_all: + stmt = eager_usergroup_options(stmt) + stmt = stmt.options(contains_eager(UserGroup.users)) return db_session.scalars(stmt).all() diff --git a/backend/ee/onyx/server/user_group/api.py b/backend/ee/onyx/server/user_group/api.py index 4f01a21ad..0167241d3 100644 --- a/backend/ee/onyx/server/user_group/api.py +++ b/backend/ee/onyx/server/user_group/api.py @@ -32,12 +32,15 @@ def list_user_groups( db_session: Session = Depends(get_session), ) -> list[UserGroup]: if user is None or user.role == UserRole.ADMIN: - user_groups = fetch_user_groups(db_session, only_up_to_date=False) + user_groups = fetch_user_groups( + db_session, only_up_to_date=False, eager_load_all=True + ) else: user_groups = fetch_user_groups_for_user( db_session=db_session, user_id=user.id, only_curator_groups=user.role == UserRole.CURATOR, + eager_load_all=True, ) return [UserGroup.from_model(user_group) for user_group in user_groups]