mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-30 04:31:49 +02:00
DAN-57 Make qa / admin endpoints permissioned optionally (#37)
This commit is contained in:
parent
090578f1f3
commit
5af35cf07c
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
DISABLE_AUTH = os.environ.get("DISABLE_AUTH", "").lower() == "true"
|
||||||
|
|
||||||
SECRET = os.environ.get("SECRET", "")
|
SECRET = os.environ.get("SECRET", "")
|
||||||
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
SESSION_EXPIRE_TIME_SECONDS = int(os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 3600))
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from danswer.auth.configs import DISABLE_AUTH
|
||||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
|
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_ID
|
||||||
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
|
from danswer.auth.configs import GOOGLE_OAUTH_CLIENT_SECRET
|
||||||
from danswer.auth.configs import SECRET
|
from danswer.auth.configs import SECRET
|
||||||
@ -10,7 +11,6 @@ from danswer.auth.schemas import UserRole
|
|||||||
from danswer.db.auth import get_access_token_db
|
from danswer.db.auth import get_access_token_db
|
||||||
from danswer.db.auth import get_user_count
|
from danswer.db.auth import get_user_count
|
||||||
from danswer.db.auth import get_user_db
|
from danswer.db.auth import get_user_db
|
||||||
from danswer.db.engine import build_async_engine
|
|
||||||
from danswer.db.models import AccessToken
|
from danswer.db.models import AccessToken
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@ -28,7 +28,6 @@ from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
|||||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
from httpx_oauth.clients.google import GoogleOAuth2
|
from httpx_oauth.clients.google import GoogleOAuth2
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
|
||||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
@ -88,11 +87,13 @@ google_oauth_client = GoogleOAuth2(GOOGLE_OAUTH_CLIENT_ID, GOOGLE_OAUTH_CLIENT_S
|
|||||||
|
|
||||||
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])
|
||||||
|
|
||||||
current_active_user = fastapi_users.current_user(active=True)
|
current_active_user = fastapi_users.current_user(active=True, optional=DISABLE_AUTH)
|
||||||
|
|
||||||
|
|
||||||
def current_admin_user(user: User = Depends(current_active_user)) -> User:
|
def current_admin_user(user: User = Depends(current_active_user)) -> User | None:
|
||||||
if not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
if DISABLE_AUTH:
|
||||||
|
return None
|
||||||
|
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied. User is not an admin.",
|
detail="Access denied. User is not an admin.",
|
||||||
|
@ -2,10 +2,6 @@ from danswer.datastores.interfaces import DatastoreFilter
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class ServerStatus(BaseModel):
|
|
||||||
status: str
|
|
||||||
|
|
||||||
|
|
||||||
class UserRoleResponse(BaseModel):
|
class UserRoleResponse(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from collections.abc import Generator
|
||||||
|
|
||||||
from danswer.auth.schemas import UserRole
|
from danswer.auth.schemas import UserRole
|
||||||
from danswer.auth.users import current_active_user
|
from danswer.auth.users import current_active_user
|
||||||
@ -18,7 +18,6 @@ from danswer.server.models import KeywordResponse
|
|||||||
from danswer.server.models import QAQuestion
|
from danswer.server.models import QAQuestion
|
||||||
from danswer.server.models import QAResponse
|
from danswer.server.models import QAResponse
|
||||||
from danswer.server.models import SearchDoc
|
from danswer.server.models import SearchDoc
|
||||||
from danswer.server.models import ServerStatus
|
|
||||||
from danswer.server.models import UserByEmail
|
from danswer.server.models import UserByEmail
|
||||||
from danswer.server.models import UserRoleResponse
|
from danswer.server.models import UserRoleResponse
|
||||||
from danswer.utils.clients import TSClient
|
from danswer.utils.clients import TSClient
|
||||||
@ -26,7 +25,6 @@ from danswer.utils.logging import setup_logger
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
@ -36,39 +34,17 @@ logger = setup_logger()
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# TODO delete this useless endpoint once frontend is integrated with auth
|
|
||||||
@router.get("/test-auth")
|
|
||||||
async def authenticated_route(user: User = Depends(current_active_user)):
|
|
||||||
return {"message": f"Hello {user.email} who is a {user.role}!"}
|
|
||||||
|
|
||||||
|
|
||||||
# TODO delete this useless endpoint once frontend is integrated with auth
|
|
||||||
@router.get("/test-admin")
|
|
||||||
async def admin_route(user: User = Depends(current_admin_user)):
|
|
||||||
return {"message": f"Hello {user.email} who is a {user.role}!"}
|
|
||||||
|
|
||||||
|
|
||||||
# TODO DAN-39 delete this once oauth is built out and tested
|
|
||||||
@router.api_route("/test", methods=["GET", "POST", "PUT", "PATCH", "DELETE"])
|
|
||||||
def test_endpoint(request: Request):
|
|
||||||
print(request)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/get-user-role", response_model=UserRoleResponse)
|
@router.get("/get-user-role", response_model=UserRoleResponse)
|
||||||
async def get_user_role(user: User = Depends(current_active_user)):
|
async def get_user_role(user: User = Depends(current_active_user)) -> UserRoleResponse:
|
||||||
|
if user is None:
|
||||||
|
raise ValueError("Invalid or missing user.")
|
||||||
return UserRoleResponse(role=user.role)
|
return UserRoleResponse(role=user.role)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=ServerStatus)
|
|
||||||
@router.get("/status", response_model=ServerStatus)
|
|
||||||
def read_server_status():
|
|
||||||
return ServerStatus(status=HTTPStatus.OK.value)
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/promote-user-to-admin", response_model=None)
|
@router.patch("/promote-user-to-admin", response_model=None)
|
||||||
async def promote_admin(
|
async def promote_admin(
|
||||||
user_email: UserByEmail, user: User = Depends(current_active_user)
|
user_email: UserByEmail, user: User = Depends(current_admin_user)
|
||||||
):
|
) -> None:
|
||||||
if user.role != UserRole.ADMIN:
|
if user.role != UserRole.ADMIN:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
async with AsyncSession(build_async_engine()) as asession:
|
async with AsyncSession(build_async_engine()) as asession:
|
||||||
@ -83,7 +59,9 @@ async def promote_admin(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/direct-qa", response_model=QAResponse)
|
@router.get("/direct-qa", response_model=QAResponse)
|
||||||
def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
|
def direct_qa(
|
||||||
|
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||||
|
) -> QAResponse:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
query = question.query
|
query = question.query
|
||||||
@ -118,10 +96,12 @@ def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/stream-direct-qa")
|
@router.get("/stream-direct-qa")
|
||||||
def stream_direct_qa(question: QAQuestion = Depends()):
|
def stream_direct_qa(
|
||||||
|
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||||
|
) -> StreamingResponse:
|
||||||
top_documents_key = "top_documents"
|
top_documents_key = "top_documents"
|
||||||
|
|
||||||
def stream_qa_portions():
|
def stream_qa_portions() -> Generator[str, None, None]:
|
||||||
query = question.query
|
query = question.query
|
||||||
collection = question.collection
|
collection = question.collection
|
||||||
filters = question.filters
|
filters = question.filters
|
||||||
@ -153,12 +133,15 @@ def stream_direct_qa(question: QAQuestion = Depends()):
|
|||||||
):
|
):
|
||||||
logger.debug(response_dict)
|
logger.debug(response_dict)
|
||||||
yield yield_json_line(response_dict)
|
yield yield_json_line(response_dict)
|
||||||
|
return
|
||||||
|
|
||||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/keyword-search", response_model=KeywordResponse)
|
@router.get("/keyword-search", response_model=KeywordResponse)
|
||||||
def keyword_search(question: QAQuestion = Depends()):
|
def keyword_search(
|
||||||
|
question: QAQuestion = Depends(), _: User = Depends(current_active_user)
|
||||||
|
) -> KeywordResponse:
|
||||||
ts_client = TSClient.get_instance()
|
ts_client = TSClient.get_instance()
|
||||||
query = question.query
|
query = question.query
|
||||||
collection = question.collection
|
collection = question.collection
|
||||||
|
Loading…
x
Reference in New Issue
Block a user