This commit is contained in:
pablodanswer 2024-08-28 14:43:16 -07:00
parent 3838908e70
commit 7b895008d3
8 changed files with 252 additions and 13 deletions

View File

@ -1,3 +1,4 @@
import contextlib
import smtplib
import uuid
from collections.abc import AsyncGenerator
@ -8,6 +9,7 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
@ -54,6 +56,7 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
@ -192,10 +195,87 @@ def send_user_verification_email(
s.send_message(msg)
def verify_sso_token(token: str) -> dict:
print("VERIFYING")
try:
print(token)
print("DECODING")
# payload = jwt.decode(token, settings.SSO_SECRET_KEY, algorithms=["HS256"])
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
print(payload)
if datetime.now(timezone.utc) > datetime.fromtimestamp(
payload["exp"], timezone.utc
):
print("EXPIRED")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
)
return payload
except jwt.PyJWTError as e:
print(e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User:
get_async_session_context = contextlib.asynccontextmanager(get_async_session)
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
async with get_async_session_context() as session:
async with get_user_db_context(session) as user_db:
existing_user = await user_db.get_by_email(email)
if existing_user:
return existing_user
# Generate a random password
uuid.uuid4().hex
# hashed_password = get_password_hash(random_password)
new_user = {
"email": email,
"id": uuid.UUID(user_id),
"role": UserRole.BASIC,
"oidc_expiry": None,
"default_model": None,
"chosen_assistants": None,
"hashed_password": "p",
"is_active": True,
"is_superuser": False,
"is_verified": True,
}
created_user = await user_db.create(new_user)
return created_user
async def create_user_session(user: User, strategy: Strategy) -> str:
token = await strategy.write_token(user)
return token
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
async def sso_authenticate(
self,
email: str,
user_id: str,
tenant_id: str,
) -> models.UP:
user = await self.get_by_email(email)
if not user:
# user_create = UserCreate(email=email, password=secrets.token_urlsafe(32))
user_create = UserCreate(role=UserRole.BASIC)
user = await self.create(user_create)
# Update user with tenant information if needed
if user.tenant_id != tenant_id:
await self.user_db.update(user, {"tenant_id": tenant_id})
return user
async def create(
self,
user_create: schemas.UC | UserCreate,
@ -280,6 +360,30 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
send_user_verification_email(user.email, token)
async def sso_authenticate(
self,
email: str,
user_id: str,
tenant_id: str,
) -> models.UP:
user = await self.get_by_email(email)
if not user:
user_create = UserCreate(UserRole.BASIC)
user = await self.create(user_create)
# Update user with tenant information if needed
if user.tenant_id != tenant_id:
await self.user_db.update(user, {"tenant_id": tenant_id})
return user
# async def get_user_manager(
# user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
# ) -> AsyncGenerator[UserManager, None]:
# yield UserManager(user_db)
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[UserManager, None]:
@ -375,7 +479,8 @@ async def optional_user(
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
val = await versioned_fetch_user(request, user, db_session)
return val
async def double_check_user(

View File

@ -11,7 +11,6 @@ from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import AccessToken
@ -46,11 +45,11 @@ async def get_user_count() -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(self, create_dict: Dict[str, Any]) -> UP:
user_count = await get_user_count()
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
create_dict["role"] = UserRole.ADMIN
else:
create_dict["role"] = UserRole.BASIC
await get_user_count()
# if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
# create_dict["role"] = UserRole.ADMIN
# else:
# create_dict["role"] = UserRole.BASIC
return await super().create(create_dict)

View File

@ -27,7 +27,7 @@ class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None
rerank_api_key: str | None = None
num_rerank: int

View File

@ -34,6 +34,7 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/reset-password", {"POST"}),
("/auth/request-verify-token", {"POST"}),
("/auth/verify", {"POST"}),
("/settings/auth/sso-callback", {"POST"}),
("/users/me", {"GET"}),
("/users/me", {"PATCH"}),
("/users/{id}", {"GET"}),

View File

@ -3,12 +3,24 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordBearer
from fastapi_users.authentication import Strategy
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from danswer.auth.users import create_user_session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_database_strategy
from danswer.auth.users import get_or_create_user
from danswer.auth.users import get_user_manager
from danswer.auth.users import is_user_admin
from danswer.auth.users import UserManager
from danswer.auth.users import verify_sso_token
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
@ -28,12 +40,65 @@ from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.utils.logger import setup_logger
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
admin_router = APIRouter(prefix="/admin/settings")
basic_router = APIRouter(prefix="/settings")
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/settings")
basic_router = APIRouter(prefix="/settings")
@basic_router.post("/auth/sso-callback")
async def sso_callback(
sso_token: str = Query(
..., alias="sso_token"
), # Get SSO token from query parameters
strategy: Strategy = Depends(get_database_strategy),
user_manager: UserManager = Depends(get_user_manager),
):
print("SSO callback reached")
payload = verify_sso_token(sso_token)
print("hi")
user = await get_or_create_user(
payload["email"], payload["user_id"], payload["tenant_id"]
)
print(user)
session_token = await create_user_session(user, strategy)
print("HIII")
response = RedirectResponse(url="/")
response.set_cookie(
key="session",
value=session_token,
httponly=True,
max_age=SESSION_EXPIRE_TIME_SECONDS,
secure=WEB_DOMAIN.startswith("https"),
)
return response
# @basic_router.post("/auth/sso-callback")
# async def sso_callback(
# user = Depends(current_user),
# token: str = Depends(oauth2_scheme),
# strategy: Strategy = Depends(get_database_strategy),
# user_manager: UserManager = Depends(get_user_manager),
# ):
# print('SSO callback reached')
# payload = verify_sso_token(token)
# user = await get_or_create_user(payload["email"], payload["user_id"], payload["tenant_id"])
# session_token = await create_user_session(user, strategy)
# response = RedirectResponse(url="/")
# response.set_cookie(
# key="session",
# value=session_token,
# httponly=True,
# max_age=SESSION_EXPIRE_TIME_SECONDS,
# secure=WEB_DOMAIN.startswith("https"),
# )
# return response
@admin_router.put("")

View File

@ -3,6 +3,9 @@ import { buildUrl } from "@/lib/utilsSS";
import { NextRequest, NextResponse } from "next/server";
export const GET = async (request: NextRequest) => {
console.log("request", request);
console.log("HELLO");
console.log("request.nextUrl.search", request.nextUrl.search);
// Wrapper around the FastAPI endpoint /auth/oauth/callback,
// which adds back a redirect to the main app.
const url = new URL(buildUrl("/auth/oauth/callback"));

View File

@ -15,9 +15,7 @@ import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";
import { redirect } from "next/navigation";
import { Button, Card } from "@tremor/react";
import LogoType from "@/components/header/LogoType";
import { Card } from "@tremor/react";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { UserProvider } from "@/components/user/UserProvider";

View File

@ -0,0 +1,68 @@
"use client";
import { useEffect, useState } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { Card, Text } from "@tremor/react";
import { Spinner } from "@/components/Spinner";
export default function SSOCallback() {
const router = useRouter();
const searchParams = useSearchParams();
const [error, setError] = useState<string | null>(null);
const [authStatus, setAuthStatus] = useState<string>("Authenticating...");
useEffect(() => {
const verifyToken = async () => {
const ssoToken = searchParams.get("sso_token");
if (!ssoToken) {
setError("No SSO token found");
return;
}
try {
setAuthStatus("Verifying SSO token...");
const response = await fetch(
`/api/settings/auth/sso-callback?sso_token=${ssoToken}`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
}
);
if (response.ok) {
setAuthStatus("Authentication successful!");
setTimeout(() => {
setAuthStatus("Redirecting to dashboard...");
setTimeout(() => {
router.push("/dashboard");
}, 1000);
}, 1000);
} else {
const errorData = await response.json();
setError(errorData.detail || "Authentication failed");
}
} catch (error) {
console.error("Error verifying token:", error);
setError("An unexpected error occurred");
}
};
verifyToken();
}, [router, searchParams]);
return (
<div className="flex items-center justify-center min-h-screen bg-gray-100">
<Card className="max-w-lg p-8 text-center">
{error ? (
<Text className="text-red-500">{error}</Text>
) : (
<>
<Spinner />
<Text className="text-lg font-semibold">{authStatus}</Text>
</>
)}
</Card>
</div>
);
}