From 09dd7b424c59dfe33b8c7d247b320d23e34b051e Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 19 Oct 2024 16:45:57 -0700 Subject: [PATCH] validated workaround for flush + reset --- backend/danswer/db/credentials.py | 8 --- backend/danswer/db/engine.py | 92 +++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 37 deletions(-) diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index de040d5f4..72617dd33 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -232,8 +232,6 @@ def create_credential( user: User | None, db_session: Session, ) -> Credential: - all_credentials = db_session.query(Credential).all() - print(f"Total number of credentials: {len(all_credentials)}") credential = Credential( credential_json=credential_data.credential_json, user_id=user.id if user else None, @@ -243,13 +241,7 @@ def create_credential( curator_public=credential_data.curator_public, ) db_session.add(credential) - # Query and print length of all credentials - all_credentials = db_session.query(Credential).all() - print(f"Total number of credentials: {len(all_credentials)}") db_session.flush() # This ensures the credential gets an IDcredentials - all_credentials = db_session.query(Credential).all() - print(f"Total number of credentials: {len(all_credentials)}") - _relate_credential_to_user_groups__no_commit( db_session=db_session, credential_id=credential.id, diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index f0fdd8a24..4c0752da4 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -291,57 +291,91 @@ async def get_async_session_with_tenant( yield session -class TenantSession(Session): - def __init__(self, *args, **kwargs): - self.tenant_id = kwargs.pop("tenant_id", None) - super().__init__(*args, **kwargs) +# @contextmanager +# def get_session_with_tenant( +# tenant_id: str | None = None, +# ) -> Generator[Session, None, None]: +# """Generate a database session with the appropriate tenant schema set.""" +# if tenant_id is None: +# tenant_id = current_tenant_id.get() - def __enter__(self): - super().__enter__() - if self.tenant_id: - self.execute(text(f'SET search_path TO "{self.tenant_id}"')) - return self +# if not is_valid_schema_name(tenant_id): +# raise HTTPException(status_code=400, detail="Invalid tenant ID") + +# engine = get_sqlalchemy_engine() +# event.listen(engine, "checkout", set_search_path_on_checkout) +# SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session) + +# # Create a session +# with SessionLocal() as session: +# # Attach the event listener to set the search_path +# @event.listens_for(session, "after_begin") +# def _set_search_path(session, transaction, connection, **kw): +# connection.execute(text(f'SET search_path TO "{tenant_id}"')) + +# try: +# yield session +# finally: +# if MULTI_TENANT: +# # Reset search_path to default after the session is used +# session.execute(text('SET search_path TO "$user", public')) +# # Optionally, attach engine-level event listener +# def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy): +# tenant_id = current_tenant_id.get() +# if tenant_id and is_valid_schema_name(tenant_id): +# with dbapi_connection.cursor() as cursor: +# cursor.execute(f'SET search_path TO "{tenant_id}"') @contextmanager def get_session_with_tenant( tenant_id: str | None = None, ) -> Generator[Session, None, None]: - """Generate a database session with the appropriate tenant schema set.""" + """Generate a database session bound to a connection with the appropriate tenant schema set.""" + engine = get_sqlalchemy_engine() + event.listen(engine, "checkout", set_search_path_on_checkout) + if tenant_id is None: tenant_id = current_tenant_id.get() if not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID") - engine = get_sqlalchemy_engine() - SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session) - - # Create a session - with SessionLocal() as session: - # Attach the event listener to set the search_path - @event.listens_for(session, "after_begin") - def _set_search_path(session, transaction, connection, **kw): - connection.execute(text(f'SET search_path TO "{tenant_id}"')) + # Establish a raw connection + with engine.connect() as connection: + # Access the raw DBAPI connection and set the search_path + dbapi_connection = connection.connection + # Set the search_path outside of any transaction + cursor = dbapi_connection.cursor() try: - yield session + cursor.execute(f'SET search_path = "{tenant_id}"') finally: - if MULTI_TENANT: + cursor.close() + + # Bind the session to the connection + with Session(bind=connection, expire_on_commit=False) as session: + try: + yield session + finally: # Reset search_path to default after the session is used - session.execute(text('SET search_path TO "$user", public')) + if MULTI_TENANT: + print("zzzzzz resetting search path") + cursor = dbapi_connection.cursor() + try: + cursor.execute('SET search_path TO "$user", public') + finally: + cursor.close() -# Optionally, attach engine-level event listener -def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy): +def set_search_path_on_checkout(dbapi_conn, connection_record, connection_proxy): tenant_id = current_tenant_id.get() if tenant_id and is_valid_schema_name(tenant_id): - with dbapi_connection.cursor() as cursor: + with dbapi_conn.cursor() as cursor: cursor.execute(f'SET search_path TO "{tenant_id}"') - - -engine = get_sqlalchemy_engine() -event.listen(engine, "checkout", set_search_path_on_checkout) + logger.debug( + f"Set search_path to {tenant_id} for connection {connection_record}" + ) def get_session_generator_with_tenant(