From 14e75bbd24b067ff2b3138af4f6c7937201b34e2 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 23 Oct 2024 16:12:17 -0700 Subject: [PATCH] add default schema config (#2888) * add default schema config * resolve circular import * k --- backend/alembic/env.py | 9 +++++++-- backend/danswer/auth/users.py | 11 ++++++++--- backend/danswer/configs/constants.py | 1 - backend/danswer/connectors/file/connector.py | 2 +- backend/danswer/danswerbot/slack/listener.py | 5 ++++- backend/danswer/db/engine.py | 14 +++++++------- backend/danswer/key_value_store/store.py | 3 ++- .../danswer/server/middleware/tenant_tracking.py | 2 +- backend/ee/danswer/server/tenants/provisioning.py | 7 ++++--- .../scripts/query_time_check/seed_dummy_docs.py | 3 ++- backend/shared_configs/configs.py | 7 ++++++- deployment/docker_compose/docker-compose.dev.yml | 2 ++ deployment/kubernetes/env-configmap.yaml | 1 + 13 files changed, 45 insertions(+), 22 deletions(-) diff --git a/backend/alembic/env.py b/backend/alembic/env.py index b4b0ecb466..7ccd04cf16 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -14,6 +14,7 @@ from danswer.db.engine import build_connection_string from danswer.db.models import Base from celery.backends.database.session import ResultModelBase # type: ignore from danswer.db.engine import get_all_tenant_ids +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # Alembic Config object config = context.config @@ -57,11 +58,15 @@ def get_schema_options() -> tuple[str, bool, bool]: if "=" in pair: key, value = pair.split("=", 1) x_args[key.strip()] = value.strip() - schema_name = x_args.get("schema", "public") + schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA) create_schema = x_args.get("create_schema", "true").lower() == "true" upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true" - if MULTI_TENANT and schema_name == "public" and not upgrade_all_tenants: + if ( + MULTI_TENANT + and schema_name == POSTGRES_DEFAULT_SCHEMA + and not upgrade_all_tenants + ): raise ValueError( "Cannot run default migrations in public schema when multi-tenancy is enabled. " "Please specify a tenant-specific schema." diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 4565073b6a..51cac314fd 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -94,6 +94,7 @@ from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType from danswer.utils.variable_functionality import fetch_versioned_implementation from shared_configs.configs import current_tenant_id +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -187,7 +188,7 @@ def verify_email_domain(email: str) -> None: def get_tenant_id_for_email(email: str) -> str: if not MULTI_TENANT: - return "public" + return POSTGRES_DEFAULT_SCHEMA # Implement logic to get tenant_id from the mapping table with Session(get_sqlalchemy_engine()) as db_session: result = db_session.execute( @@ -235,7 +236,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): ) -> User: try: tenant_id = ( - get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public" + get_tenant_id_for_email(user_create.email) + if MULTI_TENANT + else POSTGRES_DEFAULT_SCHEMA ) except exceptions.UserNotExists: raise HTTPException(status_code=401, detail="User not found") @@ -327,7 +330,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): # Get tenant_id from mapping table try: tenant_id = ( - get_tenant_id_for_email(account_email) if MULTI_TENANT else "public" + get_tenant_id_for_email(account_email) + if MULTI_TENANT + else POSTGRES_DEFAULT_SCHEMA ) except exceptions.UserNotExists: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 2c86c7f054..6a3385b9fa 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -46,7 +46,6 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing" POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child" POSTGRES_PERMISSIONS_APP_NAME = "permissions" POSTGRES_UNKNOWN_APP_NAME = "unknown" -POSTGRES_DEFAULT_SCHEMA = "public" # API Keys DANSWER_API_KEY_PREFIX = "API_KEY__" diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index 9992159eb3..eb79cce579 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -10,7 +10,6 @@ from sqlalchemy.orm import Session from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource -from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector @@ -29,6 +28,7 @@ from danswer.file_processing.extract_file_text import read_text_file from danswer.file_store.file_store import get_default_file_store from danswer.utils.logger import setup_logger from shared_configs.configs import current_tenant_id +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index e3b2d213e8..b05c3a5ce5 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -60,6 +60,7 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import current_tenant_id from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SLACK_CHANNEL_ID logger = setup_logger() @@ -510,7 +511,9 @@ if __name__ == "__main__": for tenant_id in tenant_ids: with get_session_with_tenant(tenant_id) as db_session: try: - token = current_tenant_id.set(tenant_id or "public") + token = current_tenant_id.set( + tenant_id or POSTGRES_DEFAULT_SCHEMA + ) latest_slack_bot_tokens = fetch_tokens() current_tenant_id.reset(token) diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 625c36435c..7bf813b44f 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -36,11 +36,11 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE from danswer.configs.app_configs import POSTGRES_PORT from danswer.configs.app_configs import POSTGRES_USER from danswer.configs.app_configs import SECRET_JWT_KEY -from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.configs.constants import TENANT_ID_PREFIX from danswer.utils.logger import setup_logger from shared_configs.configs import current_tenant_id +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -192,13 +192,13 @@ class SqlEngine: def get_all_tenant_ids() -> list[str] | list[None]: if not MULTI_TENANT: return [None] - with get_session_with_tenant(tenant_id="public") as session: + with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session: result = session.execute( text( - """ - SELECT schema_name - FROM information_schema.schemata - WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')""" + f""" + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')""" ) ) tenant_ids = [row[0] for row in result] @@ -365,7 +365,7 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]: def get_session() -> Generator[Session, None, None]: """Generate a database session with the appropriate tenant schema set.""" tenant_id = current_tenant_id.get() - if tenant_id == "public" and MULTI_TENANT: + if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT: raise HTTPException(status_code=401, detail="User must authenticate") engine = get_sqlalchemy_engine() diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py index 98f3d7ec1c..b461ca22fe 100644 --- a/backend/danswer/key_value_store/store.py +++ b/backend/danswer/key_value_store/store.py @@ -17,6 +17,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger from shared_configs.configs import current_tenant_id +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -35,7 +36,7 @@ class PgRedisKVStore(KeyValueStore): with Session(engine, expire_on_commit=False) as session: if MULTI_TENANT: tenant_id = current_tenant_id.get() - if tenant_id == "public": + if tenant_id == POSTGRES_DEFAULT_SCHEMA: raise HTTPException( status_code=401, detail="User must authenticate" ) diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py index f564a4fc68..63b0f82be8 100644 --- a/backend/ee/danswer/server/middleware/tenant_tracking.py +++ b/backend/ee/danswer/server/middleware/tenant_tracking.py @@ -10,9 +10,9 @@ from fastapi import Response from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import SECRET_JWT_KEY -from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.db.engine import is_valid_schema_name from shared_configs.configs import current_tenant_id +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None: diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 9ec7b8061a..311698e6a3 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -12,6 +12,7 @@ from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import UserTenantMapping from danswer.utils.logger import setup_logger +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() @@ -71,7 +72,7 @@ def ensure_schema_exists(tenant_id: str) -> bool: # For now, we're implementing a primitive mapping between users and tenants. # This function is only used to determine a user's relationship to a tenant upon creation (implying ownership). def user_owns_a_tenant(email: str) -> bool: - with get_session_with_tenant("public") as db_session: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: result = ( db_session.query(UserTenantMapping) .filter(UserTenantMapping.email == email) @@ -81,7 +82,7 @@ def user_owns_a_tenant(email: str) -> bool: def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant("public") as db_session: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: try: for email in emails: db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id)) @@ -91,7 +92,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: - with get_session_with_tenant("public") as db_session: + with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session: try: mappings_to_delete = ( db_session.query(UserTenantMapping) diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index 70cb2a4a6a..e7aa65fba7 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -21,6 +21,7 @@ from danswer.indexing.models import ChunkEmbedding from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import IndexChunk from danswer.utils.timing import log_function_time +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.model_server_models import Embedding @@ -94,7 +95,7 @@ def generate_dummy_chunk( ), document_sets={document_set for document_set in document_set_names}, boost=random.randint(-1, 1), - tenant_id="public", + tenant_id=POSTGRES_DEFAULT_SCHEMA, ) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index f10855f103..77139125f6 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -128,7 +128,12 @@ else: # If the environment variable is empty, allow all origins CORS_ALLOWED_ORIGIN = ["*"] -current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public") + +POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public" + +current_tenant_id = contextvars.ContextVar( + "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA +) SUPPORTED_EMBEDDING_MODELS = [ diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index d22bde5b46..7b31689c8f 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -63,6 +63,7 @@ services: - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} # Other services - POSTGRES_HOST=relational_db + - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-} - VESPA_HOST=index - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose @@ -147,6 +148,7 @@ services: - POSTGRES_USER=${POSTGRES_USER:-} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-} - POSTGRES_DB=${POSTGRES_DB:-} + - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-} - VESPA_HOST=index - REDIS_HOST=cache - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 1d4bf1cffd..e1eefaeca9 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -32,6 +32,7 @@ data: QA_PROMPT_OVERRIDE: "" # Other Services POSTGRES_HOST: "relational-db-service" + POSTGRES_DEFAULT_SCHEMA: "" VESPA_HOST: "document-index-service" REDIS_HOST: "redis-service" # Internet Search Tool