mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-25 23:40:58 +02:00
DAN-45 Fix OAuth User Creation Flow (#22)
This commit is contained in:
parent
e896d0786e
commit
312366eae1
@ -8,6 +8,7 @@ from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
|
|||||||
from danswer.auth.schemas import UserCreate
|
from danswer.auth.schemas import UserCreate
|
||||||
from danswer.auth.schemas import UserRole
|
from danswer.auth.schemas import UserRole
|
||||||
from danswer.db.auth import get_access_token_db
|
from danswer.db.auth import get_access_token_db
|
||||||
|
from danswer.db.auth import get_user_count
|
||||||
from danswer.db.auth import get_user_db
|
from danswer.db.auth import get_user_db
|
||||||
from danswer.db.engine import build_async_engine
|
from danswer.db.engine import build_async_engine
|
||||||
from danswer.db.models import AccessToken
|
from danswer.db.models import AccessToken
|
||||||
@ -27,9 +28,7 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
|||||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
from httpx_oauth.clients.google import GoogleOAuth2
|
from httpx_oauth.clients.google import GoogleOAuth2
|
||||||
from sqlalchemy import func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
@ -42,18 +41,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
safe: bool = False,
|
safe: bool = False,
|
||||||
request: Optional[Request] = None,
|
request: Optional[Request] = None,
|
||||||
) -> models.UP:
|
) -> models.UP:
|
||||||
if safe:
|
if hasattr(user_create, "role"):
|
||||||
if hasattr(user_create, "role"):
|
user_count = await get_user_count()
|
||||||
async with AsyncSession(build_async_engine()) as asession:
|
if user_count == 0:
|
||||||
stmt = select(func.count(User.id))
|
user_create.role = UserRole.ADMIN
|
||||||
result = await asession.execute(stmt)
|
else:
|
||||||
user_count = result.scalar()
|
user_create.role = UserRole.BASIC
|
||||||
if user_count is None:
|
|
||||||
raise RuntimeError("Was not able to fetch the user count.")
|
|
||||||
if user_count == 0:
|
|
||||||
user_create.role = UserRole.ADMIN
|
|
||||||
else:
|
|
||||||
user_create.role = UserRole.BASIC
|
|
||||||
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
return await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||||
|
|
||||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||||
|
@ -1,15 +1,44 @@
|
|||||||
|
from typing import Any
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from danswer.auth.schemas import UserRole
|
||||||
|
from danswer.db.engine import build_async_engine
|
||||||
from danswer.db.engine import get_async_session
|
from danswer.db.engine import get_async_session
|
||||||
from danswer.db.models import AccessToken
|
from danswer.db.models import AccessToken
|
||||||
from danswer.db.models import OAuthAccount
|
from danswer.db.models import OAuthAccount
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
|
from fastapi_users.models import UP
|
||||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
|
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_count() -> int:
|
||||||
|
async with AsyncSession(build_async_engine()) as asession:
|
||||||
|
stmt = select(func.count(User.id))
|
||||||
|
result = await asession.execute(stmt)
|
||||||
|
user_count = result.scalar()
|
||||||
|
if user_count is None:
|
||||||
|
raise RuntimeError("Was not able to fetch the user count.")
|
||||||
|
return user_count
|
||||||
|
|
||||||
|
|
||||||
|
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||||
|
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||||
|
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||||
|
user_count = await get_user_count()
|
||||||
|
if user_count == 0:
|
||||||
|
create_dict["role"] = UserRole.ADMIN
|
||||||
|
else:
|
||||||
|
create_dict["role"] = UserRole.BASIC
|
||||||
|
return await super().create(create_dict)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount) # type: ignore
|
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
async def get_access_token_db(
|
async def get_access_token_db(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user