Billing fixes (#3976)

This commit is contained in:
pablonyx
2025-02-13 15:59:10 -08:00
committed by GitHub
parent 1a7aca06b9
commit 3260d793d1
21 changed files with 546 additions and 233 deletions

View File

@@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"

View File

@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -126,37 +128,29 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
We gate the product when their subscription has ended.
"""
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information", response_model=BillingInformation)
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
@router.post("/create-customer-portal-session")
@@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
@@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,

View File

@@ -6,6 +6,7 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
billing_info = BillingInformation(**response.json())
return billing_info

View File

@@ -1,7 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus
class CheckoutSessionCreationRequest(BaseModel):
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
application_status: ApplicationStatus
class SubscriptionStatusResponse(BaseModel):
subscribed: bool
class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool
@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None
class ProductGatingResponse(BaseModel):
updated: bool
error: str | None
class SubscriptionSessionResponse(BaseModel):
sessionId: str

View File

@@ -0,0 +1,50 @@
from typing import cast
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
settings.application_status = application_status
store_settings(settings)
# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception:
logger.exception("Failed to gate product")
raise
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
@@ -252,7 +253,11 @@ def cloud_beat_task_generator(
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()

View File

@@ -12,10 +12,10 @@ class PageType(str, Enum):
SEARCH = "search"
class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features
class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GATED_ACCESS = "gated_access"
ACTIVE = "active"
class Notification(BaseModel):
@@ -43,7 +43,7 @@ class Settings(BaseModel):
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
pro_search_disabled: bool | None = None
auto_scroll: bool | None = None