diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 0cb4ae232..6c162dfb8 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -48,7 +48,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback from httpx_oauth.oauth2 import BaseOAuth2 from httpx_oauth.oauth2 import OAuth2Token from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy import text from sqlalchemy.orm import attributes from sqlalchemy.orm import Session @@ -83,21 +82,19 @@ from danswer.db.auth import SQLAlchemyUserAdminDB from danswer.db.engine import get_async_session_with_tenant from danswer.db.engine import get_session from danswer.db.engine import get_session_with_tenant -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User -from danswer.db.models import UserTenantMapping from danswer.db.users import get_user_by_email from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation +from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from shared_configs.configs import MULTI_TENANT -from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - logger = setup_logger() @@ -190,20 +187,6 @@ def verify_email_domain(email: str) -> None: ) -def get_tenant_id_for_email(email: str) -> str: - if not MULTI_TENANT: - return POSTGRES_DEFAULT_SCHEMA - # Implement logic to get tenant_id from the mapping table - with Session(get_sqlalchemy_engine()) as db_session: - result = db_session.execute( - select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email) - ) - 
tenant_id = result.scalar_one_or_none() - if tenant_id is None: - raise exceptions.UserNotExists() - return tenant_id - - def send_user_verification_email( user_email: str, token: str, @@ -238,19 +221,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): safe: bool = False, request: Optional[Request] = None, ) -> User: - try: - tenant_id = ( - get_tenant_id_for_email(user_create.email) - if MULTI_TENANT - else POSTGRES_DEFAULT_SCHEMA - ) - except exceptions.UserNotExists: - raise HTTPException(status_code=401, detail="User not found") - - if not tenant_id: - raise HTTPException( - status_code=401, detail="User does not belong to an organization" - ) + tenant_id = await get_or_create_tenant_id(user_create.email) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -271,7 +242,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): user_create.role = UserRole.ADMIN else: user_create.role = UserRole.BASIC - user = None + try: user = await super().create(user_create, safe=safe, request=request) # type: ignore except exceptions.UserAlreadyExists: @@ -292,7 +263,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): else: raise exceptions.UserAlreadyExists() - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + finally: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + return user async def oauth_callback( @@ -308,19 +281,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - # Get tenant_id from mapping table - try: - tenant_id = ( - get_tenant_id_for_email(account_email) - if MULTI_TENANT - else POSTGRES_DEFAULT_SCHEMA - ) - except exceptions.UserNotExists: - raise HTTPException(status_code=401, detail="User not found") + tenant_id = await get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") + 
# Proceed with the tenant context token = None async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -371,9 +337,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # Explicitly set the Postgres schema for this session to ensure # OAuth account creation happens in the correct tenant schema await db_session.execute(text(f'SET search_path = "{tenant_id}"')) - user = await self.user_db.add_oauth_account( - user, oauth_account_dict - ) + + # Add OAuth account + await self.user_db.add_oauth_account(user, oauth_account_dict) await self.on_after_register(user, request) else: diff --git a/backend/danswer/background/celery/apps/beat.py b/backend/danswer/background/celery/apps/beat.py index 5ef887121..979cf07cb 100644 --- a/backend/danswer/background/celery/apps/beat.py +++ b/backend/danswer/background/celery/apps/beat.py @@ -119,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler): else: logger.info("Schedule is up to date, no changes needed") - except (AttributeError, KeyError) as e: - logger.exception(f"Failed to process task configuration: {str(e)}") - except Exception as e: - logger.exception(f"Unexpected error updating tenant tasks: {str(e)}") + except (AttributeError, KeyError): + logger.exception("Failed to process task configuration") + except Exception: + logger.exception("Unexpected error updating tenant tasks") def _should_update_schedule( self, current_schedule: dict, new_schedule: dict diff --git a/backend/danswer/main.py b/backend/danswer/main.py index ae18ab3cc..06ce7bf40 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -277,12 +277,14 @@ def get_application() -> FastAPI: prefix="/auth", tags=["auth"], ) + include_router_with_global_prefix_prepended( application, fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth", tags=["auth"], ) + include_router_with_global_prefix_prepended( application, 
fastapi_users.get_reset_password_router(), diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 7802067b0..5bfbce209 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -30,7 +30,6 @@ from danswer.auth.schemas import UserStatus from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import get_tenant_id_for_email from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import ENABLE_EMAIL_INVITES @@ -66,7 +65,8 @@ from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_comm from ee.danswer.db.user_group import remove_curator_status__no_commit from ee.danswer.server.tenants.billing import register_tenant_users from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import remove_users_from_tenant +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email +from ee.danswer.server.tenants.user_mapping import remove_users_from_tenant from shared_configs.configs import MULTI_TENANT logger = setup_logger() diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index c1f4a7b39..41176a045 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -359,7 +359,7 @@ def handle_new_chat_message( yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: - logger.exception(f"Error in chat message streaming: {e}") + logger.exception("Error in chat message streaming") yield json.dumps({"error": str(e)}) finally: diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py 
index 1b8d5dc4b..4d6767ac2 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -279,7 +279,7 @@ def get_answer_with_quote( ): yield json.dumps(packet) if isinstance(packet, dict) else packet except Exception as e: - logger.exception(f"Error in search answer streaming: {e}") + logger.exception("Error in search answer streaming") yield json.dumps({"error": str(e)}) return StreamingResponse(stream_generator(), media_type="application/json") diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py index 8e79c0b37..8c1331c15 100644 --- a/backend/ee/danswer/server/tenants/api.py +++ b/backend/ee/danswer/server/tenants/api.py @@ -7,7 +7,6 @@ from fastapi import Response from danswer.auth.users import auth_backend from danswer.auth.users import current_admin_user from danswer.auth.users import get_jwt_strategy -from danswer.auth.users import get_tenant_id_for_email from danswer.auth.users import User from danswer.configs.app_configs import WEB_DOMAIN from danswer.db.engine import get_session_with_tenant @@ -15,7 +14,6 @@ from danswer.db.notification import create_notification from danswer.db.users import get_user_by_email from danswer.server.settings.store import load_settings from danswer.server.settings.store import store_settings -from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger from ee.danswer.auth.users import current_cloud_superuser from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY @@ -23,15 +21,9 @@ from ee.danswer.server.tenants.access import control_plane_dep from ee.danswer.server.tenants.billing import fetch_billing_information from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information from ee.danswer.server.tenants.models import BillingInformation -from ee.danswer.server.tenants.models import CreateTenantRequest from ee.danswer.server.tenants.models import ImpersonateRequest 
from ee.danswer.server.tenants.models import ProductGatingRequest -from ee.danswer.server.tenants.provisioning import add_users_to_tenant -from ee.danswer.server.tenants.provisioning import configure_default_api_keys -from ee.danswer.server.tenants.provisioning import ensure_schema_exists -from ee.danswer.server.tenants.provisioning import run_alembic_migrations -from ee.danswer.server.tenants.provisioning import user_owns_a_tenant -from shared_configs.configs import MULTI_TENANT +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR stripe.api_key = STRIPE_SECRET_KEY @@ -40,52 +32,6 @@ logger = setup_logger() router = APIRouter(prefix="/tenants") -@router.post("/create") -def create_tenant( - create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep) -) -> dict[str, str]: - if not MULTI_TENANT: - raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") - - tenant_id = create_tenant_request.tenant_id - email = create_tenant_request.initial_admin_email - token = None - - if user_owns_a_tenant(email): - raise HTTPException( - status_code=409, detail="User already belongs to an organization" - ) - - try: - if not ensure_schema_exists(tenant_id): - logger.info(f"Created schema for tenant {tenant_id}") - else: - logger.info(f"Schema already exists for tenant {tenant_id}") - - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) - run_alembic_migrations(tenant_id) - - with get_session_with_tenant(tenant_id) as db_session: - setup_danswer(db_session, tenant_id) - - configure_default_api_keys(db_session) - - add_users_to_tenant([email], tenant_id) - - return { - "status": "success", - "message": f"Tenant {tenant_id} created successfully", - } - except Exception as e: - logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") - raise HTTPException( - status_code=500, detail=f"Failed to create tenant: {str(e)}" - ) - finally: - if token 
is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) - - @router.post("/product-gating") def gate_product( product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index 2c1fdbecd..df24ff6c3 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel): class ImpersonateRequest(BaseModel): email: str + + +class TenantCreationPayload(BaseModel): + tenant_id: str + email: str diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 9106821b5..e956cf435 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -1,145 +1,210 @@ -import os -from types import SimpleNamespace +import asyncio +import logging +import uuid -from sqlalchemy import text +import aiohttp # Async HTTP client +from fastapi import HTTPException from sqlalchemy.orm import Session -from sqlalchemy.schema import CreateSchema -from alembic import command -from alembic.config import Config -from danswer.db.engine import build_connection_string +from danswer.auth.users import exceptions +from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.llm import upsert_llm_provider from danswer.db.models import UserTenantMapping +from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES +from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME +from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES +from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME from 
danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from danswer.server.manage.llm.models import LLMProviderUpsertRequest -from danswer.utils.logger import setup_logger +from danswer.setup import setup_danswer from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY +from ee.danswer.server.tenants.access import generate_data_plane_token +from ee.danswer.server.tenants.models import TenantCreationPayload +from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists +from ee.danswer.server.tenants.schema_management import drop_schema +from ee.danswer.server.tenants.schema_management import run_alembic_migrations +from ee.danswer.server.tenants.user_mapping import add_users_to_tenant +from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email +from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant +from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA +from shared_configs.configs import TENANT_ID_PREFIX +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider -logger = setup_logger() +logger = logging.getLogger(__name__) -def run_alembic_migrations(schema_name: str) -> None: - logger.info(f"Starting Alembic migrations for schema: {schema_name}") +async def get_or_create_tenant_id(email: str) -> str: + """Get existing tenant ID for an email or create a new tenant if none exists.""" + if not MULTI_TENANT: + return POSTGRES_DEFAULT_SCHEMA try: - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) - alembic_ini_path = os.path.join(root_dir, "alembic.ini") + tenant_id = get_tenant_id_for_email(email) + except exceptions.UserNotExists: + # If tenant 
does not exist and in Multi tenant mode, provision a new tenant + try: + tenant_id = await create_tenant(email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + raise HTTPException(status_code=500, detail="Failed to provision tenant.") - # Configure Alembic - alembic_cfg = Config(alembic_ini_path) - alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) - alembic_cfg.set_main_option( - "script_location", os.path.join(root_dir, "alembic") + if not tenant_id: + raise HTTPException( + status_code=401, detail="User does not belong to an organization" ) - # Ensure that logging isn't broken - alembic_cfg.attributes["configure_logger"] = False + return tenant_id - # Mimic command-line options by adding 'cmd_opts' to the config - alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore - alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore - # Run migrations programmatically - command.upgrade(alembic_cfg, "head") +async def create_tenant(email: str) -> str: + tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) + try: + # Provision tenant on data plane + await provision_tenant(tenant_id, email) + # Notify control plane + await notify_control_plane(tenant_id, email) + except Exception as e: + logger.error(f"Tenant provisioning failed: {e}") + await rollback_tenant_provisioning(tenant_id) + raise HTTPException(status_code=500, detail="Failed to provision tenant.") + return tenant_id - # Run migrations programmatically - logger.info( - f"Alembic migrations completed successfully for schema: {schema_name}" + +async def provision_tenant(tenant_id: str, email: str) -> None: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") + + if user_owns_a_tenant(email): + raise HTTPException( + status_code=409, detail="User already belongs to an organization" ) + logger.info(f"Provisioning tenant: {tenant_id}") + token = None + + try: + if not create_schema_if_not_exists(tenant_id): + 
logger.info(f"Schema already exists for tenant {tenant_id}")
+        else:
+            logger.info(f"Created schema for tenant {tenant_id}")
+
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+        # Await the Alembic migrations
+        await asyncio.to_thread(run_alembic_migrations, tenant_id)
+
+        with get_session_with_tenant(tenant_id) as db_session:
+            setup_danswer(db_session, tenant_id)
+            configure_default_api_keys(db_session)
+
+        add_users_to_tenant([email], tenant_id)
+
     except Exception as e:
-        logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
-        raise
+        logger.exception(f"Failed to create tenant {tenant_id}")
+        raise HTTPException(
+            status_code=500, detail=f"Failed to create tenant: {str(e)}"
+        )
+    finally:
+        if token is not None:
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+
+
+async def notify_control_plane(tenant_id: str, email: str) -> None:
+    logger.info(f"Notifying control plane of tenant creation: {tenant_id}")
+    token = generate_data_plane_token()
+    headers = {
+        "Authorization": f"Bearer {token}",
+        "Content-Type": "application/json",
+    }
+    payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(
+            f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
+            headers=headers,
+            json=payload.model_dump(),
+        ) as response:
+            if response.status != 200:
+                error_text = await response.text()
+                logger.error(f"Control plane tenant creation failed: {error_text}")
+                raise Exception(
+                    f"Failed to create tenant on control plane: {error_text}"
+                )
+
+
+async def rollback_tenant_provisioning(tenant_id: str) -> None:
+    # Logic to rollback tenant provisioning on data plane
+    logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
+    try:
+        # Drop the tenant's schema to rollback provisioning
+        drop_schema(tenant_id)
+        # Remove tenant mapping
+        with Session(get_sqlalchemy_engine()) as db_session:
+            db_session.query(UserTenantMapping).filter(
+                UserTenantMapping.tenant_id == tenant_id
+            
).delete() + db_session.commit() + except Exception as e: + logger.error(f"Failed to rollback tenant provisioning: {e}") def configure_default_api_keys(db_session: Session) -> None: - open_provider = LLMProviderUpsertRequest( - name="OpenAI", - provider="OpenAI", - api_key=OPENAI_DEFAULT_API_KEY, - default_model_name="gpt-4o", - ) - anthropic_provider = LLMProviderUpsertRequest( - name="Anthropic", - provider="Anthropic", - api_key=ANTHROPIC_DEFAULT_API_KEY, - default_model_name="claude-3-5-sonnet-20240620", - ) - upsert_llm_provider(open_provider, db_session) - upsert_llm_provider(anthropic_provider, db_session) - - cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( - provider_type=EmbeddingProvider.COHERE, - api_key=COHERE_DEFAULT_API_KEY, - ) - upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) - - -def ensure_schema_exists(tenant_id: str) -> bool: - with Session(get_sqlalchemy_engine()) as db_session: - with db_session.begin(): - result = db_session.execute( - text( - "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" - ), - {"schema_name": tenant_id}, - ) - schema_exists = result.scalar() is not None - if not schema_exists: - stmt = CreateSchema(tenant_id) - db_session.execute(stmt) - return True - return False - - -# For now, we're implementing a primitive mapping between users and tenants. -# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). 
-def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: - result = ( - db_session.query(UserTenantMapping) - .filter(UserTenantMapping.email == email) - .first() + if OPENAI_DEFAULT_API_KEY: + open_provider = LLMProviderUpsertRequest( + name="OpenAI", + provider=OPENAI_PROVIDER_NAME, + api_key=OPENAI_DEFAULT_API_KEY, + default_model_name="gpt-4", + fast_default_model_name="gpt-4o-mini", + model_names=OPEN_AI_MODEL_NAMES, ) - return result is not None - - -def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: try: - for email in emails: - db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + full_provider = upsert_llm_provider(open_provider, db_session) + update_default_provider(full_provider.id, db_session) except Exception as e: - logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}") - db_session.commit() + logger.error(f"Failed to configure OpenAI provider: {e}") + else: + logger.error( + "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration" + ) - -def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + if ANTHROPIC_DEFAULT_API_KEY: + anthropic_provider = LLMProviderUpsertRequest( + name="Anthropic", + provider=ANTHROPIC_PROVIDER_NAME, + api_key=ANTHROPIC_DEFAULT_API_KEY, + default_model_name="claude-3-5-sonnet-20241022", + fast_default_model_name="claude-3-5-sonnet-20241022", + model_names=ANTHROPIC_MODEL_NAMES, + ) try: - mappings_to_delete = ( - db_session.query(UserTenantMapping) - .filter( - UserTenantMapping.email.in_(emails), - UserTenantMapping.tenant_id == tenant_id, - ) - .all() - ) - - for mapping in mappings_to_delete: - db_session.delete(mapping) - - db_session.commit() + full_provider = upsert_llm_provider(anthropic_provider, db_session) + 
update_default_provider(full_provider.id, db_session) except Exception as e: - logger.exception( - f"Failed to remove users from tenant {tenant_id}: {str(e)}" - ) - db_session.rollback() + logger.error(f"Failed to configure Anthropic provider: {e}") + else: + logger.error( + "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration" + ) + + if COHERE_DEFAULT_API_KEY: + cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( + provider_type=EmbeddingProvider.COHERE, + api_key=COHERE_DEFAULT_API_KEY, + ) + try: + upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) + except Exception as e: + logger.error(f"Failed to configure Cohere embedding provider: {e}") + else: + logger.error( + "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration" + ) diff --git a/backend/ee/danswer/server/tenants/schema_management.py b/backend/ee/danswer/server/tenants/schema_management.py new file mode 100644 index 000000000..9be4e79f9 --- /dev/null +++ b/backend/ee/danswer/server/tenants/schema_management.py @@ -0,0 +1,76 @@ +import logging +import os +from types import SimpleNamespace + +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.schema import CreateSchema + +from alembic import command +from alembic.config import Config +from danswer.db.engine import build_connection_string +from danswer.db.engine import get_sqlalchemy_engine + +logger = logging.getLogger(__name__) + + +def run_alembic_migrations(schema_name: str) -> None: + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) + alembic_cfg.set_main_option( + 
"script_location", os.path.join(root_dir, "alembic") + ) + + # Ensure that logging isn't broken + alembic_cfg.attributes["configure_logger"] = False + + # Mimic command-line options by adding 'cmd_opts' to the config + alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore + alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + + # Run migrations programmatically + command.upgrade(alembic_cfg, "head") + + # Run migrations programmatically + logger.info( + f"Alembic migrations completed successfully for schema: {schema_name}" + ) + + except Exception as e: + logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") + raise + + +def create_schema_if_not_exists(tenant_id: str) -> bool: + with Session(get_sqlalchemy_engine()) as db_session: + with db_session.begin(): + result = db_session.execute( + text( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" + ), + {"schema_name": tenant_id}, + ) + schema_exists = result.scalar() is not None + if not schema_exists: + stmt = CreateSchema(tenant_id) + db_session.execute(stmt) + return True + return False + + +def drop_schema(tenant_id: str) -> None: + if not tenant_id.isidentifier(): + raise ValueError("Invalid tenant_id.") + with get_sqlalchemy_engine().connect() as connection: + connection.execute( + text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"), + {"schema_name": tenant_id}, + ) diff --git a/backend/ee/danswer/server/tenants/user_mapping.py b/backend/ee/danswer/server/tenants/user_mapping.py new file mode 100644 index 000000000..cf0e5ec5f --- /dev/null +++ b/backend/ee/danswer/server/tenants/user_mapping.py @@ -0,0 +1,70 @@ +import logging + +from fastapi_users import exceptions +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.engine import get_session_with_tenant +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import UserTenantMapping +from shared_configs.configs import 
MULTI_TENANT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA + +logger = logging.getLogger(__name__) + + +def get_tenant_id_for_email(email: str) -> str: + if not MULTI_TENANT: + return POSTGRES_DEFAULT_SCHEMA + # Implement logic to get tenant_id from the mapping table + with Session(get_sqlalchemy_engine()) as db_session: + result = db_session.execute( + select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email) + ) + tenant_id = result.scalar_one_or_none() + if tenant_id is None: + raise exceptions.UserNotExists() + return tenant_id + + +def user_owns_a_tenant(email: str) -> bool: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + result = ( + db_session.query(UserTenantMapping) + .filter(UserTenantMapping.email == email) + .first() + ) + return result is not None + + +def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + for email in emails: + db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) + except Exception: + logger.exception(f"Failed to add users to tenant {tenant_id}") + db_session.commit() + + +def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: + try: + mappings_to_delete = ( + db_session.query(UserTenantMapping) + .filter( + UserTenantMapping.email.in_(emails), + UserTenantMapping.tenant_id == tenant_id, + ) + .all() + ) + + for mapping in mappings_to_delete: + db_session.delete(mapping) + + db_session.commit() + except Exception as e: + logger.exception( + f"Failed to remove users from tenant {tenant_id}: {str(e)}" + ) + db_session.rollback() diff --git a/deployment/cloud_kubernetes/workers/beat.yaml b/deployment/cloud_kubernetes/workers/beat.yaml index a9d053f72..563dbf104 100644 --- a/deployment/cloud_kubernetes/workers/beat.yaml +++ b/deployment/cloud_kubernetes/workers/beat.yaml @@ -14,7 +14,7 @@ 
spec: spec: containers: - name: celery-beat - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -31,7 +31,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/heavy_worker.yaml b/deployment/cloud_kubernetes/workers/heavy_worker.yaml index 682cadee6..d8da6a3d3 100644 --- a/deployment/cloud_kubernetes/workers/heavy_worker.yaml +++ b/deployment/cloud_kubernetes/workers/heavy_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-heavy - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/indexing_worker.yaml b/deployment/cloud_kubernetes/workers/indexing_worker.yaml index 286cd3036..98158f62e 100644 --- a/deployment/cloud_kubernetes/workers/indexing_worker.yaml +++ b/deployment/cloud_kubernetes/workers/indexing_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-indexing - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/light_worker.yaml b/deployment/cloud_kubernetes/workers/light_worker.yaml index 055fac836..2df3b50ea 100644 --- a/deployment/cloud_kubernetes/workers/light_worker.yaml +++ 
b/deployment/cloud_kubernetes/workers/light_worker.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-light - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap diff --git a/deployment/cloud_kubernetes/workers/primary.yaml b/deployment/cloud_kubernetes/workers/primary.yaml index 7408e3bfb..32e34b5cd 100644 --- a/deployment/cloud_kubernetes/workers/primary.yaml +++ b/deployment/cloud_kubernetes/workers/primary.yaml @@ -14,7 +14,7 @@ spec: spec: containers: - name: celery-worker-primary - image: danswer/danswer-backend:v0.11.0-cloud.beta.4 + image: danswer/danswer-backend:v0.11.0-cloud.beta.8 imagePullPolicy: Always command: [ @@ -34,7 +34,7 @@ spec: name: danswer-secrets key: redis_password - name: DANSWER_VERSION - value: "v0.11.0-cloud.beta.4" + value: "v0.11.0-cloud.beta.8" envFrom: - configMapRef: name: env-configmap