DAN-5 OAuth Backend (#17)

Also added in an exception handler for logging
This commit is contained in:
Yuhong Sun 2023-05-06 23:47:21 -07:00 committed by GitHub
parent 4f4c65acac
commit e20179048d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 129 additions and 9 deletions

View File

@ -0,0 +1,55 @@
"""Google OAuth2
Revision ID: 2666d766cb9b
Revises: 6d387b3196c2
Create Date: 2023-05-05 15:49:35.716016
"""
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "2666d766cb9b"
down_revision = "6d387b3196c2"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"oauth_account",
sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column("oauth_name", sa.String(length=100), nullable=False),
sa.Column("access_token", sa.String(length=1024), nullable=False),
sa.Column("expires_at", sa.Integer(), nullable=True),
sa.Column("refresh_token", sa.String(length=1024), nullable=True),
sa.Column("account_id", sa.String(length=320), nullable=False),
sa.Column("account_email", sa.String(length=320), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_oauth_account_account_id"),
"oauth_account",
["account_id"],
unique=False,
)
op.create_index(
op.f("ix_oauth_account_oauth_name"),
"oauth_account",
["oauth_name"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_oauth_account_oauth_name"), table_name="oauth_account")
op.drop_index(op.f("ix_oauth_account_account_id"), table_name="oauth_account")
op.drop_table("oauth_account")

View File

@ -1,7 +1,8 @@
import os import os
RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"] SECRET = os.environ.get("SECRET", "")
RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[
"RESET_PASSWORD_VERIFICATION_TOKEN_SECRET"
]
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600)) SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "False").lower() != "false"
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")

View File

@ -1,8 +1,9 @@
import uuid import uuid
from typing import Optional from typing import Optional
from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.auth.configs import SECRET
from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.auth import get_access_token_db from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_db from danswer.db.auth import get_user_db
@ -18,11 +19,12 @@ from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET reset_password_token_secret = SECRET
verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET verification_token_secret = SECRET
async def on_after_register(self, user: User, request: Optional[Request] = None): async def on_after_register(self, user: User, request: Optional[Request] = None):
print(f"User {user.id} has registered.") print(f"User {user.id} has registered.")
@ -59,6 +61,8 @@ auth_backend = AuthenticationBackend(
get_strategy=get_database_strategy, get_strategy=get_database_strategy,
) )
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True) current_active_user = fastapi_users.current_user(active=True)

View File

@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
from danswer.db.engine import build_async_engine from danswer.db.engine import build_async_engine
from danswer.db.models import AccessToken from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User from danswer.db.models import User
from fastapi import Depends from fastapi import Depends
from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.db import SQLAlchemyUserDatabase
@ -17,7 +18,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_user_db(session: AsyncSession = Depends(get_async_session)): async def get_user_db(session: AsyncSession = Depends(get_async_session)):
yield SQLAlchemyUserDatabase(session, User) yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
async def get_access_token_db( async def get_access_token_db(

View File

@ -1,9 +1,11 @@
import datetime import datetime
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import Any from typing import Any
from typing import List
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType from danswer.connectors.models import InputType
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users.db import SQLAlchemyBaseUserTableUUID from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from sqlalchemy import DateTime from sqlalchemy import DateTime
@ -14,13 +16,21 @@ from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
pass
class User(SQLAlchemyBaseUserTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
)
pass pass

View File

@ -1,9 +1,12 @@
import uvicorn import uvicorn
from danswer.auth.configs import ENABLE_OAUTH
from danswer.auth.configs import SECRET
from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users from danswer.auth.users import fastapi_users
from danswer.auth.users import google_oauth_client
from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import APP_PORT
from danswer.server.admin import router as admin_router from danswer.server.admin import router as admin_router
@ -11,12 +14,22 @@ from danswer.server.event_loading import router as event_processing_router
from danswer.server.search_backend import router as backend_router from danswer.server.search_backend import router as backend_router
from danswer.utils.logging import setup_logger from danswer.utils.logging import setup_logger
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
logger = setup_logger() logger = setup_logger()
def validation_exception_handler(request: Request, exc: RequestValidationError):
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
logger.exception(f"{request}: {exc_str}")
content = {"status_code": 422, "message": exc_str, "data": None}
return JSONResponse(content=content, status_code=422)
def get_application() -> FastAPI: def get_application() -> FastAPI:
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1") application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
application.include_router(backend_router) application.include_router(backend_router)
@ -48,6 +61,31 @@ def get_application() -> FastAPI:
prefix="/users", prefix="/users",
tags=["users"], tags=["users"],
) )
if ENABLE_OAUTH:
application.include_router(
fastapi_users.get_oauth_router(
google_oauth_client,
auth_backend,
SECRET,
associate_by_email=True,
is_verified_by_default=True,
redirect_url="http://localhost:8080/test", # TODO DAN-39 set this to frontend redirect
),
prefix="/auth/google",
tags=["auth"],
)
application.include_router(
fastapi_users.get_oauth_associate_router(
google_oauth_client, UserRead, SECRET
),
prefix="/auth/associate/google",
tags=["auth"],
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)
return application return application

View File

@ -17,6 +17,7 @@ from danswer.utils.clients import TSClient
from danswer.utils.logging import setup_logger from danswer.utils.logging import setup_logger
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Depends from fastapi import Depends
from fastapi import Request
logger = setup_logger() logger = setup_logger()
@ -30,6 +31,12 @@ async def authenticated_route(user: User = Depends(current_active_user)):
return {"message": f"Hello {user.email}!"} return {"message": f"Hello {user.email}!"}
# TODO DAN-39 delete this once oauth is built out and tested
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
def test_endpoint(request: Request):
print(request)
@router.get("/", response_model=ServerStatus) @router.get("/", response_model=ServerStatus)
@router.get("/status", response_model=ServerStatus) @router.get("/status", response_model=ServerStatus)
def read_server_status(): def read_server_status():

View File

@ -8,6 +8,9 @@ filelock==3.12.0
google-api-python-client==2.86.0 google-api-python-client==2.86.0
google-auth-httplib2==0.1.0 google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0 google-auth-oauthlib==1.0.0
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2
Mako==1.2.4 Mako==1.2.4
openai==0.27.6 openai==0.27.6
playwright==1.32.1 playwright==1.32.1
@ -16,6 +19,7 @@ PyPDF2==3.0.1
pytest-playwright==0.3.2 pytest-playwright==0.3.2
qdrant-client==1.1.0 qdrant-client==1.1.0
requests==2.28.2 requests==2.28.2
rfc3986==1.5.0
sentence-transformers==2.2.2 sentence-transformers==2.2.2
slack-sdk==3.20.2 slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.12 SQLAlchemy[mypy]==2.0.12