Fix connection pooling

This commit is contained in:
Weves
2023-06-19 14:38:43 -06:00
committed by Chris Weaver
parent 490d39f081
commit 620579cbec
5 changed files with 28 additions and 16 deletions

View File

@@ -8,8 +8,8 @@ from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import build_engine
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_successful_attempt
@@ -185,7 +185,7 @@ def run_indexing_jobs(db_session: Session) -> None:
def update_loop(delay: int = 10) -> None:
engine = build_engine()
engine = get_sqlalchemy_engine()
while True:
start = time.time()
logger.info(f"Running update, current time: {time.ctime(start)}")

View File

@@ -3,8 +3,8 @@ from typing import Any
from typing import Dict
from danswer.auth.schemas import UserRole
from danswer.db.engine import build_async_engine
from danswer.db.engine import get_async_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
@@ -18,7 +18,7 @@ from sqlalchemy.future import select
async def get_user_count() -> int:
async with AsyncSession(build_async_engine()) as asession:
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
stmt = select(func.count(User.id))
result = await asession.execute(stmt)
user_count = result.scalar()

View File

@@ -1,6 +1,6 @@
from typing import Any
from danswer.db.engine import build_engine
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Credential
from danswer.db.models import User
from danswer.server.models import CredentialBase
@@ -136,7 +136,7 @@ def create_initial_public_credential() -> None:
"DB is not in a valid initial state."
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
with Session(build_engine(), expire_on_commit=False) as db_session:
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
if first_credential is not None:

View File

@@ -20,6 +20,12 @@ from sqlalchemy.orm import Session
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
def get_db_current_time(db_session: Session) -> datetime:
result = db_session.execute(text("SELECT NOW()")).scalar()
@@ -49,23 +55,29 @@ def build_connection_string(
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def build_engine() -> Engine:
connection_string = build_connection_string(db_api=SYNC_DB_API)
return create_engine(connection_string)
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(db_api=SYNC_DB_API)
_SYNC_ENGINE = create_engine(connection_string)
return _SYNC_ENGINE
def build_async_engine() -> AsyncEngine:
connection_string = build_connection_string()
return create_async_engine(connection_string)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(connection_string)
return _ASYNC_ENGINE
def get_session() -> Generator[Session, None, None]:
with Session(build_engine(), expire_on_commit=False) as session:
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
build_async_engine(), expire_on_commit=False
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
yield async_session

View File

@@ -34,8 +34,8 @@ from danswer.db.credentials import fetch_credential_by_id
from danswer.db.credentials import fetch_credentials
from danswer.db.credentials import mask_credential_dict
from danswer.db.credentials import update_credential
from danswer.db.engine import build_async_engine
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
@@ -89,7 +89,7 @@ async def promote_admin(
) -> None:
if user.role != UserRole.ADMIN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with AsyncSession(build_async_engine()) as asession:
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
user_db = SQLAlchemyUserDatabase(asession, User) # type: ignore
user_to_promote = await user_db.get_by_email(user_email.user_email)
if not user_to_promote: