add assistant notifications

This commit is contained in:
pablodanswer 2024-10-15 19:49:31 -07:00
parent 61424de531
commit fd8b11c6db
19 changed files with 241 additions and 35 deletions

View File

@ -0,0 +1,26 @@
"""add additional data to notifiations
Revision ID: 1b10e1fda030
Revises: 5d12a446f5c0
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("notification", "additional_data")

View File

@ -134,7 +134,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@ -123,6 +123,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
class BlobType(str, Enum):

View File

@ -235,6 +235,9 @@ class Notification(Base):
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
additional_data: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
"""

View File

@ -1,3 +1,5 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
@ -8,12 +10,12 @@ from danswer.db.models import User
def create_notification(
user: User | None,
user_id: UUID | None,
notif_type: NotificationType,
db_session: Session,
) -> Notification:
notification = Notification(
user_id=user.id if user else None,
user_id=user_id,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),

View File

@ -57,6 +57,7 @@ from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@ -243,6 +244,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)

View File

@ -0,0 +1,46 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.server.settings.models import Notification as NotificationModel
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/notifications")
@router.get("/")
def get_notifications_api(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[NotificationModel]:
return [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=False)
]
@router.post("/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)

View File

@ -13,8 +13,10 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.persona import create_update_persona
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
@ -28,6 +30,7 @@ from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
@ -183,6 +186,7 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
# We notify each user when a user is shared with them
@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
@ -197,6 +201,16 @@ def share_persona(
db_session=db_session,
)
for user_id in persona_share_request.user_ids:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
@basic_router.delete("/{persona_id}")
def delete_persona(

View File

@ -120,3 +120,7 @@ class PersonaSnapshot(BaseModel):
class PromptTemplateResponse(BaseModel):
final_prompt_template: str
class PersonaSharedNotificationData(BaseModel):
persona_id: int

View File

@ -15,8 +15,6 @@ from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.key_value_store.factory import get_kv_store
@ -55,7 +53,7 @@ def fetch_settings(
"""Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls"""
general_settings = load_settings()
user_notifications = get_user_notifications(user, db_session)
user_notifications = get_reindex_notification(user, db_session)
try:
kv_store = get_kv_store()
@ -70,25 +68,7 @@ def fetch_settings(
)
@basic_router.post("/notifications/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)
def get_user_notifications(
def get_reindex_notification(
user: User | None, db_session: Session
) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
@ -121,7 +101,7 @@ def get_user_notifications(
if not reindex_notifs:
notif = create_notification(
user=user,
user_id=user.id if user else None,
notif_type=NotificationType.REINDEX,
db_session=db_session,
)

View File

@ -24,6 +24,7 @@ class Notification(BaseModel):
dismissed: bool
last_shown: datetime
first_shown: datetime
additional_data: dict | None = None
@classmethod
def from_model(cls, notif: NotificationDBModel) -> "Notification":
@ -33,6 +34,7 @@ class Notification(BaseModel):
dismissed=notif.dismissed,
last_shown=notif.last_shown,
first_shown=notif.first_shown,
additional_data=notif.additional_data,
)

View File

@ -312,7 +312,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@ -312,7 +312,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@ -59,6 +59,15 @@ export interface ToolCallFinalResult {
tool_result: Record<string, any>;
}
export interface Notification {
id: string;
title: string;
message: string;
time_created: string;
dismissed: boolean;
additional_data?: Record<string, any>;
}
export interface ChatSession {
id: string;
name: string;

View File

@ -9,11 +9,7 @@ import { checkUserIsNoAuthUser, logout } from "@/lib/user";
import { Popover } from "./popover/Popover";
import { LOGOUT_DISABLED } from "@/lib/constants";
import { SettingsContext } from "./settings/SettingsProvider";
import {
AssistantsIconSkeleton,
LightSettingsIcon,
UsersIcon,
} from "./icons/icons";
import { LightSettingsIcon } from "./icons/icons";
import { pageType } from "@/app/chat/sessionSidebar/types";
import { NavigationItem } from "@/app/admin/settings/interfaces";
import DynamicFaIcon, { preloadIcons } from "./icons/DynamicFaIcon";

View File

@ -5,12 +5,15 @@ import { FiShare2 } from "react-icons/fi";
import { SetStateAction, useContext, useEffect } from "react";
import { NewChatIcon } from "../icons/icons";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { ChatSession } from "@/app/chat/interfaces";
import { ChatSession, Notification } from "@/app/chat/interfaces";
import Link from "next/link";
import { pageType } from "@/app/chat/sessionSidebar/types";
import { useRouter } from "next/navigation";
import { ChatBanner } from "@/app/chat/ChatBanner";
import LogoType from "../header/LogoType";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { NotificationCard } from "./Notification";
export default function FunctionalHeader({
user,
@ -54,6 +57,18 @@ export default function FunctionalHeader({
}, [page, currentChatSession]);
const router = useRouter();
const {
data: notifications,
error,
mutate: refreshNotifications,
} = useSWR<Notification[]>("/api/notifications", errorHandlingFetcher);
useEffect(() => {
if (error) {
console.error("Failed to fetch notificat ions:", error);
}
}, [error]);
const handleNewChat = () => {
reset();
const newChatUrl =
@ -108,6 +123,10 @@ export default function FunctionalHeader({
</div>
)}
<NotificationCard
notifications={notifications}
refreshNotifications={refreshNotifications}
/>
<div className="mobile:hidden flex my-auto">
<UserDropdown user={user} />
</div>

View File

@ -0,0 +1,93 @@
import React, { useState } from "react";
import { Notification } from "../../app/chat/interfaces";
export const NotificationCard = ({
notifications,
refreshNotifications,
}: {
notifications?: Notification[];
refreshNotifications: () => void;
}) => {
const [showDropdown, setShowDropdown] = useState(false);
const dismissNotification = async (notificationId: string) => {
try {
await fetch(`/api/notifications/${notificationId}/dismiss`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
} catch (error) {
console.error("Error dismissing notification:", error);
}
};
const handleDismiss = async (notificationId: string) => {
try {
await dismissNotification(notificationId);
refreshNotifications();
} catch (error) {
console.error("Error dismissing notification:", error);
}
};
const handleAccept = async (notification: Notification) => {
// Handle accept logic based on notification.additional_data
// For example, accept a shared persona
// await acceptSharedPersona(notification.additional_data.persona_id);
// Then dismiss the notification
await handleDismiss(notification.id);
};
if (!notifications) {
return null;
}
return (
<div className="relative">
<div
onClick={() => setShowDropdown(!showDropdown)}
className="cursor-pointer"
>
<svg className="w-6 h-6">
{/* Bell icon SVG */}
<path d="..." />
</svg>
{notifications.length > 0 && (
<span className="absolute top-0 right-0 h-2 w-2 bg-orange-500 rounded-full"></span>
)}
</div>
{showDropdown && (
<div className="absolute right-0 mt-2 py-2 w-80 bg-white rounded-md shadow-xl z-20">
{notifications.length > 0 ? (
notifications.map((notification) => (
<div key={notification.id} className="px-4 py-2 border-b">
<p className="font-semibold">{notification.title}</p>
<p className="text-sm text-gray-600">{notification.message}</p>
<div className="flex justify-end mt-2">
<button
onClick={() => handleAccept(notification)}
className="text-sm text-blue-500 mr-4"
>
Accept
</button>
<button
onClick={() => handleDismiss(notification.id)}
className="text-sm text-red-500"
>
Dismiss
</button>
</div>
</div>
))
) : (
<div className="px-4 py-2 text-center text-gray-600">
No new notifications
</div>
)}
</div>
)}
</div>
);
};

View File

@ -155,3 +155,12 @@ export const processCookies = (cookies: ReadonlyRequestCookies): string => {
.map((cookie) => `${cookie.name}=${cookie.value}`)
.join("; ");
};
export const getNotificationsSS = async (): Promise<Notification[]> => {
const response = await fetch(buildUrl("/notifications"));
if (!response.ok) {
throw new Error("Failed to fetch notifications");
}
const notifications = await response.json();
return notifications;
};