mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 21:32:36 +01:00
DAN-45 Fix OAuth User Creation Flow (#22)
This commit is contained in:
parent
e896d0786e
commit
312366eae1
@ -8,6 +8,7 @@ from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
@ -27,9 +28,7 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
@ -42,18 +41,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> models.UP:
|
||||
if safe:
|
||||
if hasattr(user_create, "role"):
|
||||
async with AsyncSession(build_async_engine()) as asession:
|
||||
stmt = select(func.count(User.id))
|
||||
result = await asession.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
if user_count is None:
|
||||
raise RuntimeError("Was not able to fetch the user count.")
|
||||
if user_count == 0:
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
|
@ -1,15 +1,44 @@
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from fastapi_users.models import UP
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
async with AsyncSession(build_async_engine()) as asession:
|
||||
stmt = select(func.count(User.id))
|
||||
result = await asession.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
if user_count is None:
|
||||
raise RuntimeError("Was not able to fetch the user count.")
|
||||
return user_count
|
||||
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0:
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
else:
|
||||
create_dict["role"] = UserRole.BASIC
|
||||
return await super().create(create_dict)
|
||||
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount) # type: ignore
|
||||
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
|
||||
|
||||
|
||||
async def get_access_token_db(
|
||||
|
Loading…
x
Reference in New Issue
Block a user