DAN-57 Make qa / admin endpoints permissioned optionally (#37)

This commit is contained in:
Yuhong Sun 2023-05-12 20:34:30 -07:00 committed by GitHub
parent 090578f1f3
commit 5af35cf07c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 43 deletions

View File

@ -1,5 +1,7 @@
import os
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
SECRET = os.environ.get("SECRET", "")
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))

View File

@ -1,6 +1,7 @@
import uuid
from typing import Optional
from danswer.auth.configs import DISABLE_AUTH
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.auth.configs import SECRET
@ -10,7 +11,6 @@ from danswer.auth.schemas import UserRole
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import build_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from fastapi import Depends
@ -28,7 +28,6 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.db import SQLAlchemyUserDatabase
from httpx_oauth.clients.google import GoogleOAuth2
from sqlalchemy.ext.asyncio import AsyncSession
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
@ -88,11 +87,13 @@ google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_S
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
current_active_user = fastapi_users.current_user(active=True)
current_active_user = fastapi_users.current_user(active=True, optional=DISABLE_AUTH)
def current_admin_user(user: User = Depends(current_active_user)) -> User:
if not hasattr(user, "role") or user.role != UserRole.ADMIN:
def current_admin_user(user: User = Depends(current_active_user)) -> User | None:
if DISABLE_AUTH:
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not an admin.",

View File

@ -2,10 +2,6 @@ from danswer.datastores.interfaces import DatastoreFilter
from pydantic import BaseModel
class ServerStatus(BaseModel):
status: str
class UserRoleResponse(BaseModel):
role: str

View File

@ -1,5 +1,5 @@
import time
from http import HTTPStatus
from collections.abc import Generator
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_active_user
@ -18,7 +18,6 @@ from danswer.server.models import KeywordResponse
from danswer.server.models import QAQuestion
from danswer.server.models import QAResponse
from danswer.server.models import SearchDoc
from danswer.server.models import ServerStatus
from danswer.server.models import UserByEmail
from danswer.server.models import UserRoleResponse
from danswer.utils.clients import TSClient
@ -26,7 +25,6 @@ from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
@ -36,39 +34,17 @@ logger = setup_logger()
router = APIRouter()
# TODO delete this useless endpoint once frontend is integrated with auth
@router.get("/test-auth")
async def authenticated_route(user: User = Depends(current_active_user)):
return {"message": f"Hello {user.email} who is a {user.role}!"}
# TODO delete this useless endpoint once frontend is integrated with auth
@router.get("/test-admin")
async def admin_route(user: User = Depends(current_admin_user)):
return {"message": f"Hello {user.email} who is a {user.role}!"}
# TODO DAN-39 delete this once oauth is built out and tested
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
def test_endpoint(request: Request):
print(request)
@router.get("/get-user-role", response_model=UserRoleResponse)
async def get_user_role(user: User = Depends(current_active_user)):
async def get_user_role(user: User = Depends(current_active_user)) -> UserRoleResponse:
if user is None:
raise ValueError("Invalid or missing user.")
return UserRoleResponse(role=user.role)
@router.get("/", response_model=ServerStatus)
@router.get("/status", response_model=ServerStatus)
def read_server_status():
return ServerStatus(status=HTTPStatus.OK.value)
@router.patch("/promote-user-to-admin", response_model=None)
async def promote_admin(
user_email: UserByEmail, user: User = Depends(current_active_user)
):
user_email: UserByEmail, user: User = Depends(current_admin_user)
) -> None:
if user.role != UserRole.ADMIN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with AsyncSession(build_async_engine()) as asession:
@ -83,7 +59,9 @@ async def promote_admin(
@router.get("/direct-qa", response_model=QAResponse)
def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
def direct_qa(
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
) -> QAResponse:
start_time = time.time()
query = question.query
@ -118,10 +96,12 @@ def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
@router.get("/stream-direct-qa")
def stream_direct_qa(question: QAQuestion = Depends()):
def stream_direct_qa(
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
) -> StreamingResponse:
top_documents_key = "top_documents"
def stream_qa_portions():
def stream_qa_portions() -> Generator[str, None, None]:
query = question.query
collection = question.collection
filters = question.filters
@ -153,12 +133,15 @@ def stream_direct_qa(question: QAQuestion = Depends()):
):
logger.debug(response_dict)
yield yield_json_line(response_dict)
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
@router.get("/keyword-search", response_model=KeywordResponse)
def keyword_search(question: QAQuestion = Depends()):
def keyword_search(
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
) -> KeywordResponse:
ts_client = TSClient.get_instance()
query = question.query
collection = question.collection