danswer/backend/onyx/db/users.py
2025-02-13 15:59:44 -08:00

344 lines
12 KiB
Python

from collections.abc import Sequence
from typing import Any
from uuid import UUID
from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import KeyedColumnElement
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
    """
    Reject role changes that may not be made by hand.

    The calling endpoint is assumed to be admin-only. Rules about the
    current role are evaluated before rules about the requested role, and
    the first rule that matches raises.

    Raises:
        HTTPException: status 400 when
            - the current role is a slack user, an external permissioned
              user, or a limited user
            - the requested role is a curator, a limited user, a slack
              user, or an external permissioned user
    """
    # Roles a user cannot be moved out of from this endpoint.
    # (The external permissioned case shouldn't happen, but just in case.)
    locked_current_roles = (
        (
            UserRole.SLACK_USER,
            "To change a Slack User's role, they must first login to Onyx via the web app.",
        ),
        (
            UserRole.EXT_PERM_USER,
            "To change an External Permissioned User's role, they must first login to Onyx via the web app.",
        ),
        (
            UserRole.LIMITED,
            "To change a Limited User's role, they must first login to Onyx via the web app.",
        ),
    )
    # Compare with `==` (not a dict lookup) so matching works exactly as a
    # chain of equality checks would.
    for locked_role, message in locked_current_roles:
        if current_role == locked_role:
            raise HTTPException(status_code=400, detail=message)

    # Roles that cannot be assigned from this endpoint.
    # None of these should be requested in practice, but just in case.
    unassignable_roles = (
        (
            UserRole.CURATOR,
            "Curator role must be set via the User Group Menu",
        ),
        (
            UserRole.LIMITED,
            "A user cannot be set to a Limited User role. "
            "This role is automatically assigned to users through certain endpoints in the API.",
        ),
        (
            UserRole.SLACK_USER,
            "A user cannot be set to a Slack User role. "
            "This role is automatically assigned to users who only use Onyx via Slack.",
        ),
        (
            UserRole.EXT_PERM_USER,
            "A user cannot be set to an External Permissioned User role. "
            "This role is automatically assigned to users who have been "
            "pulled in to the system via an external permissions system.",
        ),
    )
    for forbidden_role, message in unassignable_roles:
        if requested_role == forbidden_role:
            raise HTTPException(status_code=400, detail=message)
def get_all_users(
    db_session: Session,
    email_filter_string: str | None = None,
    include_external: bool = False,
) -> Sequence[User]:
    """Return every user that passes the optional filters.

    There is no pagination as of now; the number of users is assumed to be
    relatively small (<< 1 million).

    Parameters:
    - db_session: session used to run the query.
    - email_filter_string: when given, only users whose email contains this
      substring (matched with ILIKE) are returned.
    - include_external: when False, external permissioned users are excluded.
    """
    conditions = []
    if not include_external:
        conditions.append(User.role != UserRole.EXT_PERM_USER)
    if email_filter_string is not None:
        email_pattern = f"%{email_filter_string}%"
        conditions.append(User.email.ilike(email_pattern))  # type: ignore

    query = select(User).where(*conditions)
    # unique() collapses duplicate User rows in the result set
    return db_session.scalars(query).unique().all()
def _get_accepted_user_where_clause(
    email_filter_string: str | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
    is_active_filter: bool | None = None,
) -> list[ColumnElement[bool]]:
    """
    Generates a SQLAlchemy where clause for filtering users based on the provided parameters.
    This is used to build the filters for the function that retrieves the users for the users
    table in the admin panel.

    Users whose email ends with the API-key dummy email domain are always excluded.

    Parameters:
    - email_filter_string: A substring to filter user emails. Only users whose emails contain
      this substring will be included.
    - roles_filter: A list of user roles to filter by. Only users with roles in this list will
      be included. None or an empty list means no role filtering.
    - include_external: If False, external permissioned users will be excluded.
    - is_active_filter: When True, only active users will be included. When False, only
      inactive users will be included. None means no filtering on active state.

    Returns:
    - list: A list of conditions to be used in a SQLAlchemy query to filter users.
    """
    # Access table columns directly via __table__.c to get proper SQLAlchemy column types
    # This ensures type checking works correctly for SQL operations like ilike, endswith, and is_
    email_col: KeyedColumnElement[Any] = User.__table__.c.email
    is_active_col: KeyedColumnElement[Any] = User.__table__.c.is_active

    where_clause: list[ColumnElement[bool]] = [
        expression.not_(email_col.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN))
    ]

    if not include_external:
        where_clause.append(User.role != UserRole.EXT_PERM_USER)

    if email_filter_string is not None:
        where_clause.append(email_col.ilike(f"%{email_filter_string}%"))

    # Falsy covers both "not provided" (None) and an explicit empty list
    if roles_filter:
        where_clause.append(User.role.in_(roles_filter))

    if is_active_filter is not None:
        where_clause.append(is_active_col.is_(is_active_filter))

    return where_clause
