import contextlib
import os
import re
import ssl
import threading
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from contextlib import asynccontextmanager
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import AsyncContextManager

import asyncpg  # type: ignore
import boto3
from fastapi import HTTPException
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_DB
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_IDLE_SESSIONS_TIMEOUT
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

# why isn't this in configs?
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")


# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None


def create_ssl_context_if_iam() -> ssl.SSLContext | None:
    """Create an SSL context if IAM authentication is enabled, else return None."""
    if USE_IAM_AUTH:
        return ssl.create_default_context(cafile=SSL_CERT_FILE)
    return None


ssl_context = create_ssl_context_if_iam()


def get_iam_auth_token(
    host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
    """Generate an IAM authentication token using boto3."""
    client = boto3.client("rds", region_name=region)
    token = client.generate_db_auth_token(
        DBHostname=host, Port=int(port), DBUsername=user
    )
    return token


def configure_psycopg2_iam_auth(
    cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
    """Configure cparams for psycopg2 with IAM token and SSL."""
    token = get_iam_auth_token(host, port, user, region)
    cparams["password"] = token
    cparams["sslmode"] = "require"
    cparams["sslrootcert"] = SSL_CERT_FILE


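# Illustrative note (not executed here): these helpers are meant to be hooked into
# SQLAlchemy's "do_connect" event so that a fresh IAM token is generated for each new
# connection, roughly:
#
#   event.listen(engine, "do_connect", provide_iam_token)
#
# `provide_iam_token` is defined further down in this module; the engines below only
# register it when IAM auth is enabled.

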
def build_connection_string(
    *,
    db_api: str = ASYNC_DB_API,
    user: str = POSTGRES_USER,
    password: str = POSTGRES_PASSWORD,
    host: str = POSTGRES_HOST,
    port: str = POSTGRES_PORT,
    db: str = POSTGRES_DB,
    app_name: str | None = None,
    use_iam_auth: bool = USE_IAM_AUTH,
    region: str = "us-west-2",
) -> str:
    if use_iam_auth:
        base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
    else:
        base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"

    # For asyncpg, do not include application_name in the connection string
    if app_name and db_api != "asyncpg":
        if "?" in base_conn_str:
            return f"{base_conn_str}&application_name={app_name}"
        else:
            return f"{base_conn_str}?application_name={app_name}"
    return base_conn_str


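# Example (illustrative placeholder values, not real credentials): with the defaults
# above, a sync connection string looks roughly like
#
#   postgresql+psycopg2://postgres:password@localhost:5432/postgres?application_name=api_server_sync
#
# while for asyncpg the application_name is left out of the URL and supplied instead
# via server_settings in get_sqlalchemy_async_engine() below.

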
if LOG_POSTGRES_LATENCY:

    @event.listens_for(Engine, "before_cursor_execute")
    def before_cursor_execute(  # type: ignore
        conn, cursor, statement, parameters, context, executemany
    ):
        conn.info["query_start_time"] = time.time()

    @event.listens_for(Engine, "after_cursor_execute")
    def after_cursor_execute(  # type: ignore
        conn, cursor, statement, parameters, context, executemany
    ):
        total_time = time.time() - conn.info["query_start_time"]
        if total_time > 0.1:
            logger.debug(
                f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
            )


if LOG_POSTGRES_CONN_COUNTS:
    checkout_count = 0
    checkin_count = 0

    @event.listens_for(Engine, "checkout")
    def log_checkout(dbapi_connection, connection_record, connection_proxy):  # type: ignore
        global checkout_count
        checkout_count += 1

        active_connections = connection_proxy._pool.checkedout()
        idle_connections = connection_proxy._pool.checkedin()
        pool_size = connection_proxy._pool.size()
        logger.debug(
            "Connection Checkout\n"
            f"Active Connections: {active_connections};\n"
            f"Idle: {idle_connections};\n"
            f"Pool Size: {pool_size};\n"
            f"Total connection checkouts: {checkout_count}"
        )

    @event.listens_for(Engine, "checkin")
    def log_checkin(dbapi_connection, connection_record):  # type: ignore
        global checkin_count
        checkin_count += 1
        logger.debug(f"Total connection checkins: {checkin_count}")


def get_db_current_time(db_session: Session) -> datetime:
    result = db_session.execute(text("SELECT NOW()")).scalar()
    if result is None:
        raise ValueError("Database did not return a time")
    return result


def is_valid_schema_name(name: str) -> bool:
    return SCHEMA_NAME_REGEX.match(name) is not None


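# For reference: the regex above accepts names such as "public" or "tenant_4f9a2c" and
# rejects anything containing quotes, spaces, or semicolons, which is what keeps the
# f-string `SET search_path = "{tenant_id}"` statements below safe to build.

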
class SqlEngine:
    _engine: Engine | None = None
    _readonly_engine: Engine | None = None
    _lock: threading.Lock = threading.Lock()
    _readonly_lock: threading.Lock = threading.Lock()
    _app_name: str = POSTGRES_UNKNOWN_APP_NAME

    @classmethod
    def init_engine(
        cls,
        pool_size: int,
        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
        max_overflow: int,
        app_name: str | None = None,
        db_api: str = SYNC_DB_API,
        use_iam: bool = USE_IAM_AUTH,
        connection_string: str | None = None,
        **extra_engine_kwargs: Any,
    ) -> None:
        """NOTE: pool_size and max_overflow must be passed in explicitly. They are
        important args; when misconfigured we have previously hit the pool limit,
        opened too many connections, and overwhelmed the database.

        Specifying connection_string directly will cause some of the other parameters
        to be ignored.
        """
        with cls._lock:
            if cls._engine:
                return

            if not connection_string:
                connection_string = build_connection_string(
                    db_api=db_api,
                    app_name=cls._app_name + "_sync",
                    use_iam_auth=use_iam,
                )

            # Start with base kwargs that are valid for all pool types
            final_engine_kwargs: dict[str, Any] = {}

            if POSTGRES_USE_NULL_POOL:
                # if the null pool is specified, remove any passed-in kwargs related
                # to pool size that would cause the initialization to fail
                final_engine_kwargs.update(extra_engine_kwargs)

                final_engine_kwargs["poolclass"] = pool.NullPool
                if "pool_size" in final_engine_kwargs:
                    del final_engine_kwargs["pool_size"]
                if "max_overflow" in final_engine_kwargs:
                    del final_engine_kwargs["max_overflow"]
            else:
                final_engine_kwargs["pool_size"] = pool_size
                final_engine_kwargs["max_overflow"] = max_overflow
                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

                # any passed in kwargs override the defaults
                final_engine_kwargs.update(extra_engine_kwargs)

            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
            # pass echo=True here to inspect all emitted db queries
            engine = create_engine(connection_string, **final_engine_kwargs)

            if use_iam:
                event.listen(engine, "do_connect", provide_iam_token)

            cls._engine = engine

    @classmethod
    def init_readonly_engine(
        cls,
        pool_size: int,
        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
        max_overflow: int,
        **extra_engine_kwargs: Any,
    ) -> None:
        """NOTE: pool_size and max_overflow must be passed in explicitly. They are
        important args; when misconfigured we have previously hit the pool limit,
        opened too many connections, and overwhelmed the database."""
        with cls._readonly_lock:
            if cls._readonly_engine:
                return

            if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
                raise ValueError(
                    "Custom database user credentials not configured in environment variables"
                )

            # Build connection string with the read-only user
            connection_string = build_connection_string(
                user=DB_READONLY_USER,
                password=DB_READONLY_PASSWORD,
                use_iam_auth=False,  # Custom users typically don't use IAM auth
                db_api=SYNC_DB_API,  # Explicitly use the sync DB API
            )

            # Start with base kwargs that are valid for all pool types
            final_engine_kwargs: dict[str, Any] = {}

            if POSTGRES_USE_NULL_POOL:
                # if the null pool is specified, remove any passed-in kwargs related
                # to pool size that would cause the initialization to fail
                final_engine_kwargs.update(extra_engine_kwargs)

                final_engine_kwargs["poolclass"] = pool.NullPool
                if "pool_size" in final_engine_kwargs:
                    del final_engine_kwargs["pool_size"]
                if "max_overflow" in final_engine_kwargs:
                    del final_engine_kwargs["max_overflow"]
            else:
                final_engine_kwargs["pool_size"] = pool_size
                final_engine_kwargs["max_overflow"] = max_overflow
                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

                # any passed in kwargs override the defaults
                final_engine_kwargs.update(extra_engine_kwargs)

            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
            # pass echo=True here to inspect all emitted db queries
            engine = create_engine(connection_string, **final_engine_kwargs)

            if USE_IAM_AUTH:
                event.listen(engine, "do_connect", provide_iam_token)

            cls._readonly_engine = engine

    @classmethod
    def get_engine(cls) -> Engine:
        if not cls._engine:
            raise RuntimeError("Engine not initialized. Must call init_engine first.")
        return cls._engine

    @classmethod
    def get_readonly_engine(cls) -> Engine:
        if not cls._readonly_engine:
            raise RuntimeError(
                "Readonly engine not initialized. Must call init_readonly_engine first."
            )
        return cls._readonly_engine

    @classmethod
    def set_app_name(cls, app_name: str) -> None:
        cls._app_name = app_name

    @classmethod
    def get_app_name(cls) -> str:
        if not cls._app_name:
            return ""
        return cls._app_name

    @classmethod
    def reset_engine(cls) -> None:
        with cls._lock:
            if cls._engine:
                cls._engine.dispose()
                cls._engine = None


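# Illustrative startup usage (a sketch, not the authoritative call sites): each process
# is expected to initialize the engine exactly once before requesting any sessions, e.g.
#
#   SqlEngine.set_app_name("api_server")
#   SqlEngine.init_engine(
#       pool_size=POSTGRES_API_SERVER_POOL_SIZE,
#       max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
#   )
#
# The actual app names and pool sizes are chosen in the service entry points.

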
def get_all_tenant_ids() -> list[str]:
    """Returns [POSTGRES_DEFAULT_SCHEMA] when multi-tenancy is disabled, i.e. the only
    tenant is the default 'public' / self-hosted tenant."""

    tenant_ids: list[str]

    if not MULTI_TENANT:
        return [POSTGRES_DEFAULT_SCHEMA]

    with get_session_with_shared_schema() as session:
        result = session.execute(
            text(
                f"""
                SELECT schema_name
                FROM information_schema.schemata
                WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')"""
            )
        )
        tenant_ids = [row[0] for row in result]

    valid_tenants = [
        tenant
        for tenant in tenant_ids
        if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
    ]
    return valid_tenants


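# Example return values (illustrative): ["public"] for a self-hosted deployment, or a
# list of schemas beginning with TENANT_ID_PREFIX in multi-tenant mode, since only
# those schemas survive the filter above.

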
def get_sqlalchemy_engine() -> Engine:
    return SqlEngine.get_engine()


def get_readonly_sqlalchemy_engine() -> Engine:
    return SqlEngine.get_readonly_engine()


async def get_async_connection() -> Any:
    """Custom connection function for the async engine when using IAM auth."""
    host = POSTGRES_HOST
    port = POSTGRES_PORT
    user = POSTGRES_USER
    db = POSTGRES_DB
    token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)

    # asyncpg requires 'ssl="require"' if SSL is needed
    return await asyncpg.connect(
        user=user, password=token, host=host, port=int(port), database=db, ssl="require"
    )


def get_sqlalchemy_async_engine() -> AsyncEngine:
    global _ASYNC_ENGINE
    if _ASYNC_ENGINE is None:
        app_name = SqlEngine.get_app_name() + "_async"
        connection_string = build_connection_string(
            db_api=ASYNC_DB_API,
            use_iam_auth=USE_IAM_AUTH,
        )

        connect_args: dict[str, Any] = {}
        if app_name:
            connect_args["server_settings"] = {"application_name": app_name}

        connect_args["ssl"] = ssl_context

        engine_kwargs = {
            "connect_args": connect_args,
            "pool_pre_ping": POSTGRES_POOL_PRE_PING,
            "pool_recycle": POSTGRES_POOL_RECYCLE,
        }

        if POSTGRES_USE_NULL_POOL:
            engine_kwargs["poolclass"] = pool.NullPool
        else:
            engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
            engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW

        _ASYNC_ENGINE = create_async_engine(
            connection_string,
            **engine_kwargs,
        )

        if USE_IAM_AUTH:

            @event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
            def provide_iam_token_async(
                dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
            ) -> None:
                # For the async engine using asyncpg, we still need to set the IAM token here.
                host = POSTGRES_HOST
                port = POSTGRES_PORT
                user = POSTGRES_USER
                token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
                cparams["password"] = token
                cparams["ssl"] = ssl_context

    return _ASYNC_ENGINE


engine = get_sqlalchemy_async_engine()
AsyncSessionLocal = sessionmaker(  # type: ignore
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


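# Illustrative usage of the session factory above (a sketch; prefer the tenant-aware
# helpers further down for request handling):
#
#   async with AsyncSessionLocal() as session:
#       await session.execute(text("SELECT 1"))
#
# Note that AsyncSessionLocal does not set a tenant search_path by itself.

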
@contextmanager
def get_session_with_current_tenant() -> Generator[Session, None, None]:
    tenant_id = get_current_tenant_id()

    with get_session_with_tenant(tenant_id=tenant_id) as session:
        yield session


# Used in multi-tenant mode when we need to refer to the shared `public` schema
@contextmanager
def get_session_with_shared_schema() -> Generator[Session, None, None]:
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
        yield session
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


def _set_search_path_on_checkout__listener(
    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
    """Listener to make sure we ALWAYS set the search path on checkout."""
    tenant_id = get_current_tenant_id()
    if tenant_id and is_valid_schema_name(tenant_id):
        with dbapi_conn.cursor() as cursor:
            cursor.execute(f'SET search_path TO "{tenant_id}"')


@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
    """Generate a database session for a specific tenant."""
    engine = get_sqlalchemy_engine()

    event.listen(engine, "checkout", _set_search_path_on_checkout__listener)

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with engine.connect() as connection:
        dbapi_connection = connection.connection
        cursor = dbapi_connection.cursor()
        try:
            # NOTE: don't use `text()` here since we're using the cursor directly
            cursor.execute(f'SET search_path = "{tenant_id}"')
            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                cursor.execute(
                    f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
        except Exception as e:
            raise RuntimeError(f"search_path not set for {tenant_id}") from e
        finally:
            cursor.close()

        try:
            # automatically rollback or close
            with Session(bind=connection, expire_on_commit=False) as session:
                yield session
        finally:
            # always reset the search path on exit
            if MULTI_TENANT:
                cursor = dbapi_connection.cursor()
                try:
                    cursor.execute('SET search_path TO "$user", public')
                except Exception as e:
                    logger.warning(f"Failed to reset search path: {e}")
                    connection.rollback()
                finally:
                    cursor.close()


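# Illustrative usage (a sketch): most callers rely on the current-tenant wrapper rather
# than passing a schema explicitly, e.g.
#
#   with get_session_with_current_tenant() as db_session:
#       db_session.execute(text("SELECT 1"))
#
# while get_session_with_tenant(tenant_id=...) is for code that already knows the schema.

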
def get_session() -> Generator[Session, None, None]:
    """For use w/ Depends for FastAPI endpoints.

    Has some additional validation, and likely should be merged
    with get_session_context_manager in the future."""
    tenant_id = get_current_tenant_id()
    if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
        raise BasicAuthenticationError(detail="User must authenticate")

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with get_session_context_manager() as db_session:
        yield db_session


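# Illustrative FastAPI wiring (a sketch; the router and route shown are made up):
#
#   @router.get("/health")
#   def health(db_session: Session = Depends(get_session)) -> dict[str, str]:
#       db_session.execute(text("SELECT 1"))
#       return {"status": "ok"}

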
@contextlib.contextmanager
def get_session_context_manager() -> Generator[Session, None, None]:
    """Context manager for database sessions."""
    tenant_id = get_current_tenant_id()
    with get_session_with_tenant(tenant_id=tenant_id) as session:
        yield session


def _set_search_path_on_transaction__listener(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    """Every time a new transaction is started,
    set the search_path from the session's info."""
    tenant_id = session.info.get("tenant_id")
    if tenant_id:
        connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')


async def get_async_session(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    """For use w/ Depends for *async* FastAPI endpoints.

    For standard `async with ... as ...` use, use get_async_session_context_manager.
    """

    if tenant_id is None:
        tenant_id = get_current_tenant_id()

    engine = get_sqlalchemy_async_engine()

    async with AsyncSession(engine, expire_on_commit=False) as async_session:
        # IMPORTANT: do NOT remove. The search_path seems to get reset on every `.commit()`
        # without this. Do not fully understand why atm
        async_session.info["tenant_id"] = tenant_id
        event.listen(
            async_session.sync_session,
            "after_begin",
            _set_search_path_on_transaction__listener,
        )

        if POSTGRES_IDLE_SESSIONS_TIMEOUT:
            await async_session.execute(
                text(
                    f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
            )

        if not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID")

        # don't need to set the search path for self-hosted + default schema
        # this is also true for sync sessions, but just not adding it there for
        # now to simplify / not change too much
        if MULTI_TENANT or tenant_id != POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
            await async_session.execute(text(f'SET search_path = "{tenant_id}"'))

        yield async_session


def get_async_session_context_manager(
    tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
    return asynccontextmanager(get_async_session)(tenant_id)


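# Illustrative usage (a sketch): outside of FastAPI dependency injection, async code can
# do
#
#   async with get_async_session_context_manager() as db_session:
#       await db_session.execute(text("SELECT 1"))
#
# which resolves the tenant from the context variable when no tenant_id is passed.

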
async def warm_up_connections(
    sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
    sync_postgres_engine = get_sqlalchemy_engine()
    connections = [
        sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
    ]
    for conn in connections:
        conn.execute(text("SELECT 1"))
    for conn in connections:
        conn.close()

    async_postgres_engine = get_sqlalchemy_async_engine()
    async_connections = [
        await async_postgres_engine.connect()
        for _ in range(async_connections_to_warm_up)
    ]
    for async_conn in async_connections:
        await async_conn.execute(text("SELECT 1"))
    for async_conn in async_connections:
        await async_conn.close()


def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
    if USE_IAM_AUTH:
        host = POSTGRES_HOST
        port = POSTGRES_PORT
        user = POSTGRES_USER
        region = os.getenv("AWS_REGION_NAME", "us-east-2")
        # Configure for psycopg2 with IAM token
        configure_psycopg2_iam_auth(cparams, host, port, user, region)


@contextmanager
def get_db_readonly_user_session_with_current_tenant() -> (
    Generator[Session, None, None]
):
    """
    Generate a database session using the read-only database user for the current tenant.
    The read-only user credentials are obtained from environment variables.
    """
    tenant_id = get_current_tenant_id()

    readonly_engine = get_readonly_sqlalchemy_engine()

    event.listen(readonly_engine, "checkout", _set_search_path_on_checkout__listener)

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with readonly_engine.connect() as connection:
        dbapi_connection = connection.connection
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(f'SET search_path = "{tenant_id}"')
            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                # NOTE: raw DBAPI cursor, so pass the SQL string directly (no `text()`)
                cursor.execute(
                    f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
        finally:
            cursor.close()

        with Session(bind=connection, expire_on_commit=False) as session:
            try:
                yield session
            finally:
                if MULTI_TENANT:
                    cursor = dbapi_connection.cursor()
                    try:
                        cursor.execute('SET search_path TO "$user", public')
                    except Exception as e:
                        logger.warning(f"Failed to reset search path: {e}")
                        connection.rollback()
                    finally:
                        cursor.close()


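# Illustrative usage (a sketch): the read-only session is intended for code paths that
# should run with reduced database privileges, e.g.
#
#   with get_db_readonly_user_session_with_current_tenant() as db_session:
#       db_session.execute(text("SELECT 1"))
#
# It requires SqlEngine.init_readonly_engine(...) to have been called at startup with
# DB_READONLY_USER / DB_READONLY_PASSWORD configured.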