mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-06 21:19:54 +02:00
* fix group sync name capitalization * everything is lowercased now * comments * Added test for be2ab2aa50ee migration * polish
133 lines
4.0 KiB
Python
133 lines
4.0 KiB
Python
from collections.abc import Sequence
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.access.utils import build_ext_group_name_for_onyx
|
|
from onyx.configs.constants import DocumentSource
|
|
from onyx.db.models import User
|
|
from onyx.db.models import User__ExternalUserGroupId
|
|
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
|
from onyx.db.users import get_user_by_email
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
class ExternalUserGroup(BaseModel):
    """A group definition consumed by `replace_user__ext_group_for_cc_pair`.

    Pairs a group identifier with the emails of the users that belong to it.
    """

    # Group identifier; it is passed as `ext_group_name` to
    # `build_ext_group_name_for_onyx` before being stored.
    id: str
    # Emails of the group's members; they are resolved to user ids with a
    # lowercased lookup when the relations are written.
    user_emails: list[str]
|
|
|
|
|
|
def delete_user__ext_group_for_user__no_commit(
    db_session: Session,
    user_id: UUID,
) -> None:
    """Delete every external user group relation row for `user_id`.

    The DELETE is executed on `db_session` but nothing is committed here,
    as the `__no_commit` suffix indicates.
    """
    belongs_to_user = User__ExternalUserGroupId.user_id == user_id
    delete_stmt = delete(User__ExternalUserGroupId).where(belongs_to_user)
    db_session.execute(delete_stmt)
|
|
|
|
|
|
def delete_user__ext_group_for_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    """Delete every external user group relation row for `cc_pair_id`.

    The DELETE is executed on `db_session` but nothing is committed here,
    as the `__no_commit` suffix indicates.
    """
    belongs_to_cc_pair = User__ExternalUserGroupId.cc_pair_id == cc_pair_id
    delete_stmt = delete(User__ExternalUserGroupId).where(belongs_to_cc_pair)
    db_session.execute(delete_stmt)
|
|
|
|
|
|
def replace_user__ext_group_for_cc_pair(
    db_session: Session,
    cc_pair_id: int,
    group_defs: list[ExternalUserGroup],
    source: DocumentSource,
) -> None:
    """
    This function clears all existing external user group relations for a given cc_pair_id
    and replaces them with the new group definitions and commits the changes.

    Args:
        db_session: Session used for every read/write; committed at the end.
        cc_pair_id: The cc_pair whose relations are being replaced.
        group_defs: The new group definitions (group id + member emails).
        source: Passed, together with each group's id, to
            `build_ext_group_name_for_onyx` to build the stored group id.

    Members whose email cannot be resolved to a user are skipped with a
    warning rather than failing the whole sync.
    """

    # collect all emails from all groups to batch add all users at once for efficiency
    all_group_member_emails = {
        user_email
        for external_group in group_defs
        for user_email in external_group.user_emails
    }

    # batch add users if they don't exist and get their ids
    all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
        db_session=db_session,
        # NOTE: this function handles case sensitivity for emails
        emails=list(all_group_member_emails),
    )

    delete_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )

    # map emails to ids
    # The lookup below lowercases the incoming email, so the keys must be
    # normalized the same way; otherwise a user whose stored email has any
    # uppercase character would never match and would be dropped from groups.
    email_id_map = {user.email.lower(): user.id for user in all_group_members}

    # use these ids to create new external user group relations relating group_id to user_ids
    new_external_permissions = []
    # (user_id, external_group_id) pairs already queued. Case-variant duplicate
    # emails, or groups that resolve to the same built id, would otherwise yield
    # identical rows for this cc_pair (redundant at best; presumably a
    # uniqueness violation on commit - confirm against the table definition).
    seen_memberships: set[tuple[UUID, str]] = set()
    for external_group in group_defs:
        # Depends only on the group and source, so build it once per group
        # instead of once per member email.
        external_group_id = build_ext_group_name_for_onyx(
            ext_group_name=external_group.id,
            source=source,
        )
        for user_email in external_group.user_emails:
            user_id = email_id_map.get(user_email.lower())
            if user_id is None:
                logger.warning(
                    f"User in group {external_group.id}"
                    f" with email {user_email} not found"
                )
                continue

            membership = (user_id, external_group_id)
            if membership in seen_memberships:
                continue
            seen_memberships.add(membership)

            new_external_permissions.append(
                User__ExternalUserGroupId(
                    user_id=user_id,
                    external_user_group_id=external_group_id,
                    cc_pair_id=cc_pair_id,
                )
            )

    db_session.add_all(new_external_permissions)
    db_session.commit()
|
|
|
|
|
|
def fetch_external_groups_for_user(
    db_session: Session,
    user_id: UUID,
) -> Sequence[User__ExternalUserGroupId]:
    """Return all external user group relation rows stored for `user_id`."""
    query = select(User__ExternalUserGroupId).where(
        User__ExternalUserGroupId.user_id == user_id
    )
    matching_rows = db_session.scalars(query)
    return matching_rows.all()
|
|
|
|
|
|
def fetch_external_groups_for_user_email_and_group_ids(
    db_session: Session,
    user_email: str,
    group_ids: list[str],
) -> list[User__ExternalUserGroupId]:
    """Return the relation rows of the user with `user_email` limited to `group_ids`.

    The user is resolved through `get_user_by_email`; when no user is found an
    empty list is returned. Otherwise only rows whose `external_user_group_id`
    is one of `group_ids` are returned.

    NOTE(review): rows are written with ids produced by
    `build_ext_group_name_for_onyx`, so `group_ids` presumably has to be in
    that same form - confirm with callers.
    """
    matched_user = get_user_by_email(db_session=db_session, email=user_email)
    if matched_user is None:
        return []

    owned_by_user = User__ExternalUserGroupId.user_id == matched_user.id
    in_requested_groups = User__ExternalUserGroupId.external_user_group_id.in_(
        group_ids
    )
    query = select(User__ExternalUserGroupId).where(
        owned_by_user,
        in_requested_groups,
    )
    return list(db_session.scalars(query).all())
|