This commit is contained in:
pablonyx 2025-03-05 12:28:23 -08:00
parent b6e9e65bb8
commit 751af6824a
16 changed files with 903 additions and 171 deletions

View File

@ -26,7 +26,7 @@ from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.tenants.router import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)

View File

@ -0,0 +1,41 @@
# Tenant Provisioning Process
This directory contains the code for provisioning tenants in a multi-tenant Onyx deployment.
## Optimized Tenant Provisioning
The tenant provisioning process has been optimized to allow for faster authentication flow completion. The process is now split into two phases:
1. **Essential Setup (Synchronous)**
- Create the tenant schema
- Run essential Alembic migrations up to revision `465f78d9b7f9` (which includes OAuth access token changes)
- Add the user to the tenant mapping
- This allows the user to log in immediately without waiting for the full setup to complete
2. **Complete Setup (Asynchronous)**
- Run the remaining Alembic migrations
- Configure default API keys
- Set up Onyx (embedding models, search settings, etc.)
- Create milestone records
- This happens in the background after the user has already been able to log in
## Key Files
- `provisioning.py`: Contains the main tenant provisioning logic
- `schema_management.py`: Handles schema creation and Alembic migrations
- `async_setup.py`: Handles the asynchronous part of the tenant setup
- `user_mapping.py`: Manages user-tenant mappings
## Flow
1. User initiates login/signup
2. `provision_tenant()` is called
3. Essential migrations are run with `run_essential_alembic_migrations()`
4. User is added to tenant mapping
5. Asynchronous task is started with `complete_tenant_setup()`
6. User can log in while the rest of the setup continues in the background
## Performance Impact
This optimization significantly reduces the time required for a user to log in after tenant creation. The most time-consuming operations (full migrations, Onyx setup) are deferred to run asynchronously, allowing the auth flow to complete quickly.

View File

