From a2f144d80bebd520a4be4adc331e4dd369fb12f0 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Thu, 20 Feb 2025 19:19:29 -0800 Subject: [PATCH] try schema_translation_map --- backend/onyx/db/engine.py | 88 ++++++++++++++++----------- backend/onyx/key_value_store/store.py | 29 +++++---- 2 files changed, 68 insertions(+), 49 deletions(-) diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index 86ae21778..a76559bf7 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -423,46 +423,39 @@ def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None if tenant_id is None: tenant_id = POSTGRES_DEFAULT_SCHEMA - engine = get_sqlalchemy_engine() + schema_translate_map = {None: tenant_id} - event.listen(engine, "checkout", set_search_path_on_checkout) + engine = get_sqlalchemy_engine() if not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID") - with engine.connect() as connection: + with engine.connect().execution_options( + schema_translate_map=schema_translate_map + ) as connection: dbapi_connection = connection.connection - cursor = dbapi_connection.cursor() - try: - cursor.execute(f'SET search_path = "{tenant_id}"') - if POSTGRES_IDLE_SESSIONS_TIMEOUT: + if POSTGRES_IDLE_SESSIONS_TIMEOUT: + try: + cursor = dbapi_connection.cursor() cursor.execute( text( f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}" ) ) - finally: - cursor.close() + finally: + cursor.close() with Session(bind=connection, expire_on_commit=False) as session: - try: - yield session - finally: - if MULTI_TENANT: - cursor = dbapi_connection.cursor() - try: - cursor.execute('SET search_path TO "$user", public') - finally: - cursor.close() + yield session -def set_search_path_on_checkout( - dbapi_conn: Any, connection_record: Any, connection_proxy: Any -) -> None: - tenant_id = get_current_tenant_id() - if tenant_id and is_valid_schema_name(tenant_id): - with dbapi_conn.cursor() as cursor: - cursor.execute(f'SET search_path TO "{tenant_id}"') +# def set_search_path_on_checkout( +# dbapi_conn: Any, connection_record: Any, connection_proxy: Any +# ) -> None: +# tenant_id = get_current_tenant_id() +# if tenant_id and is_valid_schema_name(tenant_id): +# with dbapi_conn.cursor() as cursor: +# cursor.execute(f'SET search_path TO "{tenant_id}"') def get_session_generator_with_tenant() -> Generator[Session, None, None]: @@ -478,23 +471,44 @@ def get_session() -> Generator[Session, None, None]: engine = get_sqlalchemy_engine() - with Session(engine, expire_on_commit=False) as session: - if MULTI_TENANT: - if not is_valid_schema_name(tenant_id): - raise HTTPException(status_code=400, detail="Invalid tenant ID") - session.execute(text(f'SET search_path = "{tenant_id}"')) - yield session + if MULTI_TENANT: + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + schema_translate_map = {None: tenant_id} + with engine.connect().execution_options( + schema_translate_map=schema_translate_map + ) as connection: + with Session(bind=connection, expire_on_commit=False) as session: + yield session + else: + # single tenant + with engine.connect() as connection: + with Session(bind=connection, expire_on_commit=False) as session: + yield session async def get_async_session() -> AsyncGenerator[AsyncSession, None]: tenant_id = get_current_tenant_id() engine = get_sqlalchemy_async_engine() - async with AsyncSession(engine, expire_on_commit=False) as async_session: - if MULTI_TENANT: - if not is_valid_schema_name(tenant_id): - raise HTTPException(status_code=400, detail="Invalid tenant ID") - await async_session.execute(text(f'SET search_path = "{tenant_id}"')) - yield async_session + + if MULTI_TENANT: + if not is_valid_schema_name(tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + + # Create connection with schema translation + schema_translate_map = {None: tenant_id} + async with engine.connect() as connection: + connection = await connection.execution_options( + schema_translate_map=schema_translate_map + ) + async with AsyncSession( + bind=connection, expire_on_commit=False + ) as async_session: + yield async_session + else: + # single tenant + async with AsyncSession(engine, expire_on_commit=False) as async_session: + yield async_session def get_session_context_manager() -> ContextManager[Session]: diff --git a/backend/onyx/key_value_store/store.py b/backend/onyx/key_value_store/store.py index f0811b7e5..2036673d6 100644 --- a/backend/onyx/key_value_store/store.py +++ b/backend/onyx/key_value_store/store.py @@ -5,7 +5,6 @@ from typing import cast from fastapi import HTTPException from redis.client import Redis -from sqlalchemy import text from sqlalchemy.orm import Session from onyx.db.engine import get_sqlalchemy_engine @@ -40,17 +39,23 @@ class PgRedisKVStore(KeyValueStore): @contextmanager def _get_session(self) -> Iterator[Session]: engine = get_sqlalchemy_engine() - with Session(engine, expire_on_commit=False) as session: - if MULTI_TENANT: - if self.tenant_id == POSTGRES_DEFAULT_SCHEMA: - raise HTTPException( - status_code=401, detail="User must authenticate" - ) - if not is_valid_schema_name(self.tenant_id): - raise HTTPException(status_code=400, detail="Invalid tenant ID") - # Set the search_path to the tenant's schema - session.execute(text(f'SET search_path = "{self.tenant_id}"')) - yield session + if MULTI_TENANT: + if self.tenant_id == POSTGRES_DEFAULT_SCHEMA: + raise HTTPException(status_code=401, detail="User must authenticate") + if not is_valid_schema_name(self.tenant_id): + raise HTTPException(status_code=400, detail="Invalid tenant ID") + + schema_translate_map = {None: self.tenant_id} + with engine.connect().execution_options( + schema_translate_map=schema_translate_map + ) as connection: + with Session(bind=connection, expire_on_commit=False) as session: + yield session + else: + # single tenant + with engine.connect() as connection: + with Session(bind=connection, expire_on_commit=False) as session: + yield session def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: # Not encrypted in Redis, but encrypted in Postgres