mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-05 20:49:48 +02:00
Multi tenant user google auth fix (#4145)
This commit is contained in:
parent
7018bc974b
commit
e10cc8ccdb
@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
|||||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||||
from ee.onyx.server.tenants.models import BillingInformation
|
from ee.onyx.server.tenants.models import BillingInformation
|
||||||
|
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
|
|
||||||
@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
|||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
def fetch_billing_information(
|
||||||
|
tenant_id: str,
|
||||||
|
) -> BillingInformation | SubscriptionStatusResponse:
|
||||||
logger.info("Fetching billing information")
|
logger.info("Fetching billing information")
|
||||||
token = generate_data_plane_token()
|
token = generate_data_plane_token()
|
||||||
headers = {
|
headers = {
|
||||||
@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
|||||||
params = {"tenant_id": tenant_id}
|
params = {"tenant_id": tenant_id}
|
||||||
response = requests.get(url, headers=headers, params=params)
|
response = requests.get(url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
billing_info = BillingInformation(**response.json())
|
|
||||||
return billing_info
|
response_data = response.json()
|
||||||
|
|
||||||
|
# Check if the response indicates no subscription
|
||||||
|
if (
|
||||||
|
isinstance(response_data, dict)
|
||||||
|
and "subscribed" in response_data
|
||||||
|
and not response_data["subscribed"]
|
||||||
|
):
|
||||||
|
return SubscriptionStatusResponse(**response_data)
|
||||||
|
|
||||||
|
# Otherwise, parse as BillingInformation
|
||||||
|
return BillingInformation(**response_data)
|
||||||
|
|
||||||
|
|
||||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||||
|
@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
"refresh_token": refresh_token,
|
"refresh_token": refresh_token,
|
||||||
}
|
}
|
||||||
|
|
||||||
user: User
|
user: User | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Attempt to get user by OAuth account
|
# Attempt to get user by OAuth account
|
||||||
@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
except exceptions.UserNotExists:
|
except exceptions.UserNotExists:
|
||||||
try:
|
try:
|
||||||
# Attempt to get user by email
|
# Attempt to get user by email
|
||||||
user = cast(User, await self.user_db.get_by_email(account_email))
|
user = await self.user_db.get_by_email(account_email)
|
||||||
if not associate_by_email:
|
if not associate_by_email:
|
||||||
raise exceptions.UserAlreadyExists()
|
raise exceptions.UserAlreadyExists()
|
||||||
|
|
||||||
user = await self.user_db.add_oauth_account(
|
# Make sure user is not None before adding OAuth account
|
||||||
user, oauth_account_dict
|
if user is not None:
|
||||||
)
|
user = await self.user_db.add_oauth_account(
|
||||||
|
user, oauth_account_dict
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||||
|
# but adding as a safeguard
|
||||||
|
raise exceptions.UserNotExists()
|
||||||
|
|
||||||
# If user not found by OAuth account or email, create a new user
|
|
||||||
except exceptions.UserNotExists:
|
except exceptions.UserNotExists:
|
||||||
password = self.password_helper.generate()
|
password = self.password_helper.generate()
|
||||||
user_dict = {
|
user_dict = {
|
||||||
@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
|
|
||||||
user = await self.user_db.create(user_dict)
|
user = await self.user_db.create(user_dict)
|
||||||
|
|
||||||
# Explicitly set the Postgres schema for this session to ensure
|
# Add OAuth account only if user creation was successful
|
||||||
# OAuth account creation happens in the correct tenant schema
|
if user is not None:
|
||||||
|
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||||
# Add OAuth account
|
await self.on_after_register(user, request)
|
||||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
else:
|
||||||
await self.on_after_register(user, request)
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to create user account"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for existing_oauth_account in user.oauth_accounts:
|
# User exists, update OAuth account if needed
|
||||||
if (
|
if user is not None: # Add explicit check
|
||||||
existing_oauth_account.account_id == account_id
|
for existing_oauth_account in user.oauth_accounts:
|
||||||
and existing_oauth_account.oauth_name == oauth_name
|
if (
|
||||||
):
|
existing_oauth_account.account_id == account_id
|
||||||
user = await self.user_db.update_oauth_account(
|
and existing_oauth_account.oauth_name == oauth_name
|
||||||
user,
|
):
|
||||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
user = await self.user_db.update_oauth_account(
|
||||||
# but the type checker doesn't know that :(
|
user,
|
||||||
existing_oauth_account, # type: ignore
|
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||||
oauth_account_dict,
|
# but the type checker doesn't know that :(
|
||||||
)
|
existing_oauth_account, # type: ignore
|
||||||
|
oauth_account_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure user is not None before proceeding
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to authenticate or create user"
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||||
# re-authenticate that frequently, so by default this is disabled
|
# re-authenticate that frequently, so by default this is disabled
|
||||||
|
Loading…
x
Reference in New Issue
Block a user