IAM Auth for RDS (#3479)

* k

* functional iam auth

* k

* k

* improve typing

* add deployment options

* cleanup

* quick clean up

* minor cleanup

* additional clarity for db session operations

* nit

* k

* k

* update configs

* docker compose spacing
Committed by: pablonyx (via GitHub)
Date: 2024-12-17 14:02:37 -08:00
Parent: 28598694b1
Commit: 8db6d49fe5

15 changed files with 282 additions and 139 deletions

View File

@@ -1,39 +1,49 @@
+from typing import Any, Literal
+from onyx.db.engine import get_iam_auth_token
+from onyx.configs.app_configs import USE_IAM_AUTH
+from onyx.configs.app_configs import POSTGRES_HOST
+from onyx.configs.app_configs import POSTGRES_PORT
+from onyx.configs.app_configs import POSTGRES_USER
+from onyx.configs.app_configs import AWS_REGION
+from onyx.db.engine import build_connection_string
+from onyx.db.engine import get_all_tenant_ids
+from sqlalchemy import event
+from sqlalchemy import pool
+from sqlalchemy import text
 from sqlalchemy.engine.base import Connection
-from typing import Literal
+import os
+import ssl
 import asyncio
-from logging.config import fileConfig
 import logging
+from logging.config import fileConfig
 from alembic import context
-from sqlalchemy import pool
 from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.sql import text
 from sqlalchemy.sql.schema import SchemaItem
+from onyx.configs.constants import SSL_CERT_FILE
-from shared_configs.configs import MULTI_TENANT
+from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
-from onyx.db.engine import build_connection_string
 from onyx.db.models import Base
 from celery.backends.database.session import ResultModelBase  # type: ignore
-from onyx.db.engine import get_all_tenant_ids
-from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 # Alembic Config object
 config = context.config

-# Interpret the config file for Python logging.
 if config.config_file_name is not None and config.attributes.get(
     "configure_logger", True
 ):
     fileConfig(config.config_file_name)

-# Add your model's MetaData object here for 'autogenerate' support
 target_metadata = [Base.metadata, ResultModelBase.metadata]

 EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

-# Set up logging
 logger = logging.getLogger(__name__)

+ssl_context: ssl.SSLContext | None = None
+if USE_IAM_AUTH:
+    if not os.path.exists(SSL_CERT_FILE):
+        raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
+    ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)


 def include_object(
     object: SchemaItem,
@@ -49,20 +59,12 @@ def include_object(
     reflected: bool,
     compare_to: SchemaItem | None,
 ) -> bool:
-    """
-    Determines whether a database object should be included in migrations.
-    Excludes specified tables from migrations.
-    """
     if type_ == "table" and name in EXCLUDE_TABLES:
         return False
     return True


 def get_schema_options() -> tuple[str, bool, bool]:
-    """
-    Parses command-line options passed via '-x' in Alembic commands.
-    Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
-    """
     x_args_raw = context.get_x_argument()
     x_args = {}
     for arg in x_args_raw:
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
 def do_run_migrations(
     connection: Connection, schema_name: str, create_schema: bool
 ) -> None:
-    """
-    Executes migrations in the specified schema.
-    """
     logger.info(f"About to migrate schema: {schema_name}")

     if create_schema:
         connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
         connection.execute(text("COMMIT"))

-    # Set search_path to the target schema
     connection.execute(text(f'SET search_path TO "{schema_name}"'))

     context.configure(
@@ -117,11 +115,25 @@ def do_run_migrations(
     context.run_migrations()


+def provide_iam_token_for_alembic(
+    dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+) -> None:
+    if USE_IAM_AUTH:
+        # Database connection settings
+        region = AWS_REGION
+        host = POSTGRES_HOST
+        port = POSTGRES_PORT
+        user = POSTGRES_USER
+
+        # Get IAM authentication token
+        token = get_iam_auth_token(host, port, user, region)
+
+        # For Alembic / SQLAlchemy in this context, set SSL and password
+        cparams["password"] = token
+        cparams["ssl"] = ssl_context
+
+
 async def run_async_migrations() -> None:
-    """
-    Determines whether to run migrations for a single schema or all schemas,
-    and executes migrations accordingly.
-    """
     schema_name, create_schema, upgrade_all_tenants = get_schema_options()

     engine = create_async_engine(
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
         poolclass=pool.NullPool,
     )

+    if USE_IAM_AUTH:
+
+        @event.listens_for(engine.sync_engine, "do_connect")
+        def event_provide_iam_token_for_alembic(
+            dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+        ) -> None:
+            provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
+
     if upgrade_all_tenants:
-        # Run migrations for all tenant schemas sequentially
         tenant_schemas = get_all_tenant_ids()

         for schema in tenant_schemas:
             try:
                 logger.info(f"Migrating schema: {schema}")
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:
 def run_migrations_offline() -> None:
-    """
-    Run migrations in 'offline' mode.
-    """
     schema_name, _, upgrade_all_tenants = get_schema_options()
     url = build_connection_string()

     if upgrade_all_tenants:
-        # Run offline migrations for all tenant schemas
         engine = create_async_engine(url)
+
+        if USE_IAM_AUTH:
+
+            @event.listens_for(engine.sync_engine, "do_connect")
+            def event_provide_iam_token_for_alembic_offline(
+                dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+            ) -> None:
+                provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
+
         tenant_schemas = get_all_tenant_ids()
         engine.sync_engine.dispose()
@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:
 def run_migrations_online() -> None:
-    """
-    Runs migrations in 'online' mode using an asynchronous engine.
-    """
     asyncio.run(run_async_migrations())
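Note: the "do_connect" hooks above are the core of the IAM approach. SQLAlchemy fires this event before each new DBAPI connection, so a freshly generated IAM token (which expires after roughly 15 minutes) can be supplied as the password at connect time rather than being baked into the connection string. A minimal standalone sketch of the same pattern, with illustrative names that are not from this commit:

    from sqlalchemy import create_engine, event

    def make_token() -> str:
        # Placeholder for get_iam_auth_token(), which in this commit wraps
        # boto3's RDS client generate_db_auth_token().
        return "short-lived-iam-token"

    # No password in the URL; it is injected per connection below.
    engine = create_engine("postgresql+psycopg2://user@host:5432/db")

    @event.listens_for(engine, "do_connect")
    def inject_token(dialect, conn_rec, cargs, cparams):
        cparams["password"] = make_token()  # fresh token per physical connection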

View File

@@ -144,6 +144,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
 POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
 POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
 POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
+AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"

 POSTGRES_API_SERVER_POOL_SIZE = int(
     os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -174,6 +175,9 @@ try:
 except ValueError:
     POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT

+USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
+
 REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
 REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
 REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))

View File

@@ -49,6 +49,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
 POSTGRES_PERMISSIONS_APP_NAME = "permissions"
 POSTGRES_UNKNOWN_APP_NAME = "unknown"
+SSL_CERT_FILE = "bundle.pem"

 # API Keys
 DANSWER_API_KEY_PREFIX = "API_KEY__"
 DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
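The "bundle.pem" named here is expected to be the AWS-published RDS CA certificate bundle, mounted into the container (see the docker-compose and Kubernetes changes later in this diff). A minimal sketch of how it gets used, mirroring create_ssl_context_if_iam() below:

    import ssl

    # Build a client-side SSL context that trusts the mounted RDS CA bundle,
    # so the RDS server certificate can be verified when connecting with an
    # IAM token.
    ssl_context = ssl.create_default_context(cafile="bundle.pem")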

View File

@@ -1,5 +1,7 @@
 import contextlib
+import os
 import re
+import ssl
 import threading
 import time
 from collections.abc import AsyncGenerator
@@ -10,6 +12,8 @@ from datetime import datetime
 from typing import Any
 from typing import ContextManager

+import asyncpg  # type: ignore
+import boto3
 import jwt
 from fastapi import HTTPException
 from fastapi import Request
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker

+from onyx.configs.app_configs import AWS_REGION
 from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
 from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
 from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
 from onyx.configs.app_configs import POSTGRES_USER
 from onyx.configs.app_configs import USER_AUTH_SECRET
 from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
+from onyx.configs.constants import SSL_CERT_FILE
 from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
@@ -49,28 +55,87 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"

-# global so we don't create more than one engine per process
-# outside of being best practice, this is needed so we can properly pool
-# connections and not create a new pool on every request
+USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
+
+# Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None

+
+def create_ssl_context_if_iam() -> ssl.SSLContext | None:
+    """Create an SSL context if IAM authentication is enabled, else return None."""
+    if USE_IAM_AUTH:
+        return ssl.create_default_context(cafile=SSL_CERT_FILE)
+    return None
+
+
+ssl_context = create_ssl_context_if_iam()
+
+
+def get_iam_auth_token(
+    host: str, port: str, user: str, region: str = "us-east-2"
+) -> str:
+    """
+    Generate an IAM authentication token using boto3.
+    """
+    client = boto3.client("rds", region_name=region)
+    token = client.generate_db_auth_token(
+        DBHostname=host, Port=int(port), DBUsername=user
+    )
+    return token
+
+
+def configure_psycopg2_iam_auth(
+    cparams: dict[str, Any], host: str, port: str, user: str, region: str
+) -> None:
+    """
+    Configure cparams for psycopg2 with IAM token and SSL.
+    """
+    token = get_iam_auth_token(host, port, user, region)
+    cparams["password"] = token
+    cparams["sslmode"] = "require"
+    cparams["sslrootcert"] = SSL_CERT_FILE
+
+
+def build_connection_string(
+    *,
+    db_api: str = ASYNC_DB_API,
+    user: str = POSTGRES_USER,
+    password: str = POSTGRES_PASSWORD,
+    host: str = POSTGRES_HOST,
+    port: str = POSTGRES_PORT,
+    db: str = POSTGRES_DB,
+    app_name: str | None = None,
+    use_iam: bool = USE_IAM_AUTH,
+    region: str = "us-west-2",
+) -> str:
+    if use_iam:
+        base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
+    else:
+        base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
+
+    # For asyncpg, do not include application_name in the connection string
+    if app_name and db_api != "asyncpg":
+        if "?" in base_conn_str:
+            return f"{base_conn_str}&application_name={app_name}"
+        else:
+            return f"{base_conn_str}?application_name={app_name}"
+    return base_conn_str
+
+
 if LOG_POSTGRES_LATENCY:
-    # Function to log before query execution
     @event.listens_for(Engine, "before_cursor_execute")
     def before_cursor_execute(  # type: ignore
         conn, cursor, statement, parameters, context, executemany
     ):
         conn.info["query_start_time"] = time.time()

-    # Function to log after query execution
     @event.listens_for(Engine, "after_cursor_execute")
     def after_cursor_execute(  # type: ignore
         conn, cursor, statement, parameters, context, executemany
     ):
         total_time = time.time() - conn.info["query_start_time"]
-        # don't spam TOO hard
         if total_time > 0.1:
             logger.debug(
                 f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:
 if LOG_POSTGRES_CONN_COUNTS:
-    # Global counter for connection checkouts and checkins
     checkout_count = 0
     checkin_count = 0
@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
         logger.debug(f"Total connection checkins: {checkin_count}")

-    """END DEBUGGING LOGGING"""
-

 def get_db_current_time(db_session: Session) -> datetime:
-    """Get the current time from Postgres representing the start of the transaction
-    Within the same transaction this value will not update
-    This datetime object returned should be timezone aware, default Postgres timezone is UTC
-    """
     result = db_session.execute(text("SELECT NOW()")).scalar()
     if result is None:
         raise ValueError("Database did not return a time")
     return result

-# Regular expression to validate schema names to prevent SQL injection
 SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:
 class SqlEngine:
-    """Class to manage a global SQLAlchemy engine (needed for proper resource control).
-    Will eventually subsume most of the standalone functions in this file.
-    Sync only for now.
-    """
-
     _engine: Engine | None = None
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME

-    # Default parameters for engine creation
     DEFAULT_ENGINE_KWARGS = {
         "pool_size": 20,
         "max_overflow": 5,
@@ -145,33 +194,27 @@ class SqlEngine:
         "pool_recycle": POSTGRES_POOL_RECYCLE,
     }

-    def __init__(self) -> None:
-        pass
-
     @classmethod
     def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        """Private helper method to create and return an Engine."""
         connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
+            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
         )
         merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
-        return create_engine(connection_string, **merged_kwargs)
+        engine = create_engine(connection_string, **merged_kwargs)
+
+        if USE_IAM_AUTH:
+            event.listen(engine, "do_connect", provide_iam_token)
+
+        return engine

     @classmethod
     def init_engine(cls, **engine_kwargs: Any) -> None:
-        """Allow the caller to init the engine with extra params. Different clients
-        such as the API server and different Celery workers and tasks
-        need different settings.
-        """
         with cls._lock:
             if not cls._engine:
                 cls._engine = cls._init_engine(**engine_kwargs)

     @classmethod
     def get_engine(cls) -> Engine:
-        """Gets the SQLAlchemy engine. Will init a default engine if init hasn't
-        already been called. You probably want to init first!
-        """
         if not cls._engine:
             with cls._lock:
                 if not cls._engine:
@@ -180,12 +223,10 @@ class SqlEngine:
     @classmethod
     def set_app_name(cls, app_name: str) -> None:
-        """Class method to set the app name."""
         cls._app_name = app_name

     @classmethod
     def get_app_name(cls) -> str:
-        """Class method to get current app name."""
         if not cls._app_name:
             return ""
         return cls._app_name
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
         for tenant in tenant_ids
         if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
     ]

     return valid_tenants


-def build_connection_string(
-    *,
-    db_api: str = ASYNC_DB_API,
-    user: str = POSTGRES_USER,
-    password: str = POSTGRES_PASSWORD,
-    host: str = POSTGRES_HOST,
-    port: str = POSTGRES_PORT,
-    db: str = POSTGRES_DB,
-    app_name: str | None = None,
-) -> str:
-    if app_name:
-        return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
-    return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
-
-
 def get_sqlalchemy_engine() -> Engine:
     return SqlEngine.get_engine()


+async def get_async_connection() -> Any:
+    """
+    Custom connection function for async engine when using IAM auth.
+    """
+    host = POSTGRES_HOST
+    port = POSTGRES_PORT
+    user = POSTGRES_USER
+    db = POSTGRES_DB
+    token = get_iam_auth_token(host, port, user, AWS_REGION)
+    # asyncpg requires 'ssl="require"' if SSL needed
+    return await asyncpg.connect(
+        user=user, password=token, host=host, port=int(port), database=db, ssl="require"
+    )
+
+
 def get_sqlalchemy_async_engine() -> AsyncEngine:
     global _ASYNC_ENGINE
     if _ASYNC_ENGINE is None:
-        # Underlying asyncpg cannot accept application_name directly in the connection string
-        # https://github.com/MagicStack/asyncpg/issues/798
-        connection_string = build_connection_string()
+        app_name = SqlEngine.get_app_name() + "_async"
+        connection_string = build_connection_string(
+            db_api=ASYNC_DB_API,
+            use_iam=USE_IAM_AUTH,
+        )
+
+        connect_args: dict[str, Any] = {}
+        if app_name:
+            connect_args["server_settings"] = {"application_name": app_name}
+
+        connect_args["ssl"] = ssl_context
+
         _ASYNC_ENGINE = create_async_engine(
             connection_string,
-            connect_args={
-                "server_settings": {
-                    "application_name": SqlEngine.get_app_name() + "_async"
-                }
-            },
-            # async engine is only used by API server, so we can use those values
-            # here as well
+            connect_args=connect_args,
             pool_size=POSTGRES_API_SERVER_POOL_SIZE,
             max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
             pool_pre_ping=POSTGRES_POOL_PRE_PING,
             pool_recycle=POSTGRES_POOL_RECYCLE,
         )

+        if USE_IAM_AUTH:
+
+            @event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
+            def provide_iam_token_async(
+                dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+            ) -> None:
+                # For async engine using asyncpg, we still need to set the IAM token here.
+                host = POSTGRES_HOST
+                port = POSTGRES_PORT
+                user = POSTGRES_USER
+                token = get_iam_auth_token(host, port, user, AWS_REGION)
+                cparams["password"] = token
+                cparams["ssl"] = ssl_context
+
     return _ASYNC_ENGINE


-# Dependency to get the current tenant ID
-# If no token is present, uses the default schema for this use case
 def get_current_tenant_id(request: Request) -> str:
-    """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
     if not MULTI_TENANT:
         tenant_id = POSTGRES_DEFAULT_SCHEMA
         CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
     token = request.cookies.get("fastapiusersauth")
     if not token:
         current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
-        # If no token is present, use the default schema or handle accordingly
         return current_value

     try:
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
         if not is_valid_schema_name(tenant_id):
             raise HTTPException(status_code=400, detail="Invalid tenant ID format")
         CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
-
         return tenant_id
     except jwt.InvalidTokenError:
         return CURRENT_TENANT_ID_CONTEXTVAR.get()
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(
     async with async_session_factory() as session:
         try:
-            # Set the search_path to the tenant's schema
             await session.execute(text(f'SET search_path = "{tenant_id}"'))
             if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                 await session.execute(
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
             )
         except Exception:
             logger.exception("Error setting search_path.")
-            # You can choose to re-raise the exception or handle it
-            # Here, we'll re-raise to prevent proceeding with an incorrect session
             raise
         else:
             yield session
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(
 @contextmanager
 def get_session_with_default_tenant() -> Generator[Session, None, None]:
-    """
-    Get a database session using the current tenant ID from the context variable.
-    """
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     with get_session_with_tenant(tenant_id) as session:
         yield session
@@ -349,7 +397,6 @@ def get_session_with_tenant(
 ) -> Generator[Session, None, None]:
     """
     Generate a database session for a specific tenant.
-
     This function:
     1. Sets the database schema to the specified tenant's schema.
     2. Preserves the tenant ID across the session.
@@ -357,27 +404,20 @@ def get_session_with_tenant(
     4. Uses the default schema if no tenant ID is provided.
     """
     engine = get_sqlalchemy_engine()

-    # Store the previous tenant ID
     previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA

     if tenant_id is None:
         tenant_id = POSTGRES_DEFAULT_SCHEMA
     CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     event.listen(engine, "checkout", set_search_path_on_checkout)

     if not is_valid_schema_name(tenant_id):
         raise HTTPException(status_code=400, detail="Invalid tenant ID")

     try:
-        # Establish a raw connection
         with engine.connect() as connection:
-            # Access the raw DBAPI connection and set the search_path
             dbapi_connection = connection.connection

-            # Set the search_path outside of any transaction
             cursor = dbapi_connection.cursor()
             try:
                 cursor.execute(f'SET search_path = "{tenant_id}"')
@@ -390,21 +430,17 @@ def get_session_with_tenant(
             finally:
                 cursor.close()

-            # Bind the session to the connection
             with Session(bind=connection, expire_on_commit=False) as session:
                 try:
                     yield session
                 finally:
-                    # Reset search_path to default after the session is used
                     if MULTI_TENANT:
                         cursor = dbapi_connection.cursor()
                         try:
                             cursor.execute('SET search_path TO "$user", public')
                         finally:
                             cursor.close()
     finally:
-        # Restore the previous tenant ID
         CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
 def get_session() -> Generator[Session, None, None]:
-    """Generate a database session with the appropriate tenant schema set."""
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
-        raise BasicAuthenticationError(
-            detail="User must authenticate",
-        )
+        raise BasicAuthenticationError(detail="User must authenticate")

     engine = get_sqlalchemy_engine()
@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
         if MULTI_TENANT:
             if not is_valid_schema_name(tenant_id):
                 raise HTTPException(status_code=400, detail="Invalid tenant ID")
-            # Set the search_path to the tenant's schema
             session.execute(text(f'SET search_path = "{tenant_id}"'))
         yield session


 async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
-    """Generate an async database session with the appropriate tenant schema set."""
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     engine = get_sqlalchemy_async_engine()
     async with AsyncSession(engine, expire_on_commit=False) as async_session:
         if MULTI_TENANT:
             if not is_valid_schema_name(tenant_id):
                 raise HTTPException(status_code=400, detail="Invalid tenant ID")
-            # Set the search_path to the tenant's schema
             await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
         yield async_session
@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:
 def get_session_factory() -> sessionmaker[Session]:
-    """Get a session factory."""
     global SessionFactory
     if SessionFactory is None:
         SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
@@ -489,3 +518,13 @@ async def warm_up_connections(
         await async_conn.execute(text("SELECT 1"))

     for async_conn in async_connections:
         await async_conn.close()
+
+
+def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
+    if USE_IAM_AUTH:
+        host = POSTGRES_HOST
+        port = POSTGRES_PORT
+        user = POSTGRES_USER
+        region = os.getenv("AWS_REGION", "us-east-2")
+        # Configure for psycopg2 with IAM token
+        configure_psycopg2_iam_auth(cparams, host, port, user, region)
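For context, the token flow that get_iam_auth_token() and configure_psycopg2_iam_auth() implement can be exercised standalone. A hedged sketch, assuming IAM credentials in the environment, a database user granted the rds_iam role, and an illustrative endpoint; generate_db_auth_token() produces a signed token that Postgres accepts as the password for roughly 15 minutes:

    import boto3
    import psycopg2

    # Illustrative values; substitute your RDS endpoint, port, and DB user.
    host, port, user = "mydb.abc123.us-east-2.rds.amazonaws.com", 5432, "onyx"

    client = boto3.client("rds", region_name="us-east-2")
    token = client.generate_db_auth_token(DBHostname=host, Port=port, DBUsername=user)

    # The token goes in as the password. sslmode/sslrootcert mirror
    # configure_psycopg2_iam_auth() above; note that libpq only verifies the
    # server cert against the bundle with sslmode=verify-ca or verify-full.
    conn = psycopg2.connect(
        host=host,
        port=port,
        user=user,
        password=token,
        dbname="postgres",
        sslmode="require",
        sslrootcert="bundle.pem",
    )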

View File

@@ -14,7 +14,7 @@ spec:
     spec:
       containers:
       - name: celery-beat
-        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
         imagePullPolicy: IfNotPresent
         command:
           [

View File

@@ -14,7 +14,7 @@ spec:
     spec:
      containers:
       - name: celery-worker-heavy
-        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
         imagePullPolicy: IfNotPresent
         command:
           [

View File

@@ -14,7 +14,7 @@ spec:
     spec:
       containers:
       - name: celery-worker-indexing
-        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
         imagePullPolicy: IfNotPresent
         command:
           [

View File

@@ -14,7 +14,7 @@ spec:
     spec:
       containers:
       - name: celery-worker-light
-        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
         imagePullPolicy: IfNotPresent
         command:
           [

View File

@@ -14,7 +14,7 @@ spec:
     spec:
       containers:
       - name: celery-worker-primary
-        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+        image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
         imagePullPolicy: IfNotPresent
         command:
           [

View File

@@ -103,6 +103,13 @@ services:
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       # Seeding configuration
+      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -223,6 +230,13 @@
       # Enterprise Edition stuff
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
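A note on the AWS_* variables being passed through: boto3 resolves credentials and region from the environment (or from an attached instance/task role), so exporting them into the container is all get_iam_auth_token() needs; nothing is written into the image. A quick illustrative check from inside a container:

    import boto3

    # Reads AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_REGION from the
    # environment, falling back to instance or task roles when present.
    session = boto3.session.Session()
    print(session.region_name)                    # e.g. "us-east-2"
    print(session.get_credentials() is not None)  # True once credentials resolve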

View File

@@ -91,6 +91,13 @@ services:
       # Enterprise Edition only
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -192,6 +199,13 @@
       # Enterprise Edition only
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:

View File

@@ -22,6 +22,13 @@ services:
       - VESPA_HOST=index
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -52,6 +59,13 @@
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:

View File

@@ -23,6 +23,13 @@ services:
       - VESPA_HOST=index
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -57,6 +64,13 @@
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+    # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+    # volumes:
+    #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -223,7 +237,7 @@
     volumes:
       - ../data/certbot/conf:/etc/letsencrypt
       - ../data/certbot/www:/var/www/certbot
-    logging:
+    logging::wq
       driver: json-file
       options:
         max-size: "50m"
@@ -245,3 +259,6 @@
   # Created by the container itself
   model_cache_huggingface:
   indexing_huggingface_model_cache:
+
+
+

View File

@@ -60,3 +60,12 @@ spec:
           envFrom:
             - configMapRef:
                 name: env-configmap
+          # Uncomment if you are using IAM auth for Postgres
+          # volumeMounts:
+          #   - name: bundle-pem
+          #     mountPath: "/app/certs"
+          #     readOnly: true
+      # volumes:
+      #   - name: bundle-pem
+      #     secret:
+      #       secretName: bundle-pem-secret

View File

@@ -43,6 +43,7 @@ spec:
       # - name: my-ca-cert-volume
       #   mountPath: /etc/ssl/certs/custom-ca.crt
       #   subPath: my-ca.crt
+      # Optional volume for CA certificate
       # volumes:
       # - name: my-cas-cert-volume
@@ -51,3 +52,13 @@ spec:
       #   items:
       #     - key: my-ca.crt
       #       path: my-ca.crt
+
+      # Uncomment if you are using IAM auth for Postgres
+      # volumeMounts:
+      #   - name: bundle-pem
+      #     mountPath: "/app/certs"
+      #     readOnly: true
+      # volumes:
+      #   - name: bundle-pem
+      #     secret:
+      #       secretName: bundle-pem-secret