From e10cc8ccdb3d3eb94e6cbcfccb924e696d20795a Mon Sep 17 00:00:00 2001 From: pablonyx Date: Thu, 27 Feb 2025 10:35:38 -0800 Subject: [PATCH] Multi tenant user google auth fix (#4145) --- backend/ee/onyx/server/tenants/billing.py | 20 +++++-- backend/onyx/auth/users.py | 63 ++++++++++++++--------- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/backend/ee/onyx/server/tenants/billing.py b/backend/ee/onyx/server/tenants/billing.py index 98de75a9a..7c5ae8534 100644 --- a/backend/ee/onyx/server/tenants/billing.py +++ b/backend/ee/onyx/server/tenants/billing.py @@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY from ee.onyx.server.tenants.access import generate_data_plane_token from ee.onyx.server.tenants.models import BillingInformation +from ee.onyx.server.tenants.models import SubscriptionStatusResponse from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.utils.logger import setup_logger @@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict: return response.json() -def fetch_billing_information(tenant_id: str) -> BillingInformation: +def fetch_billing_information( + tenant_id: str, +) -> BillingInformation | SubscriptionStatusResponse: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { @@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation: params = {"tenant_id": tenant_id} response = requests.get(url, headers=headers, params=params) response.raise_for_status() - billing_info = BillingInformation(**response.json()) - return billing_info + + response_data = response.json() + + # Check if the response indicates no subscription + if ( + isinstance(response_data, dict) + and "subscribed" in response_data + and not response_data["subscribed"] + ): + return SubscriptionStatusResponse(**response_data) + + # Otherwise, parse as BillingInformation + return BillingInformation(**response_data) def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription: diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index ec163f167..54dec426e 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): "refresh_token": refresh_token, } - user: User + user: User | None = None try: # Attempt to get user by OAuth account @@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): except exceptions.UserNotExists: try: # Attempt to get user by email - user = cast(User, await self.user_db.get_by_email(account_email)) + user = await self.user_db.get_by_email(account_email) if not associate_by_email: raise exceptions.UserAlreadyExists() - user = await self.user_db.add_oauth_account( - user, oauth_account_dict - ) + # Make sure user is not None before adding OAuth account + if user is not None: + user = await self.user_db.add_oauth_account( + user, oauth_account_dict + ) + else: + # This shouldn't happen since get_by_email would raise UserNotExists + # but adding as a safeguard + raise exceptions.UserNotExists() - # If user not found by OAuth account or email, create a new user except exceptions.UserNotExists: password = self.password_helper.generate() user_dict = { @@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): user = await self.user_db.create(user_dict) - # Explicitly set the Postgres schema for this session to ensure - # OAuth account creation happens in the correct tenant schema - - # Add OAuth account - await self.user_db.add_oauth_account(user, oauth_account_dict) - await self.on_after_register(user, request) + # Add OAuth account only if user creation was successful + if user is not None: + await self.user_db.add_oauth_account(user, oauth_account_dict) + await self.on_after_register(user, request) + else: + raise HTTPException( + status_code=500, detail="Failed to create user account" + ) else: - for existing_oauth_account in user.oauth_accounts: - if ( - existing_oauth_account.account_id == account_id - and existing_oauth_account.oauth_name == oauth_name - ): - user = await self.user_db.update_oauth_account( - user, - # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol - # but the type checker doesn't know that :( - existing_oauth_account, # type: ignore - oauth_account_dict, - ) + # User exists, update OAuth account if needed + if user is not None: # Add explicit check + for existing_oauth_account in user.oauth_accounts: + if ( + existing_oauth_account.account_id == account_id + and existing_oauth_account.oauth_name == oauth_name + ): + user = await self.user_db.update_oauth_account( + user, + # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol + # but the type checker doesn't know that :( + existing_oauth_account, # type: ignore + oauth_account_dict, + ) + + # Ensure user is not None before proceeding + if user is None: + raise HTTPException( + status_code=500, detail="Failed to authenticate or create user" + ) # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to # re-authenticate that frequently, so by default this is disabled