Add only multi tenant dependency injection (#2588)

* add only dependency injection

* quick typing fix

* additional non-dependency context

* update nits
This commit is contained in:
pablodanswer 2024-10-05 14:08:41 -07:00 committed by GitHub
parent b04e9e9b67
commit 493c3d7314
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 130 additions and 28 deletions

View File

@ -406,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")

View File

@ -41,6 +41,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"

View File

@ -1,4 +1,6 @@
import contextlib
import contextvars
import re
import threading
import time
from collections.abc import AsyncGenerator
@ -7,6 +9,10 @@ from datetime import datetime
from typing import Any
from typing import ContextManager
import jwt
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
@ -19,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
@ -26,9 +33,12 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
SYNC_DB_API = "psycopg2"
@ -37,11 +47,10 @@ ASYNC_DB_API = "asyncpg"
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
@ -105,10 +114,19 @@ def get_db_current_time(db_session: Session) -> datetime:
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def is_valid_schema_name(name: str) -> bool:
return SCHEMA_NAME_REGEX.match(name) is not None
class SqlEngine:
"""Class to manage a global sql alchemy engine (needed for proper resource control)
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now"""
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
@ -137,16 +155,18 @@ class SqlEngine:
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different celery workers and tasks
need different settings."""
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the sql alchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!"""
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
@ -178,7 +198,6 @@ def build_connection_string(
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
@ -193,7 +212,7 @@ def get_sqlalchemy_engine() -> Engine:
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
@ -211,25 +230,110 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
# Context variable to store the current tenant ID
# This allows us to maintain tenant-specific context throughout the request lifecycle
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
# This context variable works in both synchronous and asynchronous contexts
# In async code, it's automatically carried across coroutines
# In sync code, it's managed per thread
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
def get_session() -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
# Dependency to get the current tenant ID and set the context variable
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
token = request.cookies.get("tenant_details")
if not token:
# If no token is present, use the default schema or handle accordingly
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
return tenant_id
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=400, detail="Invalid token: tenant_id missing"
)
if not is_valid_schema_name(tenant_id):
raise ValueError("Invalid tenant ID format")
current_tenant_id.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token format")
except ValueError as e:
# Let the 400 error bubble up
raise HTTPException(status_code=400, detail=str(e))
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise Exception("Invalid tenant ID")
engine = SqlEngine.get_engine()
session = Session(engine, expire_on_commit=False)
@event.listens_for(session, "after_begin")
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
return session
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
async def get_async_session(
tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
def get_session_context_manager() -> ContextManager[Session]:
"""Context manager for database sessions."""
return contextlib.contextmanager(get_session)()
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
@ -251,10 +355,3 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory