Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)
Migrate tenant upgrades to data plane (#3051)
* add provisioning on data plane
* functional but scrappy
* minor cleanup
* minor clean up
* k
* simplify
* update provisioning
* improve import logic
* ensure proper conditional
* minor pydantic update
* minor config update
* nit
This commit is contained in: parent 1fb4cdfcc3, commit f6d8f5ca89
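This commit moves tenant provisioning from the control plane onto the data plane. Before the diff itself, here is a minimal, self-contained sketch of the flow the new code implements, using the function names introduced below (get_or_create_tenant_id, provision_tenant, notify_control_plane). The bodies are simplified stand-ins (an in-memory dict instead of the UserTenantMapping table, a print instead of the HTTP call), not the real implementations.

# Simplified sketch of the data-plane provisioning flow added in this commit.
# Stand-ins: _user_to_tenant replaces the UserTenantMapping table, and the
# control-plane notification is just a print. Real code lives under ee/danswer/server/tenants/.
import asyncio
import uuid

TENANT_ID_PREFIX = "tenant_"          # mirrors shared_configs.configs.TENANT_ID_PREFIX
_user_to_tenant: dict[str, str] = {}  # stand-in for the UserTenantMapping table


class UserNotExists(Exception):
    """Stand-in for fastapi_users.exceptions.UserNotExists."""


def get_tenant_id_for_email(email: str) -> str:
    # Real code: look up the email in the UserTenantMapping table.
    try:
        return _user_to_tenant[email]
    except KeyError:
        raise UserNotExists()


async def provision_tenant(tenant_id: str, email: str) -> None:
    # Real code: create the schema, run Alembic migrations, seed default
    # LLM/embedding providers, then map the user to the tenant.
    _user_to_tenant[email] = tenant_id


async def notify_control_plane(tenant_id: str, email: str) -> None:
    # Real code: POST a TenantCreationPayload to CONTROL_PLANE_API_BASE_URL/tenants/create.
    print(f"control plane notified: {tenant_id} <- {email}")


async def get_or_create_tenant_id(email: str) -> str:
    try:
        return get_tenant_id_for_email(email)
    except UserNotExists:
        tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
        await provision_tenant(tenant_id, email)
        await notify_control_plane(tenant_id, email)
        return tenant_id


if __name__ == "__main__":
    print(asyncio.run(get_or_create_tenant_id("admin@example.com")))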
@@ -48,7 +48,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
@@ -83,21 +82,19 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()


@@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None:
)


def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id


def send_user_verification_email(
user_email: str,
token: str,
@@ -238,19 +221,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")

if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
tenant_id = await get_or_create_tenant_id(user_create.email)

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -271,7 +242,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None

try:
user = await super().create(user_create, safe=safe, request=request)  # type: ignore
except exceptions.UserAlreadyExists:
@@ -292,7 +263,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()

CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

return user

async def oauth_callback(
@@ -308,19 +281,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await get_or_create_tenant_id(account_email)

if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")

# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -371,9 +337,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)

# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)

else:
@@ -119,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler):
else:
logger.info("Schedule is up to date, no changes needed")

except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")

def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
@@ -277,12 +277,14 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)

include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),
@@ -30,7 +30,6 @@ from danswer.auth.schemas import UserStatus
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
@@ -66,7 +65,8 @@ from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_comm
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()
@@ -359,7 +359,7 @@ def handle_new_chat_message(
yield json.dumps(packet) if isinstance(packet, dict) else packet

except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
logger.exception("Error in chat message streaming")
yield json.dumps({"error": str(e)})

finally:
@@ -279,7 +279,7 @@ def get_answer_with_quote(
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in search answer streaming: {e}")
logger.exception("Error in search answer streaming")
yield json.dumps({"error": str(e)})

return StreamingResponse(stream_generator(), media_type="application/json")
@@ -7,7 +7,6 @@ from fastapi import Response
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
@@ -15,7 +14,6 @@ from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
@@ -23,15 +21,9 @@ from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

stripe.api_key = STRIPE_SECRET_KEY
@@ -40,52 +32,6 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")


@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None

if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)

try:
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")

token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
run_alembic_migrations(tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)

configure_default_api_keys(db_session)

add_users_to_tenant([email], tenant_id)

return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
|
||||
class ImpersonateRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
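For orientation only (not part of the diff): the TenantCreationPayload added above is the body the data plane sends when it notifies the control plane of a newly provisioned tenant (see notify_control_plane in the provisioning module below). A minimal serialization sketch, assuming pydantic v2 and made-up example values:

from pydantic import BaseModel


class TenantCreationPayload(BaseModel):
    tenant_id: str
    email: str


payload = TenantCreationPayload(tenant_id="tenant_1234-abcd", email="admin@example.com")
print(payload.model_dump())  # {'tenant_id': 'tenant_1234-abcd', 'email': 'admin@example.com'}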
@@ -1,145 +1,210 @@
import os
from types import SimpleNamespace
import asyncio
import logging
import uuid

from sqlalchemy import text
import aiohttp  # Async HTTP client
from fastapi import HTTPException
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.auth.users import exceptions
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
from danswer.setup import setup_danswer
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from ee.danswer.server.tenants.models import TenantCreationPayload
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
from ee.danswer.server.tenants.schema_management import drop_schema
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()
logger = logging.getLogger(__name__)


def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA

try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")

# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)

# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
return tenant_id

# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"]  # type: ignore

# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
async def create_tenant(email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id

# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"

async def provision_tenant(tenant_id: str, email: str) -> None:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")

if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)

logger.info(f"Provisioning tenant: {tenant_id}")
token = None

try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")

token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)

with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)

add_users_to_tenant([email], tenant_id)

except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


async def notify_control_plane(tenant_id: str, email: str) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)

async with aiohttp.ClientSession() as session:
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
headers=headers,
json=payload.model_dump(),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to create tenant on control plane: {error_text}"
)


async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")


def configure_default_api_keys(db_session: Session) -> None:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider="OpenAI",
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider="Anthropic",
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20240620",
)
upsert_llm_provider(open_provider, db_session)
upsert_llm_provider(anthropic_provider, db_session)

cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)


def ensure_schema_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False


# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
return result is not None


def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
db_session.commit()
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)


def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)

for mapping in mappings_to_delete:
db_session.delete(mapping)

db_session.commit()
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)

if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
backend/ee/danswer/server/tenants/schema_management.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import logging
import os
from types import SimpleNamespace

from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine

logger = logging.getLogger(__name__)


def run_alembic_migrations(schema_name: str) -> None:
    logger.info(f"Starting Alembic migrations for schema: {schema_name}")

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
        alembic_ini_path = os.path.join(root_dir, "alembic.ini")

        # Configure Alembic
        alembic_cfg = Config(alembic_ini_path)
        alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
        alembic_cfg.set_main_option(
            "script_location", os.path.join(root_dir, "alembic")
        )

        # Ensure that logging isn't broken
        alembic_cfg.attributes["configure_logger"] = False

        # Mimic command-line options by adding 'cmd_opts' to the config
        alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
        alembic_cfg.cmd_opts.x = [f"schema={schema_name}"]  # type: ignore

        # Run migrations programmatically
        command.upgrade(alembic_cfg, "head")

        # Run migrations programmatically
        logger.info(
            f"Alembic migrations completed successfully for schema: {schema_name}"
        )

    except Exception as e:
        logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
        raise


def create_schema_if_not_exists(tenant_id: str) -> bool:
    with Session(get_sqlalchemy_engine()) as db_session:
        with db_session.begin():
            result = db_session.execute(
                text(
                    "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
                ),
                {"schema_name": tenant_id},
            )
            schema_exists = result.scalar() is not None
            if not schema_exists:
                stmt = CreateSchema(tenant_id)
                db_session.execute(stmt)
                return True
            return False


def drop_schema(tenant_id: str) -> None:
    if not tenant_id.isidentifier():
        raise ValueError("Invalid tenant_id.")
    with get_sqlalchemy_engine().connect() as connection:
        connection.execute(
            text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
            {"schema_name": tenant_id},
        )
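A note on the file above: run_alembic_migrations hands the target schema to Alembic through a simulated "-x schema=<name>" argument (the cmd_opts.x assignment). How that argument is consumed lives in alembic/env.py, which is not part of this diff; the snippet below is a hedged sketch of the standard Alembic pattern for scoping a migration run to one tenant schema, not Danswer's actual env.py.

from alembic import context
from sqlalchemy import text


def run_migrations_for_tenant_schema(connection) -> None:
    # Read the "-x schema=<name>" argument that run_alembic_migrations sets via cmd_opts.x.
    x_args = context.get_x_argument(as_dictionary=True)
    schema = x_args.get("schema", "public")

    # Scope this migration run (including the alembic_version table) to the tenant schema.
    connection.execute(text(f'SET search_path TO "{schema}"'))
    context.configure(
        connection=connection,
        target_metadata=None,  # a real env.py passes the project's MetaData here
        version_table_schema=schema,
    )
    with context.begin_transaction():
        context.run_migrations()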
backend/ee/danswer/server/tenants/user_mapping.py (new file, 70 lines)
@@ -0,0 +1,70 @@
import logging

from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

logger = logging.getLogger(__name__)


def get_tenant_id_for_email(email: str) -> str:
    if not MULTI_TENANT:
        return POSTGRES_DEFAULT_SCHEMA
    # Implement logic to get tenant_id from the mapping table
    with Session(get_sqlalchemy_engine()) as db_session:
        result = db_session.execute(
            select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
        )
        tenant_id = result.scalar_one_or_none()
    if tenant_id is None:
        raise exceptions.UserNotExists()
    return tenant_id


def user_owns_a_tenant(email: str) -> bool:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
        result = (
            db_session.query(UserTenantMapping)
            .filter(UserTenantMapping.email == email)
            .first()
        )
        return result is not None


def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
        try:
            for email in emails:
                db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
        except Exception:
            logger.exception(f"Failed to add users to tenant {tenant_id}")
        db_session.commit()


def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
        try:
            mappings_to_delete = (
                db_session.query(UserTenantMapping)
                .filter(
                    UserTenantMapping.email.in_(emails),
                    UserTenantMapping.tenant_id == tenant_id,
                )
                .all()
            )

            for mapping in mappings_to_delete:
                db_session.delete(mapping)

            db_session.commit()
        except Exception as e:
            logger.exception(
                f"Failed to remove users from tenant {tenant_id}: {str(e)}"
            )
            db_session.rollback()
@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@@ -31,7 +31,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
image: danswer/danswer-backend:v0.11.0-cloud.beta.8
imagePullPolicy: Always
command:
[
@@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.4"
value: "v0.11.0-cloud.beta.8"
envFrom:
- configMapRef:
name: env-configmap