mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-09 20:55:06 +02:00
Add assistant notifications + update assistant context (#2816)
* add assistant notifications * nit * update context * validated * ensure context passed properly * validated + cleaned * nit: naming * k * k * final validation + new ui * nit + video * nit * nit * nit * k * fix typos
This commit is contained in:
@@ -0,0 +1,26 @@
|
||||
"""add additional data to notifications
|
||||
|
||||
Revision ID: 1b10e1fda030
|
||||
Revises: 6756efa39ada
|
||||
Create Date: 2024-10-15 19:26:44.071259
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b10e1fda030"
|
||||
down_revision = "6756efa39ada"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "additional_data")
|
@@ -135,6 +135,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
|
@@ -235,6 +235,9 @@ class Notification(Base):
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="notifications")
|
||||
additional_data: Mapped[dict | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
@@ -8,16 +10,37 @@ from danswer.db.models import User
|
||||
|
||||
|
||||
def create_notification(
|
||||
user: User | None,
|
||||
user_id: UUID | None,
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
additional_data: dict | None = None,
|
||||
) -> Notification:
|
||||
# Check if an undismissed notification of the same type and data exists
|
||||
existing_notification = (
|
||||
db_session.query(Notification)
|
||||
.filter_by(
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
)
|
||||
.filter(Notification.additional_data == additional_data)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_notification:
|
||||
# Update the last_shown timestamp
|
||||
existing_notification.last_shown = func.now()
|
||||
db_session.commit()
|
||||
return existing_notification
|
||||
|
||||
# Create a new notification if none exists
|
||||
notification = Notification(
|
||||
user_id=user.id if user else None,
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
last_shown=func.now(),
|
||||
first_shown=func.now(),
|
||||
additional_data=additional_data,
|
||||
)
|
||||
db_session.add(notification)
|
||||
db_session.commit()
|
||||
|
@@ -57,6 +57,7 @@ from danswer.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
|
||||
from danswer.server.features.notifications.api import router as notification_router
|
||||
from danswer.server.features.persona.api import admin_router as admin_persona_router
|
||||
from danswer.server.features.persona.api import basic_router as persona_router
|
||||
from danswer.server.features.prompt.api import basic_router as prompt_router
|
||||
@@ -246,6 +247,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, notification_router)
|
||||
include_router_with_global_prefix_prepended(application, prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, tool_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_tool_router)
|
||||
|
47
backend/danswer/server/features/notifications/api.py
Normal file
47
backend/danswer/server/features/notifications/api.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.notification import dismiss_notification
|
||||
from danswer.db.notification import get_notification_by_id
|
||||
from danswer.db.notification import get_notifications
|
||||
from danswer.server.settings.models import Notification as NotificationModel
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/notifications")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def get_notifications_api(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[NotificationModel]:
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=False)
|
||||
]
|
||||
return notifications
|
||||
|
||||
|
||||
@router.post("/{notification_id}/dismiss")
|
||||
def dismiss_notification_endpoint(
|
||||
notification_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
notification = get_notification_by_id(notification_id, user, db_session)
|
||||
except PermissionError:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authorized to dismiss this notification"
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
|
||||
dismiss_notification(notification, db_session)
|
@@ -13,8 +13,10 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.notification import create_notification
|
||||
from danswer.db.persona import create_update_persona
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import get_personas
|
||||
@@ -28,6 +30,7 @@ from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.llm.answering.prompts.utils import build_dummy_prompt
|
||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||
from danswer.server.features.persona.models import PersonaSharedNotificationData
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.features.persona.models import PromptTemplateResponse
|
||||
from danswer.server.models import DisplayPriorityRequest
|
||||
@@ -183,11 +186,12 @@ class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@basic_router.patch("/{persona_id}/share")
|
||||
def share_persona(
|
||||
persona_id: int,
|
||||
persona_share_request: PersonaShareRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared_users(
|
||||
@@ -197,6 +201,18 @@ def share_persona(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
for user_id in persona_share_request.user_ids:
|
||||
# Don't notify the user that they have access to their own persona
|
||||
if user_id != user.id:
|
||||
create_notification(
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@basic_router.delete("/{persona_id}")
|
||||
def delete_persona(
|
||||
@@ -216,23 +232,31 @@ def list_personas(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
include_deleted: bool = False,
|
||||
persona_ids: list[int] = Query(None),
|
||||
) -> list[PersonaSnapshot]:
|
||||
return [
|
||||
PersonaSnapshot.from_model(persona)
|
||||
for persona in get_personas(
|
||||
user=user,
|
||||
include_deleted=include_deleted,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
# If the persona has an image generation tool and it's not available, don't include it
|
||||
personas = get_personas(
|
||||
user=user,
|
||||
include_deleted=include_deleted,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
|
||||
if persona_ids:
|
||||
personas = [p for p in personas if p.id in persona_ids]
|
||||
|
||||
# Filter out personas with unavailable tools
|
||||
personas = [
|
||||
p
|
||||
for p in personas
|
||||
if not (
|
||||
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in persona.tools)
|
||||
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
|
||||
and not is_image_generation_available(db_session=db_session)
|
||||
)
|
||||
]
|
||||
|
||||
return [PersonaSnapshot.from_model(p) for p in personas]
|
||||
|
||||
|
||||
@basic_router.get("/{persona_id}")
|
||||
def get_persona(
|
||||
|
@@ -120,3 +120,7 @@ class PersonaSnapshot(BaseModel):
|
||||
|
||||
class PromptTemplateResponse(BaseModel):
|
||||
final_prompt_template: str
|
||||
|
||||
|
||||
class PersonaSharedNotificationData(BaseModel):
|
||||
persona_id: int
|
||||
|
@@ -15,8 +15,6 @@ from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.notification import create_notification
|
||||
from danswer.db.notification import dismiss_all_notifications
|
||||
from danswer.db.notification import dismiss_notification
|
||||
from danswer.db.notification import get_notification_by_id
|
||||
from danswer.db.notification import get_notifications
|
||||
from danswer.db.notification import update_notification_last_shown
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
@@ -55,7 +53,7 @@ def fetch_settings(
|
||||
"""Settings and notifications are stuffed into this single endpoint to reduce number of
|
||||
Postgres calls"""
|
||||
general_settings = load_settings()
|
||||
user_notifications = get_user_notifications(user, db_session)
|
||||
user_notifications = get_reindex_notification(user, db_session)
|
||||
|
||||
try:
|
||||
kv_store = get_kv_store()
|
||||
@@ -70,25 +68,7 @@ def fetch_settings(
|
||||
)
|
||||
|
||||
|
||||
@basic_router.post("/notifications/{notification_id}/dismiss")
|
||||
def dismiss_notification_endpoint(
|
||||
notification_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
notification = get_notification_by_id(notification_id, user, db_session)
|
||||
except PermissionError:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not authorized to dismiss this notification"
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Notification not found")
|
||||
|
||||
dismiss_notification(notification, db_session)
|
||||
|
||||
|
||||
def get_user_notifications(
|
||||
def get_reindex_notification(
|
||||
user: User | None, db_session: Session
|
||||
) -> list[Notification]:
|
||||
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
|
||||
@@ -121,7 +101,7 @@ def get_user_notifications(
|
||||
|
||||
if not reindex_notifs:
|
||||
notif = create_notification(
|
||||
user=user,
|
||||
user_id=user.id if user else None,
|
||||
notif_type=NotificationType.REINDEX,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
@@ -24,6 +24,7 @@ class Notification(BaseModel):
|
||||
dismissed: bool
|
||||
last_shown: datetime
|
||||
first_shown: datetime
|
||||
additional_data: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, notif: NotificationDBModel) -> "Notification":
|
||||
@@ -33,6 +34,7 @@ class Notification(BaseModel):
|
||||
dismissed=notif.dismissed,
|
||||
last_shown=notif.last_shown,
|
||||
first_shown=notif.first_shown,
|
||||
additional_data=notif.additional_data,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user