diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py index 7d7278bb2..32d09b463 100644 --- a/backend/ee/onyx/main.py +++ b/backend/ee/onyx/main.py @@ -26,7 +26,7 @@ from ee.onyx.server.query_history.api import router as query_history_router from ee.onyx.server.reporting.usage_export_api import router as usage_export_router from ee.onyx.server.saml import router as saml_router from ee.onyx.server.seeding import seed_db -from ee.onyx.server.tenants.api import router as tenants_router +from ee.onyx.server.tenants.router import router as tenants_router from ee.onyx.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) diff --git a/backend/ee/onyx/server/tenants/README.md b/backend/ee/onyx/server/tenants/README.md new file mode 100644 index 000000000..701272d00 --- /dev/null +++ b/backend/ee/onyx/server/tenants/README.md @@ -0,0 +1,41 @@ +# Tenant Provisioning Process + +This directory contains the code for provisioning tenants in a multi-tenant Onyx deployment. + +## Optimized Tenant Provisioning + +The tenant provisioning process has been optimized to allow for faster authentication flow completion. The process is now split into two phases: + +1. **Essential Setup (Synchronous)** + + - Create the tenant schema + - Run essential Alembic migrations up to revision `465f78d9b7f9` (which includes OAuth access token changes) + - Add the user to the tenant mapping + - This allows the user to log in immediately without waiting for the full setup to complete + +2. **Complete Setup (Asynchronous)** + - Run the remaining Alembic migrations + - Configure default API keys + - Set up Onyx (embedding models, search settings, etc.) + - Create milestone records + - This happens in the background after the user has already been able to log in + +## Key Files + +- `provisioning.py`: Contains the main tenant provisioning logic +- `schema_management.py`: Handles schema creation and Alembic migrations +- `async_setup.py`: Handles the asynchronous part of the tenant setup +- `user_mapping.py`: Manages user-tenant mappings + +## Flow + +1. User initiates login/signup +2. `provision_tenant()` is called +3. Essential migrations are run with `run_essential_alembic_migrations()` +4. User is added to tenant mapping +5. Asynchronous task is started with `complete_tenant_setup()` +6. User can log in while the rest of the setup continues in the background + +## Performance Impact + +This optimization significantly reduces the time required for a user to log in after tenant creation. The most time-consuming operations (full migrations, Onyx setup) are deferred to run asynchronously, allowing the auth flow to complete quickly. diff --git a/backend/ee/onyx/server/tenants/admin_api.py b/backend/ee/onyx/server/tenants/admin_api.py new file mode 100644 index 000000000..d1dbc9274 --- /dev/null +++ b/backend/ee/onyx/server/tenants/admin_api.py @@ -0,0 +1,45 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Response + +from ee.onyx.auth.users import current_cloud_superuser +from ee.onyx.server.tenants.models import ImpersonateRequest +from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email +from onyx.auth.users import auth_backend +from onyx.auth.users import get_redis_strategy +from onyx.auth.users import User +from onyx.db.engine import get_session_with_tenant +from onyx.db.users import get_user_by_email +from onyx.utils.logger import setup_logger + +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + + +@router.post("/impersonate") +async def impersonate_user( + impersonate_request: ImpersonateRequest, + _: User = Depends(current_cloud_superuser), +) -> Response: + """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token""" + tenant_id = get_tenant_id_for_email(impersonate_request.email) + + with get_session_with_tenant(tenant_id=tenant_id) as tenant_session: + user_to_impersonate = get_user_by_email( + impersonate_request.email, tenant_session + ) + if user_to_impersonate is None: + raise HTTPException(status_code=404, detail="User not found") + token = await get_redis_strategy().write_token(user_to_impersonate) + + response = await auth_backend.transport.get_login_response(token) + response.set_cookie( + key="fastapiusersauth", + value=token, + httponly=True, + secure=True, + samesite="lax", + ) + return response diff --git a/backend/ee/onyx/server/tenants/anonymous_users_api.py b/backend/ee/onyx/server/tenants/anonymous_users_api.py new file mode 100644 index 000000000..0dccc0916 --- /dev/null +++ b/backend/ee/onyx/server/tenants/anonymous_users_api.py @@ -0,0 +1,98 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Response +from sqlalchemy.exc import IntegrityError + +from ee.onyx.auth.users import generate_anonymous_user_jwt_token +from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME +from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path +from ee.onyx.server.tenants.anonymous_user_path import ( + get_tenant_id_for_anonymous_user_path, +) +from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path +from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path +from ee.onyx.server.tenants.models import AnonymousUserPath +from onyx.auth.users import anonymous_user_enabled +from onyx.auth.users import current_admin_user +from onyx.auth.users import optional_user +from onyx.auth.users import User +from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME +from onyx.db.engine import get_session_with_shared_schema +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id + +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + + +@router.get("/anonymous-user-path") +async def get_anonymous_user_path_api( + _: User | None = Depends(current_admin_user), +) -> AnonymousUserPath: + tenant_id = get_current_tenant_id() + + if tenant_id is None: + raise HTTPException(status_code=404, detail="Tenant not found") + + with get_session_with_shared_schema() as db_session: + current_path = get_anonymous_user_path(tenant_id, db_session) + + return AnonymousUserPath(anonymous_user_path=current_path) + + +@router.post("/anonymous-user-path") +async def set_anonymous_user_path_api( + anonymous_user_path: str, + _: User | None = Depends(current_admin_user), +) -> None: + tenant_id = get_current_tenant_id() + try: + validate_anonymous_user_path(anonymous_user_path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + with get_session_with_shared_schema() as db_session: + try: + modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session) + except IntegrityError: + raise HTTPException( + status_code=409, + detail="The anonymous user path is already in use. Please choose a different path.", + ) + except Exception as e: + logger.exception(f"Failed to modify anonymous user path: {str(e)}") + raise HTTPException( + status_code=500, + detail="An unexpected error occurred while modifying the anonymous user path", + ) + + +@router.post("/anonymous-user") +async def login_as_anonymous_user( + anonymous_user_path: str, + _: User | None = Depends(optional_user), +) -> Response: + with get_session_with_shared_schema() as db_session: + tenant_id = get_tenant_id_for_anonymous_user_path( + anonymous_user_path, db_session + ) + if not tenant_id: + raise HTTPException(status_code=404, detail="Tenant not found") + + if not anonymous_user_enabled(tenant_id=tenant_id): + raise HTTPException(status_code=403, detail="Anonymous user is not enabled") + + token = generate_anonymous_user_jwt_token(tenant_id) + + response = Response() + response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME) + response.set_cookie( + key=ANONYMOUS_USER_COOKIE_NAME, + value=token, + httponly=True, + secure=True, + samesite="strict", + ) + return response diff --git a/backend/ee/onyx/server/tenants/async_setup.py b/backend/ee/onyx/server/tenants/async_setup.py new file mode 100644 index 000000000..d59197123 --- /dev/null +++ b/backend/ee/onyx/server/tenants/async_setup.py @@ -0,0 +1,143 @@ +import asyncio +import logging + +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY +from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY +from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY +from ee.onyx.server.tenants.schema_management import run_alembic_migrations +from onyx.configs.constants import MilestoneRecordType +from onyx.db.engine import get_session_with_tenant +from onyx.db.llm import update_default_provider +from onyx.db.llm import upsert_cloud_embedding_provider +from onyx.db.llm import upsert_llm_provider +from onyx.db.models import IndexModelStatus +from onyx.db.models import SearchSettings +from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES +from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME +from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES +from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME +from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest +from onyx.server.manage.llm.models import LLMProviderUpsertRequest +from onyx.setup import setup_onyx +from onyx.utils.telemetry import create_milestone_and_report +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.enums import EmbeddingProvider + +logger = logging.getLogger(__name__) + + +async def complete_tenant_setup(tenant_id: str, email: str) -> None: + """ + Complete the tenant setup process asynchronously after the essential migrations + have been applied. This includes: + 1. Running the remaining Alembic migrations + 2. Setting up Onyx + 3. Creating milestone records + """ + logger.info(f"Starting asynchronous tenant setup for tenant {tenant_id}") + token = None + + try: + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + + # Run the remaining Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) + + # Configure default API keys + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + configure_default_api_keys(db_session) + + # Setup Onyx + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + current_search_settings = ( + db_session.query(SearchSettings) + .filter_by(status=IndexModelStatus.FUTURE) + .first() + ) + cohere_enabled = ( + current_search_settings is not None + and current_search_settings.provider_type == EmbeddingProvider.COHERE + ) + setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) + + # Create milestone record + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + create_milestone_and_report( + user=None, + distinct_id=tenant_id, + event_type=MilestoneRecordType.TENANT_CREATED, + properties={ + "email": email, + }, + db_session=db_session, + ) + + logger.info(f"Asynchronous tenant setup completed for tenant {tenant_id}") + + except Exception as e: + logger.exception( + f"Failed to complete asynchronous tenant setup for tenant {tenant_id}: {e}" + ) + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +def configure_default_api_keys(db_session: Session) -> None: + if ANTHROPIC_DEFAULT_API_KEY: + anthropic_provider = LLMProviderUpsertRequest( + name="Anthropic", + provider=ANTHROPIC_PROVIDER_NAME, + api_key=ANTHROPIC_DEFAULT_API_KEY, + default_model_name="claude-3-7-sonnet-20250219", + fast_default_model_name="claude-3-5-sonnet-20241022", + model_names=ANTHROPIC_MODEL_NAMES, + display_model_names=["claude-3-5-sonnet-20241022"], + ) + try: + full_provider = upsert_llm_provider(anthropic_provider, db_session) + update_default_provider(full_provider.id, db_session) + except Exception as e: + logger.error(f"Failed to configure Anthropic provider: {e}") + else: + logger.error( + "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration" + ) + + if OPENAI_DEFAULT_API_KEY: + open_provider = LLMProviderUpsertRequest( + name="OpenAI", + provider=OPENAI_PROVIDER_NAME, + api_key=OPENAI_DEFAULT_API_KEY, + default_model_name="gpt-4o", + fast_default_model_name="gpt-4o-mini", + model_names=OPEN_AI_MODEL_NAMES, + display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], + ) + try: + full_provider = upsert_llm_provider(open_provider, db_session) + update_default_provider(full_provider.id, db_session) + except Exception as e: + logger.error(f"Failed to configure OpenAI provider: {e}") + else: + logger.error( + "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration" + ) + + if COHERE_DEFAULT_API_KEY: + cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( + provider_type=EmbeddingProvider.COHERE, + api_key=COHERE_DEFAULT_API_KEY, + ) + + try: + logger.info("Attempting to upsert Cohere cloud embedding provider") + upsert_cloud_embedding_provider(cloud_embedding_provider, db_session) + except Exception as e: + logger.error(f"Failed to configure Cohere provider: {e}") + else: + logger.error( + "COHERE_DEFAULT_API_KEY not set, skipping Cohere provider configuration" + ) diff --git a/backend/ee/onyx/server/tenants/billing_api.py b/backend/ee/onyx/server/tenants/billing_api.py new file mode 100644 index 000000000..18da0b95d --- /dev/null +++ b/backend/ee/onyx/server/tenants/billing_api.py @@ -0,0 +1,96 @@ +import stripe +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException + +from ee.onyx.auth.users import current_admin_user +from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY +from ee.onyx.server.tenants.access import control_plane_dep +from ee.onyx.server.tenants.billing import fetch_billing_information +from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session +from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information +from ee.onyx.server.tenants.models import BillingInformation +from ee.onyx.server.tenants.models import ProductGatingRequest +from ee.onyx.server.tenants.models import ProductGatingResponse +from ee.onyx.server.tenants.models import SubscriptionSessionResponse +from ee.onyx.server.tenants.models import SubscriptionStatusResponse +from ee.onyx.server.tenants.product_gating import store_product_gating +from onyx.auth.users import User +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.contextvars import get_current_tenant_id + +stripe.api_key = STRIPE_SECRET_KEY +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + + +@router.post("/product-gating") +def gate_product( + product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) +) -> ProductGatingResponse: + """ + Gating the product means that the product is not available to the tenant. + They will be directed to the billing page. + We gate the product when their subscription has ended. + """ + try: + store_product_gating( + product_gating_request.tenant_id, product_gating_request.application_status + ) + return ProductGatingResponse(updated=True, error=None) + + except Exception as e: + logger.exception("Failed to gate product") + return ProductGatingResponse(updated=False, error=str(e)) + + +@router.get("/billing-information") +async def billing_information( + _: User = Depends(current_admin_user), +) -> BillingInformation | SubscriptionStatusResponse: + logger.info("Fetching billing information") + tenant_id = get_current_tenant_id() + return fetch_billing_information(tenant_id) + + +@router.post("/create-customer-portal-session") +async def create_customer_portal_session( + _: User = Depends(current_admin_user), +) -> dict: + tenant_id = get_current_tenant_id() + + try: + stripe_info = fetch_tenant_stripe_information(tenant_id) + stripe_customer_id = stripe_info.get("stripe_customer_id") + if not stripe_customer_id: + raise HTTPException(status_code=400, detail="Stripe customer ID not found") + logger.info(stripe_customer_id) + + portal_session = stripe.billing_portal.Session.create( + customer=stripe_customer_id, + return_url=f"{WEB_DOMAIN}/admin/billing", + ) + logger.info(portal_session) + return {"url": portal_session.url} + except Exception as e: + logger.exception("Failed to create customer portal session") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/create-subscription-session") +async def create_subscription_session( + _: User = Depends(current_admin_user), +) -> SubscriptionSessionResponse: + try: + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + if not tenant_id: + raise HTTPException(status_code=400, detail="Tenant ID not found") + session_id = fetch_stripe_checkout_session(tenant_id) + return SubscriptionSessionResponse(sessionId=session_id) + + except Exception as e: + logger.exception("Failed to create resubscription session") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/ee/onyx/server/tenants/models.py b/backend/ee/onyx/server/tenants/models.py index 7931a06a7..8604c422f 100644 --- a/backend/ee/onyx/server/tenants/models.py +++ b/backend/ee/onyx/server/tenants/models.py @@ -67,3 +67,19 @@ class ProductGatingResponse(BaseModel): class SubscriptionSessionResponse(BaseModel): sessionId: str + + +class TenantByDomainResponse(BaseModel): + tenant_id: str + status: str + is_complete: bool + + +class ApproveUserRequest(BaseModel): + email: str + tenant_id: str + + +class RequestInviteRequest(BaseModel): + email: str + tenant_id: str diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index 5328277ff..97ff4b2f6 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -6,47 +6,28 @@ import aiohttp # Async HTTP client import httpx from fastapi import HTTPException from fastapi import Request -from sqlalchemy import select from sqlalchemy.orm import Session -from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY -from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL -from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY from ee.onyx.server.tenants.access import generate_data_plane_token +from ee.onyx.server.tenants.async_setup import complete_tenant_setup from ee.onyx.server.tenants.models import TenantCreationPayload from ee.onyx.server.tenants.models import TenantDeletionPayload from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists from ee.onyx.server.tenants.schema_management import drop_schema -from ee.onyx.server.tenants.schema_management import run_alembic_migrations +from ee.onyx.server.tenants.schema_management import run_essential_alembic_migrations from ee.onyx.server.tenants.user_mapping import add_users_to_tenant from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant from onyx.auth.users import exceptions from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.configs.app_configs import DEV_MODE -from onyx.configs.constants import MilestoneRecordType -from onyx.db.engine import get_session_with_tenant from onyx.db.engine import get_sqlalchemy_engine -from onyx.db.llm import update_default_provider -from onyx.db.llm import upsert_cloud_embedding_provider -from onyx.db.llm import upsert_llm_provider -from onyx.db.models import IndexModelStatus -from onyx.db.models import SearchSettings from onyx.db.models import UserTenantMapping -from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES -from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME -from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES -from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME -from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest -from onyx.server.manage.llm.models import LLMProviderUpsertRequest -from onyx.setup import setup_onyx -from onyx.utils.telemetry import create_milestone_and_report from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR -from shared_configs.enums import EmbeddingProvider logger = logging.getLogger(__name__) @@ -115,35 +96,19 @@ async def provision_tenant(tenant_id: str, email: str) -> None: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - # Await the Alembic migrations - await asyncio.to_thread(run_alembic_migrations, tenant_id) - - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - configure_default_api_keys(db_session) - - current_search_settings = ( - db_session.query(SearchSettings) - .filter_by(status=IndexModelStatus.FUTURE) - .first() - ) - cohere_enabled = ( - current_search_settings is not None - and current_search_settings.provider_type == EmbeddingProvider.COHERE - ) - setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) + # Run only the essential Alembic migrations needed for auth + await asyncio.to_thread(run_essential_alembic_migrations, tenant_id) + # Add user to tenant immediately so they can log in add_users_to_tenant([email], tenant_id) - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - create_milestone_and_report( - user=None, - distinct_id=tenant_id, - event_type=MilestoneRecordType.TENANT_CREATED, - properties={ - "email": email, - }, - db_session=db_session, - ) + # Start the rest of the setup process asynchronously + asyncio.create_task(complete_tenant_setup(tenant_id, email)) + + logger.info(f"Essential tenant provisioning completed for tenant {tenant_id}") + logger.info( + f"Remaining setup will continue asynchronously for tenant {tenant_id}" + ) except Exception as e: logger.exception(f"Failed to create tenant {tenant_id}") @@ -199,136 +164,43 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None: logger.error(f"Failed to rollback tenant provisioning: {e}") -def configure_default_api_keys(db_session: Session) -> None: - if ANTHROPIC_DEFAULT_API_KEY: - anthropic_provider = LLMProviderUpsertRequest( - name="Anthropic", - provider=ANTHROPIC_PROVIDER_NAME, - api_key=ANTHROPIC_DEFAULT_API_KEY, - default_model_name="claude-3-7-sonnet-20250219", - fast_default_model_name="claude-3-5-sonnet-20241022", - model_names=ANTHROPIC_MODEL_NAMES, - display_model_names=["claude-3-5-sonnet-20241022"], - ) - try: - full_provider = upsert_llm_provider(anthropic_provider, db_session) - update_default_provider(full_provider.id, db_session) - except Exception as e: - logger.error(f"Failed to configure Anthropic provider: {e}") - else: - logger.error( - "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration" - ) - - if OPENAI_DEFAULT_API_KEY: - open_provider = LLMProviderUpsertRequest( - name="OpenAI", - provider=OPENAI_PROVIDER_NAME, - api_key=OPENAI_DEFAULT_API_KEY, - default_model_name="gpt-4o", - fast_default_model_name="gpt-4o-mini", - model_names=OPEN_AI_MODEL_NAMES, - display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"], - ) - try: - full_provider = upsert_llm_provider(open_provider, db_session) - update_default_provider(full_provider.id, db_session) - except Exception as e: - logger.error(f"Failed to configure OpenAI provider: {e}") - else: - logger.error( - "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration" - ) - - if COHERE_DEFAULT_API_KEY: - cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( - provider_type=EmbeddingProvider.COHERE, - api_key=COHERE_DEFAULT_API_KEY, - ) - - try: - logger.info("Attempting to upsert Cohere cloud embedding provider") - upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) - logger.info("Successfully upserted Cohere cloud embedding provider") - - logger.info("Updating search settings with Cohere embedding model details") - query = ( - select(SearchSettings) - .where(SearchSettings.status == IndexModelStatus.FUTURE) - .order_by(SearchSettings.id.desc()) - ) - result = db_session.execute(query) - current_search_settings = result.scalars().first() - - if current_search_settings: - current_search_settings.model_name = ( - "embed-english-v3.0" # Cohere's latest model as of now - ) - current_search_settings.model_dim = ( - 1024 # Cohere's embed-english-v3.0 dimension - ) - current_search_settings.provider_type = EmbeddingProvider.COHERE - current_search_settings.index_name = ( - "danswer_chunk_cohere_embed_english_v3_0" - ) - current_search_settings.query_prefix = "" - current_search_settings.passage_prefix = "" - db_session.commit() - else: - raise RuntimeError( - "No search settings specified, DB is not in a valid state" - ) - logger.info("Fetching updated search settings to verify changes") - updated_query = ( - select(SearchSettings) - .where(SearchSettings.status == IndexModelStatus.PRESENT) - .order_by(SearchSettings.id.desc()) - ) - updated_result = db_session.execute(updated_query) - updated_result.scalars().first() - - except Exception: - logger.exception("Failed to configure Cohere embedding provider") - else: - logger.info( - "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration" - ) - - async def submit_to_hubspot( email: str, referral_source: str | None, request: Request ) -> None: if not HUBSPOT_TRACKING_URL: - logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission") return - # HubSpot tracking cookie - hubspot_cookie = request.cookies.get("hubspotutk") + try: + user_agent = request.headers.get("user-agent", "") + referer = request.headers.get("referer", "") + ip_address = request.client.host if request.client else "" - # IP address - ip_address = request.client.host if request.client else None + payload = { + "email": email, + "referral_source": referral_source or "", + "user_agent": user_agent, + "referer": referer, + "ip_address": ip_address, + } - data = { - "fields": [ - {"name": "email", "value": email}, - {"name": "referral_source", "value": referral_source or ""}, - ], - "context": { - "hutk": hubspot_cookie, - "ipAddress": ip_address, - "pageUri": str(request.url), - "pageName": "User Registration", - }, - } - - async with httpx.AsyncClient() as client: - response = await client.post(HUBSPOT_TRACKING_URL, json=data) - - if response.status_code != 200: - logger.error(f"Failed to submit to HubSpot: {response.text}") + async with httpx.AsyncClient() as client: + response = await client.post( + HUBSPOT_TRACKING_URL, + json=payload, + timeout=5.0, + ) + if response.status_code != 200: + logger.error( + f"Failed to submit to HubSpot: {response.status_code} {response.text}" + ) + except Exception as e: + logger.error(f"Error submitting to HubSpot: {e}") async def delete_user_from_control_plane(tenant_id: str, email: str) -> None: + if DEV_MODE: + return + token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", @@ -337,15 +209,14 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None: payload = TenantDeletionPayload(tenant_id=tenant_id, email=email) async with aiohttp.ClientSession() as session: - async with session.delete( + async with session.post( f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete", headers=headers, json=payload.model_dump(), ) as response: - print(response) if response.status != 200: error_text = await response.text() - logger.error(f"Control plane tenant creation failed: {error_text}") + logger.error(f"Control plane tenant deletion failed: {error_text}") raise Exception( f"Failed to delete tenant on control plane: {error_text}" ) diff --git a/backend/ee/onyx/server/tenants/router.py b/backend/ee/onyx/server/tenants/router.py new file mode 100644 index 000000000..ca1a3036a --- /dev/null +++ b/backend/ee/onyx/server/tenants/router.py @@ -0,0 +1,62 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel + +from ee.onyx.server.tenants.admin_api import router as admin_router +from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router +from ee.onyx.server.tenants.billing_api import router as billing_router +from ee.onyx.server.tenants.team_membership_api import router as team_membership_router +from ee.onyx.server.tenants.tenant_management_api import ( + router as tenant_management_router, +) +from ee.onyx.server.tenants.user_invitations_api import ( + router as user_invitations_router, +) +from onyx.auth.users import current_user +from onyx.auth.users import User +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id + +# from ee.onyx.server.tenants.provisioning import get_tenant_setup_status + +logger = setup_logger() + +# Create a main router to include all sub-routers +router = APIRouter() + +# Include all the sub-routers +router.include_router(admin_router) +router.include_router(anonymous_users_router) +router.include_router(billing_router) +router.include_router(team_membership_router) +router.include_router(tenant_management_router) +router.include_router(user_invitations_router) + + +class TenantSetupStatusResponse(BaseModel): + """Response model for tenant setup status.""" + + tenant_id: str + status: str + is_complete: bool + + +# Add the setup status endpoint directly to the main router +@router.get("/tenants/setup-status", response_model=TenantSetupStatusResponse) +async def get_setup_status( + current_user: User = Depends(current_user), +) -> TenantSetupStatusResponse: + """ + Get the current setup status for the tenant. + This is used by the frontend to determine if the tenant setup is complete. + """ + tenant_id = get_current_tenant_id() + if not tenant_id: + raise HTTPException(status_code=404, detail="Tenant not found") + + # status = get_tenant_setup_status(tenant_id) + + return TenantSetupStatusResponse( + tenant_id=tenant_id, status="completed", is_complete=True + ) diff --git a/backend/ee/onyx/server/tenants/schema_management.py b/backend/ee/onyx/server/tenants/schema_management.py index 80712556e..3442a4697 100644 --- a/backend/ee/onyx/server/tenants/schema_management.py +++ b/backend/ee/onyx/server/tenants/schema_management.py @@ -49,6 +49,47 @@ def run_alembic_migrations(schema_name: str) -> None: raise +def run_essential_alembic_migrations(schema_name: str) -> None: + """ + Run only the essential Alembic migrations up to the 465f78d9b7f9 revision. + This is used for the auth flow to complete quickly, with the rest of the migrations + and setup being deferred to run asynchronously. + """ + logger.info(f"Starting essential Alembic migrations for schema: {schema_name}") + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) + alembic_cfg.set_main_option( + "script_location", os.path.join(root_dir, "alembic") + ) + + # Ensure that logging isn't broken + alembic_cfg.attributes["configure_logger"] = False + + # Mimic command-line options by adding 'cmd_opts' to the config + alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore + alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + + # Run migrations programmatically up to the specified revision + command.upgrade(alembic_cfg, "465f78d9b7f9") + + logger.info( + f"Essential Alembic migrations completed successfully for schema: {schema_name}" + ) + + except Exception as e: + logger.exception( + f"Essential Alembic migration failed for schema {schema_name}: {str(e)}" + ) + raise + + def create_schema_if_not_exists(tenant_id: str) -> bool: with Session(get_sqlalchemy_engine()) as db_session: with db_session.begin(): diff --git a/backend/ee/onyx/server/tenants/team_membership_api.py b/backend/ee/onyx/server/tenants/team_membership_api.py new file mode 100644 index 000000000..bdf899aaa --- /dev/null +++ b/backend/ee/onyx/server/tenants/team_membership_api.py @@ -0,0 +1,67 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane +from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant +from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant +from onyx.auth.users import current_admin_user +from onyx.auth.users import User +from onyx.db.auth import get_user_count +from onyx.db.engine import get_session +from onyx.db.users import delete_user_from_db +from onyx.db.users import get_user_by_email +from onyx.server.manage.models import UserByEmail +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id + +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + + +@router.post("/leave-team") +async def leave_organization( + user_email: UserByEmail, + current_user: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + tenant_id = get_current_tenant_id() + + if current_user is None or current_user.email != user_email.user_email: + raise HTTPException( + status_code=403, detail="You can only leave the organization as yourself" + ) + + user_to_delete = get_user_by_email(user_email.user_email, db_session) + if user_to_delete is None: + raise HTTPException(status_code=404, detail="User not found") + + num_admin_users = await get_user_count(only_admin_users=True) + + should_delete_tenant = num_admin_users == 1 + + if should_delete_tenant: + logger.info( + "Last admin user is leaving the organization. Deleting tenant from control plane." + ) + try: + await delete_user_from_control_plane(tenant_id, user_to_delete.email) + logger.debug("User deleted from control plane") + except Exception as e: + logger.exception( + f"Failed to delete user from control plane for tenant {tenant_id}: {e}" + ) + raise HTTPException( + status_code=500, + detail=f"Failed to remove user from control plane: {str(e)}", + ) + + db_session.expunge(user_to_delete) + delete_user_from_db(user_to_delete, db_session) + + if should_delete_tenant: + remove_all_users_from_tenant(tenant_id) + else: + remove_users_from_tenant([user_to_delete.email], tenant_id) diff --git a/backend/ee/onyx/server/tenants/tenant_management_api.py b/backend/ee/onyx/server/tenants/tenant_management_api.py new file mode 100644 index 000000000..f71bf5817 --- /dev/null +++ b/backend/ee/onyx/server/tenants/tenant_management_api.py @@ -0,0 +1,62 @@ +from fastapi import APIRouter +from fastapi import Depends + +from ee.onyx.server.tenants.models import TenantByDomainResponse +from onyx.auth.users import current_admin_user +from onyx.auth.users import User +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id + +# from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane + +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + +FORBIDDEN_COMMON_EMAIL_DOMAINS = [ + "gmail.com", + "yahoo.com", + "hotmail.com", + "outlook.com", + "icloud.com", + "msn.com", + "live.com", + "msn.com", + "hotmail.com", + "hotmail.co.uk", + "hotmail.fr", + "hotmail.de", + "hotmail.it", + "hotmail.es", + "hotmail.nl", + "hotmail.pl", + "hotmail.pt", + "hotmail.ro", + "hotmail.ru", + "hotmail.sa", + "hotmail.se", + "hotmail.tr", + "hotmail.tw", + "hotmail.ua", + "hotmail.us", + "hotmail.vn", + "hotmail.za", + "hotmail.zw", +] + + +@router.get("/existing-team-by-domain") +def get_existing_tenant_by_domain( + user: User | None = Depends(current_admin_user), +) -> TenantByDomainResponse | None: + if not user: + return None + domain = user.email.split("@")[1] + if domain in FORBIDDEN_COMMON_EMAIL_DOMAINS: + return None + tenant_id = get_current_tenant_id() + return TenantByDomainResponse( + tenant_id=tenant_id, status="completed", is_complete=True + ) + + # return get_tenant_by_domain_from_control_plane(domain, tenant_id) diff --git a/backend/ee/onyx/server/tenants/user_invitations_api.py b/backend/ee/onyx/server/tenants/user_invitations_api.py new file mode 100644 index 000000000..c5e9d4ebb --- /dev/null +++ b/backend/ee/onyx/server/tenants/user_invitations_api.py @@ -0,0 +1,91 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException + +from ee.onyx.server.tenants.models import ApproveUserRequest +from ee.onyx.server.tenants.models import PendingUserSnapshot +from ee.onyx.server.tenants.models import RequestInviteRequest +from ee.onyx.server.tenants.user_mapping import accept_user_invite +from ee.onyx.server.tenants.user_mapping import approve_user_invite +from ee.onyx.server.tenants.user_mapping import deny_user_invite +from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant +from onyx.auth.invited_users import get_pending_users +from onyx.auth.users import current_admin_user +from onyx.auth.users import current_user +from onyx.auth.users import User +from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id + +logger = setup_logger() + +router = APIRouter(prefix="/tenants") + + +@router.post("/request-invite") +async def request_invite( + invite_request: RequestInviteRequest, + user: User | None = Depends(current_admin_user), +) -> None: + if user is None: + raise HTTPException(status_code=401, detail="User not authenticated") + try: + invite_self_to_tenant(user.email, invite_request.tenant_id) + except Exception as e: + logger.exception( + f"Failed to invite self to tenant {invite_request.tenant_id}: {e}" + ) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/users/pending") +def list_pending_users( + _: User | None = Depends(current_admin_user), +) -> list[PendingUserSnapshot]: + pending_emails = get_pending_users() + + return [PendingUserSnapshot(email=email) for email in pending_emails] + + +@router.post("/users/approve-invite") +async def approve_user( + approve_user_request: ApproveUserRequest, + _: User | None = Depends(current_admin_user), +) -> None: + tenant_id = get_current_tenant_id() + approve_user_invite(approve_user_request.email, tenant_id) + + +@router.post("/users/accept-invite") +async def accept_invite( + invite_request: RequestInviteRequest, + user: User | None = Depends(current_user), +) -> None: + """ + Accept an invitation to join a tenant. + """ + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + accept_user_invite(user.email, invite_request.tenant_id) + except Exception as e: + logger.exception(f"Failed to accept invite: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to accept invitation") + + +@router.post("/users/deny-invite") +async def deny_invite( + invite_request: RequestInviteRequest, + user: User | None = Depends(current_user), +) -> None: + """ + Deny an invitation to join a tenant. + """ + if not user: + raise HTTPException(status_code=401, detail="Not authenticated") + + try: + deny_user_invite(user.email, invite_request.tenant_id) + except Exception as e: + logger.exception(f"Failed to deny invite: {str(e)}") + raise HTTPException(status_code=500, detail="Failed to deny invitation") diff --git a/backend/onyx/auth/essential_manager.py b/backend/onyx/auth/essential_manager.py new file mode 100644 index 000000000..187db0a98 --- /dev/null +++ b/backend/onyx/auth/essential_manager.py @@ -0,0 +1,52 @@ +from typing import Optional + +from fastapi import Depends +from fastapi import Request +from fastapi_users import BaseUserManager +from fastapi_users import UUIDIDMixin +from fastapi_users.db import SQLAlchemyUserDatabase + +from onyx.auth.essential_user import EssentialUser +from onyx.auth.essential_user import get_essential_user_db +from onyx.configs.app_configs import USER_MANAGER_SECRET + + +class EssentialUserManager(UUIDIDMixin, BaseUserManager[EssentialUser, str]): + """ + A simplified user manager that only handles essential authentication operations. + This is used during the initial tenant setup phase to avoid errors with missing columns. + """ + + reset_password_token_secret = USER_MANAGER_SECRET + verification_token_secret = USER_MANAGER_SECRET + + async def on_after_register( + self, user: EssentialUser, request: Optional[Request] = None + ) -> None: + """ + Simplified post-registration hook. + """ + + async def on_after_forgot_password( + self, user: EssentialUser, token: str, request: Optional[Request] = None + ) -> None: + """ + Simplified post-forgot-password hook. + """ + + async def on_after_request_verify( + self, user: EssentialUser, token: str, request: Optional[Request] = None + ) -> None: + """ + Simplified post-verification-request hook. + """ + + +async def get_essential_user_manager( + user_db: SQLAlchemyUserDatabase = Depends(get_essential_user_db), +) -> EssentialUserManager: + """ + Get a user manager that uses the essential user model. + This avoids errors with missing columns during the initial tenant setup. + """ + yield EssentialUserManager(user_db) diff --git a/backend/onyx/auth/essential_user.py b/backend/onyx/auth/essential_user.py new file mode 100644 index 000000000..a0bb31c56 --- /dev/null +++ b/backend/onyx/auth/essential_user.py @@ -0,0 +1,47 @@ +from collections.abc import AsyncGenerator +from typing import Optional + +from fastapi import Depends +from fastapi_users.db import SQLAlchemyBaseUserTableUUID +from fastapi_users.db import SQLAlchemyUserDatabase +from sqlalchemy import Boolean +from sqlalchemy import Column +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.declarative import DeclarativeMeta +from sqlalchemy.orm import relationship + +from onyx.db.engine import get_async_session + +Base: DeclarativeMeta = declarative_base() + + +class EssentialUser(SQLAlchemyBaseUserTableUUID, Base): + """ + A simplified user model that only includes essential columns needed for authentication. + This is used during the initial tenant setup phase to avoid errors with missing columns + that would be added in later migrations. + """ + + __tablename__ = "user" + + email: str = Column(String(length=320), unique=True, index=True, nullable=False) + hashed_password: Optional[str] = Column(String(length=1024), nullable=True) + is_active: bool = Column(Boolean, default=True, nullable=False) + is_superuser: bool = Column(Boolean, default=False, nullable=False) + is_verified: bool = Column(Boolean, default=False, nullable=False) + + # Relationships are defined but not used in the essential auth flow + oauth_accounts = relationship("OAuthAccount", lazy="joined") + credentials = relationship("Credential", lazy="joined") + + +async def get_essential_user_db( + session: AsyncSession = Depends(get_async_session), +) -> AsyncGenerator[SQLAlchemyUserDatabase, None]: + """ + Get a user database that uses the essential user model. + This avoids errors with missing columns during the initial tenant setup. + """ + yield SQLAlchemyUserDatabase(session, EssentialUser) diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 3d6dd7227..13b2953e3 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -95,7 +95,7 @@ const Page = async (props: { )} -