mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-09 06:02:00 +02:00
temporary stopgap for uperts
This commit is contained in:
@ -377,4 +377,5 @@ STRIPE_PRICE = os.environ.get("STRIPE_PRICE", "price_1PsYoPHlhTYqRZib2t5ydpq5")
|
|||||||
STRIPE_WEBHOOK_SECRET = (
|
STRIPE_WEBHOOK_SECRET = (
|
||||||
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
|
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
|
||||||
)
|
)
|
||||||
# STRIPE_SECRET_KEY="sk_test_51NwZq2HlhTYqRZibT2cssHV8E5QcLAUmaRLQPMjGb5aOxOWomVxOmzRgxf82ziDBuGdPP2GIDod8xe6DyqeGgUDi00KbsHPoT4"
|
|
||||||
|
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")
|
@ -15,6 +15,8 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
from danswer.configs.app_configs import DEFAULT_SCHEMA
|
||||||
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||||
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||||
from danswer.configs.app_configs import POSTGRES_DB
|
from danswer.configs.app_configs import POSTGRES_DB
|
||||||
@ -129,7 +131,7 @@ def init_sqlalchemy_engine(app_name: str) -> None:
|
|||||||
POSTGRES_APP_NAME = app_name
|
POSTGRES_APP_NAME = app_name
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_engine() -> Engine:
|
def get_sqlalchemy_engine(schema: str = DEFAULT_SCHEMA):
|
||||||
global _SYNC_ENGINE
|
global _SYNC_ENGINE
|
||||||
if _SYNC_ENGINE is None:
|
if _SYNC_ENGINE is None:
|
||||||
connection_string = build_connection_string(
|
connection_string = build_connection_string(
|
||||||
@ -142,14 +144,20 @@ def get_sqlalchemy_engine() -> Engine:
|
|||||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||||
)
|
)
|
||||||
|
@event.listens_for(_SYNC_ENGINE, "connect")
|
||||||
|
def set_search_path(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute(f"SET search_path TO {schema}")
|
||||||
|
cursor.close()
|
||||||
|
dbapi_connection.commit()
|
||||||
|
|
||||||
|
|
||||||
return _SYNC_ENGINE
|
return _SYNC_ENGINE
|
||||||
|
|
||||||
|
|
||||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||||
global _ASYNC_ENGINE
|
global _ASYNC_ENGINE
|
||||||
if _ASYNC_ENGINE is None:
|
if _ASYNC_ENGINE is None:
|
||||||
# underlying asyncpg cannot accept application_name directly in the connection string
|
|
||||||
# https://github.com/MagicStack/asyncpg/issues/798
|
|
||||||
connection_string = build_connection_string()
|
connection_string = build_connection_string()
|
||||||
_ASYNC_ENGINE = create_async_engine(
|
_ASYNC_ENGINE = create_async_engine(
|
||||||
connection_string,
|
connection_string,
|
||||||
@ -161,6 +169,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
|||||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _ASYNC_ENGINE
|
return _ASYNC_ENGINE
|
||||||
|
|
||||||
|
|
||||||
@ -168,12 +177,15 @@ def get_session_context_manager() -> ContextManager[Session]:
|
|||||||
return contextlib.contextmanager(get_session)()
|
return contextlib.contextmanager(get_session)()
|
||||||
|
|
||||||
|
|
||||||
def get_session() -> Generator[Session, None, None]:
|
def get_session(schema: str = DEFAULT_SCHEMA) -> Generator[Session, None, None]:
|
||||||
# The line below was added to monitor the latency caused by Postgres connections
|
# The line below was added to monitor the latency caused by Postgres connections
|
||||||
# during API calls.
|
# during API calls.
|
||||||
# with tracer.trace("db.get_session"):
|
# with tracer.trace("db.get_session"):
|
||||||
|
|
||||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
|
||||||
|
session.execute(text(f"SET search_path TO {schema}"))
|
||||||
yield session
|
yield session
|
||||||
|
session.execute(text("SET search_path TO public"))
|
||||||
|
|
||||||
|
|
||||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
Reference in New Issue
Block a user