Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-06 17:39:28 +02:00)
Adjust pg engine initialization (#4408)

* Adjust pg engine initialization
* Fix mypy
* Rename var
* Fix typo
* Fix tests
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import cast

 from celery import Celery
 from celery import signals
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")

     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
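Why the two-line form: celery does not expose `Worker.concurrency` with a precise static type, so the old one-liner needed a `# type: ignore` on the whole `init_engine(...)` call, which silenced checking of every argument. Narrowing the value once with `cast` confines the exemption to a single line. A minimal self-contained sketch, with a hypothetical `FakeWorker` standing in for celery's `Worker`:

```python
from typing import Any, cast

class FakeWorker:
    # celery exposes concurrency without a precise static type
    concurrency: Any = 4

worker = FakeWorker()

# Narrow once; only this line needs exempting from the type checker.
pool_size = cast(int, worker.concurrency)

# From here on, mypy can verify pool_size against a `pool_size: int` parameter.
assert isinstance(pool_size, int)
```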
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import cast

 from celery import Celery
 from celery import signals
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     # "SSL connection has been closed unexpectedly"
     # actually setting the spawn method in the cloud fixes 95% of these.
     # setting pre ping might help even more, but not worrying about that yet
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
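For context on the "pre ping" the comment mentions: SQLAlchemy's `pool_pre_ping` option tests each pooled connection with a lightweight round-trip at checkout and transparently replaces connections the server has already dropped, which is exactly the "SSL connection has been closed unexpectedly" failure mode. A minimal sketch, with a placeholder DSN and sizes not taken from this PR:

```python
from sqlalchemy import create_engine

# Sketch only: DSN and pool sizes are placeholders.
engine = create_engine(
    "postgresql+psycopg2://user:pass@localhost:5432/db",
    pool_size=4,         # steady-state connections held by the pool
    max_overflow=8,      # extra connections allowed under burst load
    pool_pre_ping=True,  # probe each connection on checkout, replace dead ones
)
```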
@@ -90,7 +90,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     EXTRA_CONCURRENCY = 4  # small extra fudge factor for connection limits

     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)

     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
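The `pool_size`/`max_overflow` pair is what bounds connections per process: a worker can hold at most `pool_size + max_overflow` database connections at once. A back-of-the-envelope check with illustrative numbers (not taken from this PR):

```python
# Each (pool_size, max_overflow) pair caps one worker process.
# Values below are hypothetical, for illustration only.
workers = {
    "primary": (4, 4),  # concurrency -> pool_size, EXTRA_CONCURRENCY -> max_overflow
    "heavy":   (4, 8),
    "light":   (4, 8),
}

# The sum across all processes must stay below Postgres max_connections,
# with headroom left for the API server and ad-hoc sessions.
total = sum(size + overflow for size, overflow in workers.values())
print(f"worst-case worker connections: {total}")
```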
@@ -227,17 +227,62 @@ class SqlEngine:
         return engine

     @classmethod
-    def init_engine(cls, **engine_kwargs: Any) -> None:
+    def init_engine(
+        cls,
+        pool_size: int,
+        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
+        max_overflow: int,
+        **extra_engine_kwargs: Any,
+    ) -> None:
+        """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
+        important args, and if incorrectly specified, we have run into hitting the pool
+        limit / using too many connections and overwhelming the database."""
         with cls._lock:
-            if not cls._engine:
-                cls._engine = cls._init_engine(**engine_kwargs)
+            if cls._engine:
+                return
+
+            connection_string = build_connection_string(
+                db_api=SYNC_DB_API,
+                app_name=cls._app_name + "_sync",
+                use_iam=USE_IAM_AUTH,
+            )
+
+            # Start with base kwargs that are valid for all pool types
+            final_engine_kwargs: dict[str, Any] = {}
+
+            if POSTGRES_USE_NULL_POOL:
+                # if null pool is specified, then we need to make sure that
+                # we remove any passed in kwargs related to pool size that would
+                # cause the initialization to fail
+                final_engine_kwargs.update(extra_engine_kwargs)
+
+                final_engine_kwargs["poolclass"] = pool.NullPool
+                if "pool_size" in final_engine_kwargs:
+                    del final_engine_kwargs["pool_size"]
+                if "max_overflow" in final_engine_kwargs:
+                    del final_engine_kwargs["max_overflow"]
+            else:
+                final_engine_kwargs["pool_size"] = pool_size
+                final_engine_kwargs["max_overflow"] = max_overflow
+                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+
+            # any passed in kwargs override the defaults
+            final_engine_kwargs.update(extra_engine_kwargs)
+
+            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+            # echo=True here for inspecting all emitted db queries
+            engine = create_engine(connection_string, **final_engine_kwargs)
+
+            if USE_IAM_AUTH:
+                event.listen(engine, "do_connect", provide_iam_token)
+
+            cls._engine = engine

     @classmethod
     def get_engine(cls) -> Engine:
         if not cls._engine:
-            with cls._lock:
-                if not cls._engine:
-                    cls._engine = cls._init_engine()
+            raise RuntimeError("Engine not initialized. Must call init_engine first.")
         return cls._engine

     @classmethod
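The kwarg stripping in the `POSTGRES_USE_NULL_POOL` branch exists because `NullPool` does not accept the `QueuePool` sizing arguments, so forwarding them is the "cause the initialization to fail" case the comment warns about. A standalone sketch of the failure being avoided (placeholder DSN; `create_engine` validates its keyword arguments eagerly, so no running database is needed, though the driver must be importable):

```python
from sqlalchemy import create_engine, pool

# Passing QueuePool sizing args alongside NullPool is rejected at
# engine-creation time, before any connection is attempted.
try:
    create_engine(
        "postgresql+psycopg2://user:pass@localhost:5432/db",  # placeholder DSN
        poolclass=pool.NullPool,
        pool_size=10,    # QueuePool-only argument
        max_overflow=8,  # QueuePool-only argument
    )
except TypeError as e:
    print(f"rejected as expected: {e}")
```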
@@ -4,6 +4,7 @@ import pytest

 from onyx.auth.schemas import UserRole
 from onyx.db.engine import get_session_context_manager
+from onyx.db.engine import SqlEngine
 from onyx.db.search_settings import get_current_search_settings
 from tests.integration.common_utils.constants import ADMIN_USER_NAME
 from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
 # yield session


+@pytest.fixture(scope="session", autouse=True)
+def initialize_db() -> None:
+    # Make sure that the db engine is initialized before any tests are run
+    SqlEngine.init_engine(
+        pool_size=10,
+        max_overflow=5,
+    )
+
+
 @pytest.fixture
 def vespa_client() -> vespa_fixture:
 with get_session_context_manager() as db_session:
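Because `get_engine` now raises instead of lazily constructing an engine, any process, tests included, must call `init_engine` before its first database access; the session-scoped `autouse=True` fixture above guarantees that ordering without each test opting in. A minimal sketch of the same fail-fast pattern, independent of the project's code:

```python
import pytest

_engine: object | None = None

def init_engine() -> None:
    global _engine
    _engine = object()  # stand-in for the real SQLAlchemy Engine

def get_engine() -> object:
    # Fail fast instead of silently building a default-sized engine.
    if _engine is None:
        raise RuntimeError("Engine not initialized. Must call init_engine first.")
    return _engine

@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
    # Runs once, before any test in the session touches get_engine().
    init_engine()
```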