def get_page_of_filtered_users(
    db_session: Session,
    page_size: int,
    page_num: int,
    email_filter_string: str | None = None,
    is_active_filter: bool | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
) -> Sequence[User]:
    """Return one page of users matching the given filters.

    Parameters:
    - db_session: session used to run the query.
    - page_size: maximum number of users in the page.
    - page_num: zero-based page index; the query skips `page_num * page_size` rows.
    - email_filter_string / is_active_filter / roles_filter / include_external:
      forwarded to `_get_accepted_user_where_clause`.

    Returns:
    - The users on the requested page, ordered by email and then id.
    """
    where_clause = _get_accepted_user_where_clause(
        email_filter_string=email_filter_string,
        # Normalize None so the helper always receives a list
        roles_filter=roles_filter or [],
        include_external=include_external,
        is_active_filter=is_active_filter,
    )

    # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so consecutive
    # pages could overlap or skip users. Order by email, with the id as a
    # tie-breaker so the ordering is total.
    user_columns = User.__table__.c
    users_stmt = (
        select(User)
        # Apply filtering
        .where(*where_clause)
        .order_by(user_columns.email, user_columns.id)
        # Apply pagination
        .offset(page_num * page_size)
        .limit(page_size)
    )

    return db_session.scalars(users_stmt).unique().all()
def get_total_filtered_users_count(
    db_session: Session,
    email_filter_string: str | None = None,
    is_active_filter: bool | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
) -> int:
    """Count the users matching the given filters.

    The filter parameters are forwarded to `_get_accepted_user_where_clause`.

    Returns:
    - The number of matching users; 0 when the query yields no value.
    """
    where_clause = _get_accepted_user_where_clause(
        email_filter_string=email_filter_string,
        # Normalize None so the helper always receives a list
        roles_filter=roles_filter or [],
        include_external=include_external,
        is_active_filter=is_active_filter,
    )

    total_count_stmt = select(func.count()).select_from(User)
    # Apply filtering
    total_count_stmt = total_count_stmt.where(*where_clause)

    return db_session.scalar(total_count_stmt) or 0
def get_user_by_email(email: str, db_session: Session) -> User | None:
    """Return the first user whose email equals `email`, ignoring case.

    Both sides of the comparison are lowercased in SQL. Returns None when no
    user matches.
    """
    same_email_ignoring_case = func.lower(User.email) == func.lower(email)
    return db_session.query(User).filter(same_email_ignoring_case).first()
def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
    """Return the user whose id equals `user_id`, or None when there is none."""
    matching_users = db_session.query(User).filter(User.id == user_id)  # type: ignore
    return matching_users.first()
def _generate_slack_user(email: str) -> User:
    """Build a new User with the SLACK_USER role and a generated password.

    The object is only constructed here; it is not added to any session.
    Only the hash of the generated password is kept.
    """
    pw_helper = PasswordHelper()
    generated_password_hash = pw_helper.hash(pw_helper.generate())

    slack_user = User(
        email=email,
        hashed_password=generated_password_hash,
        role=UserRole.SLACK_USER,
    )
    return slack_user
def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
    """Return the user for `email`, creating a slack user if none exists.

    The email is lowercased before the lookup and before creation. An
    existing external permissioned user is switched to the slack user role;
    existing users with any other role are returned unchanged.
    """
    normalized_email = email.lower()
    existing_user = get_user_by_email(normalized_email, db_session)

    if existing_user is None:
        created_user = _generate_slack_user(email=normalized_email)
        db_session.add(created_user)
        db_session.commit()
        return created_user

    # If the user is an external permissioned user, we update it to a slack user
    if existing_user.role == UserRole.EXT_PERM_USER:
        existing_user.role = UserRole.SLACK_USER
        db_session.commit()
    return existing_user
