Add assistant notifications + update assistant context (#2816)

* add assistant notifications

* nit

* update context

* validated

* ensure context passed properly

* validated + cleaned

* nit: naming

* k

* k

* final validation + new ui

* nit + video

* nit

* nit

* nit

* k

* fix typos
This commit is contained in:
pablodanswer
2024-10-18 18:21:11 -07:00
committed by GitHub
parent 6913efef90
commit 8b220d2dba
37 changed files with 820 additions and 344 deletions

View File

@@ -0,0 +1,26 @@
"""add additional data to notifications
Revision ID: 1b10e1fda030
Revises: 6756efa39ada
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "6756efa39ada"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("notification", "additional_data")

View File

@@ -135,6 +135,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
class BlobType(str, Enum):

View File

@@ -235,6 +235,9 @@ class Notification(Base):
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
additional_data: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
"""

View File

@@ -1,3 +1,5 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
@@ -8,16 +10,37 @@ from danswer.db.models import User
def create_notification(
user: User | None,
user_id: UUID | None,
notif_type: NotificationType,
db_session: Session,
additional_data: dict | None = None,
) -> Notification:
# Check if an undismissed notification of the same type and data exists
existing_notification = (
db_session.query(Notification)
.filter_by(
user_id=user_id,
notif_type=notif_type,
dismissed=False,
)
.filter(Notification.additional_data == additional_data)
.first()
)
if existing_notification:
# Update the last_shown timestamp
existing_notification.last_shown = func.now()
db_session.commit()
return existing_notification
# Create a new notification if none exists
notification = Notification(
user_id=user.id if user else None,
user_id=user_id,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
additional_data=additional_data,
)
db_session.add(notification)
db_session.commit()

View File

@@ -57,6 +57,7 @@ from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@@ -246,6 +247,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)

View File

@@ -0,0 +1,47 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.server.settings.models import Notification as NotificationModel
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/notifications")
@router.get("")
def get_notifications_api(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[NotificationModel]:
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=False)
]
return notifications
@router.post("/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)

View File

@@ -13,8 +13,10 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.persona import create_update_persona
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
@@ -28,6 +30,7 @@ from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
@@ -183,11 +186,12 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
# We notify each user when a user is shared with them
@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
persona_share_request: PersonaShareRequest,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
@@ -197,6 +201,18 @@ def share_persona(
db_session=db_session,
)
for user_id in persona_share_request.user_ids:
# Don't notify the user that they have access to their own persona
if user_id != user.id:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
@basic_router.delete("/{persona_id}")
def delete_persona(
@@ -216,23 +232,31 @@ def list_personas(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
# If the persona has an image generation tool and it's not available, don't include it
personas = get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
if persona_ids:
personas = [p for p in personas if p.id in persona_ids]
# Filter out personas with unavailable tools
personas = [
p
for p in personas
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in persona.tools)
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
and not is_image_generation_available(db_session=db_session)
)
]
return [PersonaSnapshot.from_model(p) for p in personas]
@basic_router.get("/{persona_id}")
def get_persona(

View File

@@ -120,3 +120,7 @@ class PersonaSnapshot(BaseModel):
class PromptTemplateResponse(BaseModel):
final_prompt_template: str
class PersonaSharedNotificationData(BaseModel):
persona_id: int

View File

@@ -15,8 +15,6 @@ from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.key_value_store.factory import get_kv_store
@@ -55,7 +53,7 @@ def fetch_settings(
"""Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls"""
general_settings = load_settings()
user_notifications = get_user_notifications(user, db_session)
user_notifications = get_reindex_notification(user, db_session)
try:
kv_store = get_kv_store()
@@ -70,25 +68,7 @@ def fetch_settings(
)
@basic_router.post("/notifications/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)
def get_user_notifications(
def get_reindex_notification(
user: User | None, db_session: Session
) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
@@ -121,7 +101,7 @@ def get_user_notifications(
if not reindex_notifs:
notif = create_notification(
user=user,
user_id=user.id if user else None,
notif_type=NotificationType.REINDEX,
db_session=db_session,
)

View File

@@ -24,6 +24,7 @@ class Notification(BaseModel):
dismissed: bool
last_shown: datetime
first_shown: datetime
additional_data: dict | None = None
@classmethod
def from_model(cls, notif: NotificationDBModel) -> "Notification":
@@ -33,6 +34,7 @@ class Notification(BaseModel):
dismissed=notif.dismissed,
last_shown=notif.last_shown,
first_shown=notif.first_shown,
additional_data=notif.additional_data,
)