mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 19:38:19 +02:00
plan out
This commit is contained in:
parent
3838908e70
commit
7b895008d3
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
@ -8,6 +9,7 @@ from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
@ -54,6 +56,7 @@ from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
@ -192,10 +195,87 @@ def send_user_verification_email(
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
def verify_sso_token(token: str) -> dict:
|
||||
print("VERIFYING")
|
||||
try:
|
||||
print(token)
|
||||
print("DECODING")
|
||||
# payload = jwt.decode(token, settings.SSO_SECRET_KEY, algorithms=["HS256"])
|
||||
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
|
||||
print(payload)
|
||||
if datetime.now(timezone.utc) > datetime.fromtimestamp(
|
||||
payload["exp"], timezone.utc
|
||||
):
|
||||
print("EXPIRED")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
|
||||
)
|
||||
return payload
|
||||
except jwt.PyJWTError as e:
|
||||
print(e)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User:
|
||||
get_async_session_context = contextlib.asynccontextmanager(get_async_session)
|
||||
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
|
||||
|
||||
async with get_async_session_context() as session:
|
||||
async with get_user_db_context(session) as user_db:
|
||||
existing_user = await user_db.get_by_email(email)
|
||||
if existing_user:
|
||||
return existing_user
|
||||
|
||||
# Generate a random password
|
||||
uuid.uuid4().hex
|
||||
# hashed_password = get_password_hash(random_password)
|
||||
|
||||
new_user = {
|
||||
"email": email,
|
||||
"id": uuid.UUID(user_id),
|
||||
"role": UserRole.BASIC,
|
||||
"oidc_expiry": None,
|
||||
"default_model": None,
|
||||
"chosen_assistants": None,
|
||||
"hashed_password": "p",
|
||||
"is_active": True,
|
||||
"is_superuser": False,
|
||||
"is_verified": True,
|
||||
}
|
||||
created_user = await user_db.create(new_user)
|
||||
return created_user
|
||||
|
||||
|
||||
async def create_user_session(user: User, strategy: Strategy) -> str:
|
||||
token = await strategy.write_token(user)
|
||||
return token
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
||||
async def sso_authenticate(
|
||||
self,
|
||||
email: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> models.UP:
|
||||
user = await self.get_by_email(email)
|
||||
if not user:
|
||||
# user_create = UserCreate(email=email, password=secrets.token_urlsafe(32))
|
||||
user_create = UserCreate(role=UserRole.BASIC)
|
||||
user = await self.create(user_create)
|
||||
|
||||
# Update user with tenant information if needed
|
||||
if user.tenant_id != tenant_id:
|
||||
await self.user_db.update(user, {"tenant_id": tenant_id})
|
||||
|
||||
return user
|
||||
|
||||
async def create(
|
||||
self,
|
||||
user_create: schemas.UC | UserCreate,
|
||||
@ -280,6 +360,30 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
send_user_verification_email(user.email, token)
|
||||
|
||||
|
||||
async def sso_authenticate(
|
||||
self,
|
||||
email: str,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> models.UP:
|
||||
user = await self.get_by_email(email)
|
||||
if not user:
|
||||
user_create = UserCreate(UserRole.BASIC)
|
||||
user = await self.create(user_create)
|
||||
|
||||
# Update user with tenant information if needed
|
||||
if user.tenant_id != tenant_id:
|
||||
await self.user_db.update(user, {"tenant_id": tenant_id})
|
||||
|
||||
return user
|
||||
|
||||
|
||||
# async def get_user_manager(
|
||||
# user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
# ) -> AsyncGenerator[UserManager, None]:
|
||||
# yield UserManager(user_db)
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
) -> AsyncGenerator[UserManager, None]:
|
||||
@ -375,7 +479,8 @@ async def optional_user(
|
||||
versioned_fetch_user = fetch_versioned_implementation(
|
||||
"danswer.auth.users", "optional_user_"
|
||||
)
|
||||
return await versioned_fetch_user(request, user, db_session)
|
||||
val = await versioned_fetch_user(request, user, db_session)
|
||||
return val
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
|
@ -11,7 +11,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
@ -46,11 +45,11 @@ async def get_user_count() -> int:
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
|
||||
async def create(self, create_dict: Dict[str, Any]) -> UP:
|
||||
user_count = await get_user_count()
|
||||
if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
create_dict["role"] = UserRole.ADMIN
|
||||
else:
|
||||
create_dict["role"] = UserRole.BASIC
|
||||
await get_user_count()
|
||||
# if user_count == 0 or create_dict["email"] in get_default_admin_user_emails():
|
||||
# create_dict["role"] = UserRole.ADMIN
|
||||
# else:
|
||||
# create_dict["role"] = UserRole.BASIC
|
||||
return await super().create(create_dict)
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
rerank_provider_type: RerankerProvider | None
|
||||
rerank_api_key: str | None
|
||||
rerank_api_key: str | None = None
|
||||
|
||||
num_rerank: int
|
||||
|
||||
|
@ -34,6 +34,7 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
("/auth/reset-password", {"POST"}),
|
||||
("/auth/request-verify-token", {"POST"}),
|
||||
("/auth/verify", {"POST"}),
|
||||
("/settings/auth/sso-callback", {"POST"}),
|
||||
("/users/me", {"GET"}),
|
||||
("/users/me", {"PATCH"}),
|
||||
("/users/{id}", {"GET"}),
|
||||
|
@ -3,12 +3,24 @@ from typing import cast
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi_users.authentication import Strategy
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import create_user_session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import get_database_strategy
|
||||
from danswer.auth.users import get_or_create_user
|
||||
from danswer.auth.users import get_user_manager
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.auth.users import UserManager
|
||||
from danswer.auth.users import verify_sso_token
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.engine import get_session
|
||||
@ -28,12 +40,65 @@ from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/settings")
|
||||
basic_router = APIRouter(prefix="/settings")
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/settings")
|
||||
basic_router = APIRouter(prefix="/settings")
|
||||
@basic_router.post("/auth/sso-callback")
|
||||
async def sso_callback(
|
||||
sso_token: str = Query(
|
||||
..., alias="sso_token"
|
||||
), # Get SSO token from query parameters
|
||||
strategy: Strategy = Depends(get_database_strategy),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
print("SSO callback reached")
|
||||
|
||||
payload = verify_sso_token(sso_token)
|
||||
print("hi")
|
||||
user = await get_or_create_user(
|
||||
payload["email"], payload["user_id"], payload["tenant_id"]
|
||||
)
|
||||
print(user)
|
||||
session_token = await create_user_session(user, strategy)
|
||||
print("HIII")
|
||||
response = RedirectResponse(url="/")
|
||||
response.set_cookie(
|
||||
key="session",
|
||||
value=session_token,
|
||||
httponly=True,
|
||||
max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
secure=WEB_DOMAIN.startswith("https"),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# @basic_router.post("/auth/sso-callback")
|
||||
# async def sso_callback(
|
||||
# user = Depends(current_user),
|
||||
# token: str = Depends(oauth2_scheme),
|
||||
# strategy: Strategy = Depends(get_database_strategy),
|
||||
# user_manager: UserManager = Depends(get_user_manager),
|
||||
# ):
|
||||
# print('SSO callback reached')
|
||||
|
||||
# payload = verify_sso_token(token)
|
||||
# user = await get_or_create_user(payload["email"], payload["user_id"], payload["tenant_id"])
|
||||
# session_token = await create_user_session(user, strategy)
|
||||
|
||||
# response = RedirectResponse(url="/")
|
||||
# response.set_cookie(
|
||||
# key="session",
|
||||
# value=session_token,
|
||||
# httponly=True,
|
||||
# max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
# secure=WEB_DOMAIN.startswith("https"),
|
||||
# )
|
||||
# return response
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
|
@ -3,6 +3,9 @@ import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
|
||||
export const GET = async (request: NextRequest) => {
|
||||
console.log("request", request);
|
||||
console.log("HELLO");
|
||||
console.log("request.nextUrl.search", request.nextUrl.search);
|
||||
// Wrapper around the FastAPI endpoint /auth/oauth/callback,
|
||||
// which adds back a redirect to the main app.
|
||||
const url = new URL(buildUrl("/auth/oauth/callback"));
|
||||
|
@ -15,9 +15,7 @@ import { buildClientUrl } from "@/lib/utilsSS";
|
||||
import { Inter } from "next/font/google";
|
||||
import Head from "next/head";
|
||||
import { EnterpriseSettings } from "./admin/settings/interfaces";
|
||||
import { redirect } from "next/navigation";
|
||||
import { Button, Card } from "@tremor/react";
|
||||
import LogoType from "@/components/header/LogoType";
|
||||
import { Card } from "@tremor/react";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { UserProvider } from "@/components/user/UserProvider";
|
||||
|
68
web/src/app/sso-callback/page.tsx
Normal file
68
web/src/app/sso-callback/page.tsx
Normal file
@ -0,0 +1,68 @@
|
||||
"use client";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { Card, Text } from "@tremor/react";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
|
||||
export default function SSOCallback() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [authStatus, setAuthStatus] = useState<string>("Authenticating...");
|
||||
|
||||
useEffect(() => {
|
||||
const verifyToken = async () => {
|
||||
const ssoToken = searchParams.get("sso_token");
|
||||
if (!ssoToken) {
|
||||
setError("No SSO token found");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setAuthStatus("Verifying SSO token...");
|
||||
const response = await fetch(
|
||||
`/api/settings/auth/sso-callback?sso_token=${ssoToken}`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setAuthStatus("Authentication successful!");
|
||||
setTimeout(() => {
|
||||
setAuthStatus("Redirecting to dashboard...");
|
||||
setTimeout(() => {
|
||||
router.push("/dashboard");
|
||||
}, 1000);
|
||||
}, 1000);
|
||||
} else {
|
||||
const errorData = await response.json();
|
||||
setError(errorData.detail || "Authentication failed");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error verifying token:", error);
|
||||
setError("An unexpected error occurred");
|
||||
}
|
||||
};
|
||||
|
||||
verifyToken();
|
||||
}, [router, searchParams]);
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen bg-gray-100">
|
||||
<Card className="max-w-lg p-8 text-center">
|
||||
{error ? (
|
||||
<Text className="text-red-500">{error}</Text>
|
||||
) : (
|
||||
<>
|
||||
<Spinner />
|
||||
<Text className="text-lg font-semibold">{authStatus}</Text>
|
||||
</>
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user