User Notification Backend (#2104)

This commit is contained in:
Yuhong Sun 2024-08-10 11:39:21 -07:00 committed by GitHub
parent 0a8d44b44c
commit a4caf66a35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 243 additions and 16 deletions

View File

@ -0,0 +1,44 @@
"""notifications
Revision ID: 213fd978c6d8
Revises: 5fc1f54cc252
Create Date: 2024-08-10 11:13:36.070790
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "213fd978c6d8"
down_revision = "5fc1f54cc252"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"notification",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"notif_type",
sa.String(),
nullable=False,
),
sa.Column(
"user_id",
sa.UUID(),
nullable=True,
),
sa.Column("dismissed", sa.Boolean(), nullable=False),
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("notification")

View File

@ -132,6 +132,10 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
class NotificationType(str, Enum):
REINDEX = "reindex"
class BlobType(str, Enum):
R2 = "r2"
S3 = "s3"

View File

@ -38,6 +38,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
from danswer.connectors.models import InputType
@ -144,6 +145,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
# Custom tools created by this user
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
# Notifications for the UI
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
class InputPrompt(Base):
@ -155,7 +160,7 @@ class InputPrompt(Base):
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
class InputPrompt__User(Base):
@ -164,7 +169,7 @@ class InputPrompt__User(Base):
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
@ -189,6 +194,21 @@ class ApiKey(Base):
)
class Notification(Base):
__tablename__ = "notification"
id: Mapped[int] = mapped_column(primary_key=True)
notif_type: Mapped[NotificationType] = mapped_column(
Enum(NotificationType, native_enum=False)
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
"""
Association Tables
NOTE: must be at the top since they are referenced by other tables
@ -215,7 +235,9 @@ class Persona__User(Base):
__tablename__ = "persona__user"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class DocumentSet__User(Base):
@ -224,7 +246,9 @@ class DocumentSet__User(Base):
document_set_id: Mapped[int] = mapped_column(
ForeignKey("document_set.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class DocumentSet__ConnectorCredentialPair(Base):
@ -1370,7 +1394,9 @@ class User__UserGroup(Base):
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class UserGroup__ConnectorCredentialPair(Base):

View File

@ -0,0 +1,76 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification
from danswer.db.models import User
def create_notification(
user: User | None,
notif_type: NotificationType,
db_session: Session,
) -> Notification:
notification = Notification(
user_id=user.id if user else None,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
)
db_session.add(notification)
db_session.commit()
return notification
def get_notification_by_id(
notification_id: int, user: User | None, db_session: Session
) -> Notification:
user_id = user.id if user else None
notif = db_session.get(Notification, notification_id)
if not notif:
raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id:
raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}"
)
return notif
def get_notifications(
user: User | None,
db_session: Session,
notif_type: NotificationType | None = None,
include_dismissed: bool = True,
) -> list[Notification]:
query = select(Notification).where(
Notification.user_id == user.id if user else Notification.user_id.is_(None)
)
if not include_dismissed:
query = query.where(Notification.dismissed.is_(False))
if notif_type:
query = query.where(Notification.notif_type == notif_type)
return list(db_session.execute(query).scalars().all())
def dismiss_all_notifications(
notif_type: NotificationType,
db_session: Session,
) -> None:
db_session.query(Notification).filter(Notification.notif_type == notif_type).update(
{"dismissed": True}
)
db_session.commit()
def dismiss_notification(notification: Notification, db_session: Session) -> None:
notification.dismissed = True
db_session.commit()
def update_notification_last_shown(
notification: Notification, db_session: Session
) -> None:
notification.last_shown = func.now()
db_session.commit()

View File

@ -3,12 +3,21 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import is_user_admin
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.settings.models import Notification
@ -38,23 +47,73 @@ def put_settings(
@basic_router.get("")
def fetch_settings(user: User | None = Depends(current_user)) -> UserSettings:
def fetch_settings(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> UserSettings:
"""Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls"""
general_settings = load_settings()
user_notifications = get_user_notifications(user)
return UserSettings(**general_settings.dict(), **user_notifications.dict())
user_notifications = get_user_notifications(user, db_session)
return UserSettings(**general_settings.dict(), notifications=user_notifications)
def get_user_notifications(user: User | None) -> Notification:
"""Get any notification names, currently the only one is the reindexing flag"""
@basic_router.post("/notifications/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)
def get_user_notifications(
user: User | None, db_session: Session
) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
is_admin = is_user_admin(user)
if not is_admin:
return Notification(notif_name=None)
# Reindexing flag should only be shown to admins, basic users can't trigger it anyway
return []
kv_store = get_dynamic_config_store()
try:
need_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
return Notification(notif_name=KV_REINDEX_KEY if need_index else None)
if not need_index:
dismiss_all_notifications(
notif_type=NotificationType.REINDEX, db_session=db_session
)
return []
except ConfigNotFoundError:
# If something goes wrong and the flag is gone, better to not start a reindexing
# it's a heavyweight long running job and maybe this flag is cleaned up later
logger.warning("Could not find reindex flag")
return Notification(notif_name=None)
return []
reindex_notifs = get_notifications(
user=user, notif_type=NotificationType.REINDEX, db_session=db_session
)
if not reindex_notifs:
notif = create_notification(
user=user, notif_type=NotificationType.REINDEX, db_session=db_session
)
return [Notification.from_model(notif)]
if len(reindex_notifs) > 1:
logger.error("User has multiple reindex notifications")
reindex_notif = reindex_notifs[0]
update_notification_last_shown(notification=reindex_notif, db_session=db_session)
return [Notification.from_model(reindex_notif)]

View File

@ -1,7 +1,11 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification as NotificationDBModel
class PageType(str, Enum):
CHAT = "chat"
@ -9,7 +13,21 @@ class PageType(str, Enum):
class Notification(BaseModel):
notif_name: str | None
id: int
notif_type: NotificationType
dismissed: bool
last_shown: datetime
first_shown: datetime
@classmethod
def from_model(cls, notif: NotificationDBModel) -> "Notification":
return cls(
id=notif.id,
notif_type=notif.notif_type,
dismissed=notif.dismissed,
last_shown=notif.last_shown,
first_shown=notif.first_shown,
)
class Settings(BaseModel):
@ -41,5 +59,5 @@ class Settings(BaseModel):
)
class UserSettings(Notification, Settings):
"""User-specific settings combining Notification and general Settings"""
class UserSettings(Settings):
notifications: list[Notification]