mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
temp
This commit is contained in:
@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
|
@@ -268,27 +268,34 @@ async def get_async_session_with_tenant(
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
else:
|
||||
current_tenant_id.set(tenant_id)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
raise Exception("Invalid tenant ID")
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
|
||||
async_session_factory = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
||||
) # type: ignore
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
# Set the search_path to the tenant's schema
|
||||
# Start a SAVEPOINT to ensure the SET command is effective
|
||||
async with session.begin():
|
||||
# Set the search_path at the session level
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting search_path: {str(e)}")
|
||||
# You can choose to re-raise the exception or handle it
|
||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Optionally reset the search_path after the session ends
|
||||
if MULTI_TENANT:
|
||||
async with session.begin():
|
||||
await session.execute(text('SET search_path TO "$user", public'))
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@@ -113,7 +113,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
"OAuthAccount", lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
|
Reference in New Issue
Block a user