mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-28 21:05:17 +02:00
Admin usage for seeding (#3683)
* admin usage for seeding * functional * proper fix * k * typing
This commit is contained in:
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from onyx.configs.app_configs import DISABLE_AUTH
|
from onyx.configs.app_configs import DISABLE_AUTH
|
||||||
from onyx.db.connector import fetch_connector_by_id
|
from onyx.db.connector import fetch_connector_by_id
|
||||||
|
from onyx.db.credentials import fetch_credential_by_id
|
||||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||||
from onyx.db.enums import AccessType
|
from onyx.db.enums import AccessType
|
||||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||||
@@ -388,8 +389,17 @@ def add_credential_to_connector(
|
|||||||
auto_sync_options: dict | None = None,
|
auto_sync_options: dict | None = None,
|
||||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
||||||
last_successful_index_time: datetime | None = None,
|
last_successful_index_time: datetime | None = None,
|
||||||
|
seeding_flow: bool = False,
|
||||||
) -> StatusResponse:
|
) -> StatusResponse:
|
||||||
connector = fetch_connector_by_id(connector_id, db_session)
|
connector = fetch_connector_by_id(connector_id, db_session)
|
||||||
|
|
||||||
|
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
|
||||||
|
if seeding_flow:
|
||||||
|
credential = fetch_credential_by_id(
|
||||||
|
db_session=db_session,
|
||||||
|
credential_id=credential_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
credential = fetch_credential_by_id_for_user(
|
credential = fetch_credential_by_id_for_user(
|
||||||
credential_id,
|
credential_id,
|
||||||
user,
|
user,
|
||||||
|
@@ -354,6 +354,26 @@ async def get_current_tenant_id(request: Request) -> str:
|
|||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
|
|
||||||
|
# Listen for events on the synchronous Session class
|
||||||
|
@event.listens_for(Session, "after_begin")
|
||||||
|
def _set_search_path(
|
||||||
|
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Every time a new transaction is started,
|
||||||
|
set the search_path from the session's info."""
|
||||||
|
tenant_id = session.info.get("tenant_id")
|
||||||
|
if tenant_id:
|
||||||
|
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
|
||||||
|
|
||||||
|
|
||||||
|
engine = get_sqlalchemy_async_engine()
|
||||||
|
AsyncSessionLocal = sessionmaker( # type: ignore
|
||||||
|
bind=engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_async_session_with_tenant(
|
async def get_async_session_with_tenant(
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
@@ -363,41 +383,22 @@ async def get_async_session_with_tenant(
|
|||||||
|
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||||
raise Exception("Invalid tenant ID")
|
raise ValueError("Invalid tenant ID")
|
||||||
|
|
||||||
engine = get_sqlalchemy_async_engine()
|
async with AsyncSessionLocal() as session:
|
||||||
async_session_factory = sessionmaker(
|
session.sync_session.info["tenant_id"] = tenant_id
|
||||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
|
|
||||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
|
||||||
|
|
||||||
async with async_session_factory() as session:
|
|
||||||
# Register an event listener that is called whenever a new transaction starts
|
|
||||||
@event.listens_for(session.sync_session, "after_begin")
|
|
||||||
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
|
|
||||||
# Because the event is sync, we can't directly await here.
|
|
||||||
# Instead we queue up an asyncio task to ensures
|
|
||||||
# the next statement sets the search_path
|
|
||||||
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
|
|
||||||
f'SET search_path = "{tenant_id}"'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _set_search_path(session, tenant_id)
|
|
||||||
|
|
||||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
logger.exception("Error setting search_path.")
|
try:
|
||||||
raise
|
|
||||||
else:
|
|
||||||
yield session
|
yield session
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@@ -189,6 +189,7 @@ def seed_initial_documents(
|
|||||||
groups=None,
|
groups=None,
|
||||||
initial_status=ConnectorCredentialPairStatus.PAUSED,
|
initial_status=ConnectorCredentialPairStatus.PAUSED,
|
||||||
last_successful_index_time=last_index_time,
|
last_successful_index_time=last_index_time,
|
||||||
|
seeding_flow=True,
|
||||||
)
|
)
|
||||||
cc_pair_id = cast(int, result.data)
|
cc_pair_id = cast(int, result.data)
|
||||||
processed_docs = fetch_versioned_implementation(
|
processed_docs = fetch_versioned_implementation(
|
||||||
|
Reference in New Issue
Block a user