mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-25 23:40:58 +02:00
Add only multi tenant dependency injection (#2588)
* add only dependency injection * quick typing fix * additional non-dependency context * update nits
This commit is contained in:
parent
b04e9e9b67
commit
493c3d7314
@ -406,3 +406,7 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
|||||||
ENTERPRISE_EDITION_ENABLED = (
|
ENTERPRISE_EDITION_ENABLED = (
|
||||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||||
|
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
|
||||||
|
@ -41,6 +41,7 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
|||||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||||
|
POSTGRES_DEFAULT_SCHEMA = "public"
|
||||||
|
|
||||||
# API Keys
|
# API Keys
|
||||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import contextvars
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
@ -7,6 +9,10 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import ContextManager
|
from typing import ContextManager
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi import Request
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.engine import create_engine
|
from sqlalchemy.engine import create_engine
|
||||||
@ -19,6 +25,7 @@ from sqlalchemy.orm import sessionmaker
|
|||||||
|
|
||||||
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||||
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||||
|
from danswer.configs.app_configs import MULTI_TENANT
|
||||||
from danswer.configs.app_configs import POSTGRES_DB
|
from danswer.configs.app_configs import POSTGRES_DB
|
||||||
from danswer.configs.app_configs import POSTGRES_HOST
|
from danswer.configs.app_configs import POSTGRES_HOST
|
||||||
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
from danswer.configs.app_configs import POSTGRES_PASSWORD
|
||||||
@ -26,9 +33,12 @@ from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
|
|||||||
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||||
from danswer.configs.app_configs import POSTGRES_PORT
|
from danswer.configs.app_configs import POSTGRES_PORT
|
||||||
from danswer.configs.app_configs import POSTGRES_USER
|
from danswer.configs.app_configs import POSTGRES_USER
|
||||||
|
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||||
|
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
SYNC_DB_API = "psycopg2"
|
SYNC_DB_API = "psycopg2"
|
||||||
@ -37,11 +47,10 @@ ASYNC_DB_API = "asyncpg"
|
|||||||
# global so we don't create more than one engine per process
|
# global so we don't create more than one engine per process
|
||||||
# outside of being best practice, this is needed so we can properly pool
|
# outside of being best practice, this is needed so we can properly pool
|
||||||
# connections and not create a new pool on every request
|
# connections and not create a new pool on every request
|
||||||
|
|
||||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||||
|
|
||||||
SessionFactory: sessionmaker[Session] | None = None
|
SessionFactory: sessionmaker[Session] | None = None
|
||||||
|
|
||||||
|
|
||||||
if LOG_POSTGRES_LATENCY:
|
if LOG_POSTGRES_LATENCY:
|
||||||
# Function to log before query execution
|
# Function to log before query execution
|
||||||
@event.listens_for(Engine, "before_cursor_execute")
|
@event.listens_for(Engine, "before_cursor_execute")
|
||||||
@ -105,10 +114,19 @@ def get_db_current_time(db_session: Session) -> datetime:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Schema names end up interpolated into ``SET search_path`` statements, so they
# are restricted to plain identifiers to prevent SQL injection.
# ``\Z`` is used instead of ``$`` because ``$`` also matches just before a
# trailing newline, which would let a value such as "tenant\n" pass validation.
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*\Z")


def is_valid_schema_name(name: str) -> bool:
    """Return True if *name* is a plain identifier usable as a schema name.

    A valid name starts with an ASCII letter or underscore and continues with
    ASCII letters, digits or underscores only. The whole string must match;
    empty strings, strings with trailing newlines and non-string values are
    rejected.
    """
    # Non-string input (e.g. None or an unresolved dependency marker) is simply
    # invalid rather than an error.
    if not isinstance(name, str):
        return False
    return SCHEMA_NAME_REGEX.fullmatch(name) is not None
|
||||||
|
|
||||||
|
|
||||||
class SqlEngine:
|
class SqlEngine:
|
||||||
"""Class to manage a global sql alchemy engine (needed for proper resource control)
|
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||||
Will eventually subsume most of the standalone functions in this file.
|
Will eventually subsume most of the standalone functions in this file.
|
||||||
Sync only for now"""
|
Sync only for now.
|
||||||
|
"""
|
||||||
|
|
||||||
_engine: Engine | None = None
|
_engine: Engine | None = None
|
||||||
_lock: threading.Lock = threading.Lock()
|
_lock: threading.Lock = threading.Lock()
|
||||||
@ -137,16 +155,18 @@ class SqlEngine:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||||
"""Allow the caller to init the engine with extra params. Different clients
|
"""Allow the caller to init the engine with extra params. Different clients
|
||||||
such as the API server and different celery workers and tasks
|
such as the API server and different Celery workers and tasks
|
||||||
need different settings."""
|
need different settings.
|
||||||
|
"""
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if not cls._engine:
|
if not cls._engine:
|
||||||
cls._engine = cls._init_engine(**engine_kwargs)
|
cls._engine = cls._init_engine(**engine_kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_engine(cls) -> Engine:
|
def get_engine(cls) -> Engine:
|
||||||
"""Gets the sql alchemy engine. Will init a default engine if init hasn't
|
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||||
already been called. You probably want to init first!"""
|
already been called. You probably want to init first!
|
||||||
|
"""
|
||||||
if not cls._engine:
|
if not cls._engine:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if not cls._engine:
|
if not cls._engine:
|
||||||
@ -178,7 +198,6 @@ def build_connection_string(
|
|||||||
) -> str:
|
) -> str:
|
||||||
if app_name:
|
if app_name:
|
||||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||||
|
|
||||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||||
|
|
||||||
|
|
||||||
@ -193,7 +212,7 @@ def get_sqlalchemy_engine() -> Engine:
|
|||||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||||
global _ASYNC_ENGINE
|
global _ASYNC_ENGINE
|
||||||
if _ASYNC_ENGINE is None:
|
if _ASYNC_ENGINE is None:
|
||||||
# underlying asyncpg cannot accept application_name directly in the connection string
|
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||||
# https://github.com/MagicStack/asyncpg/issues/798
|
# https://github.com/MagicStack/asyncpg/issues/798
|
||||||
connection_string = build_connection_string()
|
connection_string = build_connection_string()
|
||||||
_ASYNC_ENGINE = create_async_engine(
|
_ASYNC_ENGINE = create_async_engine(
|
||||||
@ -211,25 +230,110 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
|||||||
return _ASYNC_ENGINE
|
return _ASYNC_ENGINE
|
||||||
|
|
||||||
|
|
||||||
def get_session_context_manager() -> ContextManager[Session]:
|
# Context variable to store the current tenant ID
|
||||||
return contextlib.contextmanager(get_session)()
|
# This allows us to maintain tenant-specific context throughout the request lifecycle
|
||||||
|
# The default value is set to POSTGRES_DEFAULT_SCHEMA for non-multi-tenant setups
|
||||||
|
# This context variable works in both synchronous and asynchronous contexts
|
||||||
|
# In async code, it's automatically carried across coroutines
|
||||||
|
# In sync code, it's managed per thread
|
||||||
|
current_tenant_id = contextvars.ContextVar(
|
||||||
|
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_session() -> Generator[Session, None, None]:
|
# Dependency to get the current tenant ID and set the context variable
|
||||||
# The line below was added to monitor the latency caused by Postgres connections
|
def get_current_tenant_id(request: Request) -> str:
    """Resolve the tenant ID for a request and store it in ``current_tenant_id``.

    When ``MULTI_TENANT`` is disabled, or the request has no ``tenant_details``
    cookie, ``POSTGRES_DEFAULT_SCHEMA`` is used. Otherwise the cookie is decoded
    as an HS256 JWT signed with ``SECRET_JWT_KEY`` and its ``tenant_id`` claim
    must be a valid schema name.

    Args:
        request: Incoming request whose cookies may carry the tenant token.

    Returns:
        The tenant ID that was stored in the ``current_tenant_id`` context
        variable.

    Raises:
        HTTPException: 401 if the token is not a valid JWT; 400 if the
            ``tenant_id`` claim is missing or is not a valid schema name; 500
            if decoding fails for any other reason.
    """
    if not MULTI_TENANT:
        tenant_id = POSTGRES_DEFAULT_SCHEMA
        current_tenant_id.set(tenant_id)
        return tenant_id

    token = request.cookies.get("tenant_details")
    if not token:
        # If no token is present, fall back to the default schema
        tenant_id = POSTGRES_DEFAULT_SCHEMA
        current_tenant_id.set(tenant_id)
        return tenant_id

    # Only the decode step is wrapped in the try block. The validation errors
    # below are raised outside it so that their 400 status is not swallowed by
    # the generic handler and turned into a 500.
    try:
        payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=401, detail="Invalid token format") from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Internal server error") from exc

    tenant_id = payload.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")

    # A non-string claim can never be a schema name; reject it as a client
    # error instead of letting the regex check raise a TypeError.
    if not isinstance(tenant_id, str) or not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID format")

    current_tenant_id.set(tenant_id)
    return tenant_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_with_tenant(tenant_id: str | None = None) -> Session:
    """Create a session whose transactions run against a tenant's schema.

    Args:
        tenant_id: Schema to use. When None, the value currently stored in the
            ``current_tenant_id`` context variable is used.

    Returns:
        A new ``Session`` bound to the global engine. The session is returned
        open; this function does not close it.

    Raises:
        ValueError: If the resolved tenant ID is not a valid schema name.
    """
    if tenant_id is None:
        tenant_id = current_tenant_id.get()

    # ValueError is a subclass of Exception, so callers that catch the broad
    # type keep working while new callers can catch the specific one.
    if not is_valid_schema_name(tenant_id):
        raise ValueError("Invalid tenant ID")

    engine = SqlEngine.get_engine()
    session = Session(engine, expire_on_commit=False)

    # Registered on this session's "after_begin" event so the search_path is
    # applied on the connection the session uses.
    @event.listens_for(session, "after_begin")
    def set_search_path(session: Session, transaction: Any, connection: Any) -> None:
        connection.execute(text("SET search_path TO :schema"), {"schema": tenant_id})

    return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(
    tenant_id: str = Depends(get_current_tenant_id),
) -> Generator[Session, None, None]:
    """Yield a database session scoped to the given tenant's schema.

    In multi-tenant mode the tenant ID is validated as a schema name and the
    session's ``search_path`` is set to it before the session is yielded.
    Outside multi-tenant mode the session is yielded unchanged.

    Raises:
        HTTPException: 400 if ``MULTI_TENANT`` is enabled and ``tenant_id`` is
            not a valid schema name.
    """
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
        if MULTI_TENANT and not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID")
        if MULTI_TENANT:
            # tenant_id has been validated above, so interpolating it is safe
            search_path_stmt = text(f'SET search_path = "{tenant_id}"')
            db_session.execute(search_path_stmt)
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_async_session(
    tenant_id: str = Depends(get_current_tenant_id),
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session scoped to the given tenant's schema.

    In multi-tenant mode the tenant ID is validated as a schema name and the
    session's ``search_path`` is set to it before the session is yielded.
    Outside multi-tenant mode the session is yielded unchanged.

    Raises:
        HTTPException: 400 if ``MULTI_TENANT`` is enabled and ``tenant_id`` is
            not a valid schema name.
    """
    async with AsyncSession(
        get_sqlalchemy_async_engine(), expire_on_commit=False
    ) as db_session:
        if MULTI_TENANT and not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID")
        if MULTI_TENANT:
            # tenant_id has been validated above, so interpolating it is safe
            search_path_stmt = text(f'SET search_path = "{tenant_id}"')
            await db_session.execute(search_path_stmt)
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_context_manager() -> ContextManager[Session]:
    """Return a context manager that yields a tenant-scoped database session.

    This wraps ``get_session`` for use outside FastAPI dependency injection.
    The tenant ID is read from the ``current_tenant_id`` context variable and
    passed explicitly: without an argument, ``get_session`` would receive its
    ``Depends(...)`` default object instead of a string, which cannot pass
    schema-name validation in multi-tenant mode.
    """
    tenant_id = current_tenant_id.get()
    return contextlib.contextmanager(get_session)(tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_factory() -> sessionmaker[Session]:
    """Return the module-level session factory, creating it on first use."""
    global SessionFactory
    if SessionFactory is not None:
        return SessionFactory

    # First call: bind a new factory to the shared engine and cache it
    SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
    return SessionFactory
|
||||||
|
|
||||||
|
|
||||||
async def warm_up_connections(
|
async def warm_up_connections(
|
||||||
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
|
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -251,10 +355,3 @@ async def warm_up_connections(
|
|||||||
await async_conn.execute(text("SELECT 1"))
|
await async_conn.execute(text("SELECT 1"))
|
||||||
for async_conn in async_connections:
|
for async_conn in async_connections:
|
||||||
await async_conn.close()
|
await async_conn.close()
|
||||||
|
|
||||||
|
|
||||||
def get_session_factory() -> sessionmaker[Session]:
|
|
||||||
global SessionFactory
|
|
||||||
if SessionFactory is None:
|
|
||||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
|
||||||
return SessionFactory
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user