Admin usage for seeding (#3683)

* admin usage for seeding

* functional

* proper fix

* k

* typing
pablonyx 2025-01-15 11:04:25 -08:00 committed by GitHub
parent eb70699c0b
commit 76ca650972
3 changed files with 47 additions and 35 deletions


@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
 from onyx.configs.app_configs import DISABLE_AUTH
 from onyx.db.connector import fetch_connector_by_id
+from onyx.db.credentials import fetch_credential_by_id
 from onyx.db.credentials import fetch_credential_by_id_for_user
 from onyx.db.enums import AccessType
 from onyx.db.enums import ConnectorCredentialPairStatus
@@ -388,14 +389,23 @@ def add_credential_to_connector(
     auto_sync_options: dict | None = None,
     initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
     last_successful_index_time: datetime | None = None,
+    seeding_flow: bool = False,
 ) -> StatusResponse:
     connector = fetch_connector_by_id(connector_id, db_session)
-    credential = fetch_credential_by_id_for_user(
-        credential_id,
-        user,
-        db_session,
-        get_editable=False,
-    )
+    # If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
+    if seeding_flow:
+        credential = fetch_credential_by_id(
+            db_session=db_session,
+            credential_id=credential_id,
+        )
+    else:
+        credential = fetch_credential_by_id_for_user(
+            credential_id,
+            user,
+            db_session,
+            get_editable=False,
+        )
 
     if connector is None:
         raise HTTPException(status_code=404, detail="Connector does not exist")
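A minimal sketch of how a seeding-time caller might use the new flag. This is illustrative only: the ids are placeholders, the seeding path presumably passes no real user, and most other required arguments of `add_credential_to_connector` are elided.

```python
# Hypothetical seeding-time call; ids are placeholders and other required
# arguments of add_credential_to_connector are omitted for brevity.
result = add_credential_to_connector(
    db_session=db_session,
    user=None,  # seeding runs before any real user exists
    connector_id=connector_id,
    credential_id=credential_id,
    initial_status=ConnectorCredentialPairStatus.PAUSED,
    seeding_flow=True,  # skip the per-user credential ownership check
)
```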


@@ -354,6 +354,26 @@ async def get_current_tenant_id(request: Request) -> str:
         raise HTTPException(status_code=500, detail="Internal server error")
 
 
+# Listen for events on the synchronous Session class
+@event.listens_for(Session, "after_begin")
+def _set_search_path(
+    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
+) -> None:
+    """Every time a new transaction is started,
+    set the search_path from the session's info."""
+    tenant_id = session.info.get("tenant_id")
+    if tenant_id:
+        connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
+
+
+engine = get_sqlalchemy_async_engine()
+AsyncSessionLocal = sessionmaker(  # type: ignore
+    bind=engine,
+    class_=AsyncSession,
+    expire_on_commit=False,
+)
+
+
 @asynccontextmanager
 async def get_async_session_with_tenant(
     tenant_id: str | None = None,
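The listener added above is registered once on the synchronous `Session` class, so any session that stashes a `tenant_id` in `session.info` gets its `search_path` applied at the start of every transaction. A self-contained sketch of that SQLAlchemy pattern (the engine URL and schema name below are placeholders, not values from this commit):

```python
from typing import Any

from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import Session

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # placeholder


@event.listens_for(Session, "after_begin")
def set_search_path(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    # Fires at the start of every transaction for any Session; only acts
    # when a tenant id was stashed in session.info.
    tenant_id = session.info.get("tenant_id")
    if tenant_id:
        connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')


with Session(engine) as session:
    session.info["tenant_id"] = "tenant_abc"  # placeholder schema name
    session.execute(text("SELECT 1"))  # runs with search_path = "tenant_abc"
```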
@@ -363,41 +383,22 @@ async def get_async_session_with_tenant(
     if not is_valid_schema_name(tenant_id):
         logger.error(f"Invalid tenant ID: {tenant_id}")
-        raise Exception("Invalid tenant ID")
+        raise ValueError("Invalid tenant ID")
 
-    engine = get_sqlalchemy_async_engine()
-    async_session_factory = sessionmaker(
-        bind=engine, expire_on_commit=False, class_=AsyncSession
-    )  # type: ignore
+    async with AsyncSessionLocal() as session:
+        session.sync_session.info["tenant_id"] = tenant_id
-    async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
-        await session.execute(text(f'SET search_path = "{tenant_id}"'))
-
-    async with async_session_factory() as session:
-        # Register an event listener that is called whenever a new transaction starts
-        @event.listens_for(session.sync_session, "after_begin")
-        def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
-            # Because the event is sync, we can't directly await here.
-            # Instead we queue up an asyncio task to ensures
-            # the next statement sets the search_path
-            session_.do_orm_execute = lambda state: connection.exec_driver_sql(
-                f'SET search_path = "{tenant_id}"'
-            )
+        if POSTGRES_IDLE_SESSIONS_TIMEOUT:
+            await session.execute(
+                text(
+                    f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
+                )
+            )
-        try:
-            await _set_search_path(session, tenant_id)
-            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
-                await session.execute(
-                    text(
-                        f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
-                    )
-                )
-        except Exception:
-            logger.exception("Error setting search_path.")
-            raise
-        else:
-            yield session
-        finally:
-            pass
+        yield session
 
 
 @contextmanager
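With the module-level `AsyncSessionLocal` and the class-level `after_begin` listener in place, the async context manager only needs to tag the session with the tenant id. A hedged usage sketch (the caller, table name, and tenant id are hypothetical, not part of this commit):

```python
from sqlalchemy import text


async def count_documents(tenant_id: str) -> int:
    # Hypothetical caller of get_async_session_with_tenant shown above;
    # the table name is a placeholder.
    async with get_async_session_with_tenant(tenant_id) as session:
        # search_path was applied by the "after_begin" listener, so the
        # query runs against the tenant's schema.
        result = await session.execute(text("SELECT count(*) FROM document"))
        return int(result.scalar_one())
```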


@@ -189,6 +189,7 @@ def seed_initial_documents(
         groups=None,
         initial_status=ConnectorCredentialPairStatus.PAUSED,
         last_successful_index_time=last_index_time,
+        seeding_flow=True,
     )
     cc_pair_id = cast(int, result.data)
     processed_docs = fetch_versioned_implementation(