add default schema config (#2888)

* add default schema config

* resolve circular import

* k
This commit is contained in:
pablodanswer
2024-10-23 16:12:17 -07:00
committed by GitHub
parent 3eb67baf5b
commit 14e75bbd24
13 changed files with 45 additions and 22 deletions

View File

@ -14,6 +14,7 @@ from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Alembic Config object
config = context.config
@ -57,11 +58,15 @@ def get_schema_options() -> tuple[str, bool, bool]:
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if MULTI_TENANT and schema_name == "public" and not upgrade_all_tenants:
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
and not upgrade_all_tenants
):
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."

View File

@ -94,6 +94,7 @@ from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@ -187,7 +188,7 @@ def verify_email_domain(email: str) -> None:
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return "public"
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
@ -235,7 +236,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
@ -327,7 +330,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")

View File

@ -46,7 +46,6 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"

View File

@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
@ -29,6 +28,7 @@ from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()

View File

@ -60,6 +60,7 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import current_tenant_id
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
@ -510,7 +511,9 @@ if __name__ == "__main__":
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = current_tenant_id.set(tenant_id or "public")
token = current_tenant_id.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
latest_slack_bot_tokens = fetch_tokens()
current_tenant_id.reset(token)

View File

@ -36,11 +36,11 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@ -192,13 +192,13 @@ class SqlEngine:
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
f"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
)
)
tenant_ids = [row[0] for row in result]
@ -365,7 +365,7 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = current_tenant_id.get()
if tenant_id == "public" and MULTI_TENANT:
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()

View File

@ -17,6 +17,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@ -35,7 +36,7 @@ class PgRedisKVStore(KeyValueStore):
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = current_tenant_id.get()
if tenant_id == "public":
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)

View File

@ -10,9 +10,9 @@ from fastapi import Response
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.db.engine import is_valid_schema_name
from shared_configs.configs import current_tenant_id
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:

View File

@ -12,6 +12,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from danswer.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@ -71,7 +72,7 @@ def ensure_schema_exists(tenant_id: str) -> bool:
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant("public") as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@ -81,7 +82,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@ -91,7 +92,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)

View File

@ -21,6 +21,7 @@ from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import IndexChunk
from danswer.utils.timing import log_function_time
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.model_server_models import Embedding
@ -94,7 +95,7 @@ def generate_dummy_chunk(
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
tenant_id="public",
tenant_id=POSTGRES_DEFAULT_SCHEMA,
)

View File

@ -128,7 +128,12 @@ else:
# If the environment variable is empty, allow all origins
CORS_ALLOWED_ORIGIN = ["*"]
current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
SUPPORTED_EMBEDDING_MODELS = [

View File

@ -63,6 +63,7 @@ services:
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
# Other services
- POSTGRES_HOST=relational_db
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
@ -147,6 +148,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors

View File

@ -32,6 +32,7 @@ data:
QA_PROMPT_OVERRIDE: ""
# Other Services
POSTGRES_HOST: "relational-db-service"
POSTGRES_DEFAULT_SCHEMA: ""
VESPA_HOST: "document-index-service"
REDIS_HOST: "redis-service"
# Internet Search Tool