improved logout flow

This commit is contained in:
pablonyx 2025-03-11 13:37:32 -07:00
parent b0619ce198
commit 8ea0de3f6a
8 changed files with 298 additions and 59 deletions

View File

@ -160,6 +160,20 @@ class RedisPool:
def get_replica_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
def get_raw_client(self) -> Redis:
"""
Returns a Redis client with direct access to the primary connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._pool)
def get_raw_replica_client(self) -> Redis:
"""
Returns a Redis client with direct access to the replica connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@ -249,6 +263,23 @@ def get_shared_redis_replica_client() -> Redis:
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
# New functions to get Redis clients without tenant prefixing
def get_raw_redis_client() -> Redis:
"""
Returns a Redis client that doesn't apply tenant prefixing to keys.
Use this only when you need to access Redis directly without tenant isolation.
"""
return redis_pool.get_raw_client()
def get_raw_redis_replica_client() -> Redis:
"""
Returns a Redis replica client that doesn't apply tenant prefixing to keys.
Use this only when you need to access Redis replicas directly without tenant isolation.
"""
return redis_pool.get_raw_replica_client()
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,
@ -307,6 +338,34 @@ async def get_async_redis_connection() -> aioredis.Redis:
return _async_redis_connection
def retrieve_auth_token_data_from_redis_sync(request: Request) -> dict | None:
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
try:
redis = get_raw_redis_client()
redis_key = REDIS_AUTH_KEY_PREFIX + token
token_data_str = redis.get(redis_key)
if not token_data_str:
logger.debug(f"Token key {redis_key} not found or expired in Redis")
return None
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return None
except Exception as e:
logger.error(
f"Unexpected error in retrieve_auth_token_data_from_redis_sync: {str(e)}"
)
raise ValueError(
f"Unexpected error in retrieve_auth_token_data_from_redis_sync: {str(e)}"
)
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:

View File

@ -1,5 +1,6 @@
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import jwt
@ -31,9 +32,12 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.constants import AuthType
@ -50,6 +54,7 @@ from onyx.db.users import get_total_filtered_users_count
from onyx.db.users import get_user_by_email
from onyx.db.users import validate_user_role_update
from onyx.key_value_store.factory import get_kv_store
from onyx.redis.redis_pool import get_raw_redis_client
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
@ -506,6 +511,44 @@ def get_current_token_expiration_jwt(
return None
def get_current_token_expiration_redis(
user: User | None, request: Request
) -> datetime | None:
if user is None:
return None
try:
print("retrieving token data from Redis")
# Get the token from the request
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
# Get the Redis client
redis = get_raw_redis_client()
redis_key = REDIS_AUTH_KEY_PREFIX + token
# Get the TTL of the token
ttl = redis.ttl(redis_key)
if ttl <= 0:
logger.error("Token has expired or doesn't exist in Redis")
return None
# Calculate the creation time based on TTL and session expiry
# Current time minus (total session length minus remaining TTL)
current_time = datetime.now()
token_creation_time = current_time - timedelta(
seconds=(SESSION_EXPIRE_TIME_SECONDS - ttl)
)
print(f"Calculated token creation time: {token_creation_time}")
return token_creation_time
except Exception as e:
logger.error(f"Error retrieving token expiration from Redis: {e}")
return None
def get_current_token_creation(
user: User | None, db_session: Session
) -> datetime | None:
@ -533,6 +576,7 @@ def get_current_token_creation(
@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
@ -558,9 +602,13 @@ def verify_user_logged_in(
)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
get_current_token_expiration_redis(user, request)
if AUTH_BACKEND == AuthBackend.REDIS
else get_current_token_creation(user, db_session)
)
print(f"token_created_at: {token_created_at}")
team_name = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)

View File

@ -135,7 +135,6 @@ import { UserSettingsModal } from "./modal/UserSettingsModal";
import { AgenticMessage } from "./message/AgenticMessage";
import AssistantModal from "../assistants/mine/AssistantModal";
import { useSidebarShortcut } from "@/lib/browserUtilities";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
import { ErrorBanner } from "./message/Resubmit";

View File

