try schema_translation_map

This commit is contained in:
Richard Kuo (Danswer) 2025-02-20 19:19:29 -08:00
parent 599b7705c2
commit a2f144d80b
2 changed files with 68 additions and 49 deletions

View File

@ -423,46 +423,39 @@ def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
engine = get_sqlalchemy_engine()
schema_translate_map = {None: tenant_id}
event.listen(engine, "checkout", set_search_path_on_checkout)
engine = get_sqlalchemy_engine()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with engine.connect() as connection:
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
try:
cursor = dbapi_connection.cursor()
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()
finally:
cursor.close()
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
yield session
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = get_current_tenant_id()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
# def set_search_path_on_checkout(
# dbapi_conn: Any, connection_record: Any, connection_proxy: Any
# ) -> None:
# tenant_id = get_current_tenant_id()
# if tenant_id and is_valid_schema_name(tenant_id):
# with dbapi_conn.cursor() as cursor:
# cursor.execute(f'SET search_path TO "{tenant_id}"')
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
@ -478,23 +471,44 @@ def get_session() -> Generator[Session, None, None]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
schema_translate_map = {None: tenant_id}
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
else:
# single tenant
with engine.connect() as connection:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
tenant_id = get_current_tenant_id()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Create connection with schema translation
schema_translate_map = {None: tenant_id}
async with engine.connect() as connection:
connection = await connection.execution_options(
schema_translate_map=schema_translate_map
)
async with AsyncSession(
bind=connection, expire_on_commit=False
) as async_session:
yield async_session
else:
# single tenant
async with AsyncSession(engine, expire_on_commit=False) as async_session:
yield async_session
def get_session_context_manager() -> ContextManager[Session]:

View File

@ -5,7 +5,6 @@ from typing import cast
from fastapi import HTTPException
from redis.client import Redis
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.db.engine import get_sqlalchemy_engine
@ -40,17 +39,23 @@ class PgRedisKVStore(KeyValueStore):
@contextmanager
def _get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(self.tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
yield session
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(status_code=401, detail="User must authenticate")
if not is_valid_schema_name(self.tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
schema_translate_map = {None: self.tenant_id}
with engine.connect().execution_options(
schema_translate_map=schema_translate_map
) as connection:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
else:
# single tenant
with engine.connect() as connection:
with Session(bind=connection, expire_on_commit=False) as session:
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres