From 493c3d73143686894d3e6228482b7cf0ddc031dd Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 5 Oct 2024 14:08:41 -0700 Subject: [PATCH] Add only multi tenant dependency injection (#2588) * add only dependency injection * quick typing fix * additional non-dependency context * update nits --- backend/danswer/configs/app_configs.py | 4 + backend/danswer/configs/constants.py | 1 + backend/danswer/db/engine.py | 153 ++++++++++++++++++++----- 3 files changed, 130 insertions(+), 28 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 6096561b5..4857e2aa9 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -406,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads( ENTERPRISE_EDITION_ENABLED = ( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true" ) + + +MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" +SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "") diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index d8470d171..c26c2fbd6 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -41,6 +41,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light" POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy" POSTGRES_PERMISSIONS_APP_NAME = "permissions" POSTGRES_UNKNOWN_APP_NAME = "unknown" +POSTGRES_DEFAULT_SCHEMA = "public" # API Keys DANSWER_API_KEY_PREFIX = "API_KEY__" diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index af44498be..559c2dec0 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -1,4 +1,6 @@ import contextlib +import contextvars +import re import threading import time from collections.abc import AsyncGenerator @@ -7,6 +9,10 @@ from datetime import datetime from typing import Any from typing import ContextManager +import jwt +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Request from sqlalchemy import event from sqlalchemy import text from sqlalchemy.engine import create_engine @@ -19,6 +25,7 @@ from sqlalchemy.orm import sessionmaker from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS from danswer.configs.app_configs import LOG_POSTGRES_LATENCY +from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import POSTGRES_DB from danswer.configs.app_configs import POSTGRES_HOST from danswer.configs.app_configs import POSTGRES_PASSWORD @@ -26,9 +33,12 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE from danswer.configs.app_configs import POSTGRES_PORT from danswer.configs.app_configs import POSTGRES_USER +from danswer.configs.app_configs import SECRET_JWT_KEY +from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME from danswer.utils.logger import setup_logger + logger = setup_logger() SYNC_DB_API = "psycopg2" @@ -37,11 +47,10 @@ ASYNC_DB_API = "asyncpg" # global so we don't create more than one engine per process # outside of being best practice, this is needed so we can properly pool # connections and not create a new pool on every request + _ASYNC_ENGINE: AsyncEngine | None = None - SessionFactory: sessionmaker[Session] | None = None - if LOG_POSTGRES_LATENCY: # Function to log before query execution @event.listens_for(Engine, "before_cursor_execute") @@ -105,10 +114,19 @@ def get_db_current_time(db_session: Session) -> datetime: return result +# Regular expression to validate schema names to prevent SQL injection +SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +def is_valid_schema_name(name: str) -> bool: + return SCHEMA_NAME_REGEX.match(name) is not None + + class SqlEngine: - """Class to manage a global sql alchemy engine (needed for proper resource control) + """Class to manage a global SQLAlchemy engine (needed for proper resource control). Will eventually subsume most of the standalone functions in this file. - Sync only for now""" + Sync only for now. + """ _engine: Engine | None = None _lock: threading.Lock = threading.Lock() @@ -137,16 +155,18 @@ class SqlEngine: @classmethod def init_engine(cls, **engine_kwargs: Any) -> None: """Allow the caller to init the engine with extra params. Different clients - such as the API server and different celery workers and tasks - need different settings.""" + such as the API server and different Celery workers and tasks + need different settings. + """ with cls._lock: if not cls._engine: cls._engine = cls._init_engine(**engine_kwargs) @classmethod def get_engine(cls) -> Engine: - """Gets the sql alchemy engine. Will init a default engine if init hasn't - already been called. You probably want to init first!""" + """Gets the SQLAlchemy engine. Will init a default engine if init hasn't + already been called. You probably want to init first! + """ if not cls._engine: with cls._lock: if not cls._engine: @@ -178,7 +198,6 @@ def build_connection_string( ) -> str: if app_name: return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}" - return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}" @@ -193,7 +212,7 @@ def get_sqlalchemy_engine() -> Engine: def get_sqlalchemy_async_engine() -> AsyncEngine: global _ASYNC_ENGINE if _ASYNC_ENGINE is None: - # underlying asyncpg cannot accept application_name directly in the connection string + # Underlying asyncpg cannot accept application_name directly in the connection string # https://github.com/MagicStack/asyncpg/issues/798 connection_string = build_connection_string() _ASYNC_ENGINE = create_async_engine( @@ -211,25 +230,110 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: return _ASYNC_ENGINE -def get_session_context_manager() -> ContextManager[Session]: - return contextlib.contextmanager(get_session)() +# Context variable to store the current tenant ID +# This allows us to maintain tenant-specific context throughout the request lifecycle +# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups +# This context variable works in both synchronous and asynchronous contexts +# In async code, it's automatically carried across coroutines +# In sync code, it's managed per thread +current_tenant_id = contextvars.ContextVar( + "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA +) -def get_session() -> Generator[Session, None, None]: - # The line below was added to monitor the latency caused by Postgres connections - # during API calls. - # with tracer.trace("db.get_session"): - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: +# Dependency to get the current tenant ID and set the context variable +def get_current_tenant_id(request: Request) -> str: + """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable.""" + if not MULTI_TENANT: + tenant_id = POSTGRES_DEFAULT_SCHEMA + current_tenant_id.set(tenant_id) + return tenant_id + + token = request.cookies.get("tenant_details") + if not token: + # If no token is present, use the default schema or handle accordingly + tenant_id = POSTGRES_DEFAULT_SCHEMA + current_tenant_id.set(tenant_id) + return tenant_id + + try: + payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"]) + tenant_id = payload.get("tenant_id") + if not tenant_id: + raise HTTPException( + status_code=400, detail="Invalid token: tenant_id missing" + ) + if not is_valid_schema_name(tenant_id): + raise ValueError("Invalid tenant ID format") + current_tenant_id.set(tenant_id) + return tenant_id + except jwt.InvalidTokenError: + raise HTTPException(status_code=401, detail="Invalid token format") + except ValueError as e: + # Let the 400 error bubble up + raise HTTPException(status_code=400, detail=str(e)) + except Exception: + raise HTTPException(status_code=500, detail="Internal server error") + + +def get_session_with_tenant(tenant_id: str | None = None) -> Session: + if tenant_id is None: + tenant_id = current_tenant_id.get() + + if not is_valid_schema_name(tenant_id): + raise Exception("Invalid tenant ID") + + engine = SqlEngine.get_engine() + session = Session(engine, expire_on_commit=False) + + @event.listens_for(session, "after_begin") + def set_search_path(session: Session, transaction: Any, connection: Any) -> None: + connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id}) + + return session + + +def get_session( + tenant_id: str = Depends(get_current_tenant_id), +) -> Generator[Session, None, None]: + """Generate a database session with the appropriate tenant schema set.""" + engine = get_sqlalchemy_engine() + with Session(engine, expire_on_commit=False) as session: + if MULTI_TENANT: + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + # Set the search_path to the tenant's schema + session.execute(text(f'SET search_path = "{tenant_id}"')) yield session -async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with AsyncSession( - get_sqlalchemy_async_engine(), expire_on_commit=False - ) as async_session: +async def get_async_session( + tenant_id: str = Depends(get_current_tenant_id), +) -> AsyncGenerator[AsyncSession, None]: + """Generate an async database session with the appropriate tenant schema set.""" + engine = get_sqlalchemy_async_engine() + async with AsyncSession(engine, expire_on_commit=False) as async_session: + if MULTI_TENANT: + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + # Set the search_path to the tenant's schema + await async_session.execute(text(f'SET search_path = "{tenant_id}"')) yield async_session +def get_session_context_manager() -> ContextManager[Session]: + """Context manager for database sessions.""" + return contextlib.contextmanager(get_session)() + + +def get_session_factory() -> sessionmaker[Session]: + """Get a session factory.""" + global SessionFactory + if SessionFactory is None: + SessionFactory = sessionmaker(bind=get_sqlalchemy_engine()) + return SessionFactory + + async def warm_up_connections( sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20 ) -> None: @@ -251,10 +355,3 @@ async def warm_up_connections( await async_conn.execute(text("SELECT 1")) for async_conn in async_connections: await async_conn.close() - - -def get_session_factory() -> sessionmaker[Session]: - global SessionFactory - if SessionFactory is None: - SessionFactory = sessionmaker(bind=get_sqlalchemy_engine()) - return SessionFactory