Billing fixes (#3976)

This commit is contained in:
pablonyx 2025-02-13 15:59:10 -08:00 committed by GitHub
parent 1a7aca06b9
commit 3260d793d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 546 additions and 233 deletions

View File

@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"

View File

@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@ -126,37 +128,29 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
We gate the product when their subscription has ended.
"""
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information", response_model=BillingInformation)
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
@router.post("/create-customer-portal-session")
@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,

View File

@ -6,6 +6,7 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
billing_info = BillingInformation(**response.json())
return billing_info

View File

@ -1,7 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus
class CheckoutSessionCreationRequest(BaseModel):
@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
application_status: ApplicationStatus
class SubscriptionStatusResponse(BaseModel):
subscribed: bool
class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool
@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None
class ProductGatingResponse(BaseModel):
updated: bool
error: str | None
class SubscriptionSessionResponse(BaseModel):
sessionId: str

View File

@ -0,0 +1,50 @@
from typing import cast
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
settings.application_status = application_status
store_settings(settings)
# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception:
logger.exception("Failed to gate product")
raise
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
@ -252,7 +253,11 @@ def cloud_beat_task_generator(
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()

View File

@ -12,10 +12,10 @@ class PageType(str, Enum):
SEARCH = "search"
class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features
class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GATED_ACCESS = "gated_access"
ACTIVE = "active"
class Notification(BaseModel):
@ -43,7 +43,7 @@ class Settings(BaseModel):
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
pro_search_disabled: bool | None = None
auto_scroll: bool | None = None

View File

@ -1,7 +1,7 @@
export enum GatingType {
FULL = "full",
PARTIAL = "partial",
NONE = "none",
export enum ApplicationStatus {
PAYMENT_REMINDER = "payment_reminder",
GATED_ACCESS = "gated_access",
ACTIVE = "active",
}
export interface Settings {
@ -11,7 +11,7 @@ export interface Settings {
needs_reindexing: boolean;
gpu_enabled: boolean;
pro_search_disabled: boolean | null;
product_gating: GatingType;
application_status: ApplicationStatus;
auto_scroll: boolean;
}

View File

@ -2291,8 +2291,6 @@ export function ChatPage({
bg-opacity-80
duration-300
ease-in-out
${
!untoggled && (showHistorySidebar || sidebarVisible)
? "opacity-100 w-[250px] translate-x-0"

View File

@ -0,0 +1,73 @@
import React from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { CircleAlert, Info } from "lucide-react";
import { BillingInformation, BillingStatus } from "./interfaces";
export function BillingAlerts({
billingInformation,
}: {
billingInformation: BillingInformation;
}) {
const isTrialing = billingInformation.status === BillingStatus.TRIALING;
const isCancelled = billingInformation.cancel_at_period_end;
const isExpired =
new Date(billingInformation.current_period_end) < new Date();
const noPaymentMethod = !billingInformation.payment_method_enabled;
const messages: string[] = [];
if (isExpired) {
messages.push(
"Your subscription has expired. Please resubscribe to continue using the service."
);
}
if (isCancelled && !isExpired) {
messages.push(
`Your subscription will cancel on ${new Date(
billingInformation.current_period_end
).toLocaleDateString()}. You can resubscribe before this date to remain uninterrupted.`
);
}
if (isTrialing) {
messages.push(
`You're currently on a trial. Your trial ends on ${
billingInformation.trial_end
? new Date(billingInformation.trial_end).toLocaleDateString()
: "N/A"
}.`
);
}
if (noPaymentMethod) {
messages.push(
"You currently have no payment method on file. Please add one to avoid service interruption."
);
}
const variant = isExpired || noPaymentMethod ? "destructive" : "default";
if (messages.length === 0) return null;
return (
<Alert variant={variant}>
<AlertTitle className="flex items-center space-x-2">
{variant === "destructive" ? (
<CircleAlert className="h-4 w-4" />
) : (
<Info className="h-4 w-4" />
)}
<span>
{variant === "destructive"
? "Important Subscription Notice"
: "Subscription Notice"}
</span>
</AlertTitle>
<AlertDescription>
<ul className="list-disc list-inside space-y-1 mt-2">
{messages.map((msg, idx) => (
<li key={idx}>{msg}</li>
))}
</ul>
</AlertDescription>
</Alert>
);
}

View File

@ -1,18 +1,21 @@
"use client";
import { CreditCard, ArrowFatUp } from "@phosphor-icons/react";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { loadStripe } from "@stripe/stripe-js";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsIcon } from "@/components/icons/icons";
import {
updateSubscriptionQuantity,
fetchCustomerPortal,
statusToDisplay,
useBillingInformation,
} from "./utils";
import { useEffect } from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { fetchCustomerPortal, useBillingInformation } from "./utils";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { CreditCard, ArrowFatUp } from "@phosphor-icons/react";
import { SubscriptionSummary } from "./SubscriptionSummary";
import { BillingAlerts } from "./BillingAlerts";
export default function BillingInformationPage() {
const router = useRouter();
@ -24,9 +27,6 @@ export default function BillingInformationPage() {
isLoading,
} = useBillingInformation();
if (error) {
console.error("Failed to fetch billing information:", error);
}
useEffect(() => {
const url = new URL(window.location.href);
if (url.searchParams.has("session_id")) {
@ -35,22 +35,33 @@ export default function BillingInformationPage() {
"Congratulations! Your subscription has been updated successfully.",
type: "success",
});
// Remove the session_id from the URL
url.searchParams.delete("session_id");
window.history.replaceState({}, "", url.toString());
// You might want to refresh the billing information here
// by calling an API endpoint to get the latest data
}
}, [setPopup]);
if (isLoading) {
return <div>Loading...</div>;
return <div className="text-center py-8">Loading...</div>;
}
if (error) {
console.error("Failed to fetch billing information:", error);
return (
<div className="text-center py-8 text-red-500">
Error loading billing information. Please try again later.
</div>
);
}
if (!billingInformation) {
return (
<div className="text-center py-8">No billing information available.</div>
);
}
const handleManageSubscription = async () => {
try {
const response = await fetchCustomerPortal();
if (!response.ok) {
const errorData = await response.json();
throw new Error(
@ -61,11 +72,9 @@ export default function BillingInformationPage() {
}
const { url } = await response.json();
if (!url) {
throw new Error("No portal URL returned from the server");
}
router.push(url);
} catch (error) {
console.error("Error creating customer portal session:", error);
@ -75,138 +84,39 @@ export default function BillingInformationPage() {
});
}
};
if (!billingInformation) {
return <div>Loading...</div>;
}
return (
<div className="space-y-8">
<div className="bg-background-50 rounded-lg p-8 border border-background-200">
{popup}
{popup}
<Card className="shadow-md">
<CardHeader>
<CardTitle className="text-2xl font-bold flex items-center">
<CreditCard className="mr-4 text-muted-foreground" size={24} />
Subscription Details
</CardTitle>
</CardHeader>
<CardContent className="space-y-6">
<SubscriptionSummary billingInformation={billingInformation} />
<BillingAlerts billingInformation={billingInformation} />
</CardContent>
</Card>
<h2 className="text-2xl font-bold mb-6 text-text-800 flex items-center">
{/* <CreditCard className="mr-4 text-text-600" size={24} /> */}
Subscription Details
</h2>
<div className="space-y-4">
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">Seats</p>
<p className="text-sm text-text-500">
Number of licensed users
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{billingInformation.seats}
</p>
</div>
</div>
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">
Subscription Status
</p>
<p className="text-sm text-text-500">
Current state of your subscription
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{statusToDisplay(billingInformation.subscription_status)}
</p>
</div>
</div>
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">
Billing Start
</p>
<p className="text-sm text-text-500">
Start date of current billing cycle
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{new Date(
billingInformation.billing_start
).toLocaleDateString()}
</p>
</div>
</div>
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">Billing End</p>
<p className="text-sm text-text-500">
End date of current billing cycle
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{new Date(billingInformation.billing_end).toLocaleDateString()}
</p>
</div>
</div>
</div>
{!billingInformation.payment_method_enabled && (
<div className="mt-4 p-4 bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700">
<p className="font-bold">Notice:</p>
<p>
You&apos;ll need to add a payment method before your trial ends to
continue using the service.
</p>
</div>
)}
{billingInformation.subscription_status === "trialing" ? (
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md mt-8">
<p className="text-lg font-medium text-text-700">
No cap on users during trial
</p>
</div>
) : (
<div className="flex items-center space-x-4 mt-8">
<div className="flex items-center space-x-4">
<p className="text-lg font-medium text-text-700">
Current Seats:
</p>
<p className="text-xl font-semibold text-text-900">
{billingInformation.seats}
</p>
</div>
<p className="text-sm text-text-500">
Seats automatically update based on adding, removing, or inviting
users.
</p>
</div>
)}
</div>
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center mb-4">
<div>
<p className="text-lg font-medium text-text-700">
Manage Subscription
</p>
<p className="text-sm text-text-500">
View your plan, update payment, or change subscription
</p>
</div>
<SettingsIcon className="text-text-600" size={20} />
</div>
<button
onClick={handleManageSubscription}
className="bg-background-600 text-white px-4 py-2 rounded-md hover:bg-background-700 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-text-500 focus:ring-opacity-50 font-medium shadow-sm text-sm flex items-center justify-center"
>
<ArrowFatUp className="mr-2" size={16} />
Manage Subscription
</button>
</div>
<Card className="shadow-md">
<CardHeader>
<CardTitle className="text-xl font-semibold">
Manage Subscription
</CardTitle>
<CardDescription>
View your plan, update payment, or change subscription
</CardDescription>
</CardHeader>
<CardContent>
<Button onClick={handleManageSubscription} className="w-full">
<ArrowFatUp className="mr-2" size={16} />
Manage Subscription
</Button>
</CardContent>
</Card>
</div>
);
}

View File

@ -0,0 +1,17 @@
import React from "react";
interface InfoItemProps {
title: string;
value: string;
}
export function InfoItem({ title, value }: InfoItemProps) {
return (
<div className="bg-muted p-4 rounded-lg">
<p className="text-sm font-medium text-muted-foreground mb-1">{title}</p>
<p className="text-lg font-semibold text-foreground dark:text-white">
{value}
</p>
</div>
);
}

View File

@ -0,0 +1,33 @@
import React from "react";
import { InfoItem } from "./InfoItem";
import { statusToDisplay } from "./utils";
interface SubscriptionSummaryProps {
billingInformation: any;
}
export function SubscriptionSummary({
billingInformation,
}: SubscriptionSummaryProps) {
return (
<div className="grid grid-cols-2 gap-4">
<InfoItem
title="Subscription Status"
value={statusToDisplay(billingInformation.status)}
/>
<InfoItem title="Seats" value={billingInformation.seats.toString()} />
<InfoItem
title="Billing Start"
value={new Date(
billingInformation.current_period_start
).toLocaleDateString()}
/>
<InfoItem
title="Billing End"
value={new Date(
billingInformation.current_period_end
).toLocaleDateString()}
/>
</div>
);
}

View File

@ -0,0 +1,19 @@
export interface BillingInformation {
status: string;
trial_end: Date | null;
current_period_end: Date;
payment_method_enabled: boolean;
cancel_at_period_end: boolean;
current_period_start: Date;
number_of_seats: number;
canceled_at: Date | null;
trial_start: Date | null;
seats: number;
}
export enum BillingStatus {
TRIALING = "trialing",
ACTIVE = "active",
CANCELLED = "cancelled",
EXPIRED = "expired",
}

View File

@ -3,10 +3,16 @@ import BillingInformationPage from "./BillingInformationPage";
import { MdOutlineCreditCard } from "react-icons/md";
export interface BillingInformation {
stripe_subscription_id: string;
status: string;
current_period_start: Date;
current_period_end: Date;
number_of_seats: number;
cancel_at_period_end: boolean;
canceled_at: Date | null;
trial_start: Date | null;
trial_end: Date | null;
seats: number;
subscription_status: string;
billing_start: Date;
billing_end: Date;
payment_method_enabled: boolean;
}

View File

@ -35,9 +35,16 @@ export const statusToDisplay = (status: string) => {
export const useBillingInformation = () => {
const url = "/api/tenants/billing-information";
const swrResponse = useSWR<BillingInformation>(url, (url: string) =>
fetch(url).then((res) => res.json())
);
const swrResponse = useSWR<BillingInformation>(url, async (url: string) => {
const res = await fetch(url);
if (!res.ok) {
const errorData = await res.json();
throw new Error(
errorData.message || "Failed to fetch billing information"
);
}
return res.json();
});
return {
...swrResponse,

View File

@ -13,7 +13,10 @@ import {
import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces";
import {
EnterpriseSettings,
ApplicationStatus,
} from "./admin/settings/interfaces";
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
@ -28,6 +31,7 @@ import { WebVitals } from "./web-vitals";
import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";
const inter = Inter({
subsets: ["latin"],
@ -75,7 +79,7 @@ export default async function RootLayout({
]);
const productGating =
combinedSettings?.settings.product_gating ?? GatingType.NONE;
combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;
const getPageContent = async (content: React.ReactNode) => (
<html
@ -130,40 +134,16 @@ export default async function RootLayout({
</html>
);
if (productGating === ApplicationStatus.GATED_ACCESS) {
return getPageContent(<AccessRestrictedPage />);
}
if (!combinedSettings) {
return getPageContent(
NEXT_PUBLIC_CLOUD_ENABLED ? <CloudError /> : <Error />
);
}
if (productGating === GatingType.FULL) {
return getPageContent(
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<LogoType />
</div>
<CardSection className="w-full max-w-md">
<h1 className="text-2xl font-bold mb-4 text-error">
Access Restricted
</h1>
<p className="text-text-500 mb-4">
We regret to inform you that your access to Onyx has been
temporarily suspended due to a lapse in your subscription.
</p>
<p className="text-text-500 mb-4">
To reinstate your access and continue benefiting from Onyx&apos;s
powerful features, please update your payment information.
</p>
<p className="text-text-500">
If you&apos;re an admin, you can resolve this by visiting the
billing section. For other users, please reach out to your
administrator to address this matter.
</p>
</CardSection>
</div>
);
}
const { assistants, hasAnyConnectors, hasImageCompatibleModel } =
assistantsData;

View File

@ -33,6 +33,9 @@ import { MdOutlineCreditCard } from "react-icons/md";
import { UserSettingsModal } from "@/app/chat/modal/UserSettingsModal";
import { usePopup } from "./connectors/Popup";
import { useChatContext } from "../context/ChatContext";
import { ApplicationStatus } from "@/app/admin/settings/interfaces";
import Link from "next/link";
import { Button } from "../ui/button";
export function ClientLayout({
user,
@ -74,6 +77,23 @@ export function ClientLayout({
defaultModel={user?.preferences?.default_model!}
/>
)}
{settings?.settings.application_status ===
ApplicationStatus.PAYMENT_REMINDER && (
<div className="fixed top-2 left-1/2 transform -translate-x-1/2 bg-amber-400 dark:bg-amber-500 text-gray-900 dark:text-gray-100 p-4 rounded-lg shadow-lg z-50 max-w-md text-center">
<strong className="font-bold">Warning:</strong> Your trial ends in
less than 2 days and no payment method has been added.
<div className="mt-2">
<Link href="/admin/billing">
<Button
variant="default"
className="bg-amber-600 hover:bg-amber-700 text-white"
>
Update Billing Information
</Button>
</Link>
</div>
</div>
)}
<div className="default-scrollbar flex-none text-text-settings-sidebar bg-background-sidebar dark:bg-[#000] w-[250px] overflow-x-hidden z-20 pt-2 pb-8 h-full border-r border-border dark:border-none miniscroll overflow-auto">
<AdminSidebar

View File

@ -0,0 +1,148 @@
"use client";
import { FiLock } from "react-icons/fi";
import ErrorPageLayout from "./ErrorPageLayout";
import { fetchCustomerPortal } from "@/app/ee/admin/billing/utils";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
import { logout } from "@/lib/user";
import { loadStripe } from "@stripe/stripe-js";
import { NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY } from "@/lib/constants";
const fetchResubscriptionSession = async () => {
const response = await fetch("/api/tenants/create-subscription-session", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) {
throw new Error("Failed to create resubscription session");
}
return response.json();
};
export default function AccessRestricted() {
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const router = useRouter();
const handleManageSubscription = async () => {
setIsLoading(true);
setError(null);
try {
const response = await fetchCustomerPortal();
if (!response.ok) {
const errorData = await response.json();
throw new Error(
`Failed to create customer portal session: ${
errorData.message || response.statusText
}`
);
}
const { url } = await response.json();
if (!url) {
throw new Error("No portal URL returned from the server");
}
router.push(url);
} catch (error) {
console.error("Error creating customer portal session:", error);
setError("Error opening customer portal. Please try again later.");
} finally {
setIsLoading(false);
}
};
const handleResubscribe = async () => {
setIsLoading(true);
setError(null);
if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) {
setError("Stripe public key not found");
setIsLoading(false);
return;
}
try {
const { sessionId } = await fetchResubscriptionSession();
const stripe = await loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY);
if (stripe) {
await stripe.redirectToCheckout({ sessionId });
} else {
throw new Error("Stripe failed to load");
}
} catch (error) {
console.error("Error creating resubscription session:", error);
setError("Error opening resubscription page. Please try again later.");
} finally {
setIsLoading(false);
}
};
return (
<ErrorPageLayout>
<h1 className="text-2xl font-semibold flex items-center gap-2 mb-4 text-gray-800 dark:text-gray-200">
<p>Access Restricted</p>
<FiLock className="text-error inline-block" />
</h1>
<div className="space-y-4 text-gray-600 dark:text-gray-300">
<p>
We regret to inform you that your access to Onyx has been temporarily
suspended due to a lapse in your subscription.
</p>
<p>
To reinstate your access and continue benefiting from Onyx&apos;s
powerful features, please update your payment information.
</p>
<p>
If you&apos;re an admin, you can manage your subscription by clicking
the button below. For other users, please reach out to your
administrator to address this matter.
</p>
<div className="flex flex-col space-y-4 sm:flex-row sm:space-y-0 sm:space-x-4">
<Button
onClick={handleResubscribe}
disabled={isLoading}
className="w-full sm:w-auto"
>
{isLoading ? "Loading..." : "Resubscribe"}
</Button>
<Button
variant="outline"
onClick={handleManageSubscription}
disabled={isLoading}
className="w-full sm:w-auto"
>
Manage Existing Subscription
</Button>
<Button
variant="outline"
onClick={async () => {
await logout();
window.location.reload();
}}
className="w-full sm:w-auto"
>
Log out
</Button>
</div>
{error && <p className="text-error">{error}</p>}
<p>
Need help? Join our{" "}
<a
className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ"
target="_blank"
rel="noopener noreferrer"
>
Slack community
</a>{" "}
for support.
</p>
</div>
</ErrorPageLayout>
);
}

View File

@ -1,7 +1,7 @@
import {
CombinedSettings,
EnterpriseSettings,
GatingType,
ApplicationStatus,
Settings,
} from "@/app/admin/settings/interfaces";
import {
@ -45,7 +45,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
if (results[0].status === 403 || results[0].status === 401) {
settings = {
auto_scroll: true,
product_gating: GatingType.NONE,
application_status: ApplicationStatus.ACTIVE,
gpu_enabled: false,
maximum_chat_retention_days: null,
notifications: [],

View File

@ -91,3 +91,6 @@ export const NEXT_PUBLIC_ENABLE_CHROME_EXTENSION =
export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
process.env.NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK?.toLowerCase() ===
"true";
export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;