diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index dff6a6036..667c515d8 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,3 +1,4 @@ +import contextlib import smtplib import uuid from collections.abc import AsyncGenerator @@ -8,6 +9,7 @@ from email.mime.text import MIMEText from typing import Optional from typing import Tuple +import jwt from email_validator import EmailNotValidError from email_validator import validate_email from fastapi import APIRouter @@ -54,6 +56,7 @@ from danswer.db.auth import get_access_token_db from danswer.db.auth import get_default_admin_user_emails from danswer.db.auth import get_user_count from danswer.db.auth import get_user_db +from danswer.db.engine import get_async_session from danswer.db.engine import get_session from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import AccessToken @@ -192,10 +195,87 @@ def send_user_verification_email( s.send_message(msg) +def verify_sso_token(token: str) -> dict: + print("VERIFYING") + try: + print(token) + print("DECODING") + # payload = jwt.decode(token, settings.SSO_SECRET_KEY, algorithms=["HS256"]) + payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"]) + print(payload) + if datetime.now(timezone.utc) > datetime.fromtimestamp( + payload["exp"], timezone.utc + ): + print("EXPIRED") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired" + ) + return payload + except jwt.PyJWTError as e: + print(e) + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" + ) + + +async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User: + get_async_session_context = contextlib.asynccontextmanager(get_async_session) + get_user_db_context = contextlib.asynccontextmanager(get_user_db) + + async with get_async_session_context() as session: + async with get_user_db_context(session) as user_db: + existing_user = await user_db.get_by_email(email) + if existing_user: + return existing_user + + # Generate a random password + uuid.uuid4().hex + # hashed_password = get_password_hash(random_password) + + new_user = { + "email": email, + "id": uuid.UUID(user_id), + "role": UserRole.BASIC, + "oidc_expiry": None, + "default_model": None, + "chosen_assistants": None, + "hashed_password": "p", + "is_active": True, + "is_superuser": False, + "is_verified": True, + } + created_user = await user_db.create(new_user) + return created_user + + +async def create_user_session(user: User, strategy: Strategy) -> str: + token = await strategy.write_token(user) + return token + + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = USER_AUTH_SECRET verification_token_secret = USER_AUTH_SECRET + async def sso_authenticate( + self, + email: str, + user_id: str, + tenant_id: str, + ) -> models.UP: + user = await self.get_by_email(email) + if not user: + # user_create = UserCreate(email=email, password=secrets.token_urlsafe(32)) + user_create = UserCreate(role=UserRole.BASIC) + user = await self.create(user_create) + + # Update user with tenant information if needed + if user.tenant_id != tenant_id: + await self.user_db.update(user, {"tenant_id": tenant_id}) + + return user + async def create( self, user_create: schemas.UC | UserCreate, @@ -280,6 +360,30 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): send_user_verification_email(user.email, token) +async def sso_authenticate( + self, + email: str, + user_id: str, + tenant_id: str, +) -> models.UP: + user = await self.get_by_email(email) + if not user: + user_create = UserCreate(UserRole.BASIC) + user = await self.create(user_create) + + # Update user with tenant information if needed + if user.tenant_id != tenant_id: + await self.user_db.update(user, {"tenant_id": tenant_id}) + + return user + + +# async def get_user_manager( +# user_db: SQLAlchemyUserDatabase = Depends(get_user_db), +# ) -> AsyncGenerator[UserManager, None]: +# yield UserManager(user_db) + + async def get_user_manager( user_db: SQLAlchemyUserDatabase = Depends(get_user_db), ) -> AsyncGenerator[UserManager, None]: @@ -375,7 +479,8 @@ async def optional_user( versioned_fetch_user = fetch_versioned_implementation( "danswer.auth.users", "optional_user_" ) - return await versioned_fetch_user(request, user, db_session) + val = await versioned_fetch_user(request, user, db_session) + return val async def double_check_user( diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 161fdc8f1..49c7f822f 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -11,7 +11,6 @@ from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from danswer.auth.schemas import UserRole from danswer.db.engine import get_async_session from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.models import AccessToken @@ -46,11 +45,11 @@ async def get_user_count() -> int: # Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase): async def create(self, create_dict: Dict[str, Any]) -> UP: - user_count = await get_user_count() - if user_count == 0 or create_dict["email"] in get_default_admin_user_emails(): - create_dict["role"] = UserRole.ADMIN - else: - create_dict["role"] = UserRole.BASIC + await get_user_count() + # if user_count == 0 or create_dict["email"] in get_default_admin_user_emails(): + # create_dict["role"] = UserRole.ADMIN + # else: + # create_dict["role"] = UserRole.BASIC return await super().create(create_dict) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 576d1503b..15387e6c6 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -27,7 +27,7 @@ class RerankingDetails(BaseModel): # If model is None (or num_rerank is 0), then reranking is turned off rerank_model_name: str | None rerank_provider_type: RerankerProvider | None - rerank_api_key: str | None + rerank_api_key: str | None = None num_rerank: int diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 12258eba2..f85856009 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -34,6 +34,7 @@ PUBLIC_ENDPOINT_SPECS = [ ("/auth/reset-password", {"POST"}), ("/auth/request-verify-token", {"POST"}), ("/auth/verify", {"POST"}), + ("/settings/auth/sso-callback", {"POST"}), ("/users/me", {"GET"}), ("/users/me", {"PATCH"}), ("/users/{id}", {"GET"}), diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 3330f6cc5..a889ce116 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -3,12 +3,24 @@ from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from fastapi import Query +from fastapi.responses import RedirectResponse +from fastapi.security import OAuth2PasswordBearer +from fastapi_users.authentication import Strategy from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session +from danswer.auth.users import create_user_session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user +from danswer.auth.users import get_database_strategy +from danswer.auth.users import get_or_create_user +from danswer.auth.users import get_user_manager from danswer.auth.users import is_user_admin +from danswer.auth.users import UserManager +from danswer.auth.users import verify_sso_token +from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS +from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import KV_REINDEX_KEY from danswer.configs.constants import NotificationType from danswer.db.engine import get_session @@ -28,12 +40,65 @@ from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings from danswer.utils.logger import setup_logger +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +admin_router = APIRouter(prefix="/admin/settings") +basic_router = APIRouter(prefix="/settings") logger = setup_logger() -admin_router = APIRouter(prefix="/admin/settings") -basic_router = APIRouter(prefix="/settings") +@basic_router.post("/auth/sso-callback") +async def sso_callback( + sso_token: str = Query( + ..., alias="sso_token" + ), # Get SSO token from query parameters + strategy: Strategy = Depends(get_database_strategy), + user_manager: UserManager = Depends(get_user_manager), +): + print("SSO callback reached") + + payload = verify_sso_token(sso_token) + print("hi") + user = await get_or_create_user( + payload["email"], payload["user_id"], payload["tenant_id"] + ) + print(user) + session_token = await create_user_session(user, strategy) + print("HIII") + response = RedirectResponse(url="/") + response.set_cookie( + key="session", + value=session_token, + httponly=True, + max_age=SESSION_EXPIRE_TIME_SECONDS, + secure=WEB_DOMAIN.startswith("https"), + ) + return response + + +# @basic_router.post("/auth/sso-callback") +# async def sso_callback( +# user = Depends(current_user), +# token: str = Depends(oauth2_scheme), +# strategy: Strategy = Depends(get_database_strategy), +# user_manager: UserManager = Depends(get_user_manager), +# ): +# print('SSO callback reached') + +# payload = verify_sso_token(token) +# user = await get_or_create_user(payload["email"], payload["user_id"], payload["tenant_id"]) +# session_token = await create_user_session(user, strategy) + +# response = RedirectResponse(url="/") +# response.set_cookie( +# key="session", +# value=session_token, +# httponly=True, +# max_age=SESSION_EXPIRE_TIME_SECONDS, +# secure=WEB_DOMAIN.startswith("https"), +# ) +# return response @admin_router.put("") diff --git a/web/src/app/auth/oauth/callback/route.ts b/web/src/app/auth/oauth/callback/route.ts index 0b4157731..d01daca6a 100644 --- a/web/src/app/auth/oauth/callback/route.ts +++ b/web/src/app/auth/oauth/callback/route.ts @@ -3,6 +3,9 @@ import { buildUrl } from "@/lib/utilsSS"; import { NextRequest, NextResponse } from "next/server"; export const GET = async (request: NextRequest) => { + console.log("request", request); + console.log("HELLO"); + console.log("request.nextUrl.search", request.nextUrl.search); // Wrapper around the FastAPI endpoint /auth/oauth/callback, // which adds back a redirect to the main app. const url = new URL(buildUrl("/auth/oauth/callback")); diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 8219dbb9a..27d1bea55 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -15,9 +15,7 @@ import { buildClientUrl } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; import Head from "next/head"; import { EnterpriseSettings } from "./admin/settings/interfaces"; -import { redirect } from "next/navigation"; -import { Button, Card } from "@tremor/react"; -import LogoType from "@/components/header/LogoType"; +import { Card } from "@tremor/react"; import { HeaderTitle } from "@/components/header/HeaderTitle"; import { Logo } from "@/components/Logo"; import { UserProvider } from "@/components/user/UserProvider"; diff --git a/web/src/app/sso-callback/page.tsx b/web/src/app/sso-callback/page.tsx new file mode 100644 index 000000000..289ef237c --- /dev/null +++ b/web/src/app/sso-callback/page.tsx @@ -0,0 +1,68 @@ +"use client"; +import { useEffect, useState } from "react"; +import { useRouter, useSearchParams } from "next/navigation"; +import { Card, Text } from "@tremor/react"; +import { Spinner } from "@/components/Spinner"; + +export default function SSOCallback() { + const router = useRouter(); + const searchParams = useSearchParams(); + const [error, setError] = useState(null); + const [authStatus, setAuthStatus] = useState("Authenticating..."); + + useEffect(() => { + const verifyToken = async () => { + const ssoToken = searchParams.get("sso_token"); + if (!ssoToken) { + setError("No SSO token found"); + return; + } + + try { + setAuthStatus("Verifying SSO token..."); + const response = await fetch( + `/api/settings/auth/sso-callback?sso_token=${ssoToken}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + } + ); + + if (response.ok) { + setAuthStatus("Authentication successful!"); + setTimeout(() => { + setAuthStatus("Redirecting to dashboard..."); + setTimeout(() => { + router.push("/dashboard"); + }, 1000); + }, 1000); + } else { + const errorData = await response.json(); + setError(errorData.detail || "Authentication failed"); + } + } catch (error) { + console.error("Error verifying token:", error); + setError("An unexpected error occurred"); + } + }; + + verifyToken(); + }, [router, searchParams]); + + return ( +
+ + {error ? ( + {error} + ) : ( + <> + + {authStatus} + + )} + +
+ ); +}