mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 11:28:09 +02:00
DAN-5 OAuth Backend (#17)
Also added in an exception handler for logging
This commit is contained in:
parent
4f4c65acac
commit
e20179048d
55
backend/alembic/versions/2666d766cb9b_google_oauth2.py
Normal file
55
backend/alembic/versions/2666d766cb9b_google_oauth2.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Google OAuth2
|
||||
|
||||
Revision ID: 2666d766cb9b
|
||||
Revises: 6d387b3196c2
|
||||
Create Date: 2023-05-05 15:49:35.716016
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2666d766cb9b"
|
||||
down_revision = "6d387b3196c2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"oauth_account",
|
||||
sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("oauth_name", sa.String(length=100), nullable=False),
|
||||
sa.Column("access_token", sa.String(length=1024), nullable=False),
|
||||
sa.Column("expires_at", sa.Integer(), nullable=True),
|
||||
sa.Column("refresh_token", sa.String(length=1024), nullable=True),
|
||||
sa.Column("account_id", sa.String(length=320), nullable=False),
|
||||
sa.Column("account_email", sa.String(length=320), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_account_account_id"),
|
||||
"oauth_account",
|
||||
["account_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_oauth_account_oauth_name"),
|
||||
"oauth_account",
|
||||
["oauth_name"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_oauth_account_oauth_name"), table_name="oauth_account")
|
||||
op.drop_index(op.f("ix_oauth_account_account_id"), table_name="oauth_account")
|
||||
op.drop_table("oauth_account")
|
@ -1,7 +1,8 @@
|
||||
import os
|
||||
|
||||
RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"]
|
||||
RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[
|
||||
"RESET_PASSWORD_VERIFICATION_TOKEN_SECRET"
|
||||
]
|
||||
SECRET = os.environ.get("SECRET", "")
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
||||
|
||||
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "False").lower() != "false"
|
||||
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
|
||||
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
||||
|
@ -1,8 +1,9 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET
|
||||
from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
|
||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
|
||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||
from danswer.auth.configs import SECRET
|
||||
from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_db
|
||||
@ -18,11 +19,12 @@ from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET
|
||||
verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
|
||||
reset_password_token_secret = SECRET
|
||||
verification_token_secret = SECRET
|
||||
|
||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||
print(f"User {user.id} has registered.")
|
||||
@ -59,6 +61,8 @@ auth_backend = AuthenticationBackend(
|
||||
get_strategy=get_database_strategy,
|
||||
)
|
||||
|
||||
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
|
@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
|
||||
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
@ -17,7 +18,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||
yield SQLAlchemyUserDatabase(session, User)
|
||||
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||
|
||||
|
||||
async def get_access_token_db(
|
||||
|
@ -1,9 +1,11 @@
|
||||
import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
from sqlalchemy import DateTime
|
||||
@ -14,13 +16,21 @@ from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
pass
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
import uvicorn
|
||||
from danswer.auth.configs import ENABLE_OAUTH
|
||||
from danswer.auth.configs import SECRET
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.auth.users import google_oauth_client
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.server.admin import router as admin_router
|
||||
@ -11,12 +14,22 @@ from danswer.server.event_loading import router as event_processing_router
|
||||
from danswer.server.search_backend import router as backend_router
|
||||
from danswer.utils.logging import setup_logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
|
||||
logger.exception(f"{request}: {exc_str}")
|
||||
content = {"status_code": 422, "message": exc_str, "data": None}
|
||||
return JSONResponse(content=content, status_code=422)
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
|
||||
application.include_router(backend_router)
|
||||
@ -48,6 +61,31 @@ def get_application() -> FastAPI:
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
if ENABLE_OAUTH:
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_router(
|
||||
google_oauth_client,
|
||||
auth_backend,
|
||||
SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
redirect_url="http://localhost:8080/test", # TODO DAN-39 set this to frontend redirect
|
||||
),
|
||||
prefix="/auth/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
application.include_router(
|
||||
fastapi_users.get_oauth_associate_router(
|
||||
google_oauth_client, UserRead, SECRET
|
||||
),
|
||||
prefix="/auth/associate/google",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
RequestValidationError, validation_exception_handler
|
||||
)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
|
@ -17,6 +17,7 @@ from danswer.utils.clients import TSClient
|
||||
from danswer.utils.logging import setup_logger
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -30,6 +31,12 @@ async def authenticated_route(user: User = Depends(current_active_user)):
|
||||
return {"message": f"Hello {user.email}!"}
|
||||
|
||||
|
||||
# TODO DAN-39 delete this once oauth is built out and tested
|
||||
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
|
||||
def test_endpoint(request: Request):
|
||||
print(request)
|
||||
|
||||
|
||||
@router.get("/", response_model=ServerStatus)
|
||||
@router.get("/status", response_model=ServerStatus)
|
||||
def read_server_status():
|
||||
|
@ -8,6 +8,9 @@ filelock==3.12.0
|
||||
google-api-python-client==2.86.0
|
||||
google-auth-httplib2==0.1.0
|
||||
google-auth-oauthlib==1.0.0
|
||||
httpcore==0.16.3
|
||||
httpx==0.23.3
|
||||
httpx-oauth==0.11.2
|
||||
Mako==1.2.4
|
||||
openai==0.27.6
|
||||
playwright==1.32.1
|
||||
@ -16,6 +19,7 @@ PyPDF2==3.0.1
|
||||
pytest-playwright==0.3.2
|
||||
qdrant-client==1.1.0
|
||||
requests==2.28.2
|
||||
rfc3986==1.5.0
|
||||
sentence-transformers==2.2.2
|
||||
slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.12
|
||||
|
Loading…
x
Reference in New Issue
Block a user