mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 09:40:50 +02:00
k
This commit is contained in:
0
backend/ee/onyx/server/tenants/initial_models.py
Normal file
0
backend/ee/onyx/server/tenants/initial_models.py
Normal file
@ -54,20 +54,26 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
) -> tuple[str, bool]:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
|
||||
Returns:
|
||||
tuple: (tenant_id, is_newly_created) - The tenant ID and a boolean indicating if it was newly created
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
return POSTGRES_DEFAULT_SCHEMA, False
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
is_newly_created = False
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
is_newly_created = True
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
@ -77,7 +83,7 @@ async def get_or_provision_tenant(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
return tenant_id, is_newly_created
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
@ -115,36 +121,12 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
# Await the Alembic migrations up to the specified revision
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id, "465f78d9b7f9")
|
||||
|
||||
# Add users to tenant - this is needed for authentication
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
@ -349,3 +331,62 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
async def complete_tenant_setup(tenant_id: str, email: str) -> None:
|
||||
"""Complete the tenant setup process after user creation.
|
||||
|
||||
This function handles the remaining steps of tenant provisioning after the initial
|
||||
schema creation and user authentication:
|
||||
1. Completes the remaining Alembic migrations
|
||||
2. Configures default API keys
|
||||
3. Sets up Onyx
|
||||
4. Creates milestone record
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
logger.debug(f"Completing setup for tenant {tenant_id}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Complete the remaining Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(f"Tenant setup completed for {tenant_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to complete tenant setup for {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete tenant setup: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
48
backend/ee/onyx/server/tenants/router.py
Normal file
48
backend/ee/onyx/server/tenants/router.py
Normal file
@ -0,0 +1,48 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import complete_tenant_setup
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.db.models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/tenants", tags=["tenants"])
|
||||
|
||||
|
||||
class CompleteTenantSetupRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
@router.post("/complete-setup")
|
||||
async def api_complete_tenant_setup(
|
||||
request: CompleteTenantSetupRequest,
|
||||
user: User = Depends(current_user),
|
||||
) -> dict:
|
||||
"""Complete the tenant setup process for a user.
|
||||
|
||||
This endpoint is called from the frontend after user creation to complete
|
||||
the tenant setup process (migrations, seeding, etc.).
|
||||
"""
|
||||
if not user.is_admin and user.email != request.email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only complete setup for your own tenant"
|
||||
)
|
||||
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(request.email)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=404, detail="User or tenant not found")
|
||||
|
||||
try:
|
||||
await complete_tenant_setup(tenant_id, request.email)
|
||||
return {"status": "success"}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete tenant setup: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to complete tenant setup")
|
@ -14,8 +14,10 @@ from onyx.db.engine import get_sqlalchemy_engine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
def run_alembic_migrations(schema_name: str, target_revision: str = "head") -> None:
|
||||
logger.info(
|
||||
f"Starting Alembic migrations for schema: {schema_name} to target: {target_revision}"
|
||||
)
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@ -37,11 +39,10 @@ def run_alembic_migrations(schema_name: str) -> None:
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
command.upgrade(alembic_cfg, target_revision)
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
f"Alembic migrations completed successfully for schema: {schema_name} to target: {target_revision}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -90,6 +90,7 @@ from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import MinimalUser
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
@ -186,6 +187,7 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
return None
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
return
|
||||
@ -215,6 +217,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
return None
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@ -235,6 +238,13 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
class SimpleUserManager(UUIDIDMixin, BaseUserManager[MinimalUser, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
user_db: SQLAlchemyUserDatabase[MinimalUser, uuid.UUID]
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@ -247,8 +257,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)(user_email)
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
|
||||
db_session, MinimalUser, OAuthAccount
|
||||
)
|
||||
user = await tenant_user_db.get_by_email(user_email)
|
||||
else:
|
||||
@ -277,7 +287,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@ -309,7 +319,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
simple_tennat_user_db = SQLAlchemyUserAdminDB[MinimalUser, uuid.UUID](
|
||||
db_session, MinimalUser, OAuthAccount
|
||||
)
|
||||
user = await SimpleUserManager(simple_tennat_user_db).create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
@ -374,7 +389,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
getattr(request.state, "referral_source", None) if request else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@ -511,7 +526,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@ -563,7 +578,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enabled this feature.",
|
||||
)
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@ -587,8 +602,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
# Get tenant_id from mapping
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
@ -709,7 +724,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
tenant_id, is_newly_created = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
|
@ -180,6 +180,7 @@ SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
def is_valid_schema_name(name: str) -> bool:
|
||||
print(f"Checking if {name} is valid")
|
||||
return SCHEMA_NAME_REGEX.match(name) is not None
|
||||
|
||||
|
||||
|
@ -2318,3 +2318,16 @@ class TenantAnonymousUserPath(Base):
|
||||
anonymous_user_path: Mapped[str] = mapped_column(
|
||||
String, nullable=False, unique=True
|
||||
)
|
||||
|
||||
|
||||
class AdditionalBase(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class MinimalUser(SQLAlchemyBaseUserTableUUID, AdditionalBase):
|
||||
# oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
# "OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
# )
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
|
@ -43,6 +43,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
|
||||
|
||||
|
||||
def get_or_generate_uuid() -> str:
|
||||
return "hi"
|
||||
# TODO: split out the whole "instance UUID" generation logic into a separate
|
||||
# utility function. Telemetry should not be aware at all of how the UUID is
|
||||
# generated/stored.
|
||||
|
@ -8,8 +8,6 @@ import {
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||
import Text from "@/components/ui/text";
|
||||
import Link from "next/link";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import ReferralSourceSelector from "./ReferralSourceSelector";
|
||||
|
72
web/src/app/auth/waiting-on-setup/page.tsx
Normal file
72
web/src/app/auth/waiting-on-setup/page.tsx
Normal file
@ -0,0 +1,72 @@
|
||||
import {
|
||||
AuthTypeMetadata,
|
||||
getAuthTypeMetadataSS,
|
||||
getCurrentUserSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { User } from "@/lib/types";
|
||||
import Text from "@/components/ui/text";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { completeTenantSetup } from "@/lib/tenant";
|
||||
|
||||
export default async function Page() {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
// without try / catch, will just raise an exception and the page
|
||||
// will not render
|
||||
let authTypeMetadata: AuthTypeMetadata | null = null;
|
||||
let currentUser: User | null = null;
|
||||
try {
|
||||
[authTypeMetadata, currentUser] = await Promise.all([
|
||||
getAuthTypeMetadataSS(),
|
||||
getCurrentUserSS(),
|
||||
]);
|
||||
} catch (e) {
|
||||
console.log(`Some fetch failed for the waiting-on-setup page - ${e}`);
|
||||
}
|
||||
|
||||
if (!currentUser) {
|
||||
if (authTypeMetadata?.authType === "disabled") {
|
||||
return redirect("/chat");
|
||||
}
|
||||
return redirect("/auth/login");
|
||||
}
|
||||
|
||||
// If the user is already verified, redirect to chat
|
||||
if (!authTypeMetadata?.requiresVerification || currentUser.is_verified) {
|
||||
// Trigger the tenant setup completion in the background
|
||||
if (currentUser.email) {
|
||||
try {
|
||||
await completeTenantSetup(currentUser.email);
|
||||
} catch (e) {
|
||||
console.error("Failed to complete tenant setup:", e);
|
||||
}
|
||||
}
|
||||
return redirect("/chat");
|
||||
}
|
||||
|
||||
return (
|
||||
<main>
|
||||
<div className="absolute top-10x w-full">
|
||||
<HealthCheckBanner />
|
||||
</div>
|
||||
<div className="min-h-screen flex items-center justify-center py-12 px-4 sm:px-6 lg:px-8">
|
||||
<div>
|
||||
<Logo height={64} width={64} className="mx-auto w-fit" />
|
||||
|
||||
<div className="flex">
|
||||
<Text className="text-center font-medium text-lg mt-6 w-108">
|
||||
Hey <i>{currentUser.email}</i> - we're setting up your account.
|
||||
<br />
|
||||
This may take a few moments. You'll be redirected automatically
|
||||
when it's ready.
|
||||
<br />
|
||||
<br />
|
||||
If you're not redirected within a minute, please refresh the page.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
);
|
||||
}
|
19
web/src/lib/tenant.ts
Normal file
19
web/src/lib/tenant.ts
Normal file
@ -0,0 +1,19 @@
|
||||
/**
|
||||
* Completes the tenant setup process for a user
|
||||
* @param email The email of the user
|
||||
* @returns A promise that resolves when the setup is complete
|
||||
*/
|
||||
export async function completeTenantSetup(email: string): Promise<void> {
|
||||
const response = await fetch(`/api/tenants/complete-setup`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ email }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Failed to complete tenant setup: ${errorText}`);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user