mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-04 11:41:04 +02:00
add default schema config (#2888)
* add default schema config * resolve circular import * k
This commit is contained in:
@ -14,6 +14,7 @@ from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
@ -57,11 +58,15 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
if "=" in pair:
|
||||
key, value = pair.split("=", 1)
|
||||
x_args[key.strip()] = value.strip()
|
||||
schema_name = x_args.get("schema", "public")
|
||||
schema_name = x_args.get("schema", POSTGRES_DEFAULT_SCHEMA)
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
if MULTI_TENANT and schema_name == "public" and not upgrade_all_tenants:
|
||||
if (
|
||||
MULTI_TENANT
|
||||
and schema_name == POSTGRES_DEFAULT_SCHEMA
|
||||
and not upgrade_all_tenants
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
|
@ -94,6 +94,7 @@ from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -187,7 +188,7 @@ def verify_email_domain(email: str) -> None:
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return "public"
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
@ -235,7 +236,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> User:
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
@ -327,7 +330,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email) if MULTI_TENANT else "public"
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
@ -46,7 +46,6 @@ POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
|
@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
@ -29,6 +28,7 @@ from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
@ -60,6 +60,7 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
logger = setup_logger()
|
||||
@ -510,7 +511,9 @@ if __name__ == "__main__":
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = current_tenant_id.set(tenant_id or "public")
|
||||
token = current_tenant_id.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
@ -36,11 +36,11 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -192,13 +192,13 @@ class SqlEngine:
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id="public") as session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
f"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
@ -365,7 +365,7 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id == "public" and MULTI_TENANT:
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
@ -17,6 +17,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -35,7 +36,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id == "public":
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User must authenticate"
|
||||
)
|
||||
|
@ -10,9 +10,9 @@ from fastapi import Response
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.db.engine import is_valid_schema_name
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
|
||||
|
@ -12,6 +12,7 @@ from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -71,7 +72,7 @@ def ensure_schema_exists(tenant_id: str) -> bool:
|
||||
# For now, we're implementing a primitive mapping between users and tenants.
|
||||
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant("public") as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@ -81,7 +82,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant("public") as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@ -91,7 +92,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant("public") as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
|
@ -21,6 +21,7 @@ from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
from danswer.utils.timing import log_function_time
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@ -94,7 +95,7 @@ def generate_dummy_chunk(
|
||||
),
|
||||
document_sets={document_set for document_set in document_set_names},
|
||||
boost=random.randint(-1, 1),
|
||||
tenant_id="public",
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
)
|
||||
|
||||
|
||||
|
@ -128,7 +128,12 @@ else:
|
||||
# If the environment variable is empty, allow all origins
|
||||
CORS_ALLOWED_ORIGIN = ["*"]
|
||||
|
||||
current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")
|
||||
|
||||
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
|
||||
|
||||
current_tenant_id = contextvars.ContextVar(
|
||||
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
|
||||
SUPPORTED_EMBEDDING_MODELS = [
|
||||
|
@ -63,6 +63,7 @@ services:
|
||||
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
|
||||
# Other services
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
|
||||
@ -147,6 +148,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
|
||||
|
@ -32,6 +32,7 @@ data:
|
||||
QA_PROMPT_OVERRIDE: ""
|
||||
# Other Services
|
||||
POSTGRES_HOST: "relational-db-service"
|
||||
POSTGRES_DEFAULT_SCHEMA: ""
|
||||
VESPA_HOST: "document-index-service"
|
||||
REDIS_HOST: "redis-service"
|
||||
# Internet Search Tool
|
||||
|
Reference in New Issue
Block a user