mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-13 06:19:33 +02:00
functional multi tenancy (excluding accurate provisioning)
This commit is contained in:
parent
127526d080
commit
e9906c37fe
@ -194,7 +194,6 @@ def send_user_verification_email(
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
def verify_sso_token(token: str) -> dict:
|
||||
try:
|
||||
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
|
||||
@ -207,6 +206,7 @@ def verify_sso_token(token: str) -> dict:
|
||||
)
|
||||
return payload
|
||||
except jwt.PyJWTError:
|
||||
print("ATTEMPING TO DECODE")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
@ -240,12 +240,25 @@ async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User:
|
||||
created_user = await user_db.create(new_user)
|
||||
return created_user
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
async def create_user_session(user: User, strategy: Strategy) -> str:
|
||||
token = await strategy.write_token(user)
|
||||
async def create_user_session(user: User, tenant_id: str) -> str:
|
||||
# Create a payload with user information and tenant_id
|
||||
payload = {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"tenant_id": tenant_id,
|
||||
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
|
||||
}
|
||||
|
||||
# Encode the token
|
||||
token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
|
@ -177,15 +177,61 @@ def get_session_context_manager() -> ContextManager[Session]:
|
||||
return contextlib.contextmanager(get_session)()
|
||||
|
||||
|
||||
def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
|
||||
# The line below was added to monitor the latency caused by Postgres connections
|
||||
# during API calls.
|
||||
# with tracer.trace("db.get_session"):
|
||||
from typing import Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
import jwt
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
DEFAULT_SCHEMA = "public"
|
||||
|
||||
from fastapi import Request, Depends, HTTPException
|
||||
|
||||
|
||||
def get_current_tenant_id(request: Request):
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, "JWT_SECRET_KEY", algorithms=["HS256"])
|
||||
tenant_id = payload.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
|
||||
return tenant_id
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=401, detail="Invalid token") from e
|
||||
|
||||
|
||||
|
||||
def get_session(tenant_id: str = Depends(get_current_tenant_id)) -> Generator[Session, None, None]:
|
||||
|
||||
print('\n\n\n\n\ntenant_id', tenant_id)
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||
session.execute(text(f"SET search_path TO {schema}"))
|
||||
|
||||
tenant_id = "01fb5963-9ab3-4585-900a-438480857427"
|
||||
|
||||
session.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
yield session
|
||||
session.execute(text("SET search_path TO public"))
|
||||
session.execute(text('SET search_path TO "public"'))
|
||||
|
||||
|
||||
# Logic to create or retrieve a database session for the given tenant_id
|
||||
|
||||
|
||||
|
||||
# def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
|
||||
# # The line below was added to monitor the latency caused by Postgres connections
|
||||
# # during API calls.
|
||||
# # with tracer.trace("db.get_session"):
|
||||
|
||||
# with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||
# session.execute(text(f"SET search_path TO {schema}"))
|
||||
# yield session
|
||||
# session.execute(text("SET search_path TO public"))
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
@ -10,6 +10,7 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import create_user_session
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import get_database_strategy
|
||||
@ -37,35 +38,148 @@ from danswer.server.settings.models import UserSettings
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import Response
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/settings")
|
||||
basic_router = APIRouter(prefix="/settings")
|
||||
|
||||
from danswer.db.engine import get_async_session
|
||||
import subprocess
|
||||
logger = setup_logger()
|
||||
from sqlalchemy import text
|
||||
|
||||
import contextlib
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
def run_alembic_migrations(schema_name: str):
|
||||
# alembic -x "schema=tenant1,create_schema=true" upgrade head
|
||||
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
command = [
|
||||
"alembic",
|
||||
"-x",
|
||||
f"schema={schema_name},create_schema=true",
|
||||
"upgrade",
|
||||
"head"
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Alembic migration failed for schema {schema_name}: {result.stderr}")
|
||||
raise Exception(f"Migration failed for schema {schema_name}")
|
||||
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
|
||||
|
||||
|
||||
async def check_schema_exists(tenant_id: str) -> bool:
|
||||
logger.info(f"Checking if schema exists for tenant: {tenant_id}")
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
)
|
||||
async with get_async_session_context() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
|
||||
{"schema_name": tenant_id}
|
||||
)
|
||||
schema = result.scalar()
|
||||
exists = schema is not None
|
||||
logger.info(f"Schema for tenant {tenant_id} exists: {exists}")
|
||||
return exists
|
||||
|
||||
async def create_tenant_schema(tenant_id: str):
|
||||
logger.info(f"Creating schema for tenant: {tenant_id}")
|
||||
# Create the schema
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
)
|
||||
async with get_async_session_context() as session:
|
||||
|
||||
|
||||
await session.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{tenant_id}"'))
|
||||
await session.commit()
|
||||
logger.info(f"Schema created for tenant: {tenant_id}")
|
||||
|
||||
|
||||
# Run migrations for the new schema
|
||||
logger.info(f"Running migrations for tenant: {tenant_id}")
|
||||
run_alembic_migrations(tenant_id)
|
||||
logger.info(f"Migrations completed for tenant: {tenant_id}")
|
||||
|
||||
|
||||
@basic_router.post("/auth/sso-callback")
|
||||
async def sso_callback(
|
||||
response: Response,
|
||||
sso_token: str = Query(..., alias="sso_token"),
|
||||
strategy: Strategy = Depends(get_database_strategy),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
):
|
||||
logger.info("SSO callback initiated")
|
||||
payload = verify_sso_token(sso_token)
|
||||
logger.info(f"SSO token verified for email: {payload['email']}")
|
||||
|
||||
user = await user_manager.sso_authenticate(
|
||||
payload["email"], payload["user_id"], payload["tenant_id"]
|
||||
)
|
||||
logger.info(f"User authenticated: {user.email}")
|
||||
|
||||
session_token = await create_user_session(user, strategy)
|
||||
logger.info(f"Session token created: {session_token[:10]}...")
|
||||
tenant_id = payload["tenant_id"]
|
||||
logger.info(f"Checking schema for tenant: {tenant_id}")
|
||||
# Check if tenant schema exists
|
||||
|
||||
return {
|
||||
"session_token": session_token,
|
||||
"max_age": SESSION_EXPIRE_TIME_SECONDS,
|
||||
"domain": WEB_DOMAIN.split("://")[-1],
|
||||
}
|
||||
|
||||
schema_exists = await check_schema_exists(tenant_id)
|
||||
if not schema_exists:
|
||||
logger.info(f"Schema does not exist for tenant: {tenant_id}. Creating...")
|
||||
# Create schema and run migrations
|
||||
await create_tenant_schema(tenant_id)
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant: {tenant_id}")
|
||||
|
||||
|
||||
|
||||
session_token = await create_user_session(user, payload["tenant_id"])
|
||||
logger.info(f"Session token created for user: {user.email}")
|
||||
|
||||
# Set the session cookie with proper flags
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=session_token,
|
||||
max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
expires=SESSION_EXPIRE_TIME_SECONDS,
|
||||
path="/",
|
||||
domain=WEB_DOMAIN.split("://")[-1],
|
||||
secure=True,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
)
|
||||
logger.info("Session cookie set")
|
||||
|
||||
logger.info("SSO callback completed successfully")
|
||||
return JSONResponse(
|
||||
content={"message": "Authentication successful"},
|
||||
status_code=200
|
||||
)
|
||||
|
||||
|
||||
|
||||
# @basic_router.post("/auth/sso-callback")
|
||||
# async def sso_callback(
|
||||
# sso_token: str = Query(..., alias="sso_token"),
|
||||
# strategy: Strategy = Depends(get_database_strategy),
|
||||
# user_manager: UserManager = Depends(get_user_manager),
|
||||
# ):
|
||||
# payload = verify_sso_token(sso_token)
|
||||
|
||||
# user = await user_manager.sso_authenticate(
|
||||
# payload["email"], payload["user_id"], payload["tenant_id"]
|
||||
# )
|
||||
|
||||
# session_token = await create_user_session(user, payload["tenant_id"], strategy)
|
||||
# logger.info(f"Session token created: {session_token[:10]}...")
|
||||
|
||||
# return {
|
||||
# "session_token": session_token,
|
||||
# "max_age": SESSION_EXPIRE_TIME_SECONDS,
|
||||
# "domain": WEB_DOMAIN.split("://")[-1],
|
||||
# }
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
@ -123,6 +237,7 @@ def dismiss_notification_endpoint(
|
||||
def get_user_notifications(
|
||||
user: User | None, db_session: Session
|
||||
) -> list[Notification]:
|
||||
return []
|
||||
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
|
||||
is_admin = is_user_admin(user)
|
||||
if not is_admin:
|
||||
|
25
web/src/app/auth/sso-callback/layout.tsx
Normal file
25
web/src/app/auth/sso-callback/layout.tsx
Normal file
@ -0,0 +1,25 @@
|
||||
// app/auth/sso-callback/layout.tsx
|
||||
import React from "react";
|
||||
|
||||
export const metadata = {
|
||||
title: "SSO Callback",
|
||||
};
|
||||
|
||||
export default function SSOCallbackLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>SSO Callback</title>
|
||||
{/* Include any meta tags or scripts specific to this page */}
|
||||
</head>
|
||||
<body>
|
||||
{/* Minimal styling or components */}
|
||||
{children}
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
@ -27,21 +27,17 @@ export default function SSOCallback() {
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
credentials: "include", // Ensure cookies are included in requests
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
setAuthStatus("Authentication successful!");
|
||||
|
||||
// Set the session cookie manually TODO validate safety
|
||||
document.cookie = `fastapiusersauth=${data.session_token}; max-age=${data.max_age}; path=/; secure; samesite=lax`;
|
||||
|
||||
|
||||
// Redirect to the dashboard
|
||||
router.replace("/admin/plan");
|
||||
} else {
|
||||
const errorData = await response.json();
|
||||
console.error("Authentication failed:", errorData);
|
||||
console.error(errorData);
|
||||
setError(errorData.detail || "Authentication failed");
|
||||
}
|
||||
} catch (error) {
|
||||
|
@ -53,6 +53,16 @@ export default async function RootLayout({
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
// 00a89749-beab-489a-8b72-88aa3d646274
|
||||
// 01fb5963-9ab3-4585-900a-438480857427
|
||||
// return <>{children}</>
|
||||
|
||||
// SELECT table_name, column_name, data_type, character_maximum_length
|
||||
// FROM information_schema.columns
|
||||
// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274'
|
||||
// ORDER BY table_name, ordinal_position;
|
||||
|
||||
|
||||
const combinedSettings = await fetchSettingsSS();
|
||||
if (!combinedSettings) {
|
||||
// Just display a simple full page error if fetching fails.
|
||||
|
Loading…
x
Reference in New Issue
Block a user