From 5af35cf07c16bde59d0302712a66e360e3fa7922 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 12 May 2023 20:34:30 -0700 Subject: [PATCH] DAN-57 Make qa / admin endpoints permissioned optionally (#37) --- backend/danswer/auth/configs.py | 2 + backend/danswer/auth/users.py | 11 ++--- backend/danswer/server/models.py | 4 -- backend/danswer/server/search_backend.py | 51 ++++++++---------------- 4 files changed, 25 insertions(+), 43 deletions(-) diff --git a/backend/danswer/auth/configs.py b/backend/danswer/auth/configs.py index 565c098b6..276cff5e1 100644 --- a/backend/danswer/auth/configs.py +++ b/backend/danswer/auth/configs.py @@ -1,5 +1,7 @@ import os +DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true" + SECRET = os.environ.get("SECRET", "") SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600)) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 9513d972b..e535c7ad6 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,6 +1,7 @@ import uuid from typing import Optional +from danswer.auth.configs import DISABLE_AUTH from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET from danswer.auth.configs import SECRET @@ -10,7 +11,6 @@ from danswer.auth.schemas import UserRole from danswer.db.auth import get_access_token_db from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db -from danswer.db.engine import build_async_engine from danswer.db.models import AccessToken from danswer.db.models import User from fastapi import Depends @@ -28,7 +28,6 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy from fastapi_users.db import SQLAlchemyUserDatabase from httpx_oauth.clients.google import GoogleOAuth2 -from sqlalchemy.ext.asyncio import AsyncSession class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): @@ -88,11 +87,13 @@ google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_S fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend]) -current_active_user = fastapi_users.current_user(active=True) +current_active_user = fastapi_users.current_user(active=True, optional=DISABLE_AUTH) -def current_admin_user(user: User = Depends(current_active_user)) -> User: - if not hasattr(user, "role") or user.role != UserRole.ADMIN: +def current_admin_user(user: User = Depends(current_active_user)) -> User | None: + if DISABLE_AUTH: + return None + if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access denied. User is not an admin.", diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 704c3a3c4..ceebdfefc 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -2,10 +2,6 @@ from danswer.datastores.interfaces import DatastoreFilter from pydantic import BaseModel -class ServerStatus(BaseModel): - status: str - - class UserRoleResponse(BaseModel): role: str diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 35678de8c..b40909685 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,5 +1,5 @@ import time -from http import HTTPStatus +from collections.abc import Generator from danswer.auth.schemas import UserRole from danswer.auth.users import current_active_user @@ -18,7 +18,6 @@ from danswer.server.models import KeywordResponse from danswer.server.models import QAQuestion from danswer.server.models import QAResponse from danswer.server.models import SearchDoc -from danswer.server.models import ServerStatus from danswer.server.models import UserByEmail from danswer.server.models import UserRoleResponse from danswer.utils.clients import TSClient @@ -26,7 +25,6 @@ from danswer.utils.logging import setup_logger from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException -from fastapi import Request from fastapi.responses import StreamingResponse from fastapi_users.db import SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession @@ -36,39 +34,17 @@ logger = setup_logger() router = APIRouter() -# TODO delete this useless endpoint once frontend is integrated with auth -@router.get("/test-auth") -async def authenticated_route(user: User = Depends(current_active_user)): - return {"message": f"Hello {user.email} who is a {user.role}!"} - - -# TODO delete this useless endpoint once frontend is integrated with auth -@router.get("/test-admin") -async def admin_route(user: User = Depends(current_admin_user)): - return {"message": f"Hello {user.email} who is a {user.role}!"} - - -# TODO DAN-39 delete this once oauth is built out and tested -@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) -def test_endpoint(request: Request): - print(request) - - @router.get("/get-user-role", response_model=UserRoleResponse) -async def get_user_role(user: User = Depends(current_active_user)): +async def get_user_role(user: User = Depends(current_active_user)) -> UserRoleResponse: + if user is None: + raise ValueError("Invalid or missing user.") return UserRoleResponse(role=user.role) -@router.get("/", response_model=ServerStatus) -@router.get("/status", response_model=ServerStatus) -def read_server_status(): - return ServerStatus(status=HTTPStatus.OK.value) - - @router.patch("/promote-user-to-admin", response_model=None) async def promote_admin( - user_email: UserByEmail, user: User = Depends(current_active_user) -): + user_email: UserByEmail, user: User = Depends(current_admin_user) +) -> None: if user.role != UserRole.ADMIN: raise HTTPException(status_code=401, detail="Unauthorized") async with AsyncSession(build_async_engine()) as asession: @@ -83,7 +59,9 @@ async def promote_admin( @router.get("/direct-qa", response_model=QAResponse) -def direct_qa(question: QAQuestion = Depends()) -> QAResponse: +def direct_qa( + question: QAQuestion = Depends(), _: User = Depends(current_active_user) +) -> QAResponse: start_time = time.time() query = question.query @@ -118,10 +96,12 @@ def direct_qa(question: QAQuestion = Depends()) -> QAResponse: @router.get("/stream-direct-qa") -def stream_direct_qa(question: QAQuestion = Depends()): +def stream_direct_qa( + question: QAQuestion = Depends(), _: User = Depends(current_active_user) +) -> StreamingResponse: top_documents_key = "top_documents" - def stream_qa_portions(): + def stream_qa_portions() -> Generator[str, None, None]: query = question.query collection = question.collection filters = question.filters @@ -153,12 +133,15 @@ def stream_direct_qa(question: QAQuestion = Depends()): ): logger.debug(response_dict) yield yield_json_line(response_dict) + return return StreamingResponse(stream_qa_portions(), media_type="application/json") @router.get("/keyword-search", response_model=KeywordResponse) -def keyword_search(question: QAQuestion = Depends()): +def keyword_search( + question: QAQuestion = Depends(), _: User = Depends(current_active_user) +) -> KeywordResponse: ts_client = TSClient.get_instance() query = question.query collection = question.collection