This commit is contained in:
pablodanswer
2024-10-22 09:33:41 -07:00
parent 8f67f1715c
commit f47d6798e1
3 changed files with 19 additions and 12 deletions

View File

@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@@ -268,27 +268,34 @@ async def get_async_session_with_tenant(
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
# Start a SAVEPOINT to ensure the SET command is effective
async with session.begin():
# Set the search_path at the session level
await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
try:
yield session
finally:
# Optionally reset the search_path after the session ends
if MULTI_TENANT:
async with session.begin():
await session.execute(text('SET search_path TO "$user", public'))
@contextmanager

View File

@@ -113,7 +113,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
"OAuthAccount", lazy="selectin", cascade="all, delete-orphan"
)
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)