Added TTL to EE Celery tasks (#1713)

* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* default value to None

* made all deletions manual

* added fix

* Use tremor buttons now

* removed words

* Update 23957775e5f5_remove_feedback_foreignkey_constraint.py

* fixed alembic version
This commit is contained in:
hagen-danswer 2024-06-28 15:13:47 -07:00 committed by GitHub
parent de6d040349
commit bd0925611a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 412 additions and 75 deletions

View File

@ -0,0 +1,82 @@
"""remove-feedback-foreignkey-constraint
Revision ID: 23957775e5f5
Revises: bc9771dccadf
Create Date: 2024-06-27 16:04:51.480437
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None # type: ignore
depends_on = None # type: ignore
def upgrade() -> None:
    """Let feedback rows outlive the chat message they reference.

    For both feedback tables the foreign key to ``chat_message.id`` is
    recreated with ``ON DELETE SET NULL`` and ``chat_message_id`` is made
    nullable, so deleting a chat message clears the reference on the
    feedback row.
    """
    # --- chat_feedback ---
    op.drop_constraint(
        "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
    )
    op.create_foreign_key(
        "chat_feedback__chat_message_fk",
        "chat_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.alter_column(
        "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True
    )

    # --- document_retrieval_feedback ---
    # The constraint lives on "document_retrieval_feedback" (the table altered
    # just below), so it is dropped and recreated on that table.
    # NOTE(review): the constraint name is kept unchanged; confirm it matches
    # the name that actually exists in the database.
    op.drop_constraint(
        "document_retrieval__chat_message_fk",
        "document_retrieval_feedback",
        type_="foreignkey",
    )
    op.create_foreign_key(
        "document_retrieval__chat_message_fk",
        "document_retrieval_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.alter_column(
        "document_retrieval_feedback",
        "chat_message_id",
        existing_type=sa.Integer(),
        nullable=True,
    )
def downgrade() -> None:
    """Restore the NOT NULL column and plain foreign keys on both feedback tables.

    The foreign keys are recreated without an ``ON DELETE`` rule.
    """
    # NOTE(review): making the column NOT NULL presumably fails if any feedback
    # row already has a NULL chat_message_id; confirm how such rows should be
    # handled before running this downgrade.
    # --- chat_feedback ---
    op.alter_column(
        "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False
    )
    op.drop_constraint(
        "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
    )
    op.create_foreign_key(
        "chat_feedback__chat_message_fk",
        "chat_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )

    # --- document_retrieval_feedback ---
    op.alter_column(
        "document_retrieval_feedback",
        "chat_message_id",
        existing_type=sa.Integer(),
        nullable=False,
    )
    # The constraint lives on "document_retrieval_feedback" (the table altered
    # just above), so it is dropped and recreated on that table.
    # NOTE(review): the constraint name is kept unchanged; confirm it matches
    # the name that actually exists in the database.
    op.drop_constraint(
        "document_retrieval__chat_message_fk",
        "document_retrieval_feedback",
        type_="foreignkey",
    )
    op.create_foreign_key(
        "document_retrieval__chat_message_fk",
        "document_retrieval_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )

View File

@ -1,7 +1,7 @@
"""create usage reports table
Revision ID: bc9771dccadf
Revises: 48d14957fe80
Revises: 0568ccf46a6b
Create Date: 2024-06-18 10:04:26.800282
"""
@ -12,6 +12,7 @@ import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bc9771dccadf"
down_revision = "0568ccf46a6b"
branch_labels: None = None
depends_on: None = None

View File

@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import delete
@ -12,6 +14,7 @@ from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
from danswer.db.models import ChatSessionSharedStatus
from danswer.db.models import Prompt
@ -19,6 +22,7 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
@ -83,6 +87,54 @@ def get_chat_sessions_by_user(
return list(chat_sessions)
def delete_search_doc_message_relationship(
    message_id: int, db_session: Session
) -> None:
    """Delete every ChatMessage__SearchDoc row for ``message_id`` and commit."""
    links_for_message = db_session.query(ChatMessage__SearchDoc).filter(
        ChatMessage__SearchDoc.chat_message_id == message_id
    )
    # Bulk delete issued directly against the table; objects already loaded
    # in the session are not synchronized.
    links_for_message.delete(synchronize_session=False)
    db_session.commit()
def delete_orphaned_search_docs(db_session: Session) -> None:
    """Delete SearchDoc rows that have no ChatMessage__SearchDoc link, then commit."""
    # Outer join keeps SearchDocs without a link row; those surface with a
    # NULL chat_message_id on the link side.
    orphan_query = (
        db_session.query(SearchDoc)
        .outerjoin(ChatMessage__SearchDoc)
        .filter(ChatMessage__SearchDoc.chat_message_id.is_(None))
    )
    for orphan in orphan_query.all():
        db_session.delete(orphan)
    db_session.commit()
def delete_messages_and_files_from_chat_session(
    chat_session_id: int, db_session: Session
) -> None:
    """Delete all messages of a chat session together with their attachments.

    For each message in the session the search-doc link rows are removed and
    every file listed on the message is deleted by name. Afterwards the
    messages themselves are deleted, the session is committed, and search
    docs left without any message link are cleaned up.
    """
    # All messages of the session are selected (there is no age filter here);
    # `files` may be empty/None for messages without attachments.
    messages_with_files = db_session.execute(
        select(ChatMessage.id, ChatMessage.files).where(
            ChatMessage.chat_session_id == chat_session_id,
        )
    ).fetchall()

    for message_id, files in messages_with_files:
        delete_search_doc_message_relationship(
            message_id=message_id, db_session=db_session
        )
        for file_info in files or []:
            lobj_name = file_info.get("id")
            if lobj_name:
                logger.info(f"Deleting file with name: {lobj_name}")
                delete_lobj_by_name(lobj_name, db_session)

    db_session.execute(
        delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
    )
    db_session.commit()

    # Link rows are gone now, so any SearchDoc that only belonged to these
    # messages has become an orphan.
    delete_orphaned_search_docs(db_session)
def create_chat_session(
db_session: Session,
description: str,
@ -139,25 +191,30 @@ def delete_chat_session(
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if hard_delete:
stmt_messages = delete(ChatMessage).where(
ChatMessage.chat_session_id == chat_session_id
)
db_session.execute(stmt_messages)
stmt = delete(ChatSession).where(ChatSession.id == chat_session_id)
db_session.execute(stmt)
delete_messages_and_files_from_chat_session(chat_session_id, db_session)
db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
else:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
chat_session.deleted = True
db_session.commit()
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
    """Hard-delete every chat session created more than ``days_old`` days ago."""
    # Naive UTC timestamp.
    # NOTE(review): assumes ChatSession.time_created compares correctly
    # against a naive UTC value; confirm against the column definition.
    oldest_allowed = datetime.utcnow() - timedelta(days=days_old)

    expired_stmt = select(ChatSession.user_id, ChatSession.id).where(
        ChatSession.time_created < oldest_allowed
    )
    expired_rows = db_session.execute(expired_stmt).fetchall()

    for row in expired_rows:
        owner_id, expired_session_id = row
        delete_chat_session(
            owner_id, expired_session_id, db_session, hard_delete=True
        )
def get_chat_message(
chat_message_id: int,
user_id: UUID | None,

View File

@ -690,7 +690,7 @@ class ChatSession(Base):
"ChatFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
"ChatMessage", back_populates="chat_session"
)
persona: Mapped["Persona"] = relationship("Persona")
@ -737,10 +737,12 @@ class ChatMessage(Base):
chat_session: Mapped[ChatSession] = relationship("ChatSession")
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship(
"ChatMessageFeedback", back_populates="chat_message"
"ChatMessageFeedback",
back_populates="chat_message",
)
document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="chat_message"
"DocumentRetrievalFeedback",
back_populates="chat_message",
)
search_docs: Mapped[list["SearchDoc"]] = relationship(
"SearchDoc",
@ -787,7 +789,9 @@ class DocumentRetrievalFeedback(Base):
__tablename__ = "document_retrieval_feedback"
id: Mapped[int] = mapped_column(primary_key=True)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
# How high up this document is in the results, 1 for first
document_rank: Mapped[int] = mapped_column(Integer)
@ -797,7 +801,9 @@ class DocumentRetrievalFeedback(Base):
)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="document_feedbacks"
"ChatMessage",
back_populates="document_feedbacks",
foreign_keys=[chat_message_id],
)
document: Mapped[Document] = relationship(
"Document", back_populates="retrieval_feedbacks"
@ -808,14 +814,18 @@ class ChatMessageFeedback(Base):
__tablename__ = "chat_feedback"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
chat_message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
)
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="chat_message_feedbacks"
"ChatMessage",
back_populates="chat_message_feedbacks",
foreign_keys=[chat_message_id],
)

View File

@ -18,6 +18,25 @@ def get_pg_conn_from_session(db_session: Session) -> connection:
return db_session.connection().connection.connection # type: ignore
def get_pgfilestore_by_file_name(
    file_name: str,
    db_session: Session,
) -> PGFileStore:
    """Return the first PGFileStore row whose ``file_name`` matches.

    Raises:
        RuntimeError: if no matching row is found.
    """
    matches = db_session.query(PGFileStore).filter_by(file_name=file_name)
    record = matches.first()
    if record:
        return record

    raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
def delete_pgfilestore_by_file_name(
    file_name: str,
    db_session: Session,
) -> None:
    """Delete the PGFileStore rows named ``file_name``.

    No commit is issued here; the caller is responsible for committing.
    """
    rows_for_name = db_session.query(PGFileStore).filter_by(file_name=file_name)
    rows_for_name.delete()
def create_populate_lobj(
content: IO,
db_session: Session,
@ -73,6 +92,23 @@ def delete_lobj_by_id(
pg_conn.lobject(lobj_oid).unlink()
def delete_lobj_by_name(
    lobj_name: str,
    db_session: Session,
) -> None:
    """Unlink the large object stored under ``lobj_name`` and drop its record.

    If no file record with that name exists, this is logged and nothing is
    deleted. Otherwise the large object is unlinked, the file record is
    removed, and the session is committed.
    """
    try:
        file_record = get_pgfilestore_by_file_name(lobj_name, db_session)
    except RuntimeError:
        # The lookup raises RuntimeError when the name is unknown; treat that
        # as "nothing to delete".
        logger.info(f"no file with name {lobj_name} found")
        return

    raw_conn = get_pg_conn_from_session(db_session)
    large_object = raw_conn.lobject(file_record.lobj_oid)
    large_object.unlink()

    delete_pgfilestore_by_file_name(lobj_name, db_session)
    db_session.commit()
def upsert_pgfilestore(
file_name: str,
display_name: str | None,
@ -112,22 +148,3 @@ def upsert_pgfilestore(
db_session.commit()
return pgfilestore
def get_pgfilestore_by_file_name(
    file_name: str,
    db_session: Session,
) -> PGFileStore:
    """Return the first PGFileStore row whose ``file_name`` matches.

    Raises:
        RuntimeError: if no matching row is found.
    """
    matches = db_session.query(PGFileStore).filter_by(file_name=file_name)
    record = matches.first()
    if record:
        return record

    raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
def delete_pgfilestore_by_file_name(
    file_name: str,
    db_session: Session,
) -> None:
    """Delete the PGFileStore rows named ``file_name``.

    No commit is issued here; the caller is responsible for committing.
    """
    rows_for_name = db_session.query(PGFileStore).filter_by(file_name=file_name)
    rows_for_name.delete()

View File

@ -14,6 +14,7 @@ class Settings(BaseModel):
chat_page_enabled: bool = True
search_page_enabled: bool = True
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled

View File

@ -3,16 +3,17 @@ from datetime import timedelta
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.user_group_sync import name_user_group_sync_task
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import should_sync_user_groups
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from ee.danswer.user_groups.sync import sync_user_groups
@ -23,30 +24,45 @@ logger = setup_logger()
global_version.set_ee()
@build_celery_task_wrapper(name_user_group_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_user_group_task(user_group_id: int) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
task_name = name_user_group_sync_task(user_group_id)
mark_task_start(task_name, db_session)
with Session(get_sqlalchemy_engine()) as db_session:
# actual sync logic
error_msg = None
try:
sync_user_groups(user_group_id=user_group_id, db_session=db_session)
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to sync user group - {error_msg}")
logger.exception(f"Failed to sync user group - {e}")
# Need a new session so this can be committed (previous transaction may have
# been rolled back due to the exception)
with Session(engine) as db_session:
mark_task_finished(task_name, db_session, success=error_msg is None)
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
    """Delete chat sessions older than ``retention_limit_days`` days."""
    engine = get_sqlalchemy_engine()
    with Session(engine) as db_session:
        delete_chat_sessions_older_than(
            days_old=retention_limit_days, db_session=db_session
        )
#####
# Periodic Tasks
#####
@celery_app.task(
    name="check_ttl_management_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task() -> None:
    """Periodic check that queues the chat TTL task when one should run.

    Reads ``maximum_chat_retention_days`` from the loaded settings and, if
    the TTL check says a run is due, enqueues ``perform_ttl_management_task``
    with that retention limit.
    """
    retention_limit_days = load_settings().maximum_chat_retention_days

    with Session(get_sqlalchemy_engine()) as db_session:
        if not should_perform_chat_ttl_check(retention_limit_days, db_session):
            return
        perform_ttl_management_task.apply_async(
            kwargs={"retention_limit_days": retention_limit_days},
        )
@celery_app.task(
name="check_for_user_groups_sync_task",
soft_time_limit=JOB_TIMEOUT,
@ -58,23 +74,11 @@ def check_for_user_groups_sync_task() -> None:
# check if any document sets are not synced
user_groups = fetch_user_groups(db_session=db_session, only_current=False)
for user_group in user_groups:
if not user_group.is_up_to_date:
task_name = name_user_group_sync_task(user_group.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"User Group '{user_group.id}' is already syncing. Skipping."
)
continue
if should_sync_user_groups(user_group, db_session):
logger.info(f"User Group {user_group.id} is not synced. Syncing now!")
task = sync_user_group_task.apply_async(
sync_user_group_task.apply_async(
kwargs=dict(user_group_id=user_group.id),
)
register_task(task.id, task_name, db_session)
@celery_app.task(
@ -103,5 +107,9 @@ celery_app.conf.beat_schedule = {
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},
"check-ttl-management": {
"task": "check_ttl_management_task",
"schedule": timedelta(hours=1),
},
**(celery_app.conf.beat_schedule or {}),
}

View File

@ -0,0 +1,40 @@
from sqlalchemy.orm import Session
from danswer.db.models import UserGroup
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import name_user_group_sync_task
logger = setup_logger()
def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
    """Decide whether a sync task should be queued for ``user_group``.

    Returns:
        False when the group is already up to date, or when the latest sync
        task recorded for it is still live (not timed out); True otherwise.
    """
    if user_group.is_up_to_date:
        return False

    task_name = name_user_group_sync_task(user_group.id)
    latest_sync = get_latest_task(task_name, db_session)

    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
        # A sync for this group is still in flight; don't queue a second one.
        logger.info(f"User Group '{user_group.id}' is already syncing. Skipping.")
        return False

    return True
def should_perform_chat_ttl_check(
    retention_limit_days: int | None, db_session: Session
) -> bool:
    """Decide whether a chat TTL task should be queued.

    Returns:
        False when no retention limit is set (``None`` or ``0``) or when the
        latest TTL task for this limit is still live (not timed out);
        True otherwise, including when no such task has been recorded yet.
    """
    # TODO: make this a check for None and add behavior for 0 day TTL
    if not retention_limit_days:
        return False

    previous_run = get_latest_task(
        name_chat_ttl_task(retention_limit_days), db_session
    )
    # Only consult the liveness check when a previous task actually exists.
    still_running = bool(previous_run) and check_live_task_not_timed_out(
        previous_run, db_session
    )
    if still_running:
        logger.info("TTL check is already being performed. Skipping.")
    return not still_running

View File

@ -0,0 +1,6 @@
def name_user_group_sync_task(user_group_id: int) -> str:
    """Build the task name used for syncing the user group ``user_group_id``."""
    return "user_group_sync_task__{}".format(user_group_id)
def name_chat_ttl_task(retention_limit_days: int) -> str:
    """Build the task name for a chat TTL run with the given retention limit."""
    return "chat_ttl_{}_days".format(retention_limit_days)

View File

@ -1,2 +0,0 @@
def name_user_group_sync_task(user_group_id: int) -> str:
    """Build the task name used for syncing the user group ``user_group_id``."""
    return "user_group_sync_task__{}".format(user_group_id)

View File

@ -1,12 +1,16 @@
"use client";
import { Label, SubLabel } from "@/components/admin/connectors/Field";
import { usePopup } from "@/components/admin/connectors/Popup";
import { Title } from "@tremor/react";
import { Settings } from "./interfaces";
import { useRouter } from "next/navigation";
import { DefaultDropdown, Option } from "@/components/Dropdown";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import React, { useState, useEffect } from "react";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { Button } from "@tremor/react";
function Checkbox({
label,
@ -49,7 +53,7 @@ function Selector({
onSelect: (value: string | number | null) => void;
}) {
return (
<div>
<div className="mb-8">
{label && <Label>{label}</Label>}
{subtext && <SubLabel>{subtext}</SubLabel>}
@ -64,10 +68,54 @@ function Selector({
);
}
// Labeled numeric <input> with a sub-label. A null `value` renders as an
// empty field; the browser-level constraints are min=1 and step=1.
function IntegerInput(props: {
  label: string;
  sublabel: string;
  value: number | null;
  onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
  id?: string;
  placeholder?: string;
}) {
  const { label, sublabel, value, onChange, id } = props;
  // Default placeholder if none is provided
  const placeholder =
    props.placeholder === undefined ? "Enter a number" : props.placeholder;
  // Render "no value" as an empty input
  const displayedValue = value ?? "";

  return (
    <label className="flex flex-col text-sm mb-4">
      <Label>{label}</Label>
      <SubLabel>{sublabel}</SubLabel>
      <input
        id={id}
        type="number"
        min="1"
        step="1"
        className="mt-1 p-2 border rounded w-full max-w-xs"
        placeholder={placeholder}
        value={displayedValue}
        onChange={onChange}
      />
    </label>
  );
}
export function SettingsForm() {
const router = useRouter();
const combinedSettings = useContext(SettingsContext);
const [chatRetention, setChatRetention] = useState("");
const { popup, setPopup } = usePopup();
const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
useEffect(() => {
if (combinedSettings?.settings.maximum_chat_retention_days !== undefined) {
setChatRetention(
combinedSettings.settings.maximum_chat_retention_days?.toString() || ""
);
}
}, [combinedSettings?.settings.maximum_chat_retention_days]);
if (!combinedSettings) {
return null;
}
@ -99,8 +147,45 @@ export function SettingsForm() {
}
}
// Persist the retention value currently typed in the input and report the
// outcome through the popup.
function handleSetChatRetention() {
  // An empty input is sent as null; anything else is parsed as a base-10 integer
  let retentionDays: number | null = null;
  if (chatRetention !== "") {
    retentionDays = parseInt(chatRetention.toString(), 10);
  }

  const reportSuccess = () => {
    setPopup({
      message: "Chat retention settings updated successfully!",
      type: "success",
    });
  };

  const reportFailure = (error: any) => {
    console.error("Error updating settings:", error);
    // Prefer a server-provided message, then the error's own message
    const errorMessage =
      error.response?.data?.message || error.message || "Unknown error";
    setPopup({
      message: `Failed to update settings: ${errorMessage}`,
      type: "error",
    });
  };

  updateSettingField([
    { fieldName: "maximum_chat_retention_days", newValue: retentionDays },
  ])
    .then(reportSuccess)
    .catch(reportFailure);
}
// Clear the retention input and store null for the retention limit, reporting
// the outcome through the popup.
function handleClearChatRetention() {
  setChatRetention(""); // Clear the chat retention input
  updateSettingField([
    { fieldName: "maximum_chat_retention_days", newValue: null },
  ])
    .then(() => {
      setPopup({
        message: "Chat retention cleared successfully!",
        type: "success",
      });
    })
    .catch((error) => {
      // Without a handler a failed update would be an unhandled rejection
      // and the user would get no feedback.
      console.error("Error clearing chat retention:", error);
      const errorMessage =
        error.response?.data?.message || error.message || "Unknown error";
      setPopup({
        message: `Failed to clear chat retention: ${errorMessage}`,
        type: "error",
      });
    });
}
return (
<div>
{popup}
<Title className="mb-4">Page Visibility</Title>
<Checkbox
@ -152,6 +237,37 @@ export function SettingsForm() {
]);
}}
/>
{isEnterpriseEnabled && (
<>
<Title className="mb-4">Chat Settings</Title>
<IntegerInput
label="Chat Retention"
sublabel="Enter the maximum number of days you would like Danswer to retain chat messages. Leaving this field empty will cause Danswer to never delete chat messages."
value={chatRetention === "" ? null : Number(chatRetention)}
onChange={(e) => {
const numValue = parseInt(e.target.value, 10);
if (numValue >= 1) {
setChatRetention(numValue.toString());
} else if (e.target.value === "") {
setChatRetention("");
}
}}
id="chatRetentionInput"
placeholder="Infinite Retention"
/>
<Button
onClick={handleSetChatRetention}
color="green"
size="xs"
className="mr-3"
>
Set Retention Limit
</Button>
<Button onClick={handleClearChatRetention} color="blue" size="xs">
Retain All
</Button>
</>
)}
</div>
);
}

View File

@ -2,6 +2,7 @@ export interface Settings {
chat_page_enabled: boolean;
search_page_enabled: boolean;
default_page: "search" | "chat";
maximum_chat_retention_days: number | null;
}
export interface EnterpriseSettings {