IAM Auth for RDS (#3479)
* k
* functional iam auth
* k
* k
* improve typing
* add deployment options
* cleanup
* quick clean up
* minor cleanup
* additional clarity for db session operations
* nit
* k
* k
* update configs
* docker compose spacing
@@ -1,39 +1,49 @@
+from typing import Any, Literal
+from onyx.db.engine import get_iam_auth_token
+from onyx.configs.app_configs import USE_IAM_AUTH
+from onyx.configs.app_configs import POSTGRES_HOST
+from onyx.configs.app_configs import POSTGRES_PORT
+from onyx.configs.app_configs import POSTGRES_USER
+from onyx.configs.app_configs import AWS_REGION
+from onyx.db.engine import build_connection_string
+from onyx.db.engine import get_all_tenant_ids
+from sqlalchemy import event
+from sqlalchemy import pool
+from sqlalchemy import text
 from sqlalchemy.engine.base import Connection
-from typing import Literal
+import os
+import ssl
 import asyncio
-from logging.config import fileConfig
 import logging
+from logging.config import fileConfig

 from alembic import context
-from sqlalchemy import pool
 from sqlalchemy.ext.asyncio import create_async_engine
-from sqlalchemy.sql import text
 from sqlalchemy.sql.schema import SchemaItem
+from onyx.configs.constants import SSL_CERT_FILE
-from shared_configs.configs import MULTI_TENANT
-from onyx.db.engine import build_connection_string
+from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
 from onyx.db.models import Base
 from celery.backends.database.session import ResultModelBase  # type: ignore
-from onyx.db.engine import get_all_tenant_ids
-from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 # Alembic Config object
 config = context.config

-# Interpret the config file for Python logging.
 if config.config_file_name is not None and config.attributes.get(
     "configure_logger", True
 ):
     fileConfig(config.config_file_name)

-# Add your model's MetaData object here for 'autogenerate' support
 target_metadata = [Base.metadata, ResultModelBase.metadata]

 EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}

-# Set up logging
 logger = logging.getLogger(__name__)

+ssl_context: ssl.SSLContext | None = None
+if USE_IAM_AUTH:
+    if not os.path.exists(SSL_CERT_FILE):
+        raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
+    ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)


 def include_object(
     object: SchemaItem,
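Note on the block above: the module-level guard fails fast when the RDS CA bundle is missing, then builds a verifying SSL context from it. A standalone sketch of the same idea, assuming the bundle sits at ./bundle.pem (the default SSL_CERT_FILE):

    import os
    import ssl

    cert_file = "bundle.pem"  # matches SSL_CERT_FILE in onyx.configs.constants

    if not os.path.exists(cert_file):
        raise FileNotFoundError(f"Expected {cert_file} when USE_IAM_AUTH is true.")

    ctx = ssl.create_default_context(cafile=cert_file)
    # create_default_context verifies the server chain against the bundle
    assert ctx.verify_mode == ssl.CERT_REQUIRED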
@@ -49,20 +59,12 @@ def include_object(
     reflected: bool,
     compare_to: SchemaItem | None,
 ) -> bool:
-    """
-    Determines whether a database object should be included in migrations.
-    Excludes specified tables from migrations.
-    """
     if type_ == "table" and name in EXCLUDE_TABLES:
         return False
     return True


 def get_schema_options() -> tuple[str, bool, bool]:
-    """
-    Parses command-line options passed via '-x' in Alembic commands.
-    Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
-    """
     x_args_raw = context.get_x_argument()
     x_args = {}
     for arg in x_args_raw:
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
 def do_run_migrations(
     connection: Connection, schema_name: str, create_schema: bool
 ) -> None:
-    """
-    Executes migrations in the specified schema.
-    """
     logger.info(f"About to migrate schema: {schema_name}")

     if create_schema:
         connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
         connection.execute(text("COMMIT"))

-    # Set search_path to the target schema
     connection.execute(text(f'SET search_path TO "{schema_name}"'))

     context.configure(
@@ -117,11 +115,25 @@ def do_run_migrations(
     context.run_migrations()


+def provide_iam_token_for_alembic(
+    dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+) -> None:
+    if USE_IAM_AUTH:
+        # Database connection settings
+        region = AWS_REGION
+        host = POSTGRES_HOST
+        port = POSTGRES_PORT
+        user = POSTGRES_USER
+
+        # Get IAM authentication token
+        token = get_iam_auth_token(host, port, user, region)
+
+        # For Alembic / SQLAlchemy in this context, set SSL and password
+        cparams["password"] = token
+        cparams["ssl"] = ssl_context
+
+
 async def run_async_migrations() -> None:
-    """
-    Determines whether to run migrations for a single schema or all schemas,
-    and executes migrations accordingly.
-    """
     schema_name, create_schema, upgrade_all_tenants = get_schema_options()

     engine = create_async_engine(
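The new provide_iam_token_for_alembic helper is registered below on SQLAlchemy's "do_connect" event, which fires before every DBAPI connect and lets a listener mutate the connect parameters. A minimal self-contained sketch of the mechanism, with a placeholder password source standing in for the real token helper:

    from sqlalchemy import create_engine, event

    # Hypothetical local engine; the URL carries no password on purpose.
    engine = create_engine("postgresql+psycopg2://postgres@localhost:5432/postgres")

    @event.listens_for(engine, "do_connect")
    def set_dynamic_password(dialect, conn_rec, cargs, cparams):
        # Anything placed in cparams is handed to the DBAPI connect() call;
        # with IAM auth this is where the freshly minted token would go.
        cparams["password"] = "dynamic-secret"  # stand-in for get_iam_auth_token(...)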
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
         poolclass=pool.NullPool,
     )

-    if upgrade_all_tenants:
-        # Run migrations for all tenant schemas sequentially
-        tenant_schemas = get_all_tenant_ids()
+    if USE_IAM_AUTH:
+
+        @event.listens_for(engine.sync_engine, "do_connect")
+        def event_provide_iam_token_for_alembic(
+            dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+        ) -> None:
+            provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
+
+    if upgrade_all_tenants:
+        tenant_schemas = get_all_tenant_ids()
         for schema in tenant_schemas:
             try:
                 logger.info(f"Migrating schema: {schema}")
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:


 def run_migrations_offline() -> None:
-    """
-    Run migrations in 'offline' mode.
-    """
     schema_name, _, upgrade_all_tenants = get_schema_options()
     url = build_connection_string()

     if upgrade_all_tenants:
-        # Run offline migrations for all tenant schemas
         engine = create_async_engine(url)
+
+        if USE_IAM_AUTH:
+
+            @event.listens_for(engine.sync_engine, "do_connect")
+            def event_provide_iam_token_for_alembic_offline(
+                dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+            ) -> None:
+                provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
+
         tenant_schemas = get_all_tenant_ids()
         engine.sync_engine.dispose()

@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:


 def run_migrations_online() -> None:
-    """
-    Runs migrations in 'online' mode using an asynchronous engine.
-    """
     asyncio.run(run_async_migrations())

@@ -144,6 +144,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
 POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
 POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
 POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
+AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"

 POSTGRES_API_SERVER_POOL_SIZE = int(
     os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -174,6 +175,9 @@ try:
 except ValueError:
     POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT

+USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
+
+
 REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
 REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
 REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
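The string comparison in the new USE_IAM_AUTH flag is deliberate: a non-empty string like "False" is truthy in Python, so calling bool() on the raw env var would silently enable the feature. A quick illustration:

    import os

    os.environ["USE_IAM_AUTH"] = "False"
    assert bool(os.environ["USE_IAM_AUTH"])  # non-empty string is truthy: wrong signal
    assert (os.environ["USE_IAM_AUTH"].lower() == "true") is False  # the parsing used here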
@@ -49,6 +49,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
 POSTGRES_PERMISSIONS_APP_NAME = "permissions"
 POSTGRES_UNKNOWN_APP_NAME = "unknown"

+SSL_CERT_FILE = "bundle.pem"
 # API Keys
 DANSWER_API_KEY_PREFIX = "API_KEY__"
 DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
@@ -1,5 +1,7 @@
 import contextlib
+import os
 import re
+import ssl
 import threading
 import time
 from collections.abc import AsyncGenerator
@@ -10,6 +12,8 @@ from datetime import datetime
 from typing import Any
 from typing import ContextManager

+import asyncpg  # type: ignore
+import boto3
 import jwt
 from fastapi import HTTPException
 from fastapi import Request
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker

+from onyx.configs.app_configs import AWS_REGION
 from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
 from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
 from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
 from onyx.configs.app_configs import POSTGRES_USER
 from onyx.configs.app_configs import USER_AUTH_SECRET
 from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
+from onyx.configs.constants import SSL_CERT_FILE
 from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
@@ -49,28 +55,87 @@ logger = setup_logger()
 SYNC_DB_API = "psycopg2"
 ASYNC_DB_API = "asyncpg"

-# global so we don't create more than one engine per process
-# outside of being best practice, this is needed so we can properly pool
-# connections and not create a new pool on every request
+USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

+# Global so we don't create more than one engine per process
 _ASYNC_ENGINE: AsyncEngine | None = None
 SessionFactory: sessionmaker[Session] | None = None


+def create_ssl_context_if_iam() -> ssl.SSLContext | None:
+    """Create an SSL context if IAM authentication is enabled, else return None."""
+    if USE_IAM_AUTH:
+        return ssl.create_default_context(cafile=SSL_CERT_FILE)
+    return None
+
+
+ssl_context = create_ssl_context_if_iam()
+
+
+def get_iam_auth_token(
+    host: str, port: str, user: str, region: str = "us-east-2"
+) -> str:
+    """
+    Generate an IAM authentication token using boto3.
+    """
+    client = boto3.client("rds", region_name=region)
+    token = client.generate_db_auth_token(
+        DBHostname=host, Port=int(port), DBUsername=user
+    )
+    return token
+
+
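generate_db_auth_token signs a short-lived credential (RDS tokens expire after roughly 15 minutes), which is why the code mints a new one per connection instead of caching. A usage sketch; the endpoint is a placeholder and the caller needs the rds-db:connect IAM permission:

    import boto3

    client = boto3.client("rds", region_name="us-east-2")
    token = client.generate_db_auth_token(
        DBHostname="mydb.abc123.us-east-2.rds.amazonaws.com",  # hypothetical endpoint
        Port=5432,
        DBUsername="postgres",
    )
    # The token is a long presigned-style string that Postgres accepts as a password.
    print(token[:60], "...")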
+def configure_psycopg2_iam_auth(
+    cparams: dict[str, Any], host: str, port: str, user: str, region: str
+) -> None:
+    """
+    Configure cparams for psycopg2 with IAM token and SSL.
+    """
+    token = get_iam_auth_token(host, port, user, region)
+    cparams["password"] = token
+    cparams["sslmode"] = "require"
+    cparams["sslrootcert"] = SSL_CERT_FILE
+
+
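The same three parameters make a standalone psycopg2 connection work against RDS; a sketch with a placeholder endpoint and database name:

    import boto3
    import psycopg2

    host, port, user = "mydb.abc123.us-east-2.rds.amazonaws.com", 5432, "postgres"
    token = boto3.client("rds", region_name="us-east-2").generate_db_auth_token(
        DBHostname=host, Port=port, DBUsername=user
    )
    conn = psycopg2.connect(
        host=host,
        port=port,
        user=user,
        password=token,            # IAM token in place of a static password
        dbname="postgres",
        sslmode="require",         # RDS refuses IAM auth over plaintext
        sslrootcert="bundle.pem",  # the CA bundle checked at startup
    )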
+def build_connection_string(
+    *,
+    db_api: str = ASYNC_DB_API,
+    user: str = POSTGRES_USER,
+    password: str = POSTGRES_PASSWORD,
+    host: str = POSTGRES_HOST,
+    port: str = POSTGRES_PORT,
+    db: str = POSTGRES_DB,
+    app_name: str | None = None,
+    use_iam: bool = USE_IAM_AUTH,
+    region: str = "us-west-2",
+) -> str:
+    if use_iam:
+        base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
+    else:
+        base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
+
+    # For asyncpg, do not include application_name in the connection string
+    if app_name and db_api != "asyncpg":
+        if "?" in base_conn_str:
+            return f"{base_conn_str}&application_name={app_name}"
+        else:
+            return f"{base_conn_str}?application_name={app_name}"
+    return base_conn_str
+
+
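The two URL shapes differ only in whether a password is embedded; with use_iam=True the password slot stays empty and the do_connect hooks fill it in per connection. Illustrative calls, assuming hypothetical defaults of user postgres, password secret, and host db.example.com:

    print(build_connection_string(db_api="psycopg2", use_iam=True))
    # -> postgresql+psycopg2://postgres@db.example.com:5432/postgres

    print(build_connection_string(db_api="psycopg2", use_iam=False, app_name="api_sync"))
    # -> postgresql+psycopg2://postgres:secret@db.example.com:5432/postgres?application_name=api_sync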
 if LOG_POSTGRES_LATENCY:
-    # Function to log before query execution
     @event.listens_for(Engine, "before_cursor_execute")
     def before_cursor_execute(  # type: ignore
         conn, cursor, statement, parameters, context, executemany
     ):
         conn.info["query_start_time"] = time.time()

-    # Function to log after query execution
     @event.listens_for(Engine, "after_cursor_execute")
     def after_cursor_execute(  # type: ignore
         conn, cursor, statement, parameters, context, executemany
     ):
         total_time = time.time() - conn.info["query_start_time"]
-        # don't spam TOO hard
         if total_time > 0.1:
             logger.debug(
                 f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:


 if LOG_POSTGRES_CONN_COUNTS:
-    # Global counter for connection checkouts and checkins
     checkout_count = 0
     checkin_count = 0

@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
         logger.debug(f"Total connection checkins: {checkin_count}")


-"""END DEBUGGING LOGGING"""


 def get_db_current_time(db_session: Session) -> datetime:
-    """Get the current time from Postgres representing the start of the transaction
-    Within the same transaction this value will not update
-    This datetime object returned should be timezone aware, default Postgres timezone is UTC
-    """
     result = db_session.execute(text("SELECT NOW()")).scalar()
     if result is None:
         raise ValueError("Database did not return a time")
     return result


-# Regular expression to validate schema names to prevent SQL injection
 SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")

@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:


 class SqlEngine:
-    """Class to manage a global SQLAlchemy engine (needed for proper resource control).
-    Will eventually subsume most of the standalone functions in this file.
-    Sync only for now.
-    """

     _engine: Engine | None = None
     _lock: threading.Lock = threading.Lock()
     _app_name: str = POSTGRES_UNKNOWN_APP_NAME

-    # Default parameters for engine creation
     DEFAULT_ENGINE_KWARGS = {
         "pool_size": 20,
         "max_overflow": 5,
@@ -145,33 +194,27 @@ class SqlEngine:
         "pool_recycle": POSTGRES_POOL_RECYCLE,
     }

-    def __init__(self) -> None:
-        pass
-
     @classmethod
     def _init_engine(cls, **engine_kwargs: Any) -> Engine:
-        """Private helper method to create and return an Engine."""
         connection_string = build_connection_string(
-            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
+            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
         )
         merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
-        return create_engine(connection_string, **merged_kwargs)
+        engine = create_engine(connection_string, **merged_kwargs)
+
+        if USE_IAM_AUTH:
+            event.listen(engine, "do_connect", provide_iam_token)
+
+        return engine

     @classmethod
     def init_engine(cls, **engine_kwargs: Any) -> None:
-        """Allow the caller to init the engine with extra params. Different clients
-        such as the API server and different Celery workers and tasks
-        need different settings.
-        """
         with cls._lock:
             if not cls._engine:
                 cls._engine = cls._init_engine(**engine_kwargs)

     @classmethod
     def get_engine(cls) -> Engine:
-        """Gets the SQLAlchemy engine. Will init a default engine if init hasn't
-        already been called. You probably want to init first!
-        """
         if not cls._engine:
             with cls._lock:
                 if not cls._engine:
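Because _init_engine registers provide_iam_token on "do_connect", every connection the pool opens, including recycles and pre-ping replacements, receives a token that is still within its lifetime. A condensed sketch of the wiring with placeholder connection details:

    import boto3
    from sqlalchemy import create_engine, event


    def inject_iam_token(dialect, conn_rec, cargs, cparams):
        # Runs for each new DBAPI connection, so a fresh token is minted
        # whenever the pool grows, recycles, or replaces a dead connection.
        token = boto3.client("rds", region_name="us-east-2").generate_db_auth_token(
            DBHostname="db.example.com", Port=5432, DBUsername="postgres"
        )
        cparams["password"] = token
        cparams["sslmode"] = "require"


    engine = create_engine("postgresql+psycopg2://postgres@db.example.com:5432/postgres")
    event.listen(engine, "do_connect", inject_iam_token)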
@@ -180,12 +223,10 @@ class SqlEngine:

     @classmethod
     def set_app_name(cls, app_name: str) -> None:
-        """Class method to set the app name."""
         cls._app_name = app_name

     @classmethod
     def get_app_name(cls) -> str:
-        """Class method to get current app name."""
         if not cls._app_name:
             return ""
         return cls._app_name
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
         for tenant in tenant_ids
         if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
     ]

     return valid_tenants


-def build_connection_string(
-    *,
-    db_api: str = ASYNC_DB_API,
-    user: str = POSTGRES_USER,
-    password: str = POSTGRES_PASSWORD,
-    host: str = POSTGRES_HOST,
-    port: str = POSTGRES_PORT,
-    db: str = POSTGRES_DB,
-    app_name: str | None = None,
-) -> str:
-    if app_name:
-        return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
-    return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
-
-
 def get_sqlalchemy_engine() -> Engine:
     return SqlEngine.get_engine()


+async def get_async_connection() -> Any:
+    """
+    Custom connection function for async engine when using IAM auth.
+    """
+    host = POSTGRES_HOST
+    port = POSTGRES_PORT
+    user = POSTGRES_USER
+    db = POSTGRES_DB
+    token = get_iam_auth_token(host, port, user, AWS_REGION)
+
+    # asyncpg requires 'ssl="require"' if SSL needed
+    return await asyncpg.connect(
+        user=user, password=token, host=host, port=int(port), database=db, ssl="require"
+    )
+
+
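asyncpg's ssl argument also accepts an ssl.SSLContext, which verifies the server against the RDS CA bundle instead of merely requiring encryption as the string "require" does. A hedged variant of the helper above, with placeholder connection details:

    import ssl

    import asyncpg


    async def get_async_connection_pinned() -> asyncpg.Connection:
        # Same shape as get_async_connection, but pins the RDS CA.
        ctx = ssl.create_default_context(cafile="bundle.pem")
        return await asyncpg.connect(
            user="postgres",
            password="<iam-token>",  # placeholder for get_iam_auth_token(...)
            host="db.example.com",   # placeholder endpoint
            port=5432,
            database="postgres",
            ssl=ctx,
        )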
 def get_sqlalchemy_async_engine() -> AsyncEngine:
     global _ASYNC_ENGINE
     if _ASYNC_ENGINE is None:
-        # Underlying asyncpg cannot accept application_name directly in the connection string
-        # https://github.com/MagicStack/asyncpg/issues/798
-        connection_string = build_connection_string()
+        app_name = SqlEngine.get_app_name() + "_async"
+        connection_string = build_connection_string(
+            db_api=ASYNC_DB_API,
+            use_iam=USE_IAM_AUTH,
+        )
+
+        connect_args: dict[str, Any] = {}
+        if app_name:
+            connect_args["server_settings"] = {"application_name": app_name}
+
+        connect_args["ssl"] = ssl_context
+
         _ASYNC_ENGINE = create_async_engine(
             connection_string,
-            connect_args={
-                "server_settings": {
-                    "application_name": SqlEngine.get_app_name() + "_async"
-                }
-            },
-            # async engine is only used by API server, so we can use those values
-            # here as well
+            connect_args=connect_args,
             pool_size=POSTGRES_API_SERVER_POOL_SIZE,
             max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
             pool_pre_ping=POSTGRES_POOL_PRE_PING,
             pool_recycle=POSTGRES_POOL_RECYCLE,
         )

+        if USE_IAM_AUTH:
+
+            @event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
+            def provide_iam_token_async(
+                dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
+            ) -> None:
+                # For async engine using asyncpg, we still need to set the IAM token here.
+                host = POSTGRES_HOST
+                port = POSTGRES_PORT
+                user = POSTGRES_USER
+                token = get_iam_auth_token(host, port, user, AWS_REGION)
+                cparams["password"] = token
+                cparams["ssl"] = ssl_context
+
     return _ASYNC_ENGINE
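Two details in this hunk are easy to miss: asyncpg cannot take application_name in the URL, so it travels through connect_args["server_settings"], and event listeners attach to the async engine's sync_engine facade. A minimal sketch with a placeholder URL:

    import ssl

    from sqlalchemy.ext.asyncio import create_async_engine

    ctx = ssl.create_default_context(cafile="bundle.pem")
    engine = create_async_engine(
        "postgresql+asyncpg://postgres@db.example.com:5432/postgres",
        connect_args={
            # application_name must go through server_settings for asyncpg
            "server_settings": {"application_name": "api_async"},
            "ssl": ctx,
        },
    )
    # engine.sync_engine is the object that accepts event.listens_for hooks.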

-# Dependency to get the current tenant ID
-# If no token is present, uses the default schema for this use case
 def get_current_tenant_id(request: Request) -> str:
-    """Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
     if not MULTI_TENANT:
         tenant_id = POSTGRES_DEFAULT_SCHEMA
         CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
     token = request.cookies.get("fastapiusersauth")
     if not token:
         current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
-        # If no token is present, use the default schema or handle accordingly
         return current_value

     try:
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
         if not is_valid_schema_name(tenant_id):
             raise HTTPException(status_code=400, detail="Invalid tenant ID format")
         CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
-
         return tenant_id
     except jwt.InvalidTokenError:
         return CURRENT_TENANT_ID_CONTEXTVAR.get()
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(

     async with async_session_factory() as session:
         try:
-            # Set the search_path to the tenant's schema
             await session.execute(text(f'SET search_path = "{tenant_id}"'))
             if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                 await session.execute(
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
                 )
         except Exception:
             logger.exception("Error setting search_path.")
-            # You can choose to re-raise the exception or handle it
-            # Here, we'll re-raise to prevent proceeding with an incorrect session
             raise
         else:
             yield session
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(

 @contextmanager
 def get_session_with_default_tenant() -> Generator[Session, None, None]:
-    """
-    Get a database session using the current tenant ID from the context variable.
-    """
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     with get_session_with_tenant(tenant_id) as session:
         yield session
@@ -349,7 +397,6 @@ def get_session_with_tenant(
 ) -> Generator[Session, None, None]:
     """
     Generate a database session for a specific tenant.
-
     This function:
     1. Sets the database schema to the specified tenant's schema.
     2. Preserves the tenant ID across the session.
@@ -357,27 +404,20 @@ def get_session_with_tenant(
     4. Uses the default schema if no tenant ID is provided.
     """
     engine = get_sqlalchemy_engine()

-    # Store the previous tenant ID
     previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA

     if tenant_id is None:
         tenant_id = POSTGRES_DEFAULT_SCHEMA

     CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     event.listen(engine, "checkout", set_search_path_on_checkout)

     if not is_valid_schema_name(tenant_id):
         raise HTTPException(status_code=400, detail="Invalid tenant ID")

     try:
-        # Establish a raw connection
         with engine.connect() as connection:
-            # Access the raw DBAPI connection and set the search_path
             dbapi_connection = connection.connection

-            # Set the search_path outside of any transaction
             cursor = dbapi_connection.cursor()
             try:
                 cursor.execute(f'SET search_path = "{tenant_id}"')
@@ -390,21 +430,17 @@ def get_session_with_tenant(
             finally:
                 cursor.close()

-            # Bind the session to the connection
             with Session(bind=connection, expire_on_commit=False) as session:
                 try:
                     yield session
                 finally:
-                    # Reset search_path to default after the session is used
                     if MULTI_TENANT:
                         cursor = dbapi_connection.cursor()
                         try:
                             cursor.execute('SET search_path TO "$user", public')
                         finally:
                             cursor.close()

     finally:
-        # Restore the previous tenant ID
         CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)

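For callers, the context manager hides all of the search_path bookkeeping; a usage sketch with a hypothetical tenant schema name:

    from sqlalchemy import text

    # Every statement in the block runs with search_path set to this schema;
    # the previous tenant context is restored when the block exits.
    with get_session_with_tenant("tenant_abc123") as session:
        doc_count = session.execute(text("SELECT count(*) FROM document")).scalar()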
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:


 def get_session() -> Generator[Session, None, None]:
-    """Generate a database session with the appropriate tenant schema set."""
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
-        raise BasicAuthenticationError(
-            detail="User must authenticate",
-        )
+        raise BasicAuthenticationError(detail="User must authenticate")

     engine = get_sqlalchemy_engine()

@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
         if MULTI_TENANT:
             if not is_valid_schema_name(tenant_id):
                 raise HTTPException(status_code=400, detail="Invalid tenant ID")
-            # Set the search_path to the tenant's schema
             session.execute(text(f'SET search_path = "{tenant_id}"'))
         yield session


 async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
-    """Generate an async database session with the appropriate tenant schema set."""
     tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     engine = get_sqlalchemy_async_engine()
     async with AsyncSession(engine, expire_on_commit=False) as async_session:
         if MULTI_TENANT:
             if not is_valid_schema_name(tenant_id):
                 raise HTTPException(status_code=400, detail="Invalid tenant ID")
-            # Set the search_path to the tenant's schema
             await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
         yield async_session

@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:


 def get_session_factory() -> sessionmaker[Session]:
-    """Get a session factory."""
     global SessionFactory
     if SessionFactory is None:
         SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
@@ -489,3 +518,13 @@ async def warm_up_connections(
             await async_conn.execute(text("SELECT 1"))
         for async_conn in async_connections:
             await async_conn.close()
+
+
+def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
+    if USE_IAM_AUTH:
+        host = POSTGRES_HOST
+        port = POSTGRES_PORT
+        user = POSTGRES_USER
+        region = os.getenv("AWS_REGION", "us-east-2")
+        # Configure for psycopg2 with IAM token
+        configure_psycopg2_iam_auth(cparams, host, port, user, region)
@@ -14,7 +14,7 @@ spec:
     spec:
       containers:
         - name: celery-beat
-          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
          imagePullPolicy: IfNotPresent
          command:
            [
@@ -14,7 +14,7 @@ spec:
    spec:
      containers:
        - name: celery-worker-heavy
-          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
          imagePullPolicy: IfNotPresent
          command:
            [
@@ -14,7 +14,7 @@ spec:
    spec:
      containers:
        - name: celery-worker-indexing
-          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
          imagePullPolicy: IfNotPresent
          command:
            [
@@ -14,7 +14,7 @@ spec:
    spec:
      containers:
        - name: celery-worker-light
-          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
          imagePullPolicy: IfNotPresent
          command:
            [
@@ -14,7 +14,7 @@ spec:
    spec:
      containers:
        - name: celery-worker-primary
-          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
+          image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
          imagePullPolicy: IfNotPresent
          command:
            [
@@ -103,6 +103,13 @@ services:
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       # Seeding configuration
+      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -223,6 +230,13 @@ services:

       # Enterprise Edition stuff
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -91,6 +91,13 @@ services:
       # Enterprise Edition only
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -192,6 +199,13 @@ services:
       # Enterprise Edition only
       - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
       - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -22,6 +22,13 @@ services:
       - VESPA_HOST=index
       - REDIS_HOST=cache
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -52,6 +59,13 @@ services:
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -23,6 +23,13 @@ services:
       - VESPA_HOST=index
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -57,6 +64,13 @@ services:
       - REDIS_HOST=cache
       - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
+      - USE_IAM_AUTH=${USE_IAM_AUTH}
+      - AWS_REGION=${AWS_REGION-}
+      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
+      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
+      # Uncomment the lines below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
+      # volumes:
+      #   - ./bundle.pem:/app/bundle.pem:ro
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@@ -223,7 +237,7 @@ services:
     volumes:
       - ../data/certbot/conf:/etc/letsencrypt
       - ../data/certbot/www:/var/www/certbot
     logging:
       driver: json-file
       options:
         max-size: "50m"
@@ -245,3 +259,6 @@ volumes:
   # Created by the container itself
   model_cache_huggingface:
   indexing_huggingface_model_cache:
+
+
+
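Inside the containers, these variables feed the same USE_IAM_AUTH switch as app_configs.py, and bundle.pem must be mounted for SSL verification. A hedged startup check a service could run; the helper is illustrative, not part of the codebase:

    import os
    import sys


    def check_iam_env() -> None:
        # Mirrors the app's parsing: only the literal string "true" enables IAM auth.
        if os.getenv("USE_IAM_AUTH", "False").lower() != "true":
            return
        if not os.getenv("AWS_REGION"):
            sys.exit("USE_IAM_AUTH is true but AWS_REGION is not set")
        if not os.path.exists("/app/bundle.pem"):
            sys.exit("USE_IAM_AUTH is true but /app/bundle.pem is not mounted")


    check_iam_env()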
@@ -60,3 +60,12 @@ spec:
           envFrom:
             - configMapRef:
                 name: env-configmap
+          # Uncomment if you are using IAM auth for Postgres
+          # volumeMounts:
+          #   - name: bundle-pem
+          #     mountPath: "/app/certs"
+          #     readOnly: true
+      # volumes:
+      #   - name: bundle-pem
+      #     secret:
+      #       secretName: bundle-pem-secret
@@ -43,6 +43,7 @@ spec:
           # - name: my-ca-cert-volume
           #   mountPath: /etc/ssl/certs/custom-ca.crt
           #   subPath: my-ca.crt
+
       # Optional volume for CA certificate
       # volumes:
       #   - name: my-cas-cert-volume
@@ -51,3 +52,13 @@ spec:
       #   items:
       #     - key: my-ca.crt
       #       path: my-ca.crt
+
+      # Uncomment if you are using IAM auth for Postgres
+      # volumeMounts:
+      #   - name: bundle-pem
+      #     mountPath: "/app/certs"
+      #     readOnly: true
+      # volumes:
+      #   - name: bundle-pem
+      #     secret:
+      #       secretName: bundle-pem-secret