mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Add only multi tenant dependency injection (#2588)
* add only dependency injection * quick typing fix * additional non-dependency context * update nits
This commit is contained in:
parent
b04e9e9b67
commit
493c3d7314
@ -406,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
|
||||
|
@ -41,6 +41,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
|
@ -1,4 +1,6 @@
|
||||
import contextlib
|
||||
import contextvars
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@ -7,6 +9,10 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import create_engine
|
||||
@ -19,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import POSTGRES_DB
|
||||
from danswer.configs.app_configs import POSTGRES_HOST
|
||||
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
||||
@ -26,9 +33,12 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
SYNC_DB_API = "psycopg2"
|
||||
@ -37,11 +47,10 @@ ASYNC_DB_API = "asyncpg"
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
if LOG_POSTGRES_LATENCY:
|
||||
# Function to log before query execution
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
@ -105,10 +114,19 @@ def get_db_current_time(db_session: Session) -> datetime:
|
||||
return result
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||||
|
||||
|
||||
def is_valid_schema_name(name: str) -> bool:
|
||||
return SCHEMA_NAME_REGEX.match(name) is not None
|
||||
|
||||
|
||||
class SqlEngine:
|
||||
"""Class to manage a global sql alchemy engine (needed for proper resource control)
|
||||
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now"""
|
||||
Sync only for now.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
@ -137,16 +155,18 @@ class SqlEngine:
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different celery workers and tasks
|
||||
need different settings."""
|
||||
such as the API server and different Celery workers and tasks
|
||||
need different settings.
|
||||
"""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the sql alchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!"""
|
||||
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!
|
||||
"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
@ -178,7 +198,6 @@ def build_connection_string(
|
||||
) -> str:
|
||||
if app_name:
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
@ -193,7 +212,7 @@ def get_sqlalchemy_engine() -> Engine:
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
# underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# https://github.com/MagicStack/asyncpg/issues/798
|
||||
connection_string = build_connection_string()
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
@ -211,25 +230,110 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
return contextlib.contextmanager(get_session)()
|
||||
# Context variable to store the current tenant ID
|
||||
# This allows us to maintain tenant-specific context throughout the request lifecycle
|
||||
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
|
||||
# This context variable works in both synchronous and asynchronous contexts
|
||||
# In async code, it's automatically carried across coroutines
|
||||
# In sync code, it's managed per thread
|
||||
current_tenant_id = contextvars.ContextVar(
|
||||
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
# The line below was added to monitor the latency caused by Postgres connections
|
||||
# during API calls.
|
||||
# with tracer.trace("db.get_session"):
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||
# Dependency to get the current tenant ID and set the context variable
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("tenant_details")
|
||||
if not token:
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
||||
tenant_id = payload.get("tenant_id")
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid token: tenant_id missing"
|
||||
)
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise ValueError("Invalid tenant ID format")
|
||||
current_tenant_id.set(tenant_id)
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token format")
|
||||
except ValueError as e:
|
||||
# Let the 400 error bubble up
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise Exception("Invalid tenant ID")
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
session = Session(engine, expire_on_commit=False)
|
||||
|
||||
@event.listens_for(session, "after_begin")
|
||||
def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
|
||||
connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})
|
||||
|
||||
return session
|
||||
|
||||
|
||||
def get_session(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with AsyncSession(
|
||||
get_sqlalchemy_async_engine(), expire_on_commit=False
|
||||
) as async_session:
|
||||
async def get_async_session(
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
"""Context manager for database sessions."""
|
||||
return contextlib.contextmanager(get_session)()
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
"""Get a session factory."""
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
return SessionFactory
|
||||
|
||||
|
||||
async def warm_up_connections(
|
||||
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
|
||||
) -> None:
|
||||
@ -251,10 +355,3 @@ async def warm_up_connections(
|
||||
await async_conn.execute(text("SELECT 1"))
|
||||
for async_conn in async_connections:
|
||||
await async_conn.close()
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
return SessionFactory
|
||||
|
Loading…
x
Reference in New Issue
Block a user