Gating Notifications (#2868)

* functional notifications

* typing

* minor

* ports

* nit

* verify functionality

* pretty
This commit is contained in:
pablodanswer 2024-10-23 13:20:20 -07:00 committed by GitHub
parent 786a46cbd0
commit 8b72264535
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 71 additions and 58 deletions

View File

@ -136,6 +136,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
class NotificationType(str, Enum): class NotificationType(str, Enum):
REINDEX = "reindex" REINDEX = "reindex"
PERSONA_SHARED = "persona_shared" PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
class BlobType(str, Enum): class BlobType(str, Enum):

View File

@ -4,6 +4,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql import func from sqlalchemy.sql import func
from danswer.auth.schemas import UserRole
from danswer.configs.constants import NotificationType from danswer.configs.constants import NotificationType
from danswer.db.models import Notification from danswer.db.models import Notification
from danswer.db.models import User from danswer.db.models import User
@ -54,7 +55,9 @@ def get_notification_by_id(
notif = db_session.get(Notification, notification_id) notif = db_session.get(Notification, notification_id)
if not notif: if not notif:
raise ValueError(f"No notification found with id {notification_id}") raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id: if notif.user_id != user_id and not (
notif.user_id is None and user is not None and user.role == UserRole.ADMIN
):
raise PermissionError( raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}" f"User {user_id} is not authorized to access notification {notification_id}"
) )

View File

@ -53,7 +53,7 @@ def fetch_settings(
"""Settings and notifications are stuffed into this single endpoint to reduce number of """Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls""" Postgres calls"""
general_settings = load_settings() general_settings = load_settings()
user_notifications = get_reindex_notification(user, db_session) settings_notifications = get_settings_notifications(user, db_session)
try: try:
kv_store = get_kv_store() kv_store = get_kv_store()
@ -63,20 +63,29 @@ def fetch_settings(
return UserSettings( return UserSettings(
**general_settings.model_dump(), **general_settings.model_dump(),
notifications=user_notifications, notifications=settings_notifications,
needs_reindexing=needs_reindexing, needs_reindexing=needs_reindexing,
) )
def get_reindex_notification( def get_settings_notifications(
user: User | None, db_session: Session user: User | None, db_session: Session
) -> list[Notification]: ) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag""" """Get notifications for settings page, including product gating and reindex notifications"""
# Check for product gating notification
product_notif = get_notifications(
user=None,
notif_type=NotificationType.TRIAL_ENDS_TWO_DAYS,
db_session=db_session,
)
notifications = [Notification.from_model(product_notif[0])] if product_notif else []
# Only show reindex notifications to admins
is_admin = is_user_admin(user) is_admin = is_user_admin(user)
if not is_admin: if not is_admin:
# Reindexing flag should only be shown to admins, basic users can't trigger it anyway return notifications
return []
# Check if reindexing is needed
kv_store = get_kv_store() kv_store = get_kv_store()
try: try:
needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY)) needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
@ -84,12 +93,12 @@ def get_reindex_notification(
dismiss_all_notifications( dismiss_all_notifications(
notif_type=NotificationType.REINDEX, db_session=db_session notif_type=NotificationType.REINDEX, db_session=db_session
) )
return [] return notifications
except KvKeyNotFoundError: except KvKeyNotFoundError:
# If something goes wrong and the flag is gone, better to not start a reindexing # If something goes wrong and the flag is gone, better to not start a reindexing
# it's a heavyweight long running job and maybe this flag is cleaned up later # it's a heavyweight long running job and maybe this flag is cleaned up later
logger.warning("Could not find reindex flag") logger.warning("Could not find reindex flag")
return [] return notifications
try: try:
# Need a transaction in order to prevent under-counting current notifications # Need a transaction in order to prevent under-counting current notifications
@ -107,7 +116,9 @@ def get_reindex_notification(
) )
db_session.flush() db_session.flush()
db_session.commit() db_session.commit()
return [Notification.from_model(notif)]
notifications.append(Notification.from_model(notif))
return notifications
if len(reindex_notifs) > 1: if len(reindex_notifs) > 1:
logger.error("User has multiple reindex notifications") logger.error("User has multiple reindex notifications")
@ -118,8 +129,9 @@ def get_reindex_notification(
) )
db_session.commit() db_session.commit()
return [Notification.from_model(reindex_notif)] notifications.append(Notification.from_model(reindex_notif))
return notifications
except SQLAlchemyError: except SQLAlchemyError:
logger.exception("Error while processing notifications") logger.exception("Error while processing notifications")
db_session.rollback() db_session.rollback()
return [] return notifications

View File

@ -8,6 +8,7 @@ from danswer.auth.users import User
from danswer.configs.app_configs import MULTI_TENANT from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant from danswer.db.engine import get_session_with_tenant
from danswer.db.notification import create_notification
from danswer.server.settings.store import load_settings from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer from danswer.setup import setup_danswer
@ -87,12 +88,17 @@ def gate_product(
1) User has ended free trial without adding payment method 1) User has ended free trial without adding payment method
2) User's card has declined 2) User's card has declined
""" """
token = current_tenant_id.set(current_tenant_id.get()) tenant_id = product_gating_request.tenant_id
token = current_tenant_id.set(tenant_id)
settings = load_settings() settings = load_settings()
settings.product_gating = product_gating_request.product_gating settings.product_gating = product_gating_request.product_gating
store_settings(settings) store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None: if token is not None:
current_tenant_id.reset(token) current_tenant_id.reset(token)

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from danswer.configs.constants import NotificationType
from danswer.server.settings.models import GatingType from danswer.server.settings.models import GatingType
@ -15,6 +16,7 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel): class ProductGatingRequest(BaseModel):
tenant_id: str tenant_id: str
product_gating: GatingType product_gating: GatingType
notification: NotificationType | None = None
class BillingInformation(BaseModel): class BillingInformation(BaseModel):

View File

@ -18,6 +18,7 @@ export interface Settings {
export enum NotificationType { export enum NotificationType {
PERSONA_SHARED = "persona_shared", PERSONA_SHARED = "persona_shared",
REINDEX_NEEDED = "reindex_needed", REINDEX_NEEDED = "reindex_needed",
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending",
} }
export interface Notification { export interface Notification {

View File

@ -1,4 +1,4 @@
import { CLOUD_ENABLED } from "@/lib/constants"; import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS"; import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
import { NextRequest } from "next/server"; import { NextRequest } from "next/server";
@ -13,7 +13,7 @@ export const POST = async (request: NextRequest) => {
} }
// Delete cookies only if cloud is enabled (jwt auth) // Delete cookies only if cloud is enabled (jwt auth)
if (CLOUD_ENABLED) { if (NEXT_PUBLIC_CLOUD_ENABLED) {
const cookiesToDelete = ["fastapiusersauth", "tenant_details"]; const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
const cookieOptions = { const cookieOptions = {
path: "/", path: "/",

View File

@ -8,10 +8,8 @@ import {
} from "@/lib/userSS"; } from "@/lib/userSS";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm"; import { EmailPasswordForm } from "../login/EmailPasswordForm";
import { Card, Title, Text } from "@tremor/react"; import { Text } from "@tremor/react";
import Link from "next/link"; import Link from "next/link";
import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants";
import { SignInButton } from "../login/SignInButton"; import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import AuthFlowContainer from "@/components/auth/AuthFlowContainer";

View File

@ -174,20 +174,6 @@ export default async function RootLayout({
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : "" process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
}`} }`}
> >
{productGating === GatingType.PARTIAL && (
<div className="fixed top-0 left-0 right-0 z-50 bg-warning-100 text-warning-900 p-2 text-center">
<p className="text-sm font-medium">
Your account is pending payment!{" "}
<a
href="/admin/cloud-settings"
className="font-bold underline hover:text-warning-700 transition-colors"
>
Update your billing information
</a>{" "}
or access will be suspended soon.
</p>
</div>
)}
<UserProvider> <UserProvider>
<ProviderContextProvider> <ProviderContextProvider>
<SettingsProvider settings={combinedSettings}> <SettingsProvider settings={combinedSettings}>

View File

@ -27,7 +27,6 @@ export async function Layout({ children }: { children: React.ReactNode }) {
const authTypeMetadata = results[0] as AuthTypeMetadata | null; const authTypeMetadata = results[0] as AuthTypeMetadata | null;
const user = results[1] as User | null; const user = results[1] as User | null;
console.log("authTypeMetadata", authTypeMetadata);
const authDisabled = authTypeMetadata?.authType === "disabled"; const authDisabled = authTypeMetadata?.authType === "disabled";
const requiresVerification = authTypeMetadata?.requiresVerification; const requiresVerification = authTypeMetadata?.requiresVerification;

View File

@ -32,7 +32,7 @@ export function AnnouncementBanner() {
const handleDismiss = async (notificationId: number) => { const handleDismiss = async (notificationId: number) => {
try { try {
const response = await fetch( const response = await fetch(
`/api/settings/notifications/${notificationId}/dismiss`, `/api/notifications/${notificationId}/dismiss`,
{ {
method: "POST", method: "POST",
} }
@ -61,12 +61,12 @@ export function AnnouncementBanner() {
{localNotifications {localNotifications
.filter((notification) => !notification.dismissed) .filter((notification) => !notification.dismissed)
.map((notification) => { .map((notification) => {
if (notification.notif_type == "reindex") { return (
return ( <div
<div key={notification.id}
key={notification.id} className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto" >
> {notification.notif_type == "reindex" ? (
<p className="text-center"> <p className="text-center">
Your index is out of date - we strongly recommend updating Your index is out of date - we strongly recommend updating
your search settings.{" "} your search settings.{" "}
@ -77,24 +77,29 @@ export function AnnouncementBanner() {
Update here Update here
</Link> </Link>
</p> </p>
<button ) : notification.notif_type == "two_day_trial_ending" ? (
onClick={() => handleDismiss(notification.id)} <p className="text-center">
className="absolute top-0 right-0 mt-2 mr-2" Your trial is ending soon - submit your billing information to
aria-label="Dismiss" continue using Danswer.{" "}
> <Link
<CustomTooltip href="/admin/cloud-settings"
showTick className="ml-2 underline cursor-pointer"
citation
delay={100}
content="Dismiss"
> >
<XIcon className="h-5 w-5" /> Update here
</CustomTooltip> </Link>
</button> </p>
</div> ) : null}
); <button
} onClick={() => handleDismiss(notification.id)}
return null; className="absolute top-0 right-0 mt-2 mr-2"
aria-label="Dismiss"
>
<CustomTooltip showTick citation delay={100} content="Dismiss">
<XIcon className="h-5 w-5" />
</CustomTooltip>
</button>
</div>
);
})} })}
</> </>
); );

View File

@ -62,7 +62,7 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
export const DISABLE_LLM_DOC_RELEVANCE = export const DISABLE_LLM_DOC_RELEVANCE =
process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true"; process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
export const CLOUD_ENABLED = export const NEXT_PUBLIC_CLOUD_ENABLED =
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true"; process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
export const REGISTRATION_URL = export const REGISTRATION_URL =