mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-20 13:01:34 +02:00
Gating Notifications (#2868)
* functional notifications * typing * minor * ports * nit * verify functionality * pretty
This commit is contained in:
parent
786a46cbd0
commit
8b72264535
@ -136,6 +136,7 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
|
|||||||
class NotificationType(str, Enum):
|
class NotificationType(str, Enum):
|
||||||
REINDEX = "reindex"
|
REINDEX = "reindex"
|
||||||
PERSONA_SHARED = "persona_shared"
|
PERSONA_SHARED = "persona_shared"
|
||||||
|
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||||
|
|
||||||
|
|
||||||
class BlobType(str, Enum):
|
class BlobType(str, Enum):
|
||||||
|
@ -4,6 +4,7 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
|
from danswer.auth.schemas import UserRole
|
||||||
from danswer.configs.constants import NotificationType
|
from danswer.configs.constants import NotificationType
|
||||||
from danswer.db.models import Notification
|
from danswer.db.models import Notification
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
@ -54,7 +55,9 @@ def get_notification_by_id(
|
|||||||
notif = db_session.get(Notification, notification_id)
|
notif = db_session.get(Notification, notification_id)
|
||||||
if not notif:
|
if not notif:
|
||||||
raise ValueError(f"No notification found with id {notification_id}")
|
raise ValueError(f"No notification found with id {notification_id}")
|
||||||
if notif.user_id != user_id:
|
if notif.user_id != user_id and not (
|
||||||
|
notif.user_id is None and user is not None and user.role == UserRole.ADMIN
|
||||||
|
):
|
||||||
raise PermissionError(
|
raise PermissionError(
|
||||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||||
)
|
)
|
||||||
|
@ -53,7 +53,7 @@ def fetch_settings(
|
|||||||
"""Settings and notifications are stuffed into this single endpoint to reduce number of
|
"""Settings and notifications are stuffed into this single endpoint to reduce number of
|
||||||
Postgres calls"""
|
Postgres calls"""
|
||||||
general_settings = load_settings()
|
general_settings = load_settings()
|
||||||
user_notifications = get_reindex_notification(user, db_session)
|
settings_notifications = get_settings_notifications(user, db_session)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kv_store = get_kv_store()
|
kv_store = get_kv_store()
|
||||||
@ -63,20 +63,29 @@ def fetch_settings(
|
|||||||
|
|
||||||
return UserSettings(
|
return UserSettings(
|
||||||
**general_settings.model_dump(),
|
**general_settings.model_dump(),
|
||||||
notifications=user_notifications,
|
notifications=settings_notifications,
|
||||||
needs_reindexing=needs_reindexing,
|
needs_reindexing=needs_reindexing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_reindex_notification(
|
def get_settings_notifications(
|
||||||
user: User | None, db_session: Session
|
user: User | None, db_session: Session
|
||||||
) -> list[Notification]:
|
) -> list[Notification]:
|
||||||
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
|
"""Get notifications for settings page, including product gating and reindex notifications"""
|
||||||
|
# Check for product gating notification
|
||||||
|
product_notif = get_notifications(
|
||||||
|
user=None,
|
||||||
|
notif_type=NotificationType.TRIAL_ENDS_TWO_DAYS,
|
||||||
|
db_session=db_session,
|
||||||
|
)
|
||||||
|
notifications = [Notification.from_model(product_notif[0])] if product_notif else []
|
||||||
|
|
||||||
|
# Only show reindex notifications to admins
|
||||||
is_admin = is_user_admin(user)
|
is_admin = is_user_admin(user)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
# Reindexing flag should only be shown to admins, basic users can't trigger it anyway
|
return notifications
|
||||||
return []
|
|
||||||
|
|
||||||
|
# Check if reindexing is needed
|
||||||
kv_store = get_kv_store()
|
kv_store = get_kv_store()
|
||||||
try:
|
try:
|
||||||
needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
|
needs_index = cast(bool, kv_store.load(KV_REINDEX_KEY))
|
||||||
@ -84,12 +93,12 @@ def get_reindex_notification(
|
|||||||
dismiss_all_notifications(
|
dismiss_all_notifications(
|
||||||
notif_type=NotificationType.REINDEX, db_session=db_session
|
notif_type=NotificationType.REINDEX, db_session=db_session
|
||||||
)
|
)
|
||||||
return []
|
return notifications
|
||||||
except KvKeyNotFoundError:
|
except KvKeyNotFoundError:
|
||||||
# If something goes wrong and the flag is gone, better to not start a reindexing
|
# If something goes wrong and the flag is gone, better to not start a reindexing
|
||||||
# it's a heavyweight long running job and maybe this flag is cleaned up later
|
# it's a heavyweight long running job and maybe this flag is cleaned up later
|
||||||
logger.warning("Could not find reindex flag")
|
logger.warning("Could not find reindex flag")
|
||||||
return []
|
return notifications
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Need a transaction in order to prevent under-counting current notifications
|
# Need a transaction in order to prevent under-counting current notifications
|
||||||
@ -107,7 +116,9 @@ def get_reindex_notification(
|
|||||||
)
|
)
|
||||||
db_session.flush()
|
db_session.flush()
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
return [Notification.from_model(notif)]
|
|
||||||
|
notifications.append(Notification.from_model(notif))
|
||||||
|
return notifications
|
||||||
|
|
||||||
if len(reindex_notifs) > 1:
|
if len(reindex_notifs) > 1:
|
||||||
logger.error("User has multiple reindex notifications")
|
logger.error("User has multiple reindex notifications")
|
||||||
@ -118,8 +129,9 @@ def get_reindex_notification(
|
|||||||
)
|
)
|
||||||
|
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
return [Notification.from_model(reindex_notif)]
|
notifications.append(Notification.from_model(reindex_notif))
|
||||||
|
return notifications
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
logger.exception("Error while processing notifications")
|
logger.exception("Error while processing notifications")
|
||||||
db_session.rollback()
|
db_session.rollback()
|
||||||
return []
|
return notifications
|
||||||
|
@ -8,6 +8,7 @@ from danswer.auth.users import User
|
|||||||
from danswer.configs.app_configs import MULTI_TENANT
|
from danswer.configs.app_configs import MULTI_TENANT
|
||||||
from danswer.configs.app_configs import WEB_DOMAIN
|
from danswer.configs.app_configs import WEB_DOMAIN
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_tenant
|
||||||
|
from danswer.db.notification import create_notification
|
||||||
from danswer.server.settings.store import load_settings
|
from danswer.server.settings.store import load_settings
|
||||||
from danswer.server.settings.store import store_settings
|
from danswer.server.settings.store import store_settings
|
||||||
from danswer.setup import setup_danswer
|
from danswer.setup import setup_danswer
|
||||||
@ -87,12 +88,17 @@ def gate_product(
|
|||||||
1) User has ended free trial without adding payment method
|
1) User has ended free trial without adding payment method
|
||||||
2) User's card has declined
|
2) User's card has declined
|
||||||
"""
|
"""
|
||||||
token = current_tenant_id.set(current_tenant_id.get())
|
tenant_id = product_gating_request.tenant_id
|
||||||
|
token = current_tenant_id.set(tenant_id)
|
||||||
|
|
||||||
settings = load_settings()
|
settings = load_settings()
|
||||||
settings.product_gating = product_gating_request.product_gating
|
settings.product_gating = product_gating_request.product_gating
|
||||||
store_settings(settings)
|
store_settings(settings)
|
||||||
|
|
||||||
|
if product_gating_request.notification:
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
create_notification(None, product_gating_request.notification, db_session)
|
||||||
|
|
||||||
if token is not None:
|
if token is not None:
|
||||||
current_tenant_id.reset(token)
|
current_tenant_id.reset(token)
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from danswer.configs.constants import NotificationType
|
||||||
from danswer.server.settings.models import GatingType
|
from danswer.server.settings.models import GatingType
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ class CreateTenantRequest(BaseModel):
|
|||||||
class ProductGatingRequest(BaseModel):
|
class ProductGatingRequest(BaseModel):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
product_gating: GatingType
|
product_gating: GatingType
|
||||||
|
notification: NotificationType | None = None
|
||||||
|
|
||||||
|
|
||||||
class BillingInformation(BaseModel):
|
class BillingInformation(BaseModel):
|
||||||
|
@ -18,6 +18,7 @@ export interface Settings {
|
|||||||
export enum NotificationType {
|
export enum NotificationType {
|
||||||
PERSONA_SHARED = "persona_shared",
|
PERSONA_SHARED = "persona_shared",
|
||||||
REINDEX_NEEDED = "reindex_needed",
|
REINDEX_NEEDED = "reindex_needed",
|
||||||
|
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending",
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Notification {
|
export interface Notification {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||||
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
|
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
|
||||||
import { NextRequest } from "next/server";
|
import { NextRequest } from "next/server";
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ export const POST = async (request: NextRequest) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete cookies only if cloud is enabled (jwt auth)
|
// Delete cookies only if cloud is enabled (jwt auth)
|
||||||
if (CLOUD_ENABLED) {
|
if (NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||||
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
|
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
|
||||||
const cookieOptions = {
|
const cookieOptions = {
|
||||||
path: "/",
|
path: "/",
|
||||||
|
@ -8,10 +8,8 @@ import {
|
|||||||
} from "@/lib/userSS";
|
} from "@/lib/userSS";
|
||||||
import { redirect } from "next/navigation";
|
import { redirect } from "next/navigation";
|
||||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||||
import { Card, Title, Text } from "@tremor/react";
|
import { Text } from "@tremor/react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { Logo } from "@/components/Logo";
|
|
||||||
import { CLOUD_ENABLED } from "@/lib/constants";
|
|
||||||
import { SignInButton } from "../login/SignInButton";
|
import { SignInButton } from "../login/SignInButton";
|
||||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||||
|
|
||||||
|
@ -174,20 +174,6 @@ export default async function RootLayout({
|
|||||||
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
|
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
|
||||||
}`}
|
}`}
|
||||||
>
|
>
|
||||||
{productGating === GatingType.PARTIAL && (
|
|
||||||
<div className="fixed top-0 left-0 right-0 z-50 bg-warning-100 text-warning-900 p-2 text-center">
|
|
||||||
<p className="text-sm font-medium">
|
|
||||||
Your account is pending payment!{" "}
|
|
||||||
<a
|
|
||||||
href="/admin/cloud-settings"
|
|
||||||
className="font-bold underline hover:text-warning-700 transition-colors"
|
|
||||||
>
|
|
||||||
Update your billing information
|
|
||||||
</a>{" "}
|
|
||||||
or access will be suspended soon.
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<UserProvider>
|
<UserProvider>
|
||||||
<ProviderContextProvider>
|
<ProviderContextProvider>
|
||||||
<SettingsProvider settings={combinedSettings}>
|
<SettingsProvider settings={combinedSettings}>
|
||||||
|
@ -27,7 +27,6 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
|||||||
|
|
||||||
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
|
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
|
||||||
const user = results[1] as User | null;
|
const user = results[1] as User | null;
|
||||||
console.log("authTypeMetadata", authTypeMetadata);
|
|
||||||
const authDisabled = authTypeMetadata?.authType === "disabled";
|
const authDisabled = authTypeMetadata?.authType === "disabled";
|
||||||
const requiresVerification = authTypeMetadata?.requiresVerification;
|
const requiresVerification = authTypeMetadata?.requiresVerification;
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ export function AnnouncementBanner() {
|
|||||||
const handleDismiss = async (notificationId: number) => {
|
const handleDismiss = async (notificationId: number) => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
`/api/settings/notifications/${notificationId}/dismiss`,
|
`/api/notifications/${notificationId}/dismiss`,
|
||||||
{
|
{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
}
|
}
|
||||||
@ -61,12 +61,12 @@ export function AnnouncementBanner() {
|
|||||||
{localNotifications
|
{localNotifications
|
||||||
.filter((notification) => !notification.dismissed)
|
.filter((notification) => !notification.dismissed)
|
||||||
.map((notification) => {
|
.map((notification) => {
|
||||||
if (notification.notif_type == "reindex") {
|
return (
|
||||||
return (
|
<div
|
||||||
<div
|
key={notification.id}
|
||||||
key={notification.id}
|
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
|
||||||
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
|
>
|
||||||
>
|
{notification.notif_type == "reindex" ? (
|
||||||
<p className="text-center">
|
<p className="text-center">
|
||||||
Your index is out of date - we strongly recommend updating
|
Your index is out of date - we strongly recommend updating
|
||||||
your search settings.{" "}
|
your search settings.{" "}
|
||||||
@ -77,24 +77,29 @@ export function AnnouncementBanner() {
|
|||||||
Update here
|
Update here
|
||||||
</Link>
|
</Link>
|
||||||
</p>
|
</p>
|
||||||
<button
|
) : notification.notif_type == "two_day_trial_ending" ? (
|
||||||
onClick={() => handleDismiss(notification.id)}
|
<p className="text-center">
|
||||||
className="absolute top-0 right-0 mt-2 mr-2"
|
Your trial is ending soon - submit your billing information to
|
||||||
aria-label="Dismiss"
|
continue using Danswer.{" "}
|
||||||
>
|
<Link
|
||||||
<CustomTooltip
|
href="/admin/cloud-settings"
|
||||||
showTick
|
className="ml-2 underline cursor-pointer"
|
||||||
citation
|
|
||||||
delay={100}
|
|
||||||
content="Dismiss"
|
|
||||||
>
|
>
|
||||||
<XIcon className="h-5 w-5" />
|
Update here
|
||||||
</CustomTooltip>
|
</Link>
|
||||||
</button>
|
</p>
|
||||||
</div>
|
) : null}
|
||||||
);
|
<button
|
||||||
}
|
onClick={() => handleDismiss(notification.id)}
|
||||||
return null;
|
className="absolute top-0 right-0 mt-2 mr-2"
|
||||||
|
aria-label="Dismiss"
|
||||||
|
>
|
||||||
|
<CustomTooltip showTick citation delay={100} content="Dismiss">
|
||||||
|
<XIcon className="h-5 w-5" />
|
||||||
|
</CustomTooltip>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
})}
|
})}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
@ -62,7 +62,7 @@ export const CUSTOM_ANALYTICS_ENABLED = process.env.CUSTOM_ANALYTICS_SECRET_KEY
|
|||||||
export const DISABLE_LLM_DOC_RELEVANCE =
|
export const DISABLE_LLM_DOC_RELEVANCE =
|
||||||
process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
|
process.env.DISABLE_LLM_DOC_RELEVANCE?.toLowerCase() === "true";
|
||||||
|
|
||||||
export const CLOUD_ENABLED =
|
export const NEXT_PUBLIC_CLOUD_ENABLED =
|
||||||
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
|
process.env.NEXT_PUBLIC_CLOUD_ENABLED?.toLowerCase() === "true";
|
||||||
|
|
||||||
export const REGISTRATION_URL =
|
export const REGISTRATION_URL =
|
||||||
|
Loading…
x
Reference in New Issue
Block a user