pablonyx 4e2c90f4af
Proper user deletion / organization leaving (#3460)
* Proper user deletion / organization leaving

* minor nit

* update

* udpate provisioning

* minor cleanup

* typing

* post rebase
2024-12-20 21:01:03 +00:00

93 lines
3.3 KiB
Python

from collections.abc import AsyncGenerator
from collections.abc import Callable
from typing import Any
from typing import Dict
from fastapi import Depends
from fastapi_users.models import ID
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import get_api_key_email_pattern
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_default_admin_user_emails() -> list[str]:
"""Returns a list of emails who should default to Admin role.
Only used in the EE version. For MIT, just return empty list."""
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"onyx.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
)
return get_default_admin_user_emails_fn()
def get_total_users_count(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = (
db_session.query(User)
.filter(
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
User.role != UserRole.EXT_PERM_USER,
)
.count()
)
invited_users = len(get_invited_users())
return user_count + invited_users
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def create(
self,
create_dict: Dict[str, Any],
) -> UP:
user_count = await get_user_count()
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC
return await super().create(create_dict)
async def get_user_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
async def get_access_token_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
yield SQLAlchemyAccessTokenDatabase(session, AccessToken) # type: ignore