DAN-5 OAuth Backend (#17)

Also added in an exception handler for logging
This commit is contained in:
Yuhong Sun 2023-05-06 23:47:21 -07:00 committed by GitHub
parent 4f4c65acac
commit e20179048d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 9 deletions

View File

@ -0,0 +1,55 @@
"""Google OAuth2
Revision ID: 2666d766cb9b
Revises: 6d387b3196c2
Create Date: 2023-05-05 15:49:35.716016
"""
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "2666d766cb9b"
down_revision = "6d387b3196c2"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"oauth_account",
sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("oauth_name", sa.String(length=100), nullable=False),
sa.Column("access_token", sa.String(length=1024), nullable=False),
sa.Column("expires_at", sa.Integer(), nullable=True),
sa.Column("refresh_token", sa.String(length=1024), nullable=True),
sa.Column("account_id", sa.String(length=320), nullable=False),
sa.Column("account_email", sa.String(length=320), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_account_account_id"),
"oauth_account",
["account_id"],
unique=False,
)
op.create_index(
op.f("ix_oauth_account_oauth_name"),
"oauth_account",
["oauth_name"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_oauth_account_oauth_name"), table_name="oauth_account")
op.drop_index(op.f("ix_oauth_account_account_id"), table_name="oauth_account")
op.drop_table("oauth_account")

View File

@ -1,7 +1,8 @@
import os
RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"]
RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[
"RESET_PASSWORD_VERIFICATION_TOKEN_SECRET"
]
SECRET = os.environ.get("SECRET", "")
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "False").lower() != "false"
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")

View File

@ -1,8 +1,9 @@
import uuid
from typing import Optional
from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET
from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.auth.configs import SECRET
from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_db
@ -18,11 +19,12 @@ from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET
verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
reset_password_token_secret = SECRET
verification_token_secret = SECRET
async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.")
@ -59,6 +61,8 @@ auth_backend = AuthenticationBackend(
get_strategy=get_database_strategy,
)
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)

View File

@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
from danswer.db.engine import build_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase
@ -17,7 +18,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User)
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
async def get_access_token_db(

View File

@ -1,9 +1,11 @@
import datetime
from enum import Enum as PyEnum
from typing import Any
from typing import List
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from sqlalchemy import DateTime
@ -14,13 +16,21 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
class Base(DeclarativeBase):
pass
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
)
pass

View File

@ -1,9 +1,12 @@
import uvicorn
from danswer.auth.configs import ENABLE_OAUTH
from danswer.auth.configs import SECRET
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.auth.users import google_oauth_client
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.server.admin import router as admin_router
@ -11,12 +14,22 @@ from danswer.server.event_loading import router as event_processing_router
from danswer.server.search_backend import router as backend_router
from danswer.utils.logging import setup_logger
from fastapi import FastAPI
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
logger = setup_logger()
def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
logger.exception(f"{request}: {exc_str}")
content = {"status_code": 422, "message": exc_str, "data": None}
return JSONResponse(content=content, status_code=422)
def get_application() -> FastAPI:
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
application.include_router(backend_router)
@ -48,6 +61,31 @@ def get_application() -> FastAPI:
prefix="/users",
tags=["users"],
)
if ENABLE_OAUTH:
application.include_router(
fastapi_users.get_oauth_router(
google_oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
redirect_url="http://localhost:8080/test", # TODO DAN-39 set this to frontend redirect
),
prefix="/auth/google",
tags=["auth"],
)
application.include_router(
fastapi_users.get_oauth_associate_router(
google_oauth_client, UserRead, SECRET
),
prefix="/auth/associate/google",
tags=["auth"],
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)
return application

View File

@ -17,6 +17,7 @@ from danswer.utils.clients import TSClient
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Request
logger = setup_logger()
@ -30,6 +31,12 @@ async def authenticated_route(user: User = Depends(current_active_user)):
return {"message": f"Hello {user.email}!"}
# TODO DAN-39 delete this once oauth is built out and tested
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
def test_endpoint(request: Request):
print(request)
@router.get("/", response_model=ServerStatus)
@router.get("/status", response_model=ServerStatus)
def read_server_status():

View File

@ -8,6 +8,9 @@ filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2
Mako==1.2.4
openai==0.27.6
playwright==1.32.1
@ -16,6 +19,7 @@ PyPDF2==3.0.1
pytest-playwright==0.3.2
qdrant-client==1.1.0
requests==2.28.2
rfc3986==1.5.0
sentence-transformers==2.2.2
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12