Adjust pg engine intialization (#4408)

* Adjust pg engine intialization

* Fix mypy

* Rename var

* fix typo

* Fix tests
This commit is contained in:
Chris Weaver
2025-04-06 12:44:49 -07:00
committed by GitHub
parent 8b05f98d54
commit aadd4f212a
5 changed files with 69 additions and 9 deletions

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -90,7 +90,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -227,17 +227,62 @@ class SqlEngine:
return engine
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
def init_engine(
cls,
pool_size: int,
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
max_overflow: int,
**extra_engine_kwargs: Any,
) -> None:
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
important args, and if incorrectly specified, we have run into hitting the pool
limit / using too many connections and overwhelming the database."""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
if cls._engine:
return
connection_string = build_connection_string(
db_api=SYNC_DB_API,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
# Start with base kwargs that are valid for all pool types
final_engine_kwargs: dict[str, Any] = {}
if POSTGRES_USE_NULL_POOL:
# if null pool is specified, then we need to make sure that
# we remove any passed in kwargs related to pool size that would
# cause the initialization to fail
final_engine_kwargs.update(extra_engine_kwargs)
final_engine_kwargs["poolclass"] = pool.NullPool
if "pool_size" in final_engine_kwargs:
del final_engine_kwargs["pool_size"]
if "max_overflow" in final_engine_kwargs:
del final_engine_kwargs["max_overflow"]
else:
final_engine_kwargs["pool_size"] = pool_size
final_engine_kwargs["max_overflow"] = max_overflow
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
# any passed in kwargs override the defaults
final_engine_kwargs.update(extra_engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
cls._engine = engine
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
raise RuntimeError("Engine not initialized. Must call init_engine first.")
return cls._engine
@classmethod

View File

@@ -4,6 +4,7 @@ import pytest
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
# yield session
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)
@pytest.fixture
def vespa_client() -> vespa_fixture:
with get_session_context_manager() as db_session: