functional multi tenancy (excluding accurate provisioning)

This commit is contained in:
pablodanswer 2024-09-21 21:31:05 -07:00
parent 127526d080
commit e9906c37fe
6 changed files with 230 additions and 25 deletions

View File

@ -194,7 +194,6 @@ def send_user_verification_email(
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
def verify_sso_token(token: str) -> dict:
try:
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
@ -207,6 +206,7 @@ def verify_sso_token(token: str) -> dict:
)
return payload
except jwt.PyJWTError:
print("ATTEMPING TO DECODE")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
@ -240,12 +240,25 @@ async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User:
created_user = await user_db.create(new_user)
return created_user
from datetime import timedelta
async def create_user_session(user: User, strategy: Strategy) -> str:
token = await strategy.write_token(user)
async def create_user_session(user: User, tenant_id: str) -> str:
# Create a payload with user information and tenant_id
payload = {
"sub": str(user.id),
"email": user.email,
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
}
# Encode the token
token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256")
return token
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET

View File

@ -177,15 +177,61 @@ def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
from typing import Generator
from sqlalchemy.orm import Session
from sqlalchemy import text
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
import jwt
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
DEFAULT_SCHEMA = "public"
from fastapi import Request, Depends, HTTPException
def get_current_tenant_id(request: Request):
token = request.cookies.get("fastapiusersauth")
if not token:
raise HTTPException(status_code=401, detail="Authentication required")
try:
payload = jwt.decode(token, "JWT_SECRET_KEY", algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
return tenant_id
except Exception as e:
raise HTTPException(status_code=401, detail="Invalid token") from e
def get_session(tenant_id: str = Depends(get_current_tenant_id)) -> Generator[Session, None, None]:
print('\n\n\n\n\ntenant_id', tenant_id)
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
session.execute(text(f"SET search_path TO {schema}"))
tenant_id = "01fb5963-9ab3-4585-900a-438480857427"
session.execute(text(f'SET search_path TO "{tenant_id}"'))
yield session
session.execute(text("SET search_path TO public"))
session.execute(text('SET search_path TO "public"'))
# Logic to create or retrieve a database session for the given tenant_id
# def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
# # The line below was added to monitor the latency caused by Postgres connections
# # during API calls.
# # with tracer.trace("db.get_session"):
# with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
# session.execute(text(f"SET search_path TO {schema}"))
# yield session
# session.execute(text("SET search_path TO public"))
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:

View File

@ -10,6 +10,7 @@ from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from danswer.auth.users import create_user_session
from danswer.auth.users import optional_user
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_database_strategy
@ -37,35 +38,148 @@ from danswer.server.settings.models import UserSettings
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.utils.logger import setup_logger
from fastapi.responses import JSONResponse
from fastapi.responses import Response
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
admin_router = APIRouter(prefix="/admin/settings")
basic_router = APIRouter(prefix="/settings")
from danswer.db.engine import get_async_session
import subprocess
logger = setup_logger()
from sqlalchemy import text
import contextlib
from sqlalchemy.ext.asyncio import AsyncSession
def run_alembic_migrations(schema_name: str):
# alembic -x "schema=tenant1,create_schema=true" upgrade head
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
command = [
"alembic",
"-x",
f"schema={schema_name},create_schema=true",
"upgrade",
"head"
]
result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
logger.error(f"Alembic migration failed for schema {schema_name}: {result.stderr}")
raise Exception(f"Migration failed for schema {schema_name}")
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
async def check_schema_exists(tenant_id: str) -> bool:
logger.info(f"Checking if schema exists for tenant: {tenant_id}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
async with get_async_session_context() as session:
result = await session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
{"schema_name": tenant_id}
)
schema = result.scalar()
exists = schema is not None
logger.info(f"Schema for tenant {tenant_id} exists: {exists}")
return exists
async def create_tenant_schema(tenant_id: str):
logger.info(f"Creating schema for tenant: {tenant_id}")
# Create the schema
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
async with get_async_session_context() as session:
await session.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{tenant_id}"'))
await session.commit()
logger.info(f"Schema created for tenant: {tenant_id}")
# Run migrations for the new schema
logger.info(f"Running migrations for tenant: {tenant_id}")
run_alembic_migrations(tenant_id)
logger.info(f"Migrations completed for tenant: {tenant_id}")
@basic_router.post("/auth/sso-callback")
async def sso_callback(
response: Response,
sso_token: str = Query(..., alias="sso_token"),
strategy: Strategy = Depends(get_database_strategy),
user_manager: UserManager = Depends(get_user_manager),
):
logger.info("SSO callback initiated")
payload = verify_sso_token(sso_token)
logger.info(f"SSO token verified for email: {payload['email']}")
user = await user_manager.sso_authenticate(
payload["email"], payload["user_id"], payload["tenant_id"]
)
logger.info(f"User authenticated: {user.email}")
session_token = await create_user_session(user, strategy)
logger.info(f"Session token created: {session_token[:10]}...")
tenant_id = payload["tenant_id"]
logger.info(f"Checking schema for tenant: {tenant_id}")
# Check if tenant schema exists
return {
"session_token": session_token,
"max_age": SESSION_EXPIRE_TIME_SECONDS,
"domain": WEB_DOMAIN.split("://")[-1],
}
schema_exists = await check_schema_exists(tenant_id)
if not schema_exists:
logger.info(f"Schema does not exist for tenant: {tenant_id}. Creating...")
# Create schema and run migrations
await create_tenant_schema(tenant_id)
else:
logger.info(f"Schema already exists for tenant: {tenant_id}")
session_token = await create_user_session(user, payload["tenant_id"])
logger.info(f"Session token created for user: {user.email}")
# Set the session cookie with proper flags
response.set_cookie(
key="fastapiusersauth",
value=session_token,
max_age=SESSION_EXPIRE_TIME_SECONDS,
expires=SESSION_EXPIRE_TIME_SECONDS,
path="/",
domain=WEB_DOMAIN.split("://")[-1],
secure=True,
httponly=True,
samesite="lax",
)
logger.info("Session cookie set")
logger.info("SSO callback completed successfully")
return JSONResponse(
content={"message": "Authentication successful"},
status_code=200
)
# @basic_router.post("/auth/sso-callback")
# async def sso_callback(
# sso_token: str = Query(..., alias="sso_token"),
# strategy: Strategy = Depends(get_database_strategy),
# user_manager: UserManager = Depends(get_user_manager),
# ):
# payload = verify_sso_token(sso_token)
# user = await user_manager.sso_authenticate(
# payload["email"], payload["user_id"], payload["tenant_id"]
# )
# session_token = await create_user_session(user, payload["tenant_id"], strategy)
# logger.info(f"Session token created: {session_token[:10]}...")
# return {
# "session_token": session_token,
# "max_age": SESSION_EXPIRE_TIME_SECONDS,
# "domain": WEB_DOMAIN.split("://")[-1],
# }
@admin_router.put("")
@ -123,6 +237,7 @@ def dismiss_notification_endpoint(
def get_user_notifications(
user: User | None, db_session: Session
) -> list[Notification]:
return []
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
is_admin = is_user_admin(user)
if not is_admin:

View File

@ -0,0 +1,25 @@
// app/auth/sso-callback/layout.tsx
import React from "react";
export const metadata = {
title: "SSO Callback",
};
export default function SSOCallbackLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<html lang="en">
<head>
<title>SSO Callback</title>
{/* Include any meta tags or scripts specific to this page */}
</head>
<body>
{/* Minimal styling or components */}
{children}
</body>
</html>
);
}

View File

@ -27,21 +27,17 @@ export default function SSOCallback() {
headers: {
"Content-Type": "application/json",
},
credentials: "include",
credentials: "include", // Ensure cookies are included in requests
}
);
if (response.ok) {
const data = await response.json();
setAuthStatus("Authentication successful!");
// Set the session cookie manually TODO validate safety
document.cookie = `fastapiusersauth=${data.session_token}; max-age=${data.max_age}; path=/; secure; samesite=lax`;
// Redirect to the dashboard
router.replace("/admin/plan");
} else {
const errorData = await response.json();
console.error("Authentication failed:", errorData);
console.error(errorData);
setError(errorData.detail || "Authentication failed");
}
} catch (error) {

View File

@ -53,6 +53,16 @@ export default async function RootLayout({
}: {
children: React.ReactNode;
}) {
// 00a89749-beab-489a-8b72-88aa3d646274
// 01fb5963-9ab3-4585-900a-438480857427
// return <>{children}</>
// SELECT table_name, column_name, data_type, character_maximum_length
// FROM information_schema.columns
// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274'
// ORDER BY table_name, ordinal_position;
const combinedSettings = await fetchSettingsSS();
if (!combinedSettings) {
// Just display a simple full page error if fetching fails.