Multi tenant vespa (#2762)

* add vespa multi tenancy

* k

* formatting

* Billing (#2667)

* k

* data -> control

* nit

* nit: error handling

* auth + app

* nit: color standardization

* nit

* nit: typing

* k

* k

* feat: functional upgrading

* feat: add block for downgrading to seats < active users

* add auth

* remove accomplished todo + prints

* nit

* tiny nit

* nit: centralize security

* add tenant expulsion/gating + invite user -> increment billing seat no.

* add cloud configs

* k

* k

* nit: update

* k

* k

* k

* k

* nit
This commit is contained in:
pablodanswer
2024-10-12 16:53:11 -07:00
committed by GitHub
parent 7eafdae17f
commit 20df20ae51
44 changed files with 1458 additions and 602 deletions

View File

@@ -21,3 +21,7 @@ API_KEY_HASH_ROUNDS = (
# Auto Permission Sync
#####
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

View File

@@ -85,8 +85,6 @@ def get_application() -> FastAPI:
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, query_history_router)
@@ -107,6 +105,10 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
include_router_with_global_prefix_prepended(application, usage_export_router)
if MULTI_TENANT:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -0,0 +1,53 @@
from datetime import datetime
from datetime import timedelta
import jwt
from fastapi import HTTPException
from fastapi import Request
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import JWT_ALGORITHM
from danswer.utils.logger import setup_logger
logger = setup_logger()
def generate_data_plane_token() -> str:
if DATA_PLANE_SECRET is None:
raise ValueError("DATA_PLANE_SECRET is not set")
payload = {
"iss": "data_plane",
"exp": datetime.utcnow() + timedelta(minutes=5),
"iat": datetime.utcnow(),
"scope": "api_access",
}
token = jwt.encode(payload, DATA_PLANE_SECRET, algorithm=JWT_ALGORITHM)
return token
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=[JWT_ALGORITHM])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -1,19 +1,33 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from danswer.auth.users import control_plane_dep
from danswer.auth.users import current_admin_user
from danswer.auth.users import User
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@@ -22,30 +36,30 @@ router = APIRouter(prefix="/tenants")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
run_alembic_migrations(tenant_id)
token = current_tenant_id.set(tenant_id)
print("getting session", tenant_id)
run_alembic_migrations(tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
add_users_to_tenant([email], tenant_id)
return {
@@ -60,3 +74,53 @@ def create_tenant(
finally:
if token is not None:
current_tenant_id.reset(token)
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
"""
token = current_tenant_id.set(current_tenant_id.get())
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if token is not None:
current_tenant_id.reset(token)
@router.get("/billing-information", response_model=BillingInformation)
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
logger.info("Fetching billing information")
return BillingInformation(**fetch_billing_information(current_tenant_id.get()))
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
try:
# Fetch tenant_id and current tenant's information
tenant_id = current_tenant_id.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,69 @@
from typing import cast
import requests
import stripe
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from shared_configs.configs import current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/tenant-stripe-information"
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()
def fetch_billing_information(tenant_id: str) -> dict:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/billing-information"
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
return billing_info
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Send a request to the control service to register the number of users for a tenant.
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
tenant_id = current_tenant_id.get()
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"quantity": number_of_users,
}
],
metadata={"tenant_id": str(tenant_id)},
)
return updated_subscription

View File

@@ -1,6 +1,29 @@
from pydantic import BaseModel
from danswer.server.settings.models import GatingType
class CheckoutSessionCreationRequest(BaseModel):
quantity: int
class CreateTenantRequest(BaseModel):
tenant_id: str
initial_admin_email: str
class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
class BillingInformation(BaseModel):
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool
class CheckoutSessionCreationResponse(BaseModel):
id: str