@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@ -0,0 +1,143 @@
import asyncio
import logging
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def complete_tenant_setup(tenant_id: str, email: str) -> None:
"""
Complete the tenant setup process asynchronously after the essential migrations
have been applied. This includes:
1. Running the remaining Alembic migrations
2. Setting up Onyx
3. Creating milestone records
"""
logger.info(f"Starting asynchronous tenant setup for tenant {tenant_id}")
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run the remaining Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure default API keys
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
# Setup Onyx
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Create milestone record
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
logger.info(f"Asynchronous tenant setup completed for tenant {tenant_id}")
except Exception as e:
logger.exception(
f"Failed to complete asynchronous tenant setup for tenant {tenant_id}: {e}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(cloud_embedding_provider, db_session)
except Exception as e:
logger.error(f"Failed to configure Cohere provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere provider configuration"
)

View File

@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@ -67,3 +67,19 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
status: str
is_complete: bool
class ApproveUserRequest(BaseModel):
email: str
tenant_id: str
class RequestInviteRequest(BaseModel):
email: str
tenant_id: str

View File

@ -6,47 +6,28 @@ import aiohttp # Async HTTP client
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.async_setup import complete_tenant_setup
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from ee.onyx.server.tenants.schema_management import run_essential_alembic_migrations
from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
@ -115,35 +96,19 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Run only the essential Alembic migrations needed for auth
await asyncio.to_thread(run_essential_alembic_migrations, tenant_id)
# Add user to tenant immediately so they can log in
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Start the rest of the setup process asynchronously
asyncio.create_task(complete_tenant_setup(tenant_id, email))
logger.info(f"Essential tenant provisioning completed for tenant {tenant_id}")
logger.info(
f"Remaining setup will continue asynchronously for tenant {tenant_id}"
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
@ -199,136 +164,43 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
logger.info("Successfully upserted Cohere cloud embedding provider")
logger.info("Updating search settings with Cohere embedding model details")
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.FUTURE)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
current_search_settings = result.scalars().first()
if current_search_settings:
current_search_settings.model_name = (
"embed-english-v3.0" # Cohere's latest model as of now
)
current_search_settings.model_dim = (
1024 # Cohere's embed-english-v3.0 dimension
)
current_search_settings.provider_type = EmbeddingProvider.COHERE
current_search_settings.index_name = (
"danswer_chunk_cohere_embed_english_v3_0"
)
current_search_settings.query_prefix = ""
current_search_settings.passage_prefix = ""
db_session.commit()
else:
raise RuntimeError(
"No search settings specified, DB is not in a valid state"
)
logger.info("Fetching updated search settings to verify changes")
updated_query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.PRESENT)
.order_by(SearchSettings.id.desc())
)
updated_result = db_session.execute(updated_query)
updated_result.scalars().first()
except Exception:
logger.exception("Failed to configure Cohere embedding provider")
else:
logger.info(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
async def submit_to_hubspot(
email: str, referral_source: str | None, request: Request
) -> None:
if not HUBSPOT_TRACKING_URL:
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
return
# HubSpot tracking cookie
hubspot_cookie = request.cookies.get("hubspotutk")
try:
user_agent = request.headers.get("user-agent", "")
referer = request.headers.get("referer", "")
ip_address = request.client.host if request.client else ""
# IP address
ip_address = request.client.host if request.client else None
payload = {
"email": email,
"referral_source": referral_source or "",
"user_agent": user_agent,
"referer": referer,
"ip_address": ip_address,
}
data = {
"fields": [
{"name": "email", "value": email},
{"name": "referral_source", "value": referral_source or ""},
],
"context": {
"hutk": hubspot_cookie,
"ipAddress": ip_address,
"pageUri": str(request.url),
"pageName": "User Registration",
},
}
async with httpx.AsyncClient() as client:
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async with httpx.AsyncClient() as client:
response = await client.post(
HUBSPOT_TRACKING_URL,
json=payload,
timeout=5.0,
)
if response.status_code != 200:
logger.error(
f"Failed to submit to HubSpot: {response.status_code} {response.text}"
)
except Exception as e:
logger.error(f"Error submitting to HubSpot: {e}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
if DEV_MODE:
return
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
@ -337,15 +209,14 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
logger.error(f"Control plane tenant deletion failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)

View File

@ -0,0 +1,62 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
# from ee.onyx.server.tenants.provisioning import get_tenant_setup_status
logger = setup_logger()
# Create a main router to include all sub-routers
router = APIRouter()
# Include all the sub-routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
class TenantSetupStatusResponse(BaseModel):
"""Response model for tenant setup status."""
tenant_id: str
status: str
is_complete: bool
# Add the setup status endpoint directly to the main router
@router.get("/tenants/setup-status", response_model=TenantSetupStatusResponse)
async def get_setup_status(
current_user: User = Depends(current_user),
) -> TenantSetupStatusResponse:
"""
Get the current setup status for the tenant.
This is used by the frontend to determine if the tenant setup is complete.
"""
tenant_id = get_current_tenant_id()
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
# status = get_tenant_setup_status(tenant_id)
return TenantSetupStatusResponse(
tenant_id=tenant_id, status="completed", is_complete=True
)

View File

@ -49,6 +49,47 @@ def run_alembic_migrations(schema_name: str) -> None:
raise
def run_essential_alembic_migrations(schema_name: str) -> None:
"""
Run only the essential Alembic migrations up to the 465f78d9b7f9 revision.
This is used for the auth flow to complete quickly, with the rest of the migrations
and setup being deferred to run asynchronously.
"""
logger.info(f"Starting essential Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically up to the specified revision
command.upgrade(alembic_cfg, "465f78d9b7f9")
logger.info(
f"Essential Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(
f"Essential Alembic migration failed for schema {schema_name}: {str(e)}"
)
raise
def create_schema_if_not_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():

View File

@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@ -0,0 +1,62 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
# from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_DOMAINS = [
"gmail.com",
"yahoo.com",
"hotmail.com",
"outlook.com",
"icloud.com",
"msn.com",
"live.com",
"msn.com",
"hotmail.com",
"hotmail.co.uk",
"hotmail.fr",
"hotmail.de",
"hotmail.it",
"hotmail.es",
"hotmail.nl",
"hotmail.pl",
"hotmail.pt",
"hotmail.ro",
"hotmail.ru",
"hotmail.sa",
"hotmail.se",
"hotmail.tr",
"hotmail.tw",
"hotmail.ua",
"hotmail.us",
"hotmail.vn",
"hotmail.za",
"hotmail.zw",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_admin_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if domain in FORBIDDEN_COMMON_EMAIL_DOMAINS:
return None
tenant_id = get_current_tenant_id()
return TenantByDomainResponse(
tenant_id=tenant_id, status="completed", is_complete=True
)
# return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@ -0,0 +1,91 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/request-invite")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/approve-invite")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/accept-invite")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/deny-invite")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@ -0,0 +1,52 @@
from typing import Optional
from fastapi import Depends
from fastapi import Request
from fastapi_users import BaseUserManager
from fastapi_users import UUIDIDMixin
from fastapi_users.db import SQLAlchemyUserDatabase
from onyx.auth.essential_user import EssentialUser
from onyx.auth.essential_user import get_essential_user_db
from onyx.configs.app_configs import USER_MANAGER_SECRET
class EssentialUserManager(UUIDIDMixin, BaseUserManager[EssentialUser, str]):
"""
A simplified user manager that only handles essential authentication operations.
This is used during the initial tenant setup phase to avoid errors with missing columns.
"""
reset_password_token_secret = USER_MANAGER_SECRET
verification_token_secret = USER_MANAGER_SECRET
async def on_after_register(
self, user: EssentialUser, request: Optional[Request] = None
) -> None:
"""
Simplified post-registration hook.
"""
async def on_after_forgot_password(
self, user: EssentialUser, token: str, request: Optional[Request] = None
) -> None:
"""
Simplified post-forgot-password hook.
"""
async def on_after_request_verify(
self, user: EssentialUser, token: str, request: Optional[Request] = None
) -> None:
"""
Simplified post-verification-request hook.
"""
async def get_essential_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_essential_user_db),
) -> EssentialUserManager:
"""
Get a user manager that uses the essential user model.
This avoids errors with missing columns during the initial tenant setup.
"""
yield EssentialUserManager(user_db)

View File

@ -0,0 +1,47 @@
from collections.abc import AsyncGenerator
from typing import Optional
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import relationship
from onyx.db.engine import get_async_session
Base: DeclarativeMeta = declarative_base()
class EssentialUser(SQLAlchemyBaseUserTableUUID, Base):
"""
A simplified user model that only includes essential columns needed for authentication.
This is used during the initial tenant setup phase to avoid errors with missing columns
that would be added in later migrations.
"""
__tablename__ = "user"
email: str = Column(String(length=320), unique=True, index=True, nullable=False)
hashed_password: Optional[str] = Column(String(length=1024), nullable=True)
is_active: bool = Column(Boolean, default=True, nullable=False)
is_superuser: bool = Column(Boolean, default=False, nullable=False)
is_verified: bool = Column(Boolean, default=False, nullable=False)
# Relationships are defined but not used in the essential auth flow
oauth_accounts = relationship("OAuthAccount", lazy="joined")
credentials = relationship("Credential", lazy="joined")
async def get_essential_user_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
"""
Get a user database that uses the essential user model.
This avoids errors with missing columns during the initial tenant setup.
"""
yield SQLAlchemyUserDatabase(session, EssentialUser)

View File

@ -95,7 +95,7 @@ const Page = async (props: {
</div>
)}
<EmailPasswordForm
<EmailPasswordFormau
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
nextUrl={nextUrl}