DAN-45 Fix OAuth User Creation Flow (#22)

This commit is contained in:
Yuhong Sun 2023-05-09 20:45:43 -07:00 committed by GitHub
parent e896d0786e
commit 312366eae1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 15 deletions

View File

@ -8,6 +8,7 @@ from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserRole
from danswer.db.auth import get_access_token_db from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db from danswer.db.auth import get_user_db
from danswer.db.engine import build_async_engine from danswer.db.engine import build_async_engine
from danswer.db.models import AccessToken from danswer.db.models import AccessToken
@ -27,9 +28,7 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.google import GoogleOAuth2
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
@ -42,18 +41,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False, safe: bool = False,
request: Optional[Request] = None, request: Optional[Request] = None,
) -> models.UP: ) -> models.UP:
if safe: if hasattr(user_create, "role"):
if hasattr(user_create, "role"): user_count = await get_user_count()
async with AsyncSession(build_async_engine()) as asession: if user_count == 0:
stmt = select(func.count(User.id)) user_create.role = UserRole.ADMIN
result = await asession.execute(stmt) else:
user_count = result.scalar() user_create.role = UserRole.BASIC
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
if user_count == 0:
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore return await super().create(user_create, safe=safe, request=request) # type: ignore
async def on_after_register(self, user: User, request: Optional[Request] = None): async def on_after_register(self, user: User, request: Optional[Request] = None):

View File

@ -1,15 +1,44 @@
from typing import Any
from typing import Dict
from danswer.auth.schemas import UserRole
from danswer.db.engine import build_async_engine
from danswer.db.engine import get_async_session from danswer.db.engine import get_async_session
from danswer.db.models import AccessToken from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount from danswer.db.models import OAuthAccount
from danswer.db.models import User from danswer.db.models import User
from fastapi import Depends from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
async def get_user_count() -> int:
async with AsyncSession(build_async_engine()) as asession:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0:
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC
return await super().create(create_dict)
async def get_user_db(session: AsyncSession = Depends(get_async_session)): async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User, OAuthAccount) # type: ignore yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
async def get_access_token_db( async def get_access_token_db(