mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 03:58:30 +02:00
temp
This commit is contained in:
@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
|||||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||||
)
|
)
|
||||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||||
|
|
||||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||||
|
@@ -268,27 +268,34 @@ async def get_async_session_with_tenant(
|
|||||||
) -> AsyncGenerator[AsyncSession, None]:
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
if tenant_id is None:
|
if tenant_id is None:
|
||||||
tenant_id = current_tenant_id.get()
|
tenant_id = current_tenant_id.get()
|
||||||
|
else:
|
||||||
|
current_tenant_id.set(tenant_id)
|
||||||
|
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||||
raise Exception("Invalid tenant ID")
|
raise Exception("Invalid tenant ID")
|
||||||
|
|
||||||
engine = get_sqlalchemy_async_engine()
|
engine = get_sqlalchemy_async_engine()
|
||||||
|
|
||||||
async_session_factory = sessionmaker(
|
async_session_factory = sessionmaker(
|
||||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
bind=engine,
|
||||||
) # type: ignore
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autoflush=False,
|
||||||
|
)
|
||||||
|
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
try:
|
# Start a SAVEPOINT to ensure the SET command is effective
|
||||||
# Set the search_path to the tenant's schema
|
async with session.begin():
|
||||||
|
# Set the search_path at the session level
|
||||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||||
except Exception as e:
|
try:
|
||||||
logger.error(f"Error setting search_path: {str(e)}")
|
|
||||||
# You can choose to re-raise the exception or handle it
|
|
||||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
yield session
|
yield session
|
||||||
|
finally:
|
||||||
|
# Optionally reset the search_path after the session ends
|
||||||
|
if MULTI_TENANT:
|
||||||
|
async with session.begin():
|
||||||
|
await session.execute(text('SET search_path TO "$user", public'))
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@@ -113,7 +113,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
|||||||
|
|
||||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||||
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
"OAuthAccount", lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
role: Mapped[UserRole] = mapped_column(
|
role: Mapped[UserRole] = mapped_column(
|
||||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||||
|
Reference in New Issue
Block a user