Migrate tenant upgrades to data plane (#3051)

* add provisioning on data plane

* functional but scrappy

* minor cleanup

* minor clean up

* k

* simplify

* update provisioning

* improve import logic

* ensure proper conditional

* minor pydantic update

* minor config update

* nit
This commit is contained in:
pablodanswer 2024-11-08 09:13:29 -08:00 committed by GitHub
parent 1fb4cdfcc3
commit f6d8f5ca89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 356 additions and 226 deletions

View File

@ -48,7 +48,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
@ -83,21 +82,19 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@ -238,19 +221,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
tenant_id = await get_or_create_tenant_id(user_create.email)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@ -271,7 +242,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@ -292,7 +263,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@ -308,19 +281,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await get_or_create_tenant_id(account_email)
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@ -371,9 +337,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:

View File

@ -119,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler):
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict

View File

@ -277,12 +277,14 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),

View File

@ -30,7 +30,6 @@ from danswer.auth.schemas import UserStatus
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
@ -66,7 +65,8 @@ from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_comm
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()

View File

@ -359,7 +359,7 @@ def handle_new_chat_message(
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
logger.exception("Error in chat message streaming")
yield json.dumps({"error": str(e)})
finally:

View File

@ -279,7 +279,7 @@ def get_answer_with_quote(
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in search answer streaming: {e}")
logger.exception("Error in search answer streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")

View File

@ -7,7 +7,6 @@ from fastapi import Response
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
@ -15,7 +14,6 @@ from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
@ -23,15 +21,9 @@ from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
@ -40,52 +32,6 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
run_alembic_migrations(tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)

View File

@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel):
class ImpersonateRequest(BaseModel):
email: str
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str

View File

@ -1,145 +1,210 @@
import os
from types import SimpleNamespace
import asyncio
import logging
import uuid
from sqlalchemy import text
import aiohttp # Async HTTP client
from fastapi import HTTPException
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.auth.users import exceptions
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
from danswer.setup import setup_danswer
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from ee.danswer.server.tenants.models import TenantCreationPayload
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
from ee.danswer.server.tenants.schema_management import drop_schema
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
return tenant_id
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
async def create_tenant(email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
async def provision_tenant(tenant_id: str, email: str) -> None:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
logger.info(f"Provisioning tenant: {tenant_id}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(tenant_id: str, email: str) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
headers=headers,
json=payload.model_dump(),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to create tenant on control plane: {error_text}"
)
async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider="OpenAI",
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider="Anthropic",
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20240620",
)
upsert_llm_provider(open_provider, db_session)
upsert_llm_provider(anthropic_provider, db_session)
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
def ensure_schema_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
db_session.commit()
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)

View File

@ -0,0 +1,76 @@
import logging
import os
from types import SimpleNamespace
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def create_schema_if_not_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
def drop_schema(tenant_id: str) -> None:
if not tenant_id.isidentifier():
raise ValueError("Invalid tenant_id.")
with get_sqlalchemy_engine().connect() as connection:
connection.execute(
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)

View File

@ -0,0 +1,70 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()

View File

@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@ -31,7 +31,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

View File

@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

View File

@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

View File

@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

View File

@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap