temporary stopgap for uperts

This commit is contained in:
pablodanswer 2024-09-21 19:50:56 -07:00
parent d3d63ee8f7
commit 127526d080
2 changed files with 18 additions and 5 deletions

View File

@ -377,4 +377,5 @@ STRIPE_PRICE = os.environ.get("STRIPE_PRICE", "price_1PsYoPHlhTYqRZib2t5ydpq5")
STRIPE_WEBHOOK_SECRET = (
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
)
# STRIPE_SECRET_KEY="sk_test_51NwZq2HlhTYqRZibT2cssHV8E5QcLAUmaRLQPMjGb5aOxOWomVxOmzRgxf82ziDBuGdPP2GIDod8xe6DyqeGgUDi00KbsHPoT4"
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")

View File

@ -15,6 +15,8 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import DEFAULT_SCHEMA
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import POSTGRES_DB
@ -129,7 +131,7 @@ def init_sqlalchemy_engine(app_name: str) -> None:
POSTGRES_APP_NAME = app_name
def get_sqlalchemy_engine() -> Engine:
def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA):
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(
@ -142,14 +144,20 @@ def get_sqlalchemy_engine() -> Engine:
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
@event.listens_for(_SYNC_ENGINE, "connect")
def set_search_path(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute(f"SET search_path TO {schema}")
cursor.close()
dbapi_connection.commit()
return _SYNC_ENGINE
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
@ -161,6 +169,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _ASYNC_ENGINE
@ -168,12 +177,15 @@ def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
def get_session() -> Generator[Session, None, None]:
def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
session.execute(text(f"SET search_path TO {schema}"))
yield session
session.execute(text("SET search_path TO public"))
async def get_async_session() -> AsyncGenerator[AsyncSession, None]: