mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Multi tenant vespa (#2762)
* add vespa multi tenancy * k * formatting * Billing (#2667) * k * data -> control * nit * nit: error handling * auth + app * nit: color standardization * nit * nit: typing * k * k * feat: functional upgrading * feat: add block for downgrading to seats < active users * add auth * remove accomplished todo + prints * nit * tiny nit * nit: centralize security * add tenant expulsion/gating + invite user -> increment billing seat no. * add cloud configs * k * k * nit: update * k * k * k * k * nit
This commit is contained in:
@@ -21,3 +21,7 @@ API_KEY_HASH_ROUNDS = (
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
@@ -85,8 +85,6 @@ def get_application() -> FastAPI:
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
# Analytics endpoints
|
||||
include_router_with_global_prefix_prepended(application, analytics_router)
|
||||
include_router_with_global_prefix_prepended(application, query_history_router)
|
||||
@@ -107,6 +105,10 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
|
||||
include_router_with_global_prefix_prepended(application, usage_export_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
|
@@ -0,0 +1,53 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
|
||||
from danswer.configs.app_configs import DATA_PLANE_SECRET
|
||||
from danswer.configs.app_configs import EXPECTED_API_KEY
|
||||
from danswer.configs.app_configs import JWT_ALGORITHM
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_data_plane_token() -> str:
|
||||
if DATA_PLANE_SECRET is None:
|
||||
raise ValueError("DATA_PLANE_SECRET is not set")
|
||||
|
||||
payload = {
|
||||
"iss": "data_plane",
|
||||
"exp": datetime.utcnow() + timedelta(minutes=5),
|
||||
"iat": datetime.utcnow(),
|
||||
"scope": "api_access",
|
||||
}
|
||||
|
||||
token = jwt.encode(payload, DATA_PLANE_SECRET, algorithm=JWT_ALGORITHM)
|
||||
return token
|
||||
|
||||
|
||||
async def control_plane_dep(request: Request) -> None:
|
||||
api_key = request.headers.get("X-API-KEY")
|
||||
if api_key != EXPECTED_API_KEY:
|
||||
logger.warning("Invalid API key")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
logger.warning("Invalid authorization header")
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
try:
|
||||
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=[JWT_ALGORITHM])
|
||||
if payload.get("scope") != "tenant:create":
|
||||
logger.warning("Insufficient permissions")
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
logger.warning("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
@@ -1,19 +1,33 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from danswer.auth.users import control_plane_dep
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import User
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.danswer.server.tenants.access import control_plane_dep
|
||||
from ee.danswer.server.tenants.billing import fetch_billing_information
|
||||
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.danswer.server.tenants.models import BillingInformation
|
||||
from ee.danswer.server.tenants.models import CreateTenantRequest
|
||||
from ee.danswer.server.tenants.models import ProductGatingRequest
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
@@ -22,30 +36,30 @@ router = APIRouter(prefix="/tenants")
|
||||
def create_tenant(
|
||||
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
|
||||
) -> dict[str, str]:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
email = create_tenant_request.initial_admin_email
|
||||
token = None
|
||||
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
try:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
if not ensure_schema_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
run_alembic_migrations(tenant_id)
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
print("getting session", tenant_id)
|
||||
run_alembic_migrations(tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session)
|
||||
|
||||
logger.info(f"Tenant {tenant_id} created successfully")
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
return {
|
||||
@@ -60,3 +74,53 @@ def create_tenant(
|
||||
finally:
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> None:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when
|
||||
1) User has ended free trial without adding payment method
|
||||
2) User's card has declined
|
||||
"""
|
||||
token = current_tenant_id.set(current_tenant_id.get())
|
||||
|
||||
settings = load_settings()
|
||||
settings.product_gating = product_gating_request.product_gating
|
||||
store_settings(settings)
|
||||
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
@router.get("/billing-information", response_model=BillingInformation)
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
return BillingInformation(**fetch_billing_information(current_tenant_id.get()))
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
|
||||
try:
|
||||
# Fetch tenant_id and current tenant's information
|
||||
tenant_id = current_tenant_id.get()
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
69
backend/ee/danswer/server/tenants/billing.py
Normal file
69
backend/ee/danswer/server/tenants/billing.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.danswer.server.tenants.access import generate_data_plane_token
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/tenant-stripe-information"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> dict:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/billing-information"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = response.json()
|
||||
return billing_info
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
tenant_id = current_tenant_id.get()
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
metadata={"tenant_id": str(tenant_id)},
|
||||
)
|
||||
return updated_subscription
|
@@ -1,6 +1,29 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.server.settings.models import GatingType
|
||||
|
||||
|
||||
class CheckoutSessionCreationRequest(BaseModel):
|
||||
quantity: int
|
||||
|
||||
|
||||
class CreateTenantRequest(BaseModel):
|
||||
tenant_id: str
|
||||
initial_admin_email: str
|
||||
|
||||
|
||||
class ProductGatingRequest(BaseModel):
|
||||
tenant_id: str
|
||||
product_gating: GatingType
|
||||
|
||||
|
||||
class BillingInformation(BaseModel):
|
||||
seats: int
|
||||
subscription_status: str
|
||||
billing_start: str
|
||||
billing_end: str
|
||||
payment_method_enabled: bool
|
||||
|
||||
|
||||
class CheckoutSessionCreationResponse(BaseModel):
|
||||
id: str
|
||||
|
Reference in New Issue
Block a user