This commit is contained in:
pablodanswer
2024-10-22 09:33:41 -07:00
parent 8f67f1715c
commit f47d6798e1
3 changed files with 19 additions and 12 deletions

View File

@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password" os.environ.get("POSTGRES_PASSWORD") or "password"
) )
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int( POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@@ -268,27 +268,34 @@ async def get_async_session_with_tenant(
) -> AsyncGenerator[AsyncSession, None]: ) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None: if tenant_id is None:
tenant_id = current_tenant_id.get() tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
if not is_valid_schema_name(tenant_id): if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}") logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID") raise Exception("Invalid tenant ID")
engine = get_sqlalchemy_async_engine() engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker( async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession bind=engine,
) # type: ignore class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async with async_session_factory() as session: async with async_session_factory() as session:
try: # Start a SAVEPOINT to ensure the SET command is effective
# Set the search_path to the tenant's schema async with session.begin():
# Set the search_path at the session level
await session.execute(text(f'SET search_path = "{tenant_id}"')) await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e: try:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
yield session yield session
finally:
# Optionally reset the search_path after the session ends
if MULTI_TENANT:
async with session.begin():
await session.execute(text('SET search_path TO "$user", public'))
@contextmanager @contextmanager

View File

@@ -113,7 +113,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
class User(SQLAlchemyBaseUserTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship( oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined", cascade="all, delete-orphan" "OAuthAccount", lazy="selectin", cascade="all, delete-orphan"
) )
role: Mapped[UserRole] = mapped_column( role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC) Enum(UserRole, native_enum=False, default=UserRole.BASIC)