DAN-45 Fix OAuth User Creation Flow (#22)

This commit is contained in:
Yuhong Sun 2023-05-09 20:45:43 -07:00 committed by GitHub
parent e896d0786e
commit 312366eae1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 15 deletions

View File

@ -8,6 +8,7 @@ from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import build_async_engine
from danswer.db.models import AccessToken
@ -27,9 +28,7 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
@ -42,18 +41,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
if safe:
if hasattr(user_create, "role"):
async with AsyncSession(build_async_engine()) as asession:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
if user_count == 0:
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0:
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def on_after_register(self, user: User, request: Optional[Request] = None):

View File

@ -1,15 +1,44 @@
from typing import Any
from typing import Dict
from danswer.auth.schemas import UserRole
from danswer.db.engine import build_async_engine
from danswer.db.engine import get_async_session
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
async def get_user_count() -> int:
async with AsyncSession(build_async_engine()) as asession:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")
return user_count
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0:
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC
return await super().create(create_dict)
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User, OAuthAccount) # type: ignore
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
async def get_access_token_db(