mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 20:08:36 +02:00
try schema_translation_map
This commit is contained in:
parent
599b7705c2
commit
a2f144d80b
@ -423,46 +423,39 @@ def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
schema_translate_map = {None: tenant_id}
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
with engine.connect() as connection:
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
try:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
yield session
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
# def set_search_path_on_checkout(
|
||||
# dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
# ) -> None:
|
||||
# tenant_id = get_current_tenant_id()
|
||||
# if tenant_id and is_valid_schema_name(tenant_id):
|
||||
# with dbapi_conn.cursor() as cursor:
|
||||
# cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
|
||||
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
@ -478,23 +471,44 @@ def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
schema_translate_map = {None: tenant_id}
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
else:
|
||||
# single tenant
|
||||
with engine.connect() as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
# Create connection with schema translation
|
||||
schema_translate_map = {None: tenant_id}
|
||||
async with engine.connect() as connection:
|
||||
connection = await connection.execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
)
|
||||
async with AsyncSession(
|
||||
bind=connection, expire_on_commit=False
|
||||
) as async_session:
|
||||
yield async_session
|
||||
else:
|
||||
# single tenant
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
yield async_session
|
||||
|
||||
|
||||
def get_session_context_manager() -> ContextManager[Session]:
|
||||
|
@ -5,7 +5,6 @@ from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from redis.client import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
@ -40,17 +39,23 @@ class PgRedisKVStore(KeyValueStore):
|
||||
@contextmanager
|
||||
def _get_session(self) -> Iterator[Session]:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User must authenticate"
|
||||
)
|
||||
if not is_valid_schema_name(self.tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
|
||||
yield session
|
||||
if MULTI_TENANT:
|
||||
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
if not is_valid_schema_name(self.tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
schema_translate_map = {None: self.tenant_id}
|
||||
with engine.connect().execution_options(
|
||||
schema_translate_map=schema_translate_map
|
||||
) as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
else:
|
||||
# single tenant
|
||||
with engine.connect() as connection:
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
# Not encrypted in Redis, but encrypted in Postgres
|
||||
|
Loading…
x
Reference in New Issue
Block a user