mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-21 05:20:55 +02:00
DAN-5 OAuth Backend (#17)
Also added in an exception handler for logging
This commit is contained in:
parent
4f4c65acac
commit
e20179048d
55
backend/alembic/versions/2666d766cb9b_google_oauth2.py
Normal file
55
backend/alembic/versions/2666d766cb9b_google_oauth2.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Google OAuth2
|
||||||
|
|
||||||
|
Revision ID: 2666d766cb9b
|
||||||
|
Revises: 6d387b3196c2
|
||||||
|
Create Date: 2023-05-05 15:49:35.716016
|
||||||
|
|
||||||
|
"""
|
||||||
|
import fastapi_users_db_sqlalchemy
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "2666d766cb9b"
|
||||||
|
down_revision = "6d387b3196c2"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"oauth_account",
|
||||||
|
sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("oauth_name", sa.String(length=100), nullable=False),
|
||||||
|
sa.Column("access_token", sa.String(length=1024), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("refresh_token", sa.String(length=1024), nullable=True),
|
||||||
|
sa.Column("account_id", sa.String(length=320), nullable=False),
|
||||||
|
sa.Column("account_email", sa.String(length=320), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_account_account_id"),
|
||||||
|
"oauth_account",
|
||||||
|
["account_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_oauth_account_oauth_name"),
|
||||||
|
"oauth_account",
|
||||||
|
["oauth_name"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f("ix_oauth_account_oauth_name"), table_name="oauth_account")
|
||||||
|
op.drop_index(op.f("ix_oauth_account_account_id"), table_name="oauth_account")
|
||||||
|
op.drop_table("oauth_account")
|
@ -1,7 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"]
|
SECRET = os.environ.get("SECRET", "")
|
||||||
RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[
|
|
||||||
"RESET_PASSWORD_VERIFICATION_TOKEN_SECRET"
|
|
||||||
]
|
|
||||||
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
||||||
|
|
||||||
|
ENABLE_OAUTH = os.environ.get("ENABLE_OAUTH", "False").lower() != "false"
|
||||||
|
GOOGLE_OAUTH_CLIENT_ID = os.environ.get("GOOGLE_OAUTH_CLIENT_ID", "")
|
||||||
|
GOOGLE_OAUTH_CLIENT_SECRET = os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET", "")
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET
|
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
|
||||||
from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
|
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||||
|
from danswer.auth.configs import SECRET
|
||||||
from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
|
from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS
|
||||||
from danswer.db.auth import get_access_token_db
|
from danswer.db.auth import get_access_token_db
|
||||||
from danswer.db.auth import get_user_db
|
from danswer.db.auth import get_user_db
|
||||||
@ -18,11 +19,12 @@ from fastapi_users.authentication import CookieTransport
|
|||||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
|
from httpx_oauth.clients.google import GoogleOAuth2
|
||||||
|
|
||||||
|
|
||||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET
|
reset_password_token_secret = SECRET
|
||||||
verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET
|
verification_token_secret = SECRET
|
||||||
|
|
||||||
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
async def on_after_register(self, user: User, request: Optional[Request] = None):
|
||||||
print(f"User {user.id} has registered.")
|
print(f"User {user.id} has registered.")
|
||||||
@ -59,6 +61,8 @@ auth_backend = AuthenticationBackend(
|
|||||||
get_strategy=get_database_strategy,
|
get_strategy=get_database_strategy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_SECRET)
|
||||||
|
|
||||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||||
|
|
||||||
current_active_user = fastapi_users.current_user(active=True)
|
current_active_user = fastapi_users.current_user(active=True)
|
||||||
|
@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
|
|||||||
|
|
||||||
from danswer.db.engine import build_async_engine
|
from danswer.db.engine import build_async_engine
|
||||||
from danswer.db.models import AccessToken
|
from danswer.db.models import AccessToken
|
||||||
|
from danswer.db.models import OAuthAccount
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
@ -17,7 +18,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
|
|
||||||
|
|
||||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
|
||||||
yield SQLAlchemyUserDatabase(session, User)
|
yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
|
||||||
|
|
||||||
|
|
||||||
async def get_access_token_db(
|
async def get_access_token_db(
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
from danswer.connectors.models import InputType
|
from danswer.connectors.models import InputType
|
||||||
|
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
|
||||||
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
|
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
|
||||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||||
from sqlalchemy import DateTime
|
from sqlalchemy import DateTime
|
||||||
@ -14,13 +16,21 @@ from sqlalchemy.dialects import postgresql
|
|||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
from sqlalchemy.orm import Mapped
|
from sqlalchemy.orm import Mapped
|
||||||
from sqlalchemy.orm import mapped_column
|
from sqlalchemy.orm import mapped_column
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
|
oauth_accounts: Mapped[List[OAuthAccount]] = relationship(
|
||||||
|
"OAuthAccount", lazy="joined"
|
||||||
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import uvicorn
|
import uvicorn
|
||||||
|
from danswer.auth.configs import ENABLE_OAUTH
|
||||||
|
from danswer.auth.configs import SECRET
|
||||||
from danswer.auth.schemas import UserCreate
|
from danswer.auth.schemas import UserCreate
|
||||||
from danswer.auth.schemas import UserRead
|
from danswer.auth.schemas import UserRead
|
||||||
from danswer.auth.schemas import UserUpdate
|
from danswer.auth.schemas import UserUpdate
|
||||||
from danswer.auth.users import auth_backend
|
from danswer.auth.users import auth_backend
|
||||||
from danswer.auth.users import fastapi_users
|
from danswer.auth.users import fastapi_users
|
||||||
|
from danswer.auth.users import google_oauth_client
|
||||||
from danswer.configs.app_configs import APP_HOST
|
from danswer.configs.app_configs import APP_HOST
|
||||||
from danswer.configs.app_configs import APP_PORT
|
from danswer.configs.app_configs import APP_PORT
|
||||||
from danswer.server.admin import router as admin_router
|
from danswer.server.admin import router as admin_router
|
||||||
@ -11,12 +14,22 @@ from danswer.server.event_loading import router as event_processing_router
|
|||||||
from danswer.server.search_backend import router as backend_router
|
from danswer.server.search_backend import router as backend_router
|
||||||
from danswer.utils.logging import setup_logger
|
from danswer.utils.logging import setup_logger
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
|
||||||
|
logger.exception(f"{request}: {exc_str}")
|
||||||
|
content = {"status_code": 422, "message": exc_str, "data": None}
|
||||||
|
return JSONResponse(content=content, status_code=422)
|
||||||
|
|
||||||
|
|
||||||
def get_application() -> FastAPI:
|
def get_application() -> FastAPI:
|
||||||
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
|
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
|
||||||
application.include_router(backend_router)
|
application.include_router(backend_router)
|
||||||
@ -48,6 +61,31 @@ def get_application() -> FastAPI:
|
|||||||
prefix="/users",
|
prefix="/users",
|
||||||
tags=["users"],
|
tags=["users"],
|
||||||
)
|
)
|
||||||
|
if ENABLE_OAUTH:
|
||||||
|
application.include_router(
|
||||||
|
fastapi_users.get_oauth_router(
|
||||||
|
google_oauth_client,
|
||||||
|
auth_backend,
|
||||||
|
SECRET,
|
||||||
|
associate_by_email=True,
|
||||||
|
is_verified_by_default=True,
|
||||||
|
redirect_url="http://localhost:8080/test", # TODO DAN-39 set this to frontend redirect
|
||||||
|
),
|
||||||
|
prefix="/auth/google",
|
||||||
|
tags=["auth"],
|
||||||
|
)
|
||||||
|
application.include_router(
|
||||||
|
fastapi_users.get_oauth_associate_router(
|
||||||
|
google_oauth_client, UserRead, SECRET
|
||||||
|
),
|
||||||
|
prefix="/auth/associate/google",
|
||||||
|
tags=["auth"],
|
||||||
|
)
|
||||||
|
|
||||||
|
application.add_exception_handler(
|
||||||
|
RequestValidationError, validation_exception_handler
|
||||||
|
)
|
||||||
|
|
||||||
return application
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ from danswer.utils.clients import TSClient
|
|||||||
from danswer.utils.logging import setup_logger
|
from danswer.utils.logging import setup_logger
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -30,6 +31,12 @@ async def authenticated_route(user: User = Depends(current_active_user)):
|
|||||||
return {"message": f"Hello {user.email}!"}
|
return {"message": f"Hello {user.email}!"}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO DAN-39 delete this once oauth is built out and tested
|
||||||
|
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
|
||||||
|
def test_endpoint(request: Request):
|
||||||
|
print(request)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=ServerStatus)
|
@router.get("/", response_model=ServerStatus)
|
||||||
@router.get("/status", response_model=ServerStatus)
|
@router.get("/status", response_model=ServerStatus)
|
||||||
def read_server_status():
|
def read_server_status():
|
||||||
|
@ -8,6 +8,9 @@ filelock==3.12.0
|
|||||||
google-api-python-client==2.86.0
|
google-api-python-client==2.86.0
|
||||||
google-auth-httplib2==0.1.0
|
google-auth-httplib2==0.1.0
|
||||||
google-auth-oauthlib==1.0.0
|
google-auth-oauthlib==1.0.0
|
||||||
|
httpcore==0.16.3
|
||||||
|
httpx==0.23.3
|
||||||
|
httpx-oauth==0.11.2
|
||||||
Mako==1.2.4
|
Mako==1.2.4
|
||||||
openai==0.27.6
|
openai==0.27.6
|
||||||
playwright==1.32.1
|
playwright==1.32.1
|
||||||
@ -16,6 +19,7 @@ PyPDF2==3.0.1
|
|||||||
pytest-playwright==0.3.2
|
pytest-playwright==0.3.2
|
||||||
qdrant-client==1.1.0
|
qdrant-client==1.1.0
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
|
rfc3986==1.5.0
|
||||||
sentence-transformers==2.2.2
|
sentence-transformers==2.2.2
|
||||||
slack-sdk==3.20.2
|
slack-sdk==3.20.2
|
||||||
SQLAlchemy[mypy]==2.0.12
|
SQLAlchemy[mypy]==2.0.12
|
||||||
|
Loading…
x
Reference in New Issue
Block a user