@ -30,6 +30,8 @@ import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";
import { cookies } from "next/headers";
import { TokenPayload } from "@/components/auth/AuthMonitor";
const inter = Inter({
subsets: ["latin"],

View File

@ -2,18 +2,21 @@
import React, { useEffect, useState } from "react";
import { useRouter } from "next/navigation";
import { getCookie } from "cookies-next";
// Time constants (in milliseconds)
const WARNING_THRESHOLD = 5 * 60 * 1000; // 5 minutes
const CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const REFRESH_THRESHOLD = 10 * 60 * 1000; // Try to refresh when 10 minutes remain
export interface TokenPayload {
exp: number;
token: string;
}
interface AuthMonitorProps {
children: React.ReactNode;
authToken: TokenPayload | null; // Add authToken as an optional prop
}
export function AuthMonitor({ children }: AuthMonitorProps) {
export function AuthMonitor({ children, authToken }: AuthMonitorProps) {
const router = useRouter();
const [showWarning, setShowWarning] = useState(false);
const [timeRemaining, setTimeRemaining] = useState<number | null>(null);
@ -22,12 +25,16 @@ export function AuthMonitor({ children }: AuthMonitorProps) {
// Function to parse JWT and get expiration time
const getTokenExpiration = (): number | null => {
try {
const authCookie = getCookie("fastapi-users-auth") as string | undefined;
if (!authCookie) return null;
// Only use the authToken prop provided by the server
if (!authToken) {
console.log("No authToken prop provided");
return null;
}
// JWT token has 3 parts separated by dots
const payload = JSON.parse(atob(authCookie.split(".")[1]));
return payload.exp * 1000; // Convert from seconds to milliseconds
console.log("Using authToken from server, exp:", authToken.exp);
// Return the expiration time in milliseconds
return authToken.exp * 1000; // Convert from seconds to milliseconds
} catch (error) {
console.error("Error parsing auth token:", error);
return null;
@ -37,6 +44,7 @@ export function AuthMonitor({ children }: AuthMonitorProps) {
// Attempt to refresh the token
const refreshToken = async (): Promise<boolean> => {
try {
console.log("Attempting to refresh token");
setIsRefreshing(true);
// Call your refresh token endpoint here
@ -49,7 +57,11 @@ export function AuthMonitor({ children }: AuthMonitorProps) {
console.log("Session refreshed successfully");
return true;
} else {
console.error("Failed to refresh session:", await response.text());
const errorText = await response.text();
console.error(
`Failed to refresh session: ${response.status} ${response.statusText}`,
errorText
);
return false;
}
} catch (error) {
@ -57,52 +69,72 @@ export function AuthMonitor({ children }: AuthMonitorProps) {
return false;
} finally {
setIsRefreshing(false);
console.log("Token refresh attempt completed");
}
};
// Check token expiration and handle status
const checkTokenExpiration = async () => {
console.log("Checking token expiration");
const expiresAt = getTokenExpiration();
if (!expiresAt) {
// No valid token found, redirect to login
router.push("/login");
console.log("No valid token found, redirecting to login");
router.push("/auth/login");
return;
}
console.log("Token found, checking expiration");
const remaining = expiresAt - Date.now();
console.log(`Token expires in ${remaining}ms (${remaining / 1000}s)`);
setTimeRemaining(remaining);
if (remaining <= 0) {
// Token expired, redirect to login
console.log("Token expired, redirecting to login");
setShowWarning(false);
router.push("/login");
router.push("/auth/login");
} else if (remaining < WARNING_THRESHOLD) {
// Show warning when less than 5 minutes remaining
console.log(
`Token expiring soon (${remaining / 1000}s remaining), showing warning`
);
setShowWarning(true);
} else if (remaining < REFRESH_THRESHOLD && !isRefreshing) {
// Try refreshing token when less than 10 minutes remaining
console.log(
`Token refresh threshold reached (${
remaining / 1000
}s remaining), attempting refresh`
);
const refreshed = await refreshToken();
if (refreshed) {
console.log("Token refreshed successfully, rechecking expiration");
// Re-check expiration after successful refresh
checkTokenExpiration();
}
} else {
console.log("Token is valid and not near expiration");
setShowWarning(false);
}
};
useEffect(() => {
console.log("AuthMonitor mounted, initializing token check");
// Check immediately on mount
checkTokenExpiration();
// Set up interval for periodic checking
const interval = setInterval(() => {
console.log("Running scheduled token check");
checkTokenExpiration();
}, CHECK_INTERVAL);
// Clean up interval on unmount
return () => clearInterval(interval);
return () => {
console.log("AuthMonitor unmounting, clearing interval");
clearInterval(interval);
};
}, []);
// Format time remaining for display
@ -111,8 +143,9 @@ export function AuthMonitor({ children }: AuthMonitorProps) {
const minutes = Math.floor(timeRemaining / 60000);
const seconds = Math.floor((timeRemaining % 60000) / 1000);
return `${minutes}:${seconds.toString().padStart(2, "0")}`;
const formattedTime = `${minutes}:${seconds.toString().padStart(2, "0")}`;
console.log(`Formatted time remaining: ${formattedTime}`);
return formattedTime;
};
return (

View File

@ -7,8 +7,6 @@ import { AssistantsProvider } from "./AssistantsContext";
import { Persona } from "@/app/admin/assistants/interfaces";
import { User } from "@/lib/types";
import { ModalProvider } from "./ModalContext";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { AuthMonitor } from "../auth/AuthMonitor";
interface AppProviderProps {
children: React.ReactNode;
@ -19,8 +17,6 @@ interface AppProviderProps {
hasImageCompatibleModel: boolean;
}
//
export const AppProvider = ({
children,
user,
@ -30,20 +26,18 @@ export const AppProvider = ({
hasImageCompatibleModel,
}: AppProviderProps) => {
return (
<AuthMonitor>
<SettingsProvider settings={settings}>
<UserProvider settings={settings} user={user}>
<ProviderContextProvider>
<AssistantsProvider
initialAssistants={assistants}
hasAnyConnectors={hasAnyConnectors}
hasImageCompatibleModel={hasImageCompatibleModel}
>
<ModalProvider user={user}>{children}</ModalProvider>
</AssistantsProvider>
</ProviderContextProvider>
</UserProvider>
</SettingsProvider>
</AuthMonitor>
<SettingsProvider settings={settings}>
<UserProvider settings={settings} user={user}>
<ProviderContextProvider>
<AssistantsProvider
initialAssistants={assistants}
hasAnyConnectors={hasAnyConnectors}
hasImageCompatibleModel={hasImageCompatibleModel}
>
<ModalProvider user={user}>{children}</ModalProvider>
</AssistantsProvider>
</ProviderContextProvider>
</UserProvider>
</SettingsProvider>
);
};

View File

@ -8,17 +8,41 @@ import { getSecondsUntilExpiration } from "@/lib/time";
import { User } from "@/lib/types";
import { mockedRefreshToken, refreshToken } from "./refreshUtils";
import { NEXT_PUBLIC_CUSTOM_REFRESH_URL } from "@/lib/constants";
import { Button } from "../ui/button";
import { logout } from "@/lib/user";
import { usePathname, useRouter } from "next/navigation";
import Cookies from "js-cookie";
import { SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME } from "../resizable/constants";
export const HealthCheckBanner = () => {
const router = useRouter();
const { error } = useSWR("/api/health", errorHandlingFetcher);
const [expired, setExpired] = useState(false);
const [secondsUntilExpiration, setSecondsUntilExpiration] = useState<
number | null
>(null);
const { data: user, mutate: mutateUser } = useSWR<User>(
"/api/me",
errorHandlingFetcher
);
const [showExpirationWarning, setShowExpirationWarning] = useState(false);
const pathname = usePathname();
const {
data: user,
mutate: mutateUser,
error: userError,
} = useSWR<User>("/api/me", errorHandlingFetcher);
// Handle 403 errors from the /api/me endpoint
useEffect(() => {
if (userError && userError.status === 403) {
console.log("Received 403 from /api/me, logging out user");
logout().then(() => {
if (!pathname.includes("/auth")) {
router.push("/auth/login");
}
});
}
}, [userError, router]);
const updateExpirationTime = useCallback(async () => {
const updatedUser = await mutateUser();
@ -120,9 +144,101 @@ export const HealthCheckBanner = () => {
clearInterval(refreshIntervalId);
clearTimeout(expireTimeoutId);
};
} else {
let warningTimeoutId: NodeJS.Timeout;
let expireTimeoutId: NodeJS.Timeout;
const scheduleWarningAndExpire = () => {
if (secondsUntilExpiration !== null) {
const warningThreshold = 5 * 6000; // 5 minutes
// Check if there's a cookie to suppress the warning
const suppressWarning = Cookies.get(
SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME
);
if (suppressWarning) {
console.debug("Suppressing expiration warning due to cookie");
setShowExpirationWarning(false);
} else if (secondsUntilExpiration <= warningThreshold) {
setShowExpirationWarning(true);
} else {
const timeUntilWarning =
(secondsUntilExpiration - warningThreshold) * 1000;
warningTimeoutId = setTimeout(() => {
// Check again for cookie when timeout fires
if (!Cookies.get(SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME)) {
console.debug("Session about to expire. Showing warning.");
setShowExpirationWarning(true);
}
}, timeUntilWarning);
}
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expireTimeoutId = setTimeout(() => {
console.debug("Session expired. Setting expired state to true.");
setShowExpirationWarning(false);
setExpired(true);
// Remove the cookie when session actually expires
Cookies.remove(SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME);
}, timeUntilExpire);
}
};
scheduleWarningAndExpire();
return () => {
clearTimeout(warningTimeoutId);
clearTimeout(expireTimeoutId);
};
}
}, [secondsUntilExpiration, user, mutateUser, updateExpirationTime]);
// Function to handle the "Continue Session" button
const handleContinueSession = () => {
// Set a cookie that will expire when the session expires
if (secondsUntilExpiration) {
// Calculate expiry in days (js-cookie uses days for expiration)
const expiryDays = secondsUntilExpiration / (60 * 60 * 24);
Cookies.set(SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME, "true", {
expires: expiryDays,
path: "/",
});
console.debug(`Set cookie to suppress warnings for ${expiryDays} days`);
setShowExpirationWarning(false);
}
};
if (showExpirationWarning) {
return (
<Modal
width="w-1/3"
className="overflow-y-hidden flex flex-col"
title="Your Session Is About To Expire"
>
<div className="flex flex-col gap-y-4">
<p className="text-sm">
Your session will expire soon (in {secondsUntilExpiration} seconds).
Would you like to continue your session or log out?
</p>
<div className="flex flex-row gap-x-2 justify-end mt-4">
<Button onClick={handleContinueSession}>Continue Session</Button>
<Button
onClick={async () => {
await logout();
router.push("/auth/login");
}}
variant="outline"
>
Log Out
</Button>
</div>
</div>
</Modal>
);
}
if (!error && !expired) {
return null;
}
@ -132,25 +248,11 @@ export const HealthCheckBanner = () => {
);
if (error instanceof RedirectError || expired) {
return (
<Modal
width="w-1/4"
className="overflow-y-hidden flex flex-col"
title="You've been logged out"
>
<div className="flex flex-col gap-y-4">
<p className="text-sm">
Your session has expired. Please log in again to continue.
</p>
<a
href="/auth/login"
className="w-full mt-4 mx-auto rounded-md text-text-200 py-2 bg-background-900 text-center hover:bg-emphasis animtate duration-300 transition-bg"
>
Log in
</a>
</div>
</Modal>
);
if (!pathname.includes("/auth")) {
alert(pathname);
router.push("/auth/login");
}
return null;
} else {
return (
<div className="fixed top-0 left-0 z-[101] w-full text-xs mx-auto bg-gradient-to-r from-red-900 to-red-700 p-2 rounded-sm border-hidden text-text-200">

View File

@ -1,3 +1,5 @@
export const DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME = "documentSidebarWidth";
export const SIDEBAR_TOGGLED_COOKIE_NAME = "sidebarIsToggled";
export const PRO_SEARCH_TOGGLED_COOKIE_NAME = "proSearchIsToggled";
export const SUPPRESS_EXPIRATION_WARNING_COOKIE_NAME =
"suppress_expiration_warning";