diff --git a/backend/onyx/db/connector_credential_pair.py b/backend/onyx/db/connector_credential_pair.py index c8651e306..712e81894 100644 --- a/backend/onyx/db/connector_credential_pair.py +++ b/backend/onyx/db/connector_credential_pair.py @@ -12,6 +12,7 @@ from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_AUTH from onyx.db.connector import fetch_connector_by_id +from onyx.db.credentials import fetch_credential_by_id from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus @@ -388,14 +389,23 @@ def add_credential_to_connector( auto_sync_options: dict | None = None, initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE, last_successful_index_time: datetime | None = None, + seeding_flow: bool = False, ) -> StatusResponse: connector = fetch_connector_by_id(connector_id, db_session) - credential = fetch_credential_by_id_for_user( - credential_id, - user, - db_session, - get_editable=False, - ) + + # If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user + if seeding_flow: + credential = fetch_credential_by_id( + db_session=db_session, + credential_id=credential_id, + ) + else: + credential = fetch_credential_by_id_for_user( + credential_id, + user, + db_session, + get_editable=False, + ) if connector is None: raise HTTPException(status_code=404, detail="Connector does not exist") diff --git a/backend/onyx/db/engine.py b/backend/onyx/db/engine.py index 0d1ebceea..30d8d3ab4 100644 --- a/backend/onyx/db/engine.py +++ b/backend/onyx/db/engine.py @@ -354,6 +354,26 @@ async def get_current_tenant_id(request: Request) -> str: raise HTTPException(status_code=500, detail="Internal server error") +# Listen for events on the synchronous Session class +@event.listens_for(Session, "after_begin") +def _set_search_path( + session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any +) -> None: + """Every time a new transaction is started, + set the search_path from the session's info.""" + tenant_id = session.info.get("tenant_id") + if tenant_id: + connection.exec_driver_sql(f'SET search_path = "{tenant_id}"') + + +engine = get_sqlalchemy_async_engine() +AsyncSessionLocal = sessionmaker( # type: ignore + bind=engine, + class_=AsyncSession, + expire_on_commit=False, +) + + @asynccontextmanager async def get_async_session_with_tenant( tenant_id: str | None = None, @@ -363,41 +383,22 @@ async def get_async_session_with_tenant( if not is_valid_schema_name(tenant_id): logger.error(f"Invalid tenant ID: {tenant_id}") - raise Exception("Invalid tenant ID") + raise ValueError("Invalid tenant ID") - engine = get_sqlalchemy_async_engine() - async_session_factory = sessionmaker( - bind=engine, expire_on_commit=False, class_=AsyncSession - ) # type: ignore + async with AsyncSessionLocal() as session: + session.sync_session.info["tenant_id"] = tenant_id - async def _set_search_path(session: AsyncSession, tenant_id: str) -> None: - await session.execute(text(f'SET search_path = "{tenant_id}"')) - - async with async_session_factory() as session: - # Register an event listener that is called whenever a new transaction starts - @event.listens_for(session.sync_session, "after_begin") - def after_begin(session_: Any, transaction: Any, connection: Any) -> None: - # Because the event is sync, we can't directly await here. - # Instead we queue up an asyncio task to ensures - # the next statement sets the search_path - session_.do_orm_execute = lambda state: connection.exec_driver_sql( - f'SET search_path = "{tenant_id}"' + if POSTGRES_IDLE_SESSIONS_TIMEOUT: + await session.execute( + text( + f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}" + ) ) try: - await _set_search_path(session, tenant_id) - - if POSTGRES_IDLE_SESSIONS_TIMEOUT: - await session.execute( - text( - f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}" - ) - ) - except Exception: - logger.exception("Error setting search_path.") - raise - else: yield session + finally: + pass @contextmanager diff --git a/backend/onyx/seeding/load_docs.py b/backend/onyx/seeding/load_docs.py index 44d0750e0..b895a0f7a 100644 --- a/backend/onyx/seeding/load_docs.py +++ b/backend/onyx/seeding/load_docs.py @@ -189,6 +189,7 @@ def seed_initial_documents( groups=None, initial_status=ConnectorCredentialPairStatus.PAUSED, last_successful_index_time=last_index_time, + seeding_flow=True, ) cc_pair_id = cast(int, result.data) processed_docs = fetch_versioned_implementation(