mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
Admin usage for seeding (#3683)
* admin usage for seeding * functional * proper fix * k * typing
This commit is contained in:
parent
eb70699c0b
commit
76ca650972
@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@ -388,14 +389,23 @@ def add_credential_to_connector(
|
||||
auto_sync_options: dict | None = None,
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id_for_user(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
|
||||
if seeding_flow:
|
||||
credential = fetch_credential_by_id(
|
||||
db_session=db_session,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
credential = fetch_credential_by_id_for_user(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
@ -354,6 +354,26 @@ async def get_current_tenant_id(request: Request) -> str:
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
# Listen for events on the synchronous Session class
|
||||
@event.listens_for(Session, "after_begin")
def _set_search_path(
    session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
    """Point a freshly begun transaction at the session's tenant schema.

    Registered on the synchronous ``Session`` class for the ``after_begin``
    event. Looks up ``tenant_id`` in ``session.info``; when a truthy value is
    present, issues ``SET search_path`` on the transaction's connection.
    Sessions that carry no ``tenant_id`` are left untouched.
    """
    tenant_id = session.info.get("tenant_id")
    if not tenant_id:
        # Nothing recorded on this session, so there is no schema to select.
        return

    connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
|
||||
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
AsyncSessionLocal = sessionmaker( # type: ignore
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
@ -363,41 +383,22 @@ async def get_async_session_with_tenant(
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
raise Exception("Invalid tenant ID")
|
||||
raise ValueError("Invalid tenant ID")
|
||||
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async_session_factory = sessionmaker(
|
||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
||||
) # type: ignore
|
||||
async with AsyncSessionLocal() as session:
|
||||
session.sync_session.info["tenant_id"] = tenant_id
|
||||
|
||||
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
async with async_session_factory() as session:
|
||||
# Register an event listener that is called whenever a new transaction starts
|
||||
@event.listens_for(session.sync_session, "after_begin")
|
||||
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
|
||||
# Because the event is sync, we can't directly await here.
|
||||
# Instead we queue up an asyncio task to ensure
|
||||
# the next statement sets the search_path
|
||||
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
|
||||
f'SET search_path = "{tenant_id}"'
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await _set_search_path(session, tenant_id)
|
||||
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
raise
|
||||
else:
|
||||
yield session
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -189,6 +189,7 @@ def seed_initial_documents(
|
||||
groups=None,
|
||||
initial_status=ConnectorCredentialPairStatus.PAUSED,
|
||||
last_successful_index_time=last_index_time,
|
||||
seeding_flow=True,
|
||||
)
|
||||
cc_pair_id = cast(int, result.data)
|
||||
processed_docs = fetch_versioned_implementation(
|
||||
|
Loading…
x
Reference in New Issue
Block a user