fix blowing up the entire task on exception and trying to reuse an in… (#4179)

* fix blowing up the entire task on exception and trying to reuse an invalid db session

* list comprehension

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-03-03 16:57:27 -08:00 committed by GitHub
parent 33cc4be492
commit 61e8f371b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 14 deletions

View File

@ -4,7 +4,8 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
from onyx.background.celery.apps.primary import celery_app from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
@ -18,7 +19,26 @@ logger = setup_logger()
@celery_app.task(soft_time_limit=JOB_TIMEOUT) @celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None: def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session: with get_session_with_current_tenant() as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session) old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
##### #####

View File

@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from typing import cast from typing import cast
from typing import Tuple
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
@ -11,6 +12,7 @@ from sqlalchemy import desc
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy import nullsfirst from sqlalchemy import nullsfirst
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy import Row
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy import update from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.exc import MultipleResultsFound
@ -375,24 +377,33 @@ def delete_chat_session(
db_session.commit() db_session.commit()
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None: def get_chat_sessions_older_than(
days_old: int, db_session: Session
) -> list[tuple[UUID | None, UUID]]:
"""
Retrieves chat sessions older than a specified number of days.
Args:
days_old: The number of days to consider as "old".
db_session: The database session.
Returns:
A list of tuples, where each tuple contains the user_id (can be None) and the chat_session_id of an old chat session.
"""
cutoff_time = datetime.utcnow() - timedelta(days=days_old) cutoff_time = datetime.utcnow() - timedelta(days=days_old)
old_sessions = db_session.execute( old_sessions: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute(
select(ChatSession.user_id, ChatSession.id).where( select(ChatSession.user_id, ChatSession.id).where(
ChatSession.time_created < cutoff_time ChatSession.time_created < cutoff_time
) )
).fetchall() ).fetchall()
for user_id, session_id in old_sessions: # convert old_sessions to a conventional list of tuples
try: returned_sessions: list[tuple[UUID | None, UUID]] = [
delete_chat_session( (user_id, session_id) for user_id, session_id in old_sessions
user_id, session_id, db_session, include_deleted=True, hard_delete=True ]
)
except Exception: return returned_sessions
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
def get_chat_message( def get_chat_message(