validated workaround for flush + reset

This commit is contained in:
pablodanswer 2024-10-19 16:45:57 -07:00
parent a2fd8d5e0a
commit 09dd7b424c
2 changed files with 63 additions and 37 deletions

View File

@ -232,8 +232,6 @@ def create_credential(
user: User | None,
db_session: Session,
) -> Credential:
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
credential = Credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
@ -243,13 +241,7 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
# Query and print length of all credentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
db_session.flush() # This ensures the credential gets an IDcredentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,

View File

@ -291,57 +291,91 @@ async def get_async_session_with_tenant(
yield session
class TenantSession(Session):
def __init__(self, *args, **kwargs):
self.tenant_id = kwargs.pop("tenant_id", None)
super().__init__(*args, **kwargs)
# @contextmanager
# def get_session_with_tenant(
# tenant_id: str | None = None,
# ) -> Generator[Session, None, None]:
# """Generate a database session with the appropriate tenant schema set."""
# if tenant_id is None:
# tenant_id = current_tenant_id.get()
def __enter__(self):
super().__enter__()
if self.tenant_id:
self.execute(text(f'SET search_path TO "{self.tenant_id}"'))
return self
# if not is_valid_schema_name(tenant_id):
# raise HTTPException(status_code=400, detail="Invalid tenant ID")
# engine = get_sqlalchemy_engine()
# event.listen(engine, "checkout", set_search_path_on_checkout)
# SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
# # Create a session
# with SessionLocal() as session:
# # Attach the event listener to set the search_path
# @event.listens_for(session, "after_begin")
# def _set_search_path(session, transaction, connection, **kw):
# connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# try:
# yield session
# finally:
# if MULTI_TENANT:
# # Reset search_path to default after the session is used
# session.execute(text('SET search_path TO "$user", public'))
# # Optionally, attach engine-level event listener
# def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
# tenant_id = current_tenant_id.get()
# if tenant_id and is_valid_schema_name(tenant_id):
# with dbapi_connection.cursor() as cursor:
# cursor.execute(f'SET search_path TO "{tenant_id}"')
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
engine = get_sqlalchemy_engine()
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
# Create a session
with SessionLocal() as session:
# Attach the event listener to set the search_path
@event.listens_for(session, "after_begin")
def _set_search_path(session, transaction, connection, **kw):
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
if MULTI_TENANT:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
session.execute(text('SET search_path TO "$user", public'))
if MULTI_TENANT:
print("zzzzzz resetting search path")
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
# Optionally, attach engine-level event listener
def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
def set_search_path_on_checkout(dbapi_conn, connection_record, connection_proxy):
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_connection.cursor() as cursor:
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant(