rkuo-danswer 7c283b090d
Feature/postgres connection names (#1998)
* avoid reindexing secondary indexes after they succeed

* use postgres application names to facilitate connection debugging

* centralize all postgres application_name constants in the constants file

* missed a couple of files

* mypy fixes

* update dev background script
2024-07-31 20:36:30 +00:00

148 lines
4.8 KiB
Python

import contextlib
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import ContextManager
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory = None
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
Within the same transaction this value will not update
This datetime object returned should be timezone aware, default Postgres timezone is UTC
"""
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(connection_string, pool_size=40, max_overflow=10)
return _SYNC_ENGINE
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
)
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
def get_session() -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
yield async_session
async def warm_up_connections(
sync_connections_to_warm_up: int = 10, async_connections_to_warm_up: int = 10
) -> None:
sync_postgres_engine = get_sqlalchemy_engine()
connections = [
sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
]
for conn in connections:
conn.execute(text("SELECT 1"))
for conn in connections:
conn.close()
async_postgres_engine = get_sqlalchemy_async_engine()
async_connections = [
await async_postgres_engine.connect()
for _ in range(async_connections_to_warm_up)
]
for async_conn in async_connections:
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory