diff --git a/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py new file mode 100644 index 000000000..96cebab97 --- /dev/null +++ b/backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py @@ -0,0 +1,82 @@ +"""remove-feedback-foreignkey-constraint + +Revision ID: 23957775e5f5 +Revises: bc9771dccadf +Create Date: 2024-06-27 16:04:51.480437 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "23957775e5f5" +down_revision = "bc9771dccadf" +branch_labels = None  # type: ignore +depends_on = None  # type: ignore + + +def upgrade() -> None: + op.drop_constraint( + "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "chat_feedback__chat_message_fk", + "chat_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.alter_column( + "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True + ) + op.drop_constraint( + "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ondelete="SET NULL", + ) + op.alter_column( + "document_retrieval_feedback", + "chat_message_id", + existing_type=sa.Integer(), + nullable=True, + ) + + +def downgrade() -> None: + op.alter_column( + "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False + ) + op.drop_constraint( + "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "chat_feedback__chat_message_fk", + "chat_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ) + + op.alter_column( + "document_retrieval_feedback", + "chat_message_id", + existing_type=sa.Integer(), + nullable=False, + ) + 
op.drop_constraint( + "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", type_="foreignkey" + ) + op.create_foreign_key( + "document_retrieval_feedback__chat_message_fk", + "document_retrieval_feedback", + "chat_message", + ["chat_message_id"], + ["id"], + ) diff --git a/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py b/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py index f50645f71..eab3253a0 100644 --- a/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py +++ b/backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py @@ -1,7 +1,7 @@ """create usage reports table Revision ID: bc9771dccadf -Revises: 48d14957fe80 +Revises: 0568ccf46a6b Create Date: 2024-06-18 10:04:26.800282 """ @@ -12,6 +12,7 @@ import fastapi_users_db_sqlalchemy # revision identifiers, used by Alembic. revision = "bc9771dccadf" down_revision = "0568ccf46a6b" + branch_labels: None = None depends_on: None = None diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 60c61922e..61e42bde6 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -1,3 +1,5 @@ +from datetime import datetime +from datetime import timedelta from uuid import UUID from sqlalchemy import delete @@ -12,6 +14,7 @@ from danswer.auth.schemas import UserRole from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage +from danswer.db.models import ChatMessage__SearchDoc from danswer.db.models import ChatSession from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import Prompt @@ -19,6 +22,7 @@ from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.pg_file_store import delete_lobj_by_name from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride 
from danswer.llm.override_models import PromptOverride @@ -83,6 +87,54 @@ def get_chat_sessions_by_user( return list(chat_sessions) +def delete_search_doc_message_relationship( + message_id: int, db_session: Session +) -> None: + db_session.query(ChatMessage__SearchDoc).filter( + ChatMessage__SearchDoc.chat_message_id == message_id + ).delete(synchronize_session=False) + + db_session.commit() + + +def delete_orphaned_search_docs(db_session: Session) -> None: + orphaned_docs = ( + db_session.query(SearchDoc) + .outerjoin(ChatMessage__SearchDoc) + .filter(ChatMessage__SearchDoc.chat_message_id.is_(None)) + .all() + ) + for doc in orphaned_docs: + db_session.delete(doc) + db_session.commit() + + +def delete_messages_and_files_from_chat_session( + chat_session_id: int, db_session: Session +) -> None: + # Gather all messages in this chat session along with their attached file references + messages_with_files = db_session.execute( + select(ChatMessage.id, ChatMessage.files).where( + ChatMessage.chat_session_id == chat_session_id, + ) + ).fetchall() + + for id, files in messages_with_files: + delete_search_doc_message_relationship(message_id=id, db_session=db_session) + for file_info in files or {}: + lobj_name = file_info.get("id") + if lobj_name: + logger.info(f"Deleting file with name: {lobj_name}") + delete_lobj_by_name(lobj_name, db_session) + + db_session.execute( + delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id) + ) + db_session.commit() + + delete_orphaned_search_docs(db_session) + + def create_chat_session( db_session: Session, description: str, @@ -139,25 +191,30 @@ def delete_chat_session( db_session: Session, hard_delete: bool = HARD_DELETE_CHATS, ) -> None: - chat_session = get_chat_session_by_id( - chat_session_id=chat_session_id, user_id=user_id, db_session=db_session - ) - if hard_delete: - stmt_messages = delete(ChatMessage).where( - ChatMessage.chat_session_id == chat_session_id - ) - db_session.execute(stmt_messages) - - stmt = 
delete(ChatSession).where(ChatSession.id == chat_session_id) - db_session.execute(stmt) - + delete_messages_and_files_from_chat_session(chat_session_id, db_session) + db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id)) else: + chat_session = get_chat_session_by_id( + chat_session_id=chat_session_id, user_id=user_id, db_session=db_session + ) chat_session.deleted = True db_session.commit() +def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None: + cutoff_time = datetime.utcnow() - timedelta(days=days_old) + old_sessions = db_session.execute( + select(ChatSession.user_id, ChatSession.id).where( + ChatSession.time_created < cutoff_time + ) + ).fetchall() + + for user_id, session_id in old_sessions: + delete_chat_session(user_id, session_id, db_session, hard_delete=True) + + def get_chat_message( chat_message_id: int, user_id: UUID | None, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 7bed22a3b..bfed0a03e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -690,7 +690,7 @@ class ChatSession(Base): "ChatFolder", back_populates="chat_sessions" ) messages: Mapped[list["ChatMessage"]] = relationship( - "ChatMessage", back_populates="chat_session", cascade="delete" + "ChatMessage", back_populates="chat_session" ) persona: Mapped["Persona"] = relationship("Persona") @@ -737,10 +737,12 @@ class ChatMessage(Base): chat_session: Mapped[ChatSession] = relationship("ChatSession") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt") chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship( - "ChatMessageFeedback", back_populates="chat_message" + "ChatMessageFeedback", + back_populates="chat_message", ) document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( - "DocumentRetrievalFeedback", back_populates="chat_message" + "DocumentRetrievalFeedback", + back_populates="chat_message", ) search_docs: Mapped[list["SearchDoc"]] = 
relationship( "SearchDoc", @@ -787,7 +789,9 @@ class DocumentRetrievalFeedback(Base): __tablename__ = "document_retrieval_feedback" id: Mapped[int] = mapped_column(primary_key=True) - chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True + ) document_id: Mapped[str] = mapped_column(ForeignKey("document.id")) # How high up this document is in the results, 1 for first document_rank: Mapped[int] = mapped_column(Integer) @@ -797,7 +801,9 @@ class DocumentRetrievalFeedback(Base): ) chat_message: Mapped[ChatMessage] = relationship( - "ChatMessage", back_populates="document_feedbacks" + "ChatMessage", + back_populates="document_feedbacks", + foreign_keys=[chat_message_id], ) document: Mapped[Document] = relationship( "Document", back_populates="retrieval_feedbacks" @@ -808,14 +814,18 @@ class ChatMessageFeedback(Base): __tablename__ = "chat_feedback" id: Mapped[int] = mapped_column(Integer, primary_key=True) - chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + chat_message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True + ) is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True) required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True) feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True) predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True) chat_message: Mapped[ChatMessage] = relationship( - "ChatMessage", back_populates="chat_message_feedbacks" + "ChatMessage", + back_populates="chat_message_feedbacks", + foreign_keys=[chat_message_id], ) diff --git a/backend/danswer/db/pg_file_store.py b/backend/danswer/db/pg_file_store.py index 3325c2ffd..1333dcd6c 100644 --- a/backend/danswer/db/pg_file_store.py +++ b/backend/danswer/db/pg_file_store.py @@ -18,6 +18,25 @@ def 
get_pg_conn_from_session(db_session: Session) -> connection: return db_session.connection().connection.connection # type: ignore +def get_pgfilestore_by_file_name( + file_name: str, + db_session: Session, +) -> PGFileStore: + pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first() + + if not pgfilestore: + raise RuntimeError(f"File by name {file_name} does not exist or was deleted") + + return pgfilestore + + +def delete_pgfilestore_by_file_name( + file_name: str, + db_session: Session, +) -> None: + db_session.query(PGFileStore).filter_by(file_name=file_name).delete() + + def create_populate_lobj( content: IO, db_session: Session, @@ -73,6 +92,23 @@ def delete_lobj_by_id( pg_conn.lobject(lobj_oid).unlink() +def delete_lobj_by_name( + lobj_name: str, + db_session: Session, +) -> None: + try: + pgfilestore = get_pgfilestore_by_file_name(lobj_name, db_session) + except RuntimeError: + logger.info(f"no file with name {lobj_name} found") + return + + pg_conn = get_pg_conn_from_session(db_session) + pg_conn.lobject(pgfilestore.lobj_oid).unlink() + + delete_pgfilestore_by_file_name(lobj_name, db_session) + db_session.commit() + + def upsert_pgfilestore( file_name: str, display_name: str | None, @@ -112,22 +148,3 @@ def upsert_pgfilestore( db_session.commit() return pgfilestore - - -def get_pgfilestore_by_file_name( - file_name: str, - db_session: Session, -) -> PGFileStore: - pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first() - - if not pgfilestore: - raise RuntimeError(f"File by name {file_name} does not exist or was deleted") - - return pgfilestore - - -def delete_pgfilestore_by_file_name( - file_name: str, - db_session: Session, -) -> None: - db_session.query(PGFileStore).filter_by(file_name=file_name).delete() diff --git a/backend/danswer/server/settings/models.py b/backend/danswer/server/settings/models.py index 041e360d7..9afacf5ad 100644 --- a/backend/danswer/server/settings/models.py +++ 
b/backend/danswer/server/settings/models.py @@ -14,6 +14,7 @@ class Settings(BaseModel): chat_page_enabled: bool = True search_page_enabled: bool = True default_page: PageType = PageType.SEARCH + maximum_chat_retention_days: int | None = None def check_validity(self) -> None: chat_page_enabled = self.chat_page_enabled diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 7b959f428..2dd3ecb47 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -3,16 +3,17 @@ from datetime import timedelta from sqlalchemy.orm import Session from danswer.background.celery.celery_app import celery_app +from danswer.background.task_utils import build_celery_task_wrapper from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.db.chat import delete_chat_sessions_older_than from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.tasks import check_live_task_not_timed_out -from danswer.db.tasks import get_latest_task -from danswer.db.tasks import mark_task_finished -from danswer.db.tasks import mark_task_start -from danswer.db.tasks import register_task +from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version -from ee.danswer.background.user_group_sync import name_user_group_sync_task +from ee.danswer.background.celery_utils import should_perform_chat_ttl_check +from ee.danswer.background.celery_utils import should_sync_user_groups +from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_user_group_sync_task from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report from ee.danswer.user_groups.sync import sync_user_groups @@ -23,30 +24,45 @@ logger = setup_logger() 
global_version.set_ee() +@build_celery_task_wrapper(name_user_group_sync_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) def sync_user_group_task(user_group_id: int) -> None: - engine = get_sqlalchemy_engine() - with Session(engine) as db_session: - task_name = name_user_group_sync_task(user_group_id) - mark_task_start(task_name, db_session) - + with Session(get_sqlalchemy_engine()) as db_session: # actual sync logic - error_msg = None try: sync_user_groups(user_group_id=user_group_id, db_session=db_session) except Exception as e: - error_msg = str(e) - logger.exception(f"Failed to sync user group - {error_msg}") + logger.exception(f"Failed to sync user group - {e}") - # Need a new session so this can be committed (previous transaction may have - # been rolled back due to the exception) - with Session(engine) as db_session: - mark_task_finished(task_name, db_session, success=error_msg is None) + +@build_celery_task_wrapper(name_chat_ttl_task) +@celery_app.task(soft_time_limit=JOB_TIMEOUT) +def perform_ttl_management_task(retention_limit_days: int) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + delete_chat_sessions_older_than(retention_limit_days, db_session) ##### # Periodic Tasks ##### + + +@celery_app.task( + name="check_ttl_management_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_ttl_management_task() -> None: + """Runs periodically to check if any ttl tasks should be run and adds them + to the queue""" + settings = load_settings() + retention_limit_days = settings.maximum_chat_retention_days + with Session(get_sqlalchemy_engine()) as db_session: + if should_perform_chat_ttl_check(retention_limit_days, db_session): + perform_ttl_management_task.apply_async( + kwargs=dict(retention_limit_days=retention_limit_days), + ) + + @celery_app.task( name="check_for_user_groups_sync_task", soft_time_limit=JOB_TIMEOUT, @@ -58,23 +74,11 @@ def check_for_user_groups_sync_task() -> None: # check if any document sets are not synced user_groups = 
fetch_user_groups(db_session=db_session, only_current=False) for user_group in user_groups: - if not user_group.is_up_to_date: - task_name = name_user_group_sync_task(user_group.id) - latest_sync = get_latest_task(task_name, db_session) - - if latest_sync and check_live_task_not_timed_out( - latest_sync, db_session - ): - logger.info( - f"User Group '{user_group.id}' is already syncing. Skipping." - ) - continue - + if should_sync_user_groups(user_group, db_session): logger.info(f"User Group {user_group.id} is not synced. Syncing now!") - task = sync_user_group_task.apply_async( + sync_user_group_task.apply_async( kwargs=dict(user_group_id=user_group.id), ) - register_task(task.id, task_name, db_session) @celery_app.task( @@ -103,5 +107,9 @@ celery_app.conf.beat_schedule = { "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, + "check-ttl-management": { + "task": "check_ttl_management_task", + "schedule": timedelta(hours=1), + }, **(celery_app.conf.beat_schedule or {}), } diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py new file mode 100644 index 000000000..9ab436596 --- /dev/null +++ b/backend/ee/danswer/background/celery_utils.py @@ -0,0 +1,40 @@ +from sqlalchemy.orm import Session + +from danswer.db.models import UserGroup +from danswer.db.tasks import check_live_task_not_timed_out +from danswer.db.tasks import get_latest_task +from danswer.utils.logger import setup_logger +from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_user_group_sync_task + +logger = setup_logger() + + +def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool: + if user_group.is_up_to_date: + return False + task_name = name_user_group_sync_task(user_group.id) + latest_sync = get_latest_task(task_name, db_session) + + if latest_sync and 
check_live_task_not_timed_out(latest_sync, db_session): + logger.info("User group sync is already in progress. Skipping.") + return False + return True + + +def should_perform_chat_ttl_check( + retention_limit_days: int | None, db_session: Session +) -> bool: + # TODO: make this a check for None and add behavior for 0 day TTL + if not retention_limit_days: + return False + + task_name = name_chat_ttl_task(retention_limit_days) + latest_task = get_latest_task(task_name, db_session) + if not latest_task: + return True + + if latest_task and check_live_task_not_timed_out(latest_task, db_session): + logger.info("TTL check is already being performed. Skipping.") + return False + return True diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py new file mode 100644 index 000000000..4f1046adb --- /dev/null +++ b/backend/ee/danswer/background/task_name_builders.py @@ -0,0 +1,6 @@ +def name_user_group_sync_task(user_group_id: int) -> str: + return f"user_group_sync_task__{user_group_id}" + + +def name_chat_ttl_task(retention_limit_days: int) -> str: + return f"chat_ttl_{retention_limit_days}_days" diff --git a/backend/ee/danswer/background/user_group_sync.py b/backend/ee/danswer/background/user_group_sync.py deleted file mode 100644 index ce824c471..000000000 --- a/backend/ee/danswer/background/user_group_sync.py +++ /dev/null @@ -1,2 +0,0 @@ -def name_user_group_sync_task(user_group_id: int) -> str: - return f"user_group_sync_task__{user_group_id}" diff --git a/web/src/app/admin/settings/SettingsForm.tsx b/web/src/app/admin/settings/SettingsForm.tsx index 782f06079..858a5c0ff 100644 --- a/web/src/app/admin/settings/SettingsForm.tsx +++ b/web/src/app/admin/settings/SettingsForm.tsx @@ -1,12 +1,16 @@ "use client"; import { Label, SubLabel } from "@/components/admin/connectors/Field"; +import { usePopup } from "@/components/admin/connectors/Popup"; import { Title } from "@tremor/react"; import { Settings } from 
"./interfaces"; import { useRouter } from "next/navigation"; import { DefaultDropdown, Option } from "@/components/Dropdown"; import { useContext } from "react"; import { SettingsContext } from "@/components/settings/SettingsProvider"; +import React, { useState, useEffect } from "react"; +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import { Button } from "@tremor/react"; function Checkbox({ label, @@ -49,7 +53,7 @@ function Selector({ onSelect: (value: string | number | null) => void; }) { return ( -