Updated refreshing (#2327)

* clean up + add environment variables

* remove log

* update

* update api settings

* somewhat cleaner refresh functionality

* fully functional

* update settings

* validated

* remove random logs

* remove unneeded paramter + log

* move to ee + remove comments

* Cleanup unused

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
This commit is contained in:
pablodanswer
2024-09-05 21:36:55 -07:00
committed by GitHub
parent 2bd3833c55
commit 69c0419146
11 changed files with 208 additions and 23 deletions

View File

@@ -66,7 +66,7 @@ def fetch_settings(
return UserSettings(
**general_settings.model_dump(),
notifications=user_notifications,
needs_reindexing=needs_reindexing
needs_reindexing=needs_reindexing,
)

View File

@@ -1,14 +1,24 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME
@@ -18,10 +28,117 @@ from ee.danswer.server.enterprise_settings.store import load_settings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import store_settings
from ee.danswer.server.enterprise_settings.store import upload_logo
from shared_configs.configs import CUSTOM_REFRESH_URL
admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")
logger = setup_logger()
def mocked_refresh_token() -> dict:
"""
This function mocks the response from a token refresh endpoint.
It generates a mock access token, refresh token, and user information
with an expiration time set to 1 hour from now.
This is useful for testing or development when the actual refresh endpoint is not available.
"""
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
data = {
"access_token": "asdf Mock access token",
"refresh_token": "asdf Mock refresh token",
"session": {"exp": mock_exp},
"userinfo": {
"sub": "Mock email",
"familyName": "Mock name",
"givenName": "Mock name",
"fullName": "Mock name",
"userId": "Mock User ID",
"email": "test_email@danswer.ai",
},
}
return data
@basic_router.get("/refresh-token")
async def refresh_access_token(
user: User = Depends(current_user),
user_manager: UserManager = Depends(get_user_manager),
) -> None:
# return
if CUSTOM_REFRESH_URL is None:
logger.error(
"Custom refresh URL is not set and client is attempting to custom refresh"
)
raise HTTPException(
status_code=500,
detail="Custom refresh URL is not set",
)
try:
async with httpx.AsyncClient() as client:
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
access_token = user.oauth_accounts[0].access_token
response = await client.get(
CUSTOM_REFRESH_URL,
params={"info": "json", "access_token_refresh_interval": 3600},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
# NOTE: Here is where we can mock the response
# data = mocked_refresh_token()
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
# Extract new tokens
new_access_token = data["access_token"]
new_refresh_token = data["refresh_token"]
new_expiry = datetime.fromtimestamp(
data["session"]["exp"] / 1000, tz=timezone.utc
)
expires_at_timestamp = int(new_expiry.timestamp())
logger.debug(f"Access token has been refreshed for user {user.id}")
await user_manager.oauth_callback(
oauth_name="custom",
access_token=new_access_token,
account_id=data["userinfo"]["userId"],
account_email=data["userinfo"]["email"],
expires_at=expires_at_timestamp,
refresh_token=new_refresh_token,
associate_by_email=True,
)
logger.info(f"Successfully refreshed tokens for user {user.id}")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
logger.warning(f"Full authentication required for user {user.id}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Full authentication required",
)
logger.error(
f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to refresh token",
)
except Exception as e:
logger.error(
f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred",
)
@admin_router.put("")
def put_settings(

View File

@@ -70,3 +70,5 @@ PRESERVED_SEARCH_FIELDS = [
"passage_prefix",
"query_prefix",
]
CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token"

View File

@@ -172,6 +172,7 @@ services:
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
# Logging
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
@@ -220,6 +221,7 @@ services:
# Enterprise Edition only
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
inference_model_server:
image: danswer/danswer-model-server:${IMAGE_TAG:-latest}

View File

@@ -58,6 +58,7 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
RUN npx next build
# Step 2. Production image, copy all the files and run next

View File

@@ -91,7 +91,6 @@ import FunctionalHeader from "@/components/chat_search/Header";
import { useSidebarVisibility } from "@/components/chat_search/hooks";
import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants";
import FixedLogo from "./shared_chat_search/FixedLogo";
import { getSecondsUntilExpiration } from "@/lib/time";
import { SetDefaultModelModal } from "./modal/SetDefaultModelModal";
import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal";
import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown";
@@ -1559,7 +1558,6 @@ export function ChatPage({
setDocumentSelection((documentSelection) => !documentSelection);
setShowDocSidebar(false);
};
const secondsUntilExpiration = getSecondsUntilExpiration(user);
interface RegenerationRequest {
messageId: number;
@@ -1579,7 +1577,7 @@ export function ChatPage({
return (
<>
<HealthCheckBanner secondsUntilExpiration={secondsUntilExpiration} />
<HealthCheckBanner />
{/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit.
Only used in the EE version of the app. */}
{popup}

View File

@@ -3,7 +3,6 @@ import "./globals.css";
import {
fetchEnterpriseSettingsSS,
fetchSettingsSS,
SettingsError,
} from "@/components/settings/lib";
import {
CUSTOM_ANALYTICS_ENABLED,
@@ -11,7 +10,7 @@ import {
} from "@/lib/constants";
import { SettingsProvider } from "@/components/settings/SettingsProvider";
import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { buildClientUrl, fetchSS } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";

View File

@@ -3,7 +3,6 @@ import {
getAuthTypeMetadataSS,
getCurrentUserSS,
} from "@/lib/userSS";
import { getSecondsUntilExpiration } from "@/lib/time";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
@@ -179,11 +178,10 @@ export default async function Home() {
const agenticSearchEnabled = agenticSearchToggle
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false
: false;
const secondsUntilExpiration = getSecondsUntilExpiration(user);
return (
<>
<HealthCheckBanner secondsUntilExpiration={secondsUntilExpiration} />
<HealthCheckBanner />
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
<InstantSSRAutoRefresh />

View File

@@ -3,29 +3,95 @@
import { errorHandlingFetcher, RedirectError } from "@/lib/fetcher";
import useSWR from "swr";
import { Modal } from "../Modal";
import { useState } from "react";
import { useEffect, useState } from "react";
import { getSecondsUntilExpiration } from "@/lib/time";
import { User } from "@/lib/types";
export const HealthCheckBanner = ({
secondsUntilExpiration,
}: {
secondsUntilExpiration?: number | null;
}) => {
export const HealthCheckBanner = () => {
const { error } = useSWR("/api/health", errorHandlingFetcher);
const [expired, setExpired] = useState(false);
if (secondsUntilExpiration !== null && secondsUntilExpiration !== undefined) {
setTimeout(
() => {
setExpired(true);
},
secondsUntilExpiration * 1000 - 200
const [secondsUntilExpiration, setSecondsUntilExpiration] = useState<
number | null
>(null);
const { data: user, mutate: mutateUser } = useSWR<User>(
"/api/me",
errorHandlingFetcher
);
const updateExpirationTime = async () => {
const updatedUser = await mutateUser();
if (updatedUser) {
const seconds = getSecondsUntilExpiration(updatedUser);
setSecondsUntilExpiration(seconds);
console.debug(`Updated seconds until expiration:! ${seconds}`);
}
};
useEffect(() => {
updateExpirationTime();
}, [user]);
useEffect(() => {
if (true) {
let refreshTimeoutId: NodeJS.Timeout;
let expireTimeoutId: NodeJS.Timeout;
const refreshToken = async () => {
try {
const response = await fetch(
"/api/enterprise-settings/refresh-token",
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
}
);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
console.debug("Token refresh successful");
// Force revalidation of user data
await mutateUser(undefined, { revalidate: true });
updateExpirationTime();
} catch (error) {
console.error("Error refreshing token:", error);
}
};
const scheduleRefreshAndExpire = () => {
if (secondsUntilExpiration !== null) {
const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000;
refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh);
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expireTimeoutId = setTimeout(() => {
console.debug("Session expired. Setting expired state to true.");
setExpired(true);
}, timeUntilExpire);
}
};
scheduleRefreshAndExpire();
return () => {
clearTimeout(refreshTimeoutId);
clearTimeout(expireTimeoutId);
};
}
}, [secondsUntilExpiration, user]);
if (!error && !expired) {
return null;
}
console.debug(
`Rendering HealthCheckBanner. Error: ${error}, Expired: ${expired}`
);
if (error instanceof RedirectError || expired) {
return (
<Modal

View File

@@ -101,6 +101,7 @@ export function getSecondsUntilExpiration(
if (!userInfo) {
return null;
}
const { oidc_expiry, current_token_created_at, current_token_expiry_length } =
userInfo;

View File

@@ -10,6 +10,7 @@ const eePaths = [
"/admin/whitelabeling",
"/admin/performance/custom-analytics",
];
const eePathsForMatcher = eePaths.map((path) => `${path}/:path*`);
export async function middleware(request: NextRequest) {