From e9906c37fe92b74bc067a80da741811712f34355 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 21 Sep 2024 21:31:05 -0700 Subject: [PATCH] functional multi tenancy (excluding accurate provisioning) --- backend/danswer/auth/users.py | 19 +++- backend/danswer/db/engine.py | 58 +++++++++- backend/danswer/server/settings/api.py | 133 +++++++++++++++++++++-- web/src/app/auth/sso-callback/layout.tsx | 25 +++++ web/src/app/auth/sso-callback/page.tsx | 10 +- web/src/app/layout.tsx | 10 ++ 6 files changed, 230 insertions(+), 25 deletions(-) create mode 100644 web/src/app/auth/sso-callback/layout.tsx diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 51d4c5b08..8d5255f0b 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -194,7 +194,6 @@ def send_user_verification_email( s.login(SMTP_USER, SMTP_PASS) s.send_message(msg) - def verify_sso_token(token: str) -> dict: try: payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"]) @@ -207,6 +206,7 @@ def verify_sso_token(token: str) -> dict: ) return payload except jwt.PyJWTError: + print("ATTEMPING TO DECODE") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" ) @@ -240,12 +240,25 @@ async def get_or_create_user(email: str, user_id: str, tenant_id: str) -> User: created_user = await user_db.create(new_user) return created_user +from datetime import timedelta -async def create_user_session(user: User, strategy: Strategy) -> str: - token = await strategy.write_token(user) +async def create_user_session(user: User, tenant_id: str) -> str: + # Create a payload with user information and tenant_id + payload = { + "sub": str(user.id), + "email": user.email, + "tenant_id": tenant_id, + "exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS) + } + + # Encode the token + token = jwt.encode(payload, "JWT_SECRET_KEY", algorithm="HS256") + + return token + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = USER_AUTH_SECRET verification_token_secret = USER_AUTH_SECRET diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 7b1a0bd88..d0d483774 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -177,15 +177,61 @@ def get_session_context_manager() -> ContextManager[Session]: return contextlib.contextmanager(get_session)() -def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]: - # The line below was added to monitor the latency caused by Postgres connections - # during API calls. - # with tracer.trace("db.get_session"): +from typing import Generator +from sqlalchemy.orm import Session +from sqlalchemy import text +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +import jwt +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +DEFAULT_SCHEMA = "public" + +from fastapi import Request, Depends, HTTPException + + +def get_current_tenant_id(request: Request): + token = request.cookies.get("fastapiusersauth") + if not token: + raise HTTPException(status_code=401, detail="Authentication required") + + try: + payload = jwt.decode(token, "JWT_SECRET_KEY", algorithms=["HS256"]) + tenant_id = payload.get("tenant_id") + if not tenant_id: + raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing") + return tenant_id + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid token") from e + + + +def get_session(tenant_id: str = Depends(get_current_tenant_id)) -> Generator[Session, None, None]: + + print('\n\n\n\n\ntenant_id', tenant_id) with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: - session.execute(text(f"SET search_path TO {schema}")) + + tenant_id = "01fb5963-9ab3-4585-900a-438480857427" + + session.execute(text(f'SET search_path TO "{tenant_id}"')) yield session - session.execute(text("SET search_path TO public")) + session.execute(text('SET search_path TO "public"')) + + + # Logic to create or retrieve a database session for the given tenant_id + + + +# def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]: +# # The line below was added to monitor the latency caused by Postgres connections +# # during API calls. +# # with tracer.trace("db.get_session"): + +# with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: +# session.execute(text(f"SET search_path TO {schema}")) +# yield session +# session.execute(text("SET search_path TO public")) async def get_async_session() -> AsyncGenerator[AsyncSession, None]: diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 2d28f1b4d..aac5b1dcc 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -10,6 +10,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session from danswer.auth.users import create_user_session +from danswer.auth.users import optional_user from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.auth.users import get_database_strategy @@ -37,35 +38,148 @@ from danswer.server.settings.models import UserSettings from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings from danswer.utils.logger import setup_logger +from fastapi.responses import JSONResponse +from fastapi.responses import Response oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") admin_router = APIRouter(prefix="/admin/settings") basic_router = APIRouter(prefix="/settings") - +from danswer.db.engine import get_async_session +import subprocess logger = setup_logger() +from sqlalchemy import text + +import contextlib +from sqlalchemy.ext.asyncio import AsyncSession +def run_alembic_migrations(schema_name: str): + # alembic -x "schema=tenant1,create_schema=true" upgrade head + + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + command = [ + "alembic", + "-x", + f"schema={schema_name},create_schema=true", + "upgrade", + "head" + ] + result = subprocess.run(command, capture_output=True, text=True) + if result.returncode != 0: + logger.error(f"Alembic migration failed for schema {schema_name}: {result.stderr}") + raise Exception(f"Migration failed for schema {schema_name}") + logger.info(f"Alembic migrations completed successfully for schema: {schema_name}") + + +async def check_schema_exists(tenant_id: str) -> bool: + logger.info(f"Checking if schema exists for tenant: {tenant_id}") + get_async_session_context = contextlib.asynccontextmanager( + get_async_session + ) + async with get_async_session_context() as session: + result = await session.execute( + text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"), + {"schema_name": tenant_id} + ) + schema = result.scalar() + exists = schema is not None + logger.info(f"Schema for tenant {tenant_id} exists: {exists}") + return exists + +async def create_tenant_schema(tenant_id: str): + logger.info(f"Creating schema for tenant: {tenant_id}") + # Create the schema + get_async_session_context = contextlib.asynccontextmanager( + get_async_session + ) + async with get_async_session_context() as session: + + + await session.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{tenant_id}"')) + await session.commit() + logger.info(f"Schema created for tenant: {tenant_id}") + + + # Run migrations for the new schema + logger.info(f"Running migrations for tenant: {tenant_id}") + run_alembic_migrations(tenant_id) + logger.info(f"Migrations completed for tenant: {tenant_id}") @basic_router.post("/auth/sso-callback") async def sso_callback( + response: Response, sso_token: str = Query(..., alias="sso_token"), - strategy: Strategy = Depends(get_database_strategy), user_manager: UserManager = Depends(get_user_manager), ): + logger.info("SSO callback initiated") payload = verify_sso_token(sso_token) + logger.info(f"SSO token verified for email: {payload['email']}") user = await user_manager.sso_authenticate( payload["email"], payload["user_id"], payload["tenant_id"] ) + logger.info(f"User authenticated: {user.email}") - session_token = await create_user_session(user, strategy) - logger.info(f"Session token created: {session_token[:10]}...") + tenant_id = payload["tenant_id"] + logger.info(f"Checking schema for tenant: {tenant_id}") + # Check if tenant schema exists - return { - "session_token": session_token, - "max_age": SESSION_EXPIRE_TIME_SECONDS, - "domain": WEB_DOMAIN.split("://")[-1], - } + + schema_exists = await check_schema_exists(tenant_id) + if not schema_exists: + logger.info(f"Schema does not exist for tenant: {tenant_id}. Creating...") + # Create schema and run migrations + await create_tenant_schema(tenant_id) + else: + logger.info(f"Schema already exists for tenant: {tenant_id}") + + + + session_token = await create_user_session(user, payload["tenant_id"]) + logger.info(f"Session token created for user: {user.email}") + + # Set the session cookie with proper flags + response.set_cookie( + key="fastapiusersauth", + value=session_token, + max_age=SESSION_EXPIRE_TIME_SECONDS, + expires=SESSION_EXPIRE_TIME_SECONDS, + path="/", + domain=WEB_DOMAIN.split("://")[-1], + secure=True, + httponly=True, + samesite="lax", + ) + logger.info("Session cookie set") + + logger.info("SSO callback completed successfully") + return JSONResponse( + content={"message": "Authentication successful"}, + status_code=200 + ) + + + +# @basic_router.post("/auth/sso-callback") +# async def sso_callback( +# sso_token: str = Query(..., alias="sso_token"), +# strategy: Strategy = Depends(get_database_strategy), +# user_manager: UserManager = Depends(get_user_manager), +# ): +# payload = verify_sso_token(sso_token) + +# user = await user_manager.sso_authenticate( +# payload["email"], payload["user_id"], payload["tenant_id"] +# ) + +# session_token = await create_user_session(user, payload["tenant_id"], strategy) +# logger.info(f"Session token created: {session_token[:10]}...") + +# return { +# "session_token": session_token, +# "max_age": SESSION_EXPIRE_TIME_SECONDS, +# "domain": WEB_DOMAIN.split("://")[-1], +# } @admin_router.put("") @@ -123,6 +237,7 @@ def dismiss_notification_endpoint( def get_user_notifications( user: User | None, db_session: Session ) -> list[Notification]: + return [] """Get notifications for the user, currently the logic is very specific to the reindexing flag""" is_admin = is_user_admin(user) if not is_admin: diff --git a/web/src/app/auth/sso-callback/layout.tsx b/web/src/app/auth/sso-callback/layout.tsx new file mode 100644 index 000000000..ba79fa30f --- /dev/null +++ b/web/src/app/auth/sso-callback/layout.tsx @@ -0,0 +1,25 @@ +// app/auth/sso-callback/layout.tsx +import React from "react"; + +export const metadata = { + title: "SSO Callback", +}; + +export default function SSOCallbackLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + + SSO Callback + {/* Include any meta tags or scripts specific to this page */} + + + {/* Minimal styling or components */} + {children} + + + ); +} diff --git a/web/src/app/auth/sso-callback/page.tsx b/web/src/app/auth/sso-callback/page.tsx index ea6a14cad..9e34ce8da 100644 --- a/web/src/app/auth/sso-callback/page.tsx +++ b/web/src/app/auth/sso-callback/page.tsx @@ -27,21 +27,17 @@ export default function SSOCallback() { headers: { "Content-Type": "application/json", }, - credentials: "include", + credentials: "include", // Ensure cookies are included in requests } ); if (response.ok) { - const data = await response.json(); setAuthStatus("Authentication successful!"); - - // Set the session cookie manually TODO validate safety - document.cookie = `fastapiusersauth=${data.session_token}; max-age=${data.max_age}; path=/; secure; samesite=lax`; - + // Redirect to the dashboard router.replace("/admin/plan"); } else { const errorData = await response.json(); - console.error("Authentication failed:", errorData); + console.error(errorData); setError(errorData.detail || "Authentication failed"); } } catch (error) { diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 27d1bea55..8b157bc51 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -53,6 +53,16 @@ export default async function RootLayout({ }: { children: React.ReactNode; }) { +// 00a89749-beab-489a-8b72-88aa3d646274 +// 01fb5963-9ab3-4585-900a-438480857427 + // return <>{children} + +// SELECT table_name, column_name, data_type, character_maximum_length +// FROM information_schema.columns +// WHERE table_schema = '00a89749-beab-489a-8b72-88aa3d646274' +// ORDER BY table_name, ordinal_position; + + const combinedSettings = await fetchSettingsSS(); if (!combinedSettings) { // Just display a simple full page error if fetching fails.