mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
DAN-57 Make qa / admin endpoints permissioned optionally (#37)
This commit is contained in:
parent
090578f1f3
commit
5af35cf07c
@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
|
||||
|
||||
SECRET = os.environ.get("SECRET", "")
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from danswer.auth.configs import DISABLE_AUTH
|
||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
|
||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||
from danswer.auth.configs import SECRET
|
||||
@ -10,7 +11,6 @@ from danswer.auth.schemas import UserRole
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from fastapi import Depends
|
||||
@ -28,7 +28,6 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
@ -88,11 +87,13 @@ google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_S
|
||||
|
||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||
|
||||
current_active_user = fastapi_users.current_user(active=True)
|
||||
current_active_user = fastapi_users.current_user(active=True, optional=DISABLE_AUTH)
|
||||
|
||||
|
||||
def current_admin_user(user: User = Depends(current_active_user)) -> User:
|
||||
if not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
def current_admin_user(user: User = Depends(current_active_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not an admin.",
|
||||
|
@ -2,10 +2,6 @@ from danswer.datastores.interfaces import DatastoreFilter
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ServerStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class UserRoleResponse(BaseModel):
|
||||
role: str
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.users import current_active_user
|
||||
@ -18,7 +18,6 @@ from danswer.server.models import KeywordResponse
|
||||
from danswer.server.models import QAQuestion
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.server.models import ServerStatus
|
||||
from danswer.server.models import UserByEmail
|
||||
from danswer.server.models import UserRoleResponse
|
||||
from danswer.utils.clients import TSClient
|
||||
@ -26,7 +25,6 @@ from danswer.utils.logging import setup_logger
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -36,39 +34,17 @@ logger = setup_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# TODO delete this useless endpoint once frontend is integrated with auth
|
||||
@router.get("/test-auth")
|
||||
async def authenticated_route(user: User = Depends(current_active_user)):
|
||||
return {"message": f"Hello {user.email} who is a {user.role}!"}
|
||||
|
||||
|
||||
# TODO delete this useless endpoint once frontend is integrated with auth
|
||||
@router.get("/test-admin")
|
||||
async def admin_route(user: User = Depends(current_admin_user)):
|
||||
return {"message": f"Hello {user.email} who is a {user.role}!"}
|
||||
|
||||
|
||||
# TODO DAN-39 delete this once oauth is built out and tested
|
||||
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
|
||||
def test_endpoint(request: Request):
|
||||
print(request)
|
||||
|
||||
|
||||
@router.get("/get-user-role", response_model=UserRoleResponse)
|
||||
async def get_user_role(user: User = Depends(current_active_user)):
|
||||
async def get_user_role(user: User = Depends(current_active_user)) -> UserRoleResponse:
|
||||
if user is None:
|
||||
raise ValueError("Invalid or missing user.")
|
||||
return UserRoleResponse(role=user.role)
|
||||
|
||||
|
||||
@router.get("/", response_model=ServerStatus)
|
||||
@router.get("/status", response_model=ServerStatus)
|
||||
def read_server_status():
|
||||
return ServerStatus(status=HTTPStatus.OK.value)
|
||||
|
||||
|
||||
@router.patch("/promote-user-to-admin", response_model=None)
|
||||
async def promote_admin(
|
||||
user_email: UserByEmail, user: User = Depends(current_active_user)
|
||||
):
|
||||
user_email: UserByEmail, user: User = Depends(current_admin_user)
|
||||
) -> None:
|
||||
if user.role != UserRole.ADMIN:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
async with AsyncSession(build_async_engine()) as asession:
|
||||
@ -83,7 +59,9 @@ async def promote_admin(
|
||||
|
||||
|
||||
@router.get("/direct-qa", response_model=QAResponse)
|
||||
def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
|
||||
def direct_qa(
|
||||
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||
) -> QAResponse:
|
||||
start_time = time.time()
|
||||
|
||||
query = question.query
|
||||
@ -118,10 +96,12 @@ def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
|
||||
|
||||
|
||||
@router.get("/stream-direct-qa")
|
||||
def stream_direct_qa(question: QAQuestion = Depends()):
|
||||
def stream_direct_qa(
|
||||
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||
) -> StreamingResponse:
|
||||
top_documents_key = "top_documents"
|
||||
|
||||
def stream_qa_portions():
|
||||
def stream_qa_portions() -> Generator[str, None, None]:
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
filters = question.filters
|
||||
@ -153,12 +133,15 @@ def stream_direct_qa(question: QAQuestion = Depends()):
|
||||
):
|
||||
logger.debug(response_dict)
|
||||
yield yield_json_line(response_dict)
|
||||
return
|
||||
|
||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||
|
||||
|
||||
@router.get("/keyword-search", response_model=KeywordResponse)
|
||||
def keyword_search(question: QAQuestion = Depends()):
|
||||
def keyword_search(
|
||||
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||
) -> KeywordResponse:
|
||||
ts_client = TSClient.get_instance()
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
|
Loading…
x
Reference in New Issue
Block a user