From 312366eae1624165d9b8a63f5dedab39315705d5 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 9 May 2023 20:45:43 -0700 Subject: [PATCH] DAN-45 Fix OAuth User Creation Flow (#22) --- backend/danswer/auth/users.py | 21 +++++++-------------- backend/danswer/db/auth.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 9d46d6d05..9513d972b 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -8,6 +8,7 @@ from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole from danswer.db.auth import get_access_token_db +from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db from danswer.db.engine import build_async_engine from danswer.db.models import AccessToken @@ -27,9 +28,7 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.db import SQLAlchemyUserDatabase from httpx_oauth.clients.google import GoogleOAuth2 -from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): @@ -42,18 +41,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): safe: bool = False, request: Optional[Request] = None, ) -> models.UP: - if safe: - if hasattr(user_create, "role"): - async with AsyncSession(build_async_engine()) as asession: - stmt = select(func.count(User.id)) - result = await asession.execute(stmt) - user_count = result.scalar() - if user_count is None: - raise RuntimeError("Was not able to fetch the user count.") - if user_count == 0: - user_create.role = UserRole.ADMIN - else: - user_create.role = UserRole.BASIC + if hasattr(user_create, "role"): + user_count = await get_user_count() + if user_count == 0: + user_create.role = UserRole.ADMIN + else: + user_create.role = UserRole.BASIC return await super().create(user_create, safe=safe, request=request) # type: ignore async def on_after_register(self, user: User, request: Optional[Request] = None): diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 429a83743..f3cc44804 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -1,15 +1,44 @@ +from typing import Any +from typing import Dict + +from danswer.auth.schemas import UserRole +from danswer.db.engine import build_async_engine from danswer.db.engine import get_async_session from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User from fastapi import Depends from fastapi_users.db import SQLAlchemyUserDatabase +from fastapi_users.models import UP from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + + +async def get_user_count() -> int: + async with AsyncSession(build_async_engine()) as asession: + stmt = select(func.count(User.id)) + result = await asession.execute(stmt) + user_count = result.scalar() + if user_count is None: + raise RuntimeError("Was not able to fetch the user count.") + return user_count + + +# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow +class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase): + async def create(self, create_dict: Dict[str, Any]) -> UP: + user_count = await get_user_count() + if user_count == 0: + create_dict["role"] = UserRole.ADMIN + else: + create_dict["role"] = UserRole.BASIC + return await super().create(create_dict) async def get_user_db(session: AsyncSession = Depends(get_async_session)): - yield SQLAlchemyUserDatabase(session, User, OAuthAccount) # type: ignore + yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore async def get_access_token_db(