diff --git a/backend/alembic/versions/2666d766cb9b_google_oauth2.py b/backend/alembic/versions/2666d766cb9b_google_oauth2.py
new file mode 100644
index 000000000..b163fe38b
--- /dev/null
+++ b/backend/alembic/versions/2666d766cb9b_google_oauth2.py
@@ -0,0 +1,54 @@
+"""Google OAuth2
+
+Revision ID: 2666d766cb9b
+Revises: 6d387b3196c2
+Create Date: 2023-05-05 15:49:35.716016
+
+"""
+import fastapi_users_db_sqlalchemy
+import sqlalchemy as sa
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision = "2666d766cb9b"
+down_revision = "6d387b3196c2"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Create the ``oauth_account`` table, then its two non-unique indexes."""
+    guid_type = fastapi_users_db_sqlalchemy.generics.GUID
+    table_name = "oauth_account"
+    schema_items = [
+        sa.Column("id", guid_type(), nullable=False),
+        sa.Column("user_id", guid_type(), nullable=False),
+        sa.Column("oauth_name", sa.String(length=100), nullable=False),
+        sa.Column("access_token", sa.String(length=1024), nullable=False),
+        sa.Column("expires_at", sa.Integer(), nullable=True),
+        sa.Column("refresh_token", sa.String(length=1024), nullable=True),
+        sa.Column("account_id", sa.String(length=320), nullable=False),
+        sa.Column("account_email", sa.String(length=320), nullable=False),
+        # user_id references user.id with ON DELETE CASCADE.
+        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
+        sa.PrimaryKeyConstraint("id"),
+    ]
+    op.create_table(table_name, *schema_items)
+    # Index creation order: account_id first, then oauth_name.
+    for index_name, column_name in (
+        ("ix_oauth_account_account_id", "account_id"),
+        ("ix_oauth_account_oauth_name", "oauth_name"),
+    ):
+        op.create_index(op.f(index_name), table_name, [column_name], unique=False)
+
+
+def downgrade() -> None:
+    """Undo ``upgrade``: drop both indexes (oauth_name first), then the table."""
+    table_name = "oauth_account"
+    for index_name in (
+        "ix_oauth_account_oauth_name",
+        "ix_oauth_account_account_id",
+    ):
+        op.drop_index(op.f(index_name), table_name=table_name)
+    op.drop_table(table_name)
diff --git a/backend/danswer/auth/configs.py b/backend/danswer/auth/configs.py
index 16b356d03..565c098b6 100644
--- a/backend/danswer/auth/configs.py
+++ 
b/backend/danswer/auth/configs.py
@@ -1,7 +1,16 @@
 import os
 
-RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"]
-RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[
-    "RESET_PASSWORD_VERIFICATION_TOKEN_SECRET"
-]
+# Single signing secret; users.py uses it for both the reset-password and the
+# verification tokens.
+# NOTE(review): the "" fallback means tokens are signed with an empty key when
+# SECRET is unset -- confirm every deployment exports SECRET.
+SECRET = os.environ.get("SECRET", "")
 SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
+
+# OAuth is opt-in: unset, empty, and the usual "off" spellings all disable it.
+_FALSY_ENV_VALUES = {"", "0", "false", "no", "off"}
+ENABLE_OAUTH = (
+    os.environ.get("ENABLE_OAUTH", "False").strip().lower() not in _FALSY_ENV_VALUES
+)
+GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
+GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index 1abce3062..38c2fe5e6 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -1,8 +1,9 @@
 import uuid
 from typing import Optional
 
-from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET
-from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
+from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
+from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
+from danswer.auth.configs import SECRET
 from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
 from danswer.db.auth import get_access_token_db
 from danswer.db.auth import get_user_db
@@ -18,11 +19,12 @@ from fastapi_users.authentication import CookieTransport
 from fastapi_users.authentication.strategy.db import AccessTokenDatabase
 from fastapi_users.authentication.strategy.db import DatabaseStrategy
 from fastapi_users.db import SQLAlchemyUserDatabase
+from httpx_oauth.clients.google import GoogleOAuth2
 
 
 class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
-    reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET
-    verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
+    reset_password_token_secret = SECRET
+    verification_token_secret = SECRET
 
     async def on_after_register(self, user: User, request: Optional[Request] = None):
         print(f"User {user.id} has registered.")
@@ 
-59,6 +61,8 @@ auth_backend = AuthenticationBackend(
     get_strategy=get_database_strategy,
 )
 
+google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
+
 fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
 
 current_active_user = fastapi_users.current_user(active=True)
diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py
index a2146d585..d14f5c68a 100644
--- a/backend/danswer/db/auth.py
+++ b/backend/danswer/db/auth.py
@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
 
 from danswer.db.engine import build_async_engine
 from danswer.db.models import AccessToken
+from danswer.db.models import OAuthAccount
 from danswer.db.models import User
 from fastapi import Depends
 from fastapi_users.db import SQLAlchemyUserDatabase
@@ -17,7 +18,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
 
 
 async def get_user_db(session: AsyncSession = Depends(get_async_session)):
-    yield SQLAlchemyUserDatabase(session, User)
+    yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
 
 
 async def get_access_token_db(
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 79de9667b..2d942ef31 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -1,9 +1,11 @@
 import datetime
 from enum import Enum as PyEnum
 from typing import Any
+from typing import List
 
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import InputType
+from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
 from fastapi_users.db import SQLAlchemyBaseUserTableUUID
 from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
 from sqlalchemy import DateTime
@@ -14,13 +16,21 @@ from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
 
 
 class Base(DeclarativeBase):
     pass
 
 
+class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
+    """OAuth account model; everything is inherited from the fastapi-users base."""
+
+
 class User(SQLAlchemyBaseUserTableUUID, Base):
+    # lazy="joined": the linked OAuth accounts are fetched together with the user.
+    oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
+        "OAuthAccount", lazy="joined"
+    )
-    pass
 
 
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 8f360dd8f..220bb627e 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -1,9 +1,12 @@
 import uvicorn
+from danswer.auth.configs import ENABLE_OAUTH
+from danswer.auth.configs import SECRET
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate
 from danswer.auth.users import auth_backend
 from danswer.auth.users import fastapi_users
+from danswer.auth.users import google_oauth_client
 from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.server.admin import router as admin_router
@@ -11,12 +14,31 @@ from danswer.server.event_loading import router as event_processing_router
 from danswer.server.search_backend import router as backend_router
 from danswer.utils.logging import setup_logger
 from fastapi import FastAPI
+from fastapi import Request
+from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
 
 
 logger = setup_logger()
 
 
+def validation_exception_handler(
+    request: Request, exc: RequestValidationError
+) -> JSONResponse:
+    """Log a request-validation failure and answer with a JSON 422 body.
+
+    The error text is flattened onto one line (every run of whitespace,
+    newlines included, becomes a single space) so it fits in one log record.
+    """
+    exc_str = " ".join(f"{exc}".split())
+    # This runs outside any ``except`` block, so hand the exception over
+    # explicitly instead of relying on logger.exception() to find an active one.
+    logger.error(f"{request}: {exc_str}", exc_info=exc)
+    content = {"status_code": 422, "message": exc_str, "data": None}
+    return JSONResponse(content=content, status_code=422)
+
+
 def get_application() -> FastAPI:
     application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
     application.include_router(backend_router)
@@ -48,6 +70,32 @@ def get_application() -> FastAPI:
         prefix="/users",
         tags=["users"],
     )
+    # Google login / account-association routes are opt-in via ENABLE_OAUTH.
+    if ENABLE_OAUTH:
+        application.include_router(
+            fastapi_users.get_oauth_router(
+                google_oauth_client,
+                auth_backend,
+                SECRET,
+                associate_by_email=True,
+                is_verified_by_default=True,
+                redirect_url="http://localhost:8080/test",  # TODO DAN-39 set this to frontend redirect
+            ),
+            prefix="/auth/google",
+            tags=["auth"],
+        )
+        application.include_router(
+            fastapi_users.get_oauth_associate_router(
+                google_oauth_client, UserRead, SECRET
+            ),
+            prefix="/auth/associate/google",
+            tags=["auth"],
+        )
+
+    application.add_exception_handler(
+        RequestValidationError, validation_exception_handler
+    )
+
     return application
 
 
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index fce7c8042..84b8dd288 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -17,6 +17,7 @@ from danswer.utils.clients import TSClient
 from danswer.utils.logging import setup_logger
 from fastapi import APIRouter
 from fastapi import Depends
+from fastapi import Request
 
 logger = setup_logger()
 
@@ -30,6 +31,13 @@ async def authenticated_route(user: User = Depends(current_active_user)):
     return {"message": f"Hello {user.email}!"}
 
 
+# TODO DAN-39 delete this once oauth is built out and tested
+@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
+def test_endpoint(request: Request):
+    """Temporary debug endpoint: print the incoming request and return nothing."""
+    print(request)
+
+
 @router.get("/", response_model=ServerStatus)
 @router.get("/status", response_model=ServerStatus)
 def read_server_status():
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 947e692b8..dd7adc91d 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -8,6 +8,9 @@ filelock==3.12.0
 google-api-python-client==2.86.0
 google-auth-httplib2==0.1.0
 google-auth-oauthlib==1.0.0
+httpcore==0.16.3
+httpx==0.23.3
+httpx-oauth==0.11.2
 Mako==1.2.4
 openai==0.27.6
 playwright==1.32.1
@@ -16,6 +19,7 @@ PyPDF2==3.0.1
 pytest-playwright==0.3.2
 qdrant-client==1.1.0
 requests==2.28.2
+rfc3986==1.5.0
 sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12