def _get_users_by_emails(
    db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
    """Split `lower_emails` into users that exist and emails that do not.

    Parameters:
    - db_session: session used to run the query.
    - lower_emails: emails to look up; they are compared against the
      lowercased stored email, so callers are expected to pass them already
      lowercased.

    Returns:
    - A tuple of (users found, emails from `lower_emails` with no matching
      user, in their original order).
    """
    stmt = select(User).filter(func.lower(User.email).in_(lower_emails))  # type: ignore
    found_users = list(db_session.scalars(stmt).unique().all())  # Convert to list

    # Extract found emails and convert to lowercase to avoid case sensitivity issues.
    # A set keeps each membership test below O(1) instead of scanning a list.
    found_users_emails = {user.email.lower() for user in found_users}

    # Separate emails for users that were not found
    missing_user_emails = [
        email for email in lower_emails if email not in found_users_emails
    ]
    return found_users, missing_user_emails
def _generate_ext_permissioned_user(email: str) -> User:
    """Build a new User with the EXT_PERM_USER role and a generated password.

    The object is only constructed here; it is not added to any session.
    Only the hash of the generated password is kept.
    """
    pw_helper = PasswordHelper()
    generated_password_hash = pw_helper.hash(pw_helper.generate())

    ext_perm_user = User(
        email=email,
        hashed_password=generated_password_hash,
        role=UserRole.EXT_PERM_USER,
    )
    return ext_perm_user
def batch_add_ext_perm_user_if_not_exists(
    db_session: Session, emails: list[str], continue_on_error: bool = False
) -> list[User]:
    """Ensure a user exists for each email, creating external permissioned users as needed.

    Parameters:
    - db_session: session used for lookups and inserts; this function commits.
    - emails: emails to ensure; they are lowercased and de-duplicated first.
    - continue_on_error: when False, an IntegrityError from the batch insert
      is re-raised after rollback. When True, the users are re-inserted one
      at a time and individual IntegrityErrors are skipped.

    Returns:
    - All users that exist for the given emails after the inserts.

    Raises:
    - IntegrityError: if the batch insert fails and `continue_on_error` is False.
    """
    # Lowercase and de-duplicate while keeping first-seen order. Without this,
    # inputs such as "A@x.com" and "a@x.com" would yield two new rows with the
    # same email and make the batch insert fail on its own input.
    lower_emails = list(dict.fromkeys(email.lower() for email in emails))

    _, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)

    new_users: list[User] = [
        _generate_ext_permissioned_user(email=email) for email in missing_lower_emails
    ]

    try:
        db_session.add_all(new_users)
        db_session.commit()
    except IntegrityError:
        db_session.rollback()
        if not continue_on_error:
            raise
        # Fall back to one insert per user so a single conflicting row does
        # not prevent the others from being created.
        for user in new_users:
            try:
                db_session.add(user)
                db_session.commit()
            except IntegrityError:
                db_session.rollback()
                continue

    # Fetch all users again to ensure we have the most up-to-date list
    all_users, _ = _get_users_by_emails(db_session, lower_emails)
    return all_users
def delete_user_from_db(
    user_to_delete: User,
    db_session: Session,
) -> None:
    """Delete a user together with the rows that reference it, then commit.

    Removed, in order: the user's OAuth accounts, whatever the
    `delete_user__ext_group_for_user__no_commit` hook removes, the user's
    SAML accounts, and the user's document-set, persona and user-group
    association rows. After the commit, the user's email is also dropped
    from the invited-users list.
    """
    for linked_oauth_account in user_to_delete.oauth_accounts:
        db_session.delete(linked_oauth_account)

    user_id = user_to_delete.id

    # Resolved by name at runtime; presumably a no-op when the enterprise
    # implementation is unavailable (going by the helper's name).
    delete_ext_group_links = fetch_ee_implementation_or_noop(
        "onyx.db.external_perm",
        "delete_user__ext_group_for_user__no_commit",
    )
    delete_ext_group_links(
        db_session=db_session,
        user_id=user_id,
    )

    # Bulk-delete the rows that reference the user before deleting the user itself
    saml_accounts = db_session.query(SamlAccount).filter(
        SamlAccount.user_id == user_id
    )
    saml_accounts.delete()

    document_set_links = db_session.query(DocumentSet__User).filter(
        DocumentSet__User.user_id == user_id
    )
    document_set_links.delete()

    persona_links = db_session.query(Persona__User).filter(
        Persona__User.user_id == user_id
    )
    persona_links.delete()

    user_group_links = db_session.query(User__UserGroup).filter(
        User__UserGroup.user_id == user_id
    )
    user_group_links.delete()

    db_session.delete(user_to_delete)
    db_session.commit()

    # NOTE: edge case may exist with race conditions
    # with this `invited user` scheme generally.
    invited_emails = get_invited_users()
    still_invited = [
        invited_email
        for invited_email in invited_emails
        if invited_email != user_to_delete.email
    ]
    write_invited_users(still_invited)