From a4caf66a3510163f75f1c3c0f3142f2455578122 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sat, 10 Aug 2024 11:39:21 -0700 Subject: [PATCH] User Notification Backend (#2104) --- .../versions/213fd978c6d8_notifications.py | 44 +++++++++++ backend/danswer/configs/constants.py | 4 + backend/danswer/db/models.py | 36 +++++++-- backend/danswer/db/notification.py | 76 +++++++++++++++++++ backend/danswer/server/settings/api.py | 75 ++++++++++++++++-- backend/danswer/server/settings/models.py | 24 +++++- 6 files changed, 243 insertions(+), 16 deletions(-) create mode 100644 backend/alembic/versions/213fd978c6d8_notifications.py create mode 100644 backend/danswer/db/notification.py diff --git a/backend/alembic/versions/213fd978c6d8_notifications.py b/backend/alembic/versions/213fd978c6d8_notifications.py new file mode 100644 index 000000000..563556ea5 --- /dev/null +++ b/backend/alembic/versions/213fd978c6d8_notifications.py @@ -0,0 +1,44 @@ +"""notifications + +Revision ID: 213fd978c6d8 +Revises: 5fc1f54cc252 +Create Date: 2024-08-10 11:13:36.070790 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "213fd978c6d8" +down_revision = "5fc1f54cc252" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.create_table( + "notification", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "notif_type", + sa.String(), + nullable=False, + ), + sa.Column( + "user_id", + sa.UUID(), + nullable=True, + ), + sa.Column("dismissed", sa.Boolean(), nullable=False), + sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False), + sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("notification") diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 4cdc726a0..98ef2bb03 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -132,6 +132,10 @@ class DocumentSource(str, Enum): NOT_APPLICABLE = "not_applicable" +class NotificationType(str, Enum): + REINDEX = "reindex" + + class BlobType(str, Enum): R2 = "r2" S3 = "s3" diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 6b5303951..0fa09c534 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -38,6 +38,7 @@ from danswer.configs.constants import DEFAULT_BOOST from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType +from danswer.configs.constants import NotificationType from danswer.configs.constants import SearchFeedbackType from danswer.configs.constants import TokenRateLimitScope from danswer.connectors.models import InputType @@ -144,6 +145,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base): personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") # Custom tools created by this user custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user") + # Notifications for the UI + notifications: Mapped[list["Notification"]] = relationship( + "Notification", back_populates="user" + ) class InputPrompt(Base): @@ -155,7 +160,7 @@ class InputPrompt(Base): active: Mapped[bool] = mapped_column(Boolean) user: Mapped[User | None] = relationship("User", back_populates="input_prompts") is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id")) + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) class InputPrompt__User(Base): @@ -164,7 +169,7 @@ class InputPrompt__User(Base): input_prompt_id: Mapped[int] = mapped_column( ForeignKey("inputprompt.id"), primary_key=True ) - user_id: Mapped[UUID] = mapped_column( + user_id: Mapped[UUID | None] = mapped_column( ForeignKey("inputprompt.id"), primary_key=True ) @@ -189,6 +194,21 @@ class ApiKey(Base): ) +class Notification(Base): + __tablename__ = "notification" + + id: Mapped[int] = mapped_column(primary_key=True) + notif_type: Mapped[NotificationType] = mapped_column( + Enum(NotificationType, native_enum=False) + ) + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + dismissed: Mapped[bool] = mapped_column(Boolean, default=False) + last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) + first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) + + user: Mapped[User] = relationship("User", back_populates="notifications") + + """ Association Tables NOTE: must be at the top since they are referenced by other tables @@ -215,7 +235,9 @@ class Persona__User(Base): __tablename__ = "persona__user" persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) - user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id"), primary_key=True, nullable=True + ) class DocumentSet__User(Base): @@ -224,7 +246,9 @@ class DocumentSet__User(Base): document_set_id: Mapped[int] = mapped_column( ForeignKey("document_set.id"), primary_key=True ) - user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id"), primary_key=True, nullable=True + ) class DocumentSet__ConnectorCredentialPair(Base): @@ -1370,7 +1394,9 @@ class User__UserGroup(Base): user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) - user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id"), primary_key=True, nullable=True + ) class UserGroup__ConnectorCredentialPair(Base): diff --git a/backend/danswer/db/notification.py b/backend/danswer/db/notification.py new file mode 100644 index 000000000..61586208c --- /dev/null +++ b/backend/danswer/db/notification.py @@ -0,0 +1,76 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import func + +from danswer.configs.constants import NotificationType +from danswer.db.models import Notification +from danswer.db.models import User + + +def create_notification( + user: User | None, + notif_type: NotificationType, + db_session: Session, +) -> Notification: + notification = Notification( + user_id=user.id if user else None, + notif_type=notif_type, + dismissed=False, + last_shown=func.now(), + first_shown=func.now(), + ) + db_session.add(notification) + db_session.commit() + return notification + + +def get_notification_by_id( + notification_id: int, user: User | None, db_session: Session +) -> Notification: + user_id = user.id if user else None + notif = db_session.get(Notification, notification_id) + if not notif: + raise ValueError(f"No notification found with id {notification_id}") + if notif.user_id != user_id: + raise PermissionError( + f"User {user_id} is not authorized to access notification {notification_id}" + ) + return notif + + +def get_notifications( + user: User | None, + db_session: Session, + notif_type: NotificationType | None = None, + include_dismissed: bool = True, +) -> list[Notification]: + query = select(Notification).where( + Notification.user_id == user.id if user else Notification.user_id.is_(None) + ) + if not include_dismissed: + query = query.where(Notification.dismissed.is_(False)) + if notif_type: + query = query.where(Notification.notif_type == notif_type) + return list(db_session.execute(query).scalars().all()) + + +def dismiss_all_notifications( + notif_type: NotificationType, + db_session: Session, +) -> None: + db_session.query(Notification).filter(Notification.notif_type == notif_type).update( + {"dismissed": True} + ) + db_session.commit() + + +def dismiss_notification(notification: Notification, db_session: Session) -> None: + notification.dismissed = True + db_session.commit() + + +def update_notification_last_shown( + notification: Notification, db_session: Session +) -> None: + notification.last_shown = func.now() + db_session.commit() diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 25cf1b6c6..f75dc38a9 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -3,12 +3,21 @@ from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException +from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.auth.users import is_user_admin from danswer.configs.constants import KV_REINDEX_KEY +from danswer.configs.constants import NotificationType +from danswer.db.engine import get_session from danswer.db.models import User +from danswer.db.notification import create_notification +from danswer.db.notification import dismiss_all_notifications +from danswer.db.notification import dismiss_notification +from danswer.db.notification import get_notification_by_id +from danswer.db.notification import get_notifications +from danswer.db.notification import update_notification_last_shown from danswer.dynamic_configs.factory import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.server.settings.models import Notification @@ -38,23 +47,73 @@ def put_settings( @basic_router.get("") -def fetch_settings(user: User | None = Depends(current_user)) -> UserSettings: +def fetch_settings( + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> UserSettings: + """Settings and notifications are stuffed into this single endpoint to reduce number of + Postgres calls""" general_settings = load_settings() - user_notifications = get_user_notifications(user) - return UserSettings(**general_settings.dict(), **user_notifications.dict()) + user_notifications = get_user_notifications(user, db_session) + return UserSettings(**general_settings.dict(), notifications=user_notifications) -def get_user_notifications(user: User | None) -> Notification: - """Get any notification names, currently the only one is the reindexing flag""" +@basic_router.post("/notifications/{notification_id}/dismiss") +def dismiss_notification_endpoint( + notification_id: int, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + try: + notification = get_notification_by_id(notification_id, user, db_session) + except PermissionError: + raise HTTPException( + status_code=403, detail="Not authorized to dismiss this notification" + ) + except ValueError: + raise HTTPException(status_code=404, detail="Notification not found") + + dismiss_notification(notification, db_session) + + +def get_user_notifications( + user: User | None, db_session: Session +) -> list[Notification]: + """Get notifications for the user, currently the logic is very specific to the reindexing flag""" is_admin = is_user_admin(user) if not is_admin: - return Notification(notif_name=None) + # Reindexing flag should only be shown to admins, basic users can't trigger it anyway + return [] + kv_store = get_dynamic_config_store() try: need_index = cast(bool, kv_store.load(KV_REINDEX_KEY)) - return Notification(notif_name=KV_REINDEX_KEY if need_index else None) + if not need_index: + dismiss_all_notifications( + notif_type=NotificationType.REINDEX, db_session=db_session + ) + return [] except ConfigNotFoundError: # If something goes wrong and the flag is gone, better to not start a reindexing # it's a heavyweight long running job and maybe this flag is cleaned up later logger.warning("Could not find reindex flag") - return Notification(notif_name=None) + return [] + + reindex_notifs = get_notifications( + user=user, notif_type=NotificationType.REINDEX, db_session=db_session + ) + + if not reindex_notifs: + notif = create_notification( + user=user, notif_type=NotificationType.REINDEX, db_session=db_session + ) + return [Notification.from_model(notif)] + + if len(reindex_notifs) > 1: + logger.error("User has multiple reindex notifications") + + reindex_notif = reindex_notifs[0] + + update_notification_last_shown(notification=reindex_notif, db_session=db_session) + + return [Notification.from_model(reindex_notif)] diff --git a/backend/danswer/server/settings/models.py b/backend/danswer/server/settings/models.py index 1547c469b..5a574fb14 100644 --- a/backend/danswer/server/settings/models.py +++ b/backend/danswer/server/settings/models.py @@ -1,7 +1,11 @@ +from datetime import datetime from enum import Enum from pydantic import BaseModel +from danswer.configs.constants import NotificationType +from danswer.db.models import Notification as NotificationDBModel + class PageType(str, Enum): CHAT = "chat" @@ -9,7 +13,21 @@ class PageType(str, Enum): class Notification(BaseModel): - notif_name: str | None + id: int + notif_type: NotificationType + dismissed: bool + last_shown: datetime + first_shown: datetime + + @classmethod + def from_model(cls, notif: NotificationDBModel) -> "Notification": + return cls( + id=notif.id, + notif_type=notif.notif_type, + dismissed=notif.dismissed, + last_shown=notif.last_shown, + first_shown=notif.first_shown, + ) class Settings(BaseModel): @@ -41,5 +59,5 @@ class Settings(BaseModel): ) -class UserSettings(Notification, Settings): - """User-specific settings combining Notification and general Settings""" +class UserSettings(Settings): + notifications: list[Notification]