mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
User Notification Backend (#2104)
This commit is contained in:
parent
0a8d44b44c
commit
a4caf66a35
44
backend/alembic/versions/213fd978c6d8_notifications.py
Normal file
44
backend/alembic/versions/213fd978c6d8_notifications.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""notifications
|
||||
|
||||
Revision ID: 213fd978c6d8
|
||||
Revises: 5fc1f54cc252
|
||||
Create Date: 2024-08-10 11:13:36.070790
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "213fd978c6d8"
|
||||
down_revision = "5fc1f54cc252"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"notification",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"notif_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
sa.UUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("dismissed", sa.Boolean(), nullable=False),
|
||||
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("notification")
|
@ -132,6 +132,10 @@ class DocumentSource(str, Enum):
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
S3 = "s3"
|
||||
|
@ -38,6 +38,7 @@ from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
from danswer.connectors.models import InputType
|
||||
@ -144,6 +145,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
# Custom tools created by this user
|
||||
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
|
||||
# Notifications for the UI
|
||||
notifications: Mapped[list["Notification"]] = relationship(
|
||||
"Notification", back_populates="user"
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
@ -155,7 +160,7 @@ class InputPrompt(Base):
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
@ -164,7 +169,7 @@ class InputPrompt__User(Base):
|
||||
input_prompt_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
|
||||
@ -189,6 +194,21 @@ class ApiKey(Base):
|
||||
)
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notification"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
notif_type: Mapped[NotificationType] = mapped_column(
|
||||
Enum(NotificationType, native_enum=False)
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="notifications")
|
||||
|
||||
|
||||
"""
|
||||
Association Tables
|
||||
NOTE: must be at the top since they are referenced by other tables
|
||||
@ -215,7 +235,9 @@ class Persona__User(Base):
|
||||
__tablename__ = "persona__user"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__User(Base):
|
||||
@ -224,7 +246,9 @@ class DocumentSet__User(Base):
|
||||
document_set_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("document_set.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__ConnectorCredentialPair(Base):
|
||||
@ -1370,7 +1394,9 @@ class User__UserGroup(Base):
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
|
76
backend/danswer/db/notification.py
Normal file
76
backend/danswer/db/notification.py
Normal file
@ -0,0 +1,76 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.models import Notification
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
def create_notification(
|
||||
user: User | None,
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
) -> Notification:
|
||||
notification = Notification(
|
||||
user_id=user.id if user else None,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
last_shown=func.now(),
|
||||
first_shown=func.now(),
|
||||
)
|
||||
db_session.add(notification)
|
||||
db_session.commit()
|
||||
return notification
|
||||
|
||||
|
||||
def get_notification_by_id(
|
||||
notification_id: int, user: User | None, db_session: Session
|
||||
) -> Notification:
|
||||
user_id = user.id if user else None
|
||||
notif = db_session.get(Notification, notification_id)
|
||||
if not notif:
|
||||
raise ValueError(f"No notification found with id {notification_id}")
|
||||
if notif.user_id != user_id:
|
||||
raise PermissionError(
|
||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||
)
|
||||
return notif
|
||||
|
||||
|
||||
def get_notifications(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
notif_type: NotificationType | None = None,
|
||||
include_dismissed: bool = True,
|
||||
) -> list[Notification]:
|
||||
query = select(Notification).where(
|
||||
Notification.user_id == user.id if user else Notification.user_id.is_(None)
|
||||
)
|
||||
if not include_dismissed:
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
query = query.where(Notification.notif_type == notif_type)
|
||||
return list(db_session.execute(query).scalars().all())
|
||||
|
||||
|
||||
def dismiss_all_notifications(
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(Notification).filter(Notification.notif_type == notif_type).update(
|
||||
{"dismissed": True}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def dismiss_notification(notification: Notification, db_session: Session) -> None:
|
||||
notification.dismissed = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_notification_last_shown(
|
||||
notification: Notification, db_session: Session
|
||||
) -> None:
|
||||
notification.last_shown = func.now()
|
||||
db_session.commit()
|
@ -3,12 +3,21 @@ from typing import cast
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.notification import create_notification
|
||||
from danswer.db.notification import dismiss_all_notifications
|
||||
from danswer.db.notification import dismiss_notification
|
||||
from danswer.db.notification import get_notification_by_id
|
||||
from danswer.db.notification import get_notifications
|
||||
from danswer.db.notification import update_notification_last_shown
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.settings.models import Notification
|
||||
@ -38,23 +47,73 @@ def put_settings(
|
||||
|
||||
|
||||
@basic_router.get("")
|
||||
def fetch_settings(user: User | None = Depends(current_user)) -> UserSettings:
|
||||
def fetch_settings(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserSettings:
|
||||
"""Settings and notifications are stuffed into this single endpoint to reduce number of
|
||||
Postgres calls"""
|
||||
general_settings = load_settings()
|
||||
user_notifications = get_user_notifications(user)
|
||||
return UserSettings(**general_settings.dict(), **user_notifications.dict())
|
||||
user_notifications = get_user_notifications(user, db_session)
|
||||
return UserSettings(**general_settings.dict(), notifications=user_notifications)
|
||||
|
||||
|
||||
def get_user_notifications(user: User | None) -> Notification:
|
||||
"""Get any notification names, currently the only one is the reindexing flag"""
|
||||
@basic_router.post("/notifications/{notification_id}/dismiss")
|
||||
def dismiss_notification_endpoint(
|
||||
notification_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
notification = get_notification_by_id(notification_id, user, db_session)
|
||||
except PermissionError:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authorized to dismiss this notification"
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
|
||||
dismiss_notification(notification, db_session)
|
||||
|
||||
|
||||
def get_user_notifications(
|
||||
user: User | None, db_session: Session
|
||||
) -> list[Notification]:
|
||||
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
|
||||
is_admin = is_user_admin(user)
|
||||
if not is_admin:
|
||||
return Notification(notif_name=None)
|
||||
# Reindexing flag should only be shown to admins, basic users can't trigger it anyway
|
||||
return []
|
||||
|
||||
kv_store = get_dynamic_config_store()
|
||||
try:
|
||||
need_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
|
||||
return Notification(notif_name=KV_REINDEX_KEY if need_index else None)
|
||||
if not need_index:
|
||||
dismiss_all_notifications(
|
||||
notif_type=NotificationType.REINDEX, db_session=db_session
|
||||
)
|
||||
return []
|
||||
except ConfigNotFoundError:
|
||||
# If something goes wrong and the flag is gone, better to not start a reindexing
|
||||
# it's a heavyweight long running job and maybe this flag is cleaned up later
|
||||
logger.warning("Could not find reindex flag")
|
||||
return Notification(notif_name=None)
|
||||
return []
|
||||
|
||||
reindex_notifs = get_notifications(
|
||||
user=user, notif_type=NotificationType.REINDEX, db_session=db_session
|
||||
)
|
||||
|
||||
if not reindex_notifs:
|
||||
notif = create_notification(
|
||||
user=user, notif_type=NotificationType.REINDEX, db_session=db_session
|
||||
)
|
||||
return [Notification.from_model(notif)]
|
||||
|
||||
if len(reindex_notifs) > 1:
|
||||
logger.error("User has multiple reindex notifications")
|
||||
|
||||
reindex_notif = reindex_notifs[0]
|
||||
|
||||
update_notification_last_shown(notification=reindex_notif, db_session=db_session)
|
||||
|
||||
return [Notification.from_model(reindex_notif)]
|
||||
|
@ -1,7 +1,11 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.models import Notification as NotificationDBModel
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@ -9,7 +13,21 @@ class PageType(str, Enum):
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
notif_name: str | None
|
||||
id: int
|
||||
notif_type: NotificationType
|
||||
dismissed: bool
|
||||
last_shown: datetime
|
||||
first_shown: datetime
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, notif: NotificationDBModel) -> "Notification":
|
||||
return cls(
|
||||
id=notif.id,
|
||||
notif_type=notif.notif_type,
|
||||
dismissed=notif.dismissed,
|
||||
last_shown=notif.last_shown,
|
||||
first_shown=notif.first_shown,
|
||||
)
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
@ -41,5 +59,5 @@ class Settings(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class UserSettings(Notification, Settings):
|
||||
"""User-specific settings combining Notification and general Settings"""
|
||||
class UserSettings(Settings):
|
||||
notifications: list[Notification]
|
||||
|
Loading…
x
Reference in New Issue
Block a user