From 09e4e73ba6896b2a0288347071bc431a767e784b Mon Sep 17 00:00:00 2001 From: pablonyx Date: Wed, 5 Mar 2025 13:13:24 -0800 Subject: [PATCH] k --- .../ee/onyx/server/tenants/initial_models.py | 0 .../ee/onyx/server/tenants/provisioning.py | 103 ++++++++++++------ backend/ee/onyx/server/tenants/router.py | 48 ++++++++ .../onyx/server/tenants/schema_management.py | 11 +- backend/onyx/auth/users.py | 35 ++++-- backend/onyx/db/engine.py | 1 + backend/onyx/db/models.py | 13 +++ backend/onyx/utils/telemetry.py | 1 + web/src/app/auth/signup/page.tsx | 2 - web/src/app/auth/waiting-on-setup/page.tsx | 72 ++++++++++++ web/src/lib/tenant.ts | 19 ++++ 11 files changed, 257 insertions(+), 48 deletions(-) create mode 100644 backend/ee/onyx/server/tenants/initial_models.py create mode 100644 backend/ee/onyx/server/tenants/router.py create mode 100644 web/src/app/auth/waiting-on-setup/page.tsx create mode 100644 web/src/lib/tenant.ts diff --git a/backend/ee/onyx/server/tenants/initial_models.py b/backend/ee/onyx/server/tenants/initial_models.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index 5328277ff6..5b1fc99a5b 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -54,20 +54,26 @@ logger = logging.getLogger(__name__) async def get_or_provision_tenant( email: str, referral_source: str | None = None, request: Request | None = None -) -> str: - """Get existing tenant ID for an email or create a new tenant if none exists.""" +) -> tuple[str, bool]: + """Get existing tenant ID for an email or create a new tenant if none exists. + + Returns: + tuple: (tenant_id, is_newly_created) - The tenant ID and a boolean indicating if it was newly created + """ if not MULTI_TENANT: - return POSTGRES_DEFAULT_SCHEMA + return POSTGRES_DEFAULT_SCHEMA, False if referral_source and request: await submit_to_hubspot(email, referral_source, request) + is_newly_created = False try: tenant_id = get_tenant_id_for_email(email) except exceptions.UserNotExists: # If tenant does not exist and in Multi tenant mode, provision a new tenant try: tenant_id = await create_tenant(email, referral_source) + is_newly_created = True except Exception as e: logger.error(f"Tenant provisioning failed: {e}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") @@ -77,7 +83,7 @@ async def get_or_provision_tenant( status_code=401, detail="User does not belong to an organization" ) - return tenant_id + return tenant_id, is_newly_created async def create_tenant(email: str, referral_source: str | None = None) -> str: @@ -115,36 +121,12 @@ async def provision_tenant(tenant_id: str, email: str) -> None: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - # Await the Alembic migrations - await asyncio.to_thread(run_alembic_migrations, tenant_id) - - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - configure_default_api_keys(db_session) - - current_search_settings = ( - db_session.query(SearchSettings) - .filter_by(status=IndexModelStatus.FUTURE) - .first() - ) - cohere_enabled = ( - current_search_settings is not None - and current_search_settings.provider_type == EmbeddingProvider.COHERE - ) - setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) + # Await the Alembic migrations up to the specified revision + await asyncio.to_thread(run_alembic_migrations, tenant_id, "465f78d9b7f9") + # Add users to tenant - this is needed for authentication add_users_to_tenant([email], tenant_id) - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - create_milestone_and_report( - user=None, - distinct_id=tenant_id, - event_type=MilestoneRecordType.TENANT_CREATED, - properties={ - "email": email, - }, - db_session=db_session, - ) - except Exception as e: logger.exception(f"Failed to create tenant {tenant_id}") raise HTTPException( @@ -349,3 +331,62 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None: raise Exception( f"Failed to delete tenant on control plane: {error_text}" ) + + +async def complete_tenant_setup(tenant_id: str, email: str) -> None: + """Complete the tenant setup process after user creation. + + This function handles the remaining steps of tenant provisioning after the initial + schema creation and user authentication: + 1. Completes the remaining Alembic migrations + 2. Configures default API keys + 3. Sets up Onyx + 4. Creates milestone record + """ + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") + + logger.debug(f"Completing setup for tenant {tenant_id}") + token = None + + try: + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + + # Complete the remaining Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) + + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + configure_default_api_keys(db_session) + + current_search_settings = ( + db_session.query(SearchSettings) + .filter_by(status=IndexModelStatus.FUTURE) + .first() + ) + cohere_enabled = ( + current_search_settings is not None + and current_search_settings.provider_type == EmbeddingProvider.COHERE + ) + setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) + + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + create_milestone_and_report( + user=None, + distinct_id=tenant_id, + event_type=MilestoneRecordType.TENANT_CREATED, + properties={ + "email": email, + }, + db_session=db_session, + ) + + logger.info(f"Tenant setup completed for {tenant_id}") + + except Exception as e: + logger.exception(f"Failed to complete tenant setup for {tenant_id}") + raise HTTPException( + status_code=500, detail=f"Failed to complete tenant setup: {str(e)}" + ) + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) diff --git a/backend/ee/onyx/server/tenants/router.py b/backend/ee/onyx/server/tenants/router.py new file mode 100644 index 0000000000..81ebb7cc41 --- /dev/null +++ b/backend/ee/onyx/server/tenants/router.py @@ -0,0 +1,48 @@ +import logging + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from pydantic import BaseModel + +from ee.onyx.server.tenants.provisioning import complete_tenant_setup +from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email +from onyx.auth.users import current_user +from onyx.auth.users import exceptions +from onyx.db.models import User + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/tenants", tags=["tenants"]) + + +class CompleteTenantSetupRequest(BaseModel): + email: str + + +@router.post("/complete-setup") +async def api_complete_tenant_setup( + request: CompleteTenantSetupRequest, + user: User = Depends(current_user), +) -> dict: + """Complete the tenant setup process for a user. + + This endpoint is called from the frontend after user creation to complete + the tenant setup process (migrations, seeding, etc.). + """ + if not user.is_admin and user.email != request.email: + raise HTTPException( + status_code=403, detail="You can only complete setup for your own tenant" + ) + + try: + tenant_id = get_tenant_id_for_email(request.email) + except exceptions.UserNotExists: + raise HTTPException(status_code=404, detail="User or tenant not found") + + try: + await complete_tenant_setup(tenant_id, request.email) + return {"status": "success"} + except Exception as e: + logger.error(f"Failed to complete tenant setup: {e}") + raise HTTPException(status_code=500, detail="Failed to complete tenant setup") diff --git a/backend/ee/onyx/server/tenants/schema_management.py b/backend/ee/onyx/server/tenants/schema_management.py index 80712556ea..6a58b9d873 100644 --- a/backend/ee/onyx/server/tenants/schema_management.py +++ b/backend/ee/onyx/server/tenants/schema_management.py @@ -14,8 +14,10 @@ from onyx.db.engine import get_sqlalchemy_engine logger = logging.getLogger(__name__) -def run_alembic_migrations(schema_name: str) -> None: - logger.info(f"Starting Alembic migrations for schema: {schema_name}") +def run_alembic_migrations(schema_name: str, target_revision: str = "head") -> None: + logger.info( + f"Starting Alembic migrations for schema: {schema_name} to target: {target_revision}" + ) try: current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -37,11 +39,10 @@ def run_alembic_migrations(schema_name: str) -> None: alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore # Run migrations programmatically - command.upgrade(alembic_cfg, "head") + command.upgrade(alembic_cfg, target_revision) - # Run migrations programmatically logger.info( - f"Alembic migrations completed successfully for schema: {schema_name}" + f"Alembic migrations completed successfully for schema: {schema_name} to target: {target_revision}" ) except Exception as e: diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index df72c442df..b02f0905b6 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -90,6 +90,7 @@ from onyx.db.engine import get_async_session from onyx.db.engine import get_async_session_with_tenant from onyx.db.engine import get_session_with_tenant from onyx.db.models import AccessToken +from onyx.db.models import MinimalUser from onyx.db.models import OAuthAccount from onyx.db.models import User from onyx.db.users import get_user_by_email @@ -186,6 +187,7 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool: def verify_email_is_invited(email: str) -> None: + return None whitelist = get_invited_users() if not whitelist: return @@ -215,6 +217,7 @@ def verify_email_is_invited(email: str) -> None: def verify_email_in_whitelist(email: str, tenant_id: str) -> None: + return None with get_session_with_tenant(tenant_id=tenant_id) as db_session: if not get_user_by_email(email, db_session): verify_email_is_invited(email) @@ -235,6 +238,13 @@ def verify_email_domain(email: str) -> None: ) +class SimpleUserManager(UUIDIDMixin, BaseUserManager[MinimalUser, uuid.UUID]): + reset_password_token_secret = USER_AUTH_SECRET + verification_token_secret = USER_AUTH_SECRET + verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS + user_db: SQLAlchemyUserDatabase[MinimalUser, uuid.UUID] + + class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = USER_AUTH_SECRET verification_token_secret = USER_AUTH_SECRET @@ -247,8 +257,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): )(user_email) async with get_async_session_with_tenant(tenant_id) as db_session: if MULTI_TENANT: - tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID]( - db_session, User, OAuthAccount + tenant_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID]( + db_session, MinimalUser, OAuthAccount ) user = await tenant_user_db.get_by_email(user_email) else: @@ -277,7 +287,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): else None ) - tenant_id = await fetch_ee_implementation_or_noop( + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, @@ -309,7 +319,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): else: user_create.role = UserRole.BASIC try: - user = await super().create(user_create, safe=safe, request=request) # type: ignore + simple_tennat_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID]( + db_session, MinimalUser, OAuthAccount + ) + user = await SimpleUserManager(simple_tennat_user_db).create( + user_create, safe=safe, request=request + ) # type: ignore except exceptions.UserAlreadyExists: user = await self.get_by_email(user_create.email) # Handle case where user has used product outside of web and is now creating an account through web @@ -374,7 +389,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): getattr(request.state, "referral_source", None) if request else None ) - tenant_id = await fetch_ee_implementation_or_noop( + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, @@ -511,7 +526,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): async def on_after_register( self, user: User, request: Optional[Request] = None ) -> None: - tenant_id = await fetch_ee_implementation_or_noop( + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, @@ -563,7 +578,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): status.HTTP_500_INTERNAL_SERVER_ERROR, "Your admin has not enabled this feature.", ) - tenant_id = await fetch_ee_implementation_or_noop( + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, @@ -587,8 +602,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): ) -> Optional[User]: email = credentials.username - # Get tenant_id from mapping table - tenant_id = await fetch_ee_implementation_or_noop( + # Get tenant_id from mapping + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, @@ -709,7 +724,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]): async def write_token(self, user: User) -> str: redis = await get_async_redis_connection() - tenant_id = await fetch_ee_implementation_or_noop( + tenant_id, is_newly_created = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index 840b5d93d5..1b32b4b48a 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -180,6 +180,7 @@ SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$") def is_valid_schema_name(name: str) -> bool: + print(f"Checking if {name} is valid") return SCHEMA_NAME_REGEX.match(name) is not None diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 484d246209..c370667b4a 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -2318,3 +2318,16 @@ class TenantAnonymousUserPath(Base): anonymous_user_path: Mapped[str] = mapped_column( String, nullable=False, unique=True ) + + +class AdditionalBase(DeclarativeBase): + __abstract__ = True + + +class MinimalUser(SQLAlchemyBaseUserTableUUID, AdditionalBase): + # oauth_accounts: Mapped[list[OAuthAccount]] = relationship( + # "OAuthAccount", lazy="joined", cascade="all, delete-orphan" + # ) + role: Mapped[UserRole] = mapped_column( + Enum(UserRole, native_enum=False, default=UserRole.BASIC) + ) diff --git a/backend/onyx/utils/telemetry.py b/backend/onyx/utils/telemetry.py index af1eb6c0c8..66fc84a78a 100644 --- a/backend/onyx/utils/telemetry.py +++ b/backend/onyx/utils/telemetry.py @@ -43,6 +43,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str: def get_or_generate_uuid() -> str: + return "hi" # TODO: split out the whole "instance UUID" generation logic into a separate # utility function. Telemetry should not be aware at all of how the UUID is # generated/stored. diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index 3d6dd7227b..4235f5c52d 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -8,8 +8,6 @@ import { } from "@/lib/userSS"; import { redirect } from "next/navigation"; import { EmailPasswordForm } from "../login/EmailPasswordForm"; -import Text from "@/components/ui/text"; -import Link from "next/link"; import { SignInButton } from "../login/SignInButton"; import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import ReferralSourceSelector from "./ReferralSourceSelector"; diff --git a/web/src/app/auth/waiting-on-setup/page.tsx b/web/src/app/auth/waiting-on-setup/page.tsx new file mode 100644 index 0000000000..bae959196e --- /dev/null +++ b/web/src/app/auth/waiting-on-setup/page.tsx @@ -0,0 +1,72 @@ +import { + AuthTypeMetadata, + getAuthTypeMetadataSS, + getCurrentUserSS, +} from "@/lib/userSS"; +import { redirect } from "next/navigation"; +import { HealthCheckBanner } from "@/components/health/healthcheck"; +import { User } from "@/lib/types"; +import Text from "@/components/ui/text"; +import { Logo } from "@/components/logo/Logo"; +import { completeTenantSetup } from "@/lib/tenant"; + +export default async function Page() { + // catch cases where the backend is completely unreachable here + // without try / catch, will just raise an exception and the page + // will not render + let authTypeMetadata: AuthTypeMetadata | null = null; + let currentUser: User | null = null; + try { + [authTypeMetadata, currentUser] = await Promise.all([ + getAuthTypeMetadataSS(), + getCurrentUserSS(), + ]); + } catch (e) { + console.log(`Some fetch failed for the waiting-on-setup page - ${e}`); + } + + if (!currentUser) { + if (authTypeMetadata?.authType === "disabled") { + return redirect("/chat"); + } + return redirect("/auth/login"); + } + + // If the user is already verified, redirect to chat + if (!authTypeMetadata?.requiresVerification || currentUser.is_verified) { + // Trigger the tenant setup completion in the background + if (currentUser.email) { + try { + await completeTenantSetup(currentUser.email); + } catch (e) { + console.error("Failed to complete tenant setup:", e); + } + } + return redirect("/chat"); + } + + return ( +
+
+ +
+
+
+ + +
+ + Hey {currentUser.email} - we're setting up your account. +
+ This may take a few moments. You'll be redirected automatically + when it's ready. +
+
+ If you're not redirected within a minute, please refresh the page. +
+
+
+
+
+ ); +} diff --git a/web/src/lib/tenant.ts b/web/src/lib/tenant.ts new file mode 100644 index 0000000000..5553c4da80 --- /dev/null +++ b/web/src/lib/tenant.ts @@ -0,0 +1,19 @@ +/** + * Completes the tenant setup process for a user + * @param email The email of the user + * @returns A promise that resolves when the setup is complete + */ +export async function completeTenantSetup(email: string): Promise { + const response = await fetch(`/api/tenants/complete-setup`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ email }), + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`Failed to complete tenant setup: ${errorText}`); + } +}