diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 9ee4994c7ad4..909871bb5d9b 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -8,8 +8,8 @@ from danswer.db.connector import disable_connector from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.credentials import backend_update_credential_json -from danswer.db.engine import build_engine from danswer.db.engine import get_db_current_time +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts from danswer.db.index_attempt import get_last_successful_attempt @@ -185,7 +185,7 @@ def run_indexing_jobs(db_session: Session) -> None: def update_loop(delay: int = 10) -> None: - engine = build_engine() + engine = get_sqlalchemy_engine() while True: start = time.time() logger.info(f"Running update, current time: {time.ctime(start)}") diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 20809d7e7e7c..28438a396fb3 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -3,8 +3,8 @@ from typing import Any from typing import Dict from danswer.auth.schemas import UserRole -from danswer.db.engine import build_async_engine from danswer.db.engine import get_async_session +from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.models import AccessToken from danswer.db.models import OAuthAccount from danswer.db.models import User @@ -18,7 +18,7 @@ from sqlalchemy.future import select async def get_user_count() -> int: - async with AsyncSession(build_async_engine()) as asession: + async with AsyncSession(get_sqlalchemy_async_engine()) as asession: stmt = select(func.count(User.id)) result = await asession.execute(stmt) user_count = result.scalar() diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index d748f793fcef..ddb4f2190d9f 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -1,6 +1,6 @@ from typing import Any -from danswer.db.engine import build_engine +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import Credential from danswer.db.models import User from danswer.server.models import CredentialBase @@ -136,7 +136,7 @@ def create_initial_public_credential() -> None: "DB is not in a valid initial state." "There must exist an empty public credential for data connectors that do not require additional Auth." ) - with Session(build_engine(), expire_on_commit=False) as db_session: + with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: first_credential = fetch_credential_by_id(public_cred_id, None, db_session) if first_credential is not None: diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 02eb1047944d..97c4e21c51e8 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -20,6 +20,12 @@ from sqlalchemy.orm import Session SYNC_DB_API = "psycopg2" ASYNC_DB_API = "asyncpg" +# global so we don't create more than one engine per process +# outside of being best practice, this is needed so we can properly pool +# connections and not create a new pool on every request +_SYNC_ENGINE: Engine | None = None +_ASYNC_ENGINE: AsyncEngine | None = None + def get_db_current_time(db_session: Session) -> datetime: result = db_session.execute(text("SELECT NOW()")).scalar() @@ -49,23 +55,29 @@ def build_connection_string( return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}" -def build_engine() -> Engine: - connection_string = build_connection_string(db_api=SYNC_DB_API) - return create_engine(connection_string) +def get_sqlalchemy_engine() -> Engine: + global _SYNC_ENGINE + if _SYNC_ENGINE is None: + connection_string = build_connection_string(db_api=SYNC_DB_API) + _SYNC_ENGINE = create_engine(connection_string) + return _SYNC_ENGINE -def build_async_engine() -> AsyncEngine: - connection_string = build_connection_string() - return create_async_engine(connection_string) +def get_sqlalchemy_async_engine() -> AsyncEngine: + global _ASYNC_ENGINE + if _ASYNC_ENGINE is None: + connection_string = build_connection_string() + _ASYNC_ENGINE = create_async_engine(connection_string) + return _ASYNC_ENGINE def get_session() -> Generator[Session, None, None]: - with Session(build_engine(), expire_on_commit=False) as session: + with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session: yield session async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async with AsyncSession( - build_async_engine(), expire_on_commit=False + get_sqlalchemy_async_engine(), expire_on_commit=False ) as async_session: yield async_session diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 9319f5262009..c0bc4214a76a 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -34,8 +34,8 @@ from danswer.db.credentials import fetch_credential_by_id from danswer.db.credentials import fetch_credentials from danswer.db.credentials import mask_credential_dict from danswer.db.credentials import update_credential -from danswer.db.engine import build_async_engine from danswer.db.engine import get_session +from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.index_attempt import create_index_attempt from danswer.db.models import Connector from danswer.db.models import IndexAttempt @@ -89,7 +89,7 @@ async def promote_admin( ) -> None: if user.role != UserRole.ADMIN: raise HTTPException(status_code=401, detail="Unauthorized") - async with AsyncSession(build_async_engine()) as asession: + async with AsyncSession(get_sqlalchemy_async_engine()) as asession: user_db = SQLAlchemyUserDatabase(asession, User) # type: ignore user_to_promote = await user_db.get_by_email(user_email.user_email) if not user_to_promote: