diff --git a/backend/alembic/README.md b/backend/alembic/README.md index 22cbffa16..ef9b50f56 100644 --- a/backend/alembic/README.md +++ b/backend/alembic/README.md @@ -1,6 +1,7 @@ Generic single-database configuration with an async dbapi. ## To generate new migrations: +run from danswer/backend: `alembic revision --autogenerate -m ` More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html diff --git a/backend/alembic/versions/6d387b3196c2_basic_auth.py b/backend/alembic/versions/6d387b3196c2_basic_auth.py new file mode 100644 index 000000000..09748bcf3 --- /dev/null +++ b/backend/alembic/versions/6d387b3196c2_basic_auth.py @@ -0,0 +1,86 @@ +"""Basic Auth + +Revision ID: 6d387b3196c2 +Revises: 47433d30de82 +Create Date: 2023-05-05 14:40:10.242502 + +""" +import fastapi_users_db_sqlalchemy +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "6d387b3196c2" +down_revision = "47433d30de82" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "user", + sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False), + sa.Column("email", sa.String(length=320), nullable=False), + sa.Column("hashed_password", sa.String(length=1024), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=False), + sa.Column("is_verified", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True) + op.create_table( + "accesstoken", + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=False, + ), + sa.Column("token", sa.String(length=43), nullable=False), + sa.Column( + "created_at", + fastapi_users_db_sqlalchemy.generics.TIMESTAMPAware(timezone=True), + nullable=False, + ), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"), + sa.PrimaryKeyConstraint("token"), + ) + op.create_index( + op.f("ix_accesstoken_created_at"), + "accesstoken", + ["created_at"], + unique=False, + ) + op.alter_column( + "index_attempt", + "time_created", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=False, + existing_server_default=sa.text("now()"), # type: ignore + ) + op.alter_column( + "index_attempt", + "time_updated", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=False, + ) + + +def downgrade() -> None: + op.alter_column( + "index_attempt", + "time_updated", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=True, + ) + op.alter_column( + "index_attempt", + "time_created", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=True, + existing_server_default=sa.text("now()"), # type: ignore + ) + op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken") + op.drop_table("accesstoken") + op.drop_index(op.f("ix_user_email"), table_name="user") + op.drop_table("user") diff --git a/backend/danswer/auth/__init__.py b/backend/danswer/auth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/danswer/auth/configs.py b/backend/danswer/auth/configs.py new file mode 100644 index 000000000..16b356d03 --- /dev/null +++ b/backend/danswer/auth/configs.py @@ -0,0 +1,7 @@ +import os + +RESET_PASSWORD_TOKEN_SECRET = os.environ["RESET_PASSWORD_TOKEN_SECRET"] +RESET_PASSWORD_VERIFICATION_TOKEN_SECRET = os.environ[ + "RESET_PASSWORD_VERIFICATION_TOKEN_SECRET" +] +SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600)) diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py new file mode 100644 index 000000000..de1169e4c --- /dev/null +++ b/backend/danswer/auth/schemas.py @@ -0,0 +1,15 @@ +import uuid + +from fastapi_users import schemas + + +class UserRead(schemas.BaseUser[uuid.UUID]): + pass + + +class UserCreate(schemas.BaseUserCreate): + pass + + +class UserUpdate(schemas.BaseUserUpdate): + pass diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py new file mode 100644 index 000000000..1abce3062 --- /dev/null +++ b/backend/danswer/auth/users.py @@ -0,0 +1,64 @@ +import uuid +from typing import Optional + +from danswer.auth.configs import RESET_PASSWORD_TOKEN_SECRET +from danswer.auth.configs import RESET_PASSWORD_VERIFICATION_TOKEN_SECRET +from danswer.auth.configs import SESSION_EXPIRE_TIME_SECONDS +from danswer.db.auth import get_access_token_db +from danswer.db.auth import get_user_db +from danswer.db.models import AccessToken +from danswer.db.models import User +from fastapi import Depends +from fastapi import Request +from fastapi_users import BaseUserManager +from fastapi_users import FastAPIUsers +from fastapi_users import UUIDIDMixin +from fastapi_users.authentication import AuthenticationBackend +from fastapi_users.authentication import CookieTransport +from fastapi_users.authentication.strategy.db import AccessTokenDatabase +from fastapi_users.authentication.strategy.db import DatabaseStrategy +from fastapi_users.db import SQLAlchemyUserDatabase + + +class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): + reset_password_token_secret = RESET_PASSWORD_TOKEN_SECRET + verification_token_secret = RESET_PASSWORD_VERIFICATION_TOKEN_SECRET + + async def on_after_register(self, user: User, request: Optional[Request] = None): + print(f"User {user.id} has registered.") + + async def on_after_forgot_password( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"User {user.id} has forgot their password. Reset token: {token}") + + async def on_after_request_verify( + self, user: User, token: str, request: Optional[Request] = None + ): + print(f"Verification requested for user {user.id}. Verification token: {token}") + + +async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): + yield UserManager(user_db) + + +cookie_transport = CookieTransport(cookie_max_age=SESSION_EXPIRE_TIME_SECONDS) + + +def get_database_strategy( + access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db), +) -> DatabaseStrategy: + return DatabaseStrategy( + access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS + ) + + +auth_backend = AuthenticationBackend( + name="database", + transport=cookie_transport, + get_strategy=get_database_strategy, +) + +fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) + +current_active_user = fastapi_users.current_user(active=True) diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py new file mode 100644 index 000000000..a2146d585 --- /dev/null +++ b/backend/danswer/db/auth.py @@ -0,0 +1,26 @@ +from collections.abc import AsyncGenerator + +from danswer.db.engine import build_async_engine +from danswer.db.models import AccessToken +from danswer.db.models import User +from fastapi import Depends +from fastapi_users.db import SQLAlchemyUserDatabase +from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase +from sqlalchemy.ext.asyncio import AsyncSession + + +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + async with AsyncSession( + build_async_engine(), future=True, expire_on_commit=False + ) as async_session: + yield async_session + + +async def get_user_db(session: AsyncSession = Depends(get_async_session)): + yield SQLAlchemyUserDatabase(session, User) + + +async def get_access_token_db( + session: AsyncSession = Depends(get_async_session), +): + yield SQLAlchemyAccessTokenDatabase(session, AccessToken) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index d679981b6..79de9667b 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -4,7 +4,8 @@ from typing import Any from danswer.configs.constants import DocumentSource from danswer.connectors.models import InputType -from fastapi.encoders import jsonable_encoder +from fastapi_users.db import SQLAlchemyBaseUserTableUUID +from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID from sqlalchemy import DateTime from sqlalchemy import Enum from sqlalchemy import func @@ -19,6 +20,14 @@ class Base(DeclarativeBase): pass +class User(SQLAlchemyBaseUserTableUUID, Base): + pass + + +class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): + pass + + class IndexingStatus(str, PyEnum): NOT_STARTED = "not_started" IN_PROGRESS = "in_progress" diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 8149e8b63..8f360dd8f 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -1,4 +1,9 @@ import uvicorn +from danswer.auth.schemas import UserCreate +from danswer.auth.schemas import UserRead +from danswer.auth.schemas import UserUpdate +from danswer.auth.users import auth_backend +from danswer.auth.users import fastapi_users from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_PORT from danswer.server.admin import router as admin_router @@ -17,6 +22,32 @@ def get_application() -> FastAPI: application.include_router(backend_router) application.include_router(event_processing_router) application.include_router(admin_router) + + application.include_router( + fastapi_users.get_auth_router(auth_backend), + prefix="/auth/database", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_register_router(UserRead, UserCreate), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_reset_password_router(), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_verify_router(UserRead), + prefix="/auth", + tags=["auth"], + ) + application.include_router( + fastapi_users.get_users_router(UserRead, UserUpdate), + prefix="/users", + tags=["users"], + ) return application diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index f5e777791..fce7c8042 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,10 +1,12 @@ import time from http import HTTPStatus +from danswer.auth.users import current_active_user from danswer.configs.app_configs import KEYWORD_MAX_HITS from danswer.configs.constants import CONTENT from danswer.configs.constants import SOURCE_LINKS from danswer.datastores import create_datastore +from danswer.db.models import User from danswer.direct_qa import get_default_backend_qa_model from danswer.direct_qa.semantic_search import semantic_search from danswer.server.models import KeywordResponse @@ -14,6 +16,7 @@ from danswer.server.models import ServerStatus from danswer.utils.clients import TSClient from danswer.utils.logging import setup_logger from fastapi import APIRouter +from fastapi import Depends logger = setup_logger() @@ -21,6 +24,12 @@ logger = setup_logger() router = APIRouter() +# TODO delete this useless endpoint once frontend is integrated with auth +@router.get("/test-auth") +async def authenticated_route(user: User = Depends(current_active_user)): + return {"message": f"Hello {user.email}!"} + + @router.get("/", response_model=ServerStatus) @router.get("/status", response_model=ServerStatus) def read_server_status(): diff --git a/backend/mypy.ini b/backend/mypy.ini new file mode 100644 index 000000000..9b2161c57 --- /dev/null +++ b/backend/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +mypy_path = . +explicit-package-bases = True +no-site-packages = True diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 538cd207e..947e692b8 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -2,6 +2,8 @@ alembic==1.10.4 asyncpg==0.27.0 beautifulsoup4==4.12.0 fastapi==0.95.0 +fastapi-users==11.0.0 +fastapi-users-db-sqlalchemy==5.0.0 filelock==3.12.0 google-api-python-client==2.86.0 google-auth-httplib2==0.1.0