mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 03:58:30 +02:00
Fix connection pooling
This commit is contained in:
@@ -8,8 +8,8 @@ from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.credentials import backend_update_credential_json
|
||||
from danswer.db.engine import build_engine
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_successful_attempt
|
||||
@@ -185,7 +185,7 @@ def run_indexing_jobs(db_session: Session) -> None:
|
||||
|
||||
|
||||
def update_loop(delay: int = 10) -> None:
|
||||
engine = build_engine()
|
||||
engine = get_sqlalchemy_engine()
|
||||
while True:
|
||||
start = time.time()
|
||||
logger.info(f"Running update, current time: {time.ctime(start)}")
|
||||
|
@@ -3,8 +3,8 @@ from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
@@ -18,7 +18,7 @@ from sqlalchemy.future import select
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
async with AsyncSession(build_async_engine()) as asession:
|
||||
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
|
||||
stmt = select(func.count(User.id))
|
||||
result = await asession.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.db.engine import build_engine
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import User
|
||||
from danswer.server.models import CredentialBase
|
||||
@@ -136,7 +136,7 @@ def create_initial_public_credential() -> None:
|
||||
"DB is not in a valid initial state."
|
||||
"There must exist an empty public credential for data connectors that do not require additional Auth."
|
||||
)
|
||||
with Session(build_engine(), expire_on_commit=False) as db_session:
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
|
||||
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
|
||||
|
||||
if first_credential is not None:
|
||||
|
@@ -20,6 +20,12 @@ from sqlalchemy.orm import Session
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
_SYNC_ENGINE: Engine | None = None
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
|
||||
|
||||
def get_db_current_time(db_session: Session) -> datetime:
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
@@ -49,23 +55,29 @@ def build_connection_string(
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
def build_engine() -> Engine:
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
return create_engine(connection_string)
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
global _SYNC_ENGINE
|
||||
if _SYNC_ENGINE is None:
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
_SYNC_ENGINE = create_engine(connection_string)
|
||||
return _SYNC_ENGINE
|
||||
|
||||
|
||||
def build_async_engine() -> AsyncEngine:
|
||||
connection_string = build_connection_string()
|
||||
return create_async_engine(connection_string)
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
connection_string = build_connection_string()
|
||||
_ASYNC_ENGINE = create_async_engine(connection_string)
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
with Session(build_engine(), expire_on_commit=False) as session:
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with AsyncSession(
|
||||
build_async_engine(), expire_on_commit=False
|
||||
get_sqlalchemy_async_engine(), expire_on_commit=False
|
||||
) as async_session:
|
||||
yield async_session
|
||||
|
@@ -34,8 +34,8 @@ from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.credentials import fetch_credentials
|
||||
from danswer.db.credentials import mask_credential_dict
|
||||
from danswer.db.credentials import update_credential
|
||||
from danswer.db.engine import build_async_engine
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -89,7 +89,7 @@ async def promote_admin(
|
||||
) -> None:
|
||||
if user.role != UserRole.ADMIN:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
async with AsyncSession(build_async_engine()) as asession:
|
||||
async with AsyncSession(get_sqlalchemy_async_engine()) as asession:
|
||||
user_db = SQLAlchemyUserDatabase(asession, User) # type: ignore
|
||||
user_to_promote = await user_db.get_by_email(user_email.user_email)
|
||||
if not user_to_promote:
|
||||
|
Reference in New Issue
Block a user