Improved logout flow (#4258)

* improved app provider modals

* improved logout flow

* k

* updates

* add docstring
This commit is contained in:
pablonyx 2025-03-12 12:19:39 -07:00 committed by GitHub
parent 2f8f0f01be
commit 0153ff6b51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 281 additions and 76 deletions

View File

@ -160,6 +160,20 @@ class RedisPool:
def get_replica_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
def get_raw_client(self) -> Redis:
"""
Returns a Redis client with direct access to the primary connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._pool)
def get_raw_replica_client(self) -> Redis:
"""
Returns a Redis client with direct access to the replica connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@ -224,6 +238,15 @@ def get_redis_client(
# This argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis client with tenant-specific key prefixing.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this when working with tenant-specific data that should be
isolated from other tenants.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@ -235,6 +258,15 @@ def get_redis_replica_client(
# this argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis replica client with tenant-specific key prefixing.
Similar to get_redis_client(), but connects to a read replica when available.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this for read-heavy operations on tenant-specific data.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@ -242,13 +274,57 @@ def get_redis_replica_client(
def get_shared_redis_client() -> Redis:
"""
Returns a Redis client with a shared namespace prefix.
Unlike tenant-specific clients, this uses a common prefix for all keys,
creating a shared namespace accessible across all tenants.
Use this for data that should be shared across the application and
isn't specific to any individual tenant.
"""
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
def get_shared_redis_replica_client() -> Redis:
"""
Returns a Redis replica client with a shared namespace prefix.
Similar to get_shared_redis_client(), but connects to a read replica when available.
Uses a common prefix for all keys, creating a shared namespace.
Use this for read-heavy operations on data that should be shared
across the application.
"""
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
def get_raw_redis_client() -> Redis:
"""
Returns a Redis client that doesn't apply tenant prefixing to keys.
Use this only when you need to access Redis directly without tenant isolation
or any key prefixing. Typically needed for integrating with external systems
or libraries that have inflexible key requirements.
Warning: Be careful with this client as it bypasses tenant isolation.
"""
return redis_pool.get_raw_client()
def get_raw_redis_replica_client() -> Redis:
"""
Returns a Redis replica client that doesn't apply tenant prefixing to keys.
Similar to get_raw_redis_client(), but connects to a read replica when available.
Use this for read-heavy operations that need direct Redis access without
tenant isolation or key prefixing.
Warning: Be careful with this client as it bypasses tenant isolation.
"""
return redis_pool.get_raw_replica_client()
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

View File

@ -1,6 +1,8 @@
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
import jwt
from email_validator import EmailNotValidError
@ -31,9 +33,12 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.constants import AuthType
@ -50,6 +55,7 @@ from onyx.db.users import get_total_filtered_users_count
from onyx.db.users import get_user_by_email
from onyx.db.users import validate_user_role_update
from onyx.key_value_store.factory import get_kv_store
from onyx.redis.redis_pool import get_raw_redis_client
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
@ -477,7 +483,7 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
return UserRoleResponse(role=user.role)
def get_current_token_expiration_jwt(
def get_current_auth_token_expiration_jwt(
user: User | None, request: Request
) -> datetime | None:
if user is None:
@ -506,6 +512,48 @@ def get_current_token_expiration_jwt(
return None
def get_current_auth_token_creation_redis(
user: User | None, request: Request
) -> datetime | None:
"""Calculate the token creation time from Redis TTL information.
This function retrieves the authentication token from cookies,
checks its TTL in Redis, and calculates when the token was created.
Despite the function name, it returns the token creation time, not the expiration time.
"""
if user is None:
return None
try:
# Get the token from the request
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
# Get the Redis client
redis = get_raw_redis_client()
redis_key = REDIS_AUTH_KEY_PREFIX + token
# Get the TTL of the token
ttl = cast(int, redis.ttl(redis_key))
if ttl <= 0:
logger.error("Token has expired or doesn't exist in Redis")
return None
# Calculate the creation time based on TTL and session expiry
# Current time minus (total session length minus remaining TTL)
current_time = datetime.now(timezone.utc)
token_creation_time = current_time - timedelta(
seconds=(SESSION_EXPIRE_TIME_SECONDS - ttl)
)
return token_creation_time
except Exception as e:
logger.error(f"Error retrieving token expiration from Redis: {e}")
return None
def get_current_token_creation(
user: User | None, db_session: Session
) -> datetime | None:
@ -533,6 +581,7 @@ def get_current_token_creation(
@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
@ -558,7 +607,9 @@ def verify_user_logged_in(
)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
get_current_auth_token_creation_redis(user, request)
if AUTH_BACKEND == AuthBackend.REDIS
else get_current_token_creation(user, db_session)
)
team_name = fetch_ee_implementation_or_noop(

23
web/package-lock.json generated
View File

@ -45,6 +45,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"cmdk": "^1.0.0",
"cookies-next": "^5.1.0",
"date-fns": "^3.6.0",
"favicon-fetch": "^1.0.0",
"formik": "^2.2.9",
@ -9849,6 +9850,28 @@
"resolved": "https://registry.npmjs.org/convert-source-map/-/convert-source-map-2.0.0.tgz",
"integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg=="
},
"node_modules/cookie": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/cookie/-/cookie-1.0.2.tgz",
"integrity": "sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==",
"license": "MIT",
"engines": {
"node": ">=18"
}
},
"node_modules/cookies-next": {
"version": "5.1.0",
"resolved": "https://registry.npmjs.org/cookies-next/-/cookies-next-5.1.0.tgz",
"integrity": "sha512-9Ekne+q8hfziJtnT9c1yDUBqT0eDMGgPrfPl4bpR3xwQHLTd/8gbSf6+IEkP/pjGsDZt1TGbC6emYmFYRbIXwQ==",
"license": "MIT",
"dependencies": {
"cookie": "^1.0.1"
},
"peerDependencies": {
"next": ">=15.0.0",
"react": ">= 16.8.0"
}
},
"node_modules/core-js": {
"version": "3.38.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.38.1.tgz",

View File

@ -48,6 +48,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"cmdk": "^1.0.0",
"cookies-next": "^5.1.0",
"date-fns": "^3.6.0",
"favicon-fetch": "^1.0.0",
"formik": "^2.2.9",

View File

@ -134,7 +134,6 @@ import { UserSettingsModal } from "./modal/UserSettingsModal";
import { AgenticMessage } from "./message/AgenticMessage";
import AssistantModal from "../assistants/mine/AssistantModal";
import { useSidebarShortcut } from "@/lib/browserUtilities";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
import { ErrorBanner } from "./message/Resubmit";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";

View File

@ -7,7 +7,6 @@ import { AssistantsProvider } from "./AssistantsContext";
import { Persona } from "@/app/admin/assistants/interfaces";
import { User } from "@/lib/types";
import { ModalProvider } from "./ModalContext";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
interface AppProviderProps {
children: React.ReactNode;
@ -18,8 +17,6 @@ interface AppProviderProps {
hasImageCompatibleModel: boolean;
}
//
export const AppProvider = ({
children,
user,

View File

@ -3,42 +3,97 @@
import { errorHandlingFetcher, RedirectError } from "@/lib/fetcher";
import useSWR from "swr";
import { Modal } from "../Modal";
import { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useState, useRef } from "react";
import { getSecondsUntilExpiration } from "@/lib/time";
import { User } from "@/lib/types";
import { mockedRefreshToken, refreshToken } from "./refreshUtils";
import { refreshToken } from "./refreshUtils";
import { NEXT_PUBLIC_CUSTOM_REFRESH_URL } from "@/lib/constants";
import { Button } from "../ui/button";
import { logout } from "@/lib/user";
import { usePathname, useRouter } from "next/navigation";
export const HealthCheckBanner = () => {
const router = useRouter();
const { error } = useSWR("/api/health", errorHandlingFetcher);
const [expired, setExpired] = useState(false);
const [secondsUntilExpiration, setSecondsUntilExpiration] = useState<
number | null
>(null);
const { data: user, mutate: mutateUser } = useSWR<User>(
"/api/me",
errorHandlingFetcher
const [showLoggedOutModal, setShowLoggedOutModal] = useState(false);
const pathname = usePathname();
const expirationTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const refreshIntervalRef = useRef<NodeJS.Timer | null>(null);
// Reduce revalidation frequency with dedicated SWR config
const {
data: user,
mutate: mutateUser,
error: userError,
} = useSWR<User>("/api/me", errorHandlingFetcher, {
revalidateOnFocus: false,
revalidateOnReconnect: false,
dedupingInterval: 30000, // 30 seconds
});
// Handle 403 errors from the /api/me endpoint
useEffect(() => {
if (userError && userError.status === 403) {
logout().then(() => {
if (!pathname.includes("/auth")) {
setShowLoggedOutModal(true);
}
});
}
}, [userError, pathname]);
// Function to handle the "Log in" button click
const handleLogin = () => {
setShowLoggedOutModal(false);
router.push("/auth/login");
};
// Function to set up expiration timeout
const setupExpirationTimeout = useCallback(
(secondsUntilExpiration: number) => {
// Clear any existing timeout
if (expirationTimeoutRef.current) {
clearTimeout(expirationTimeoutRef.current);
}
// Set timeout to show logout modal when session expires
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expirationTimeoutRef.current = setTimeout(() => {
setExpired(true);
if (!pathname.includes("/auth")) {
setShowLoggedOutModal(true);
}
}, timeUntilExpire);
},
[pathname]
);
const updateExpirationTime = useCallback(async () => {
const updatedUser = await mutateUser();
if (updatedUser) {
const seconds = getSecondsUntilExpiration(updatedUser);
setSecondsUntilExpiration(seconds);
console.debug(`Updated seconds until expiration:! ${seconds}`);
}
}, [mutateUser]);
// Clean up any timeouts/intervals when component unmounts
useEffect(() => {
updateExpirationTime();
}, [user, updateExpirationTime]);
return () => {
if (expirationTimeoutRef.current) {
clearTimeout(expirationTimeoutRef.current);
}
if (refreshIntervalRef.current) {
clearInterval(refreshIntervalRef.current);
}
};
}, []);
// Set up token refresh logic if custom refresh URL exists
useEffect(() => {
if (!user) return;
const secondsUntilExpiration = getSecondsUntilExpiration(user);
if (secondsUntilExpiration === null) return;
// Set up expiration timeout based on current user data
setupExpirationTimeout(secondsUntilExpiration);
if (NEXT_PUBLIC_CUSTOM_REFRESH_URL) {
const refreshUrl = NEXT_PUBLIC_CUSTOM_REFRESH_URL;
let refreshIntervalId: NodeJS.Timer;
let expireTimeoutId: NodeJS.Timeout;
const attemptTokenRefresh = async () => {
let retryCount = 0;
@ -46,9 +101,6 @@ export const HealthCheckBanner = () => {
while (retryCount < maxRetries) {
try {
// NOTE: This is a mocked refresh token for testing purposes.
// const refreshTokenData = mockedRefreshToken();
const refreshTokenData = await refreshToken(refreshUrl);
if (!refreshTokenData) {
throw new Error("Failed to refresh token");
@ -67,10 +119,25 @@ export const HealthCheckBanner = () => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
// Wait for backend to process the token
await new Promise((resolve) => setTimeout(resolve, 4000));
await mutateUser(undefined, { revalidate: true });
updateExpirationTime();
// Get updated user data
const updatedUser = await mutateUser();
if (updatedUser) {
// Reset expiration timeout with new expiration time
const newSecondsUntilExpiration =
getSecondsUntilExpiration(updatedUser);
if (newSecondsUntilExpiration !== null) {
setupExpirationTimeout(newSecondsUntilExpiration);
console.debug(
`Token refreshed, new expiration in ${newSecondsUntilExpiration} seconds`
);
}
}
break; // Success - exit the retry loop
} catch (error) {
console.error(
@ -93,64 +160,55 @@ export const HealthCheckBanner = () => {
}
};
const scheduleRefreshAndExpire = () => {
if (secondsUntilExpiration !== null) {
const refreshInterval = 60 * 15; // 15 mins
refreshIntervalId = setInterval(
attemptTokenRefresh,
refreshInterval * 1000
);
// Set up refresh interval
const refreshInterval = 60 * 15; // 15 mins
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expireTimeoutId = setTimeout(() => {
console.debug("Session expired. Setting expired state to true.");
setExpired(true);
}, timeUntilExpire);
// Clear any existing interval
if (refreshIntervalRef.current) {
clearInterval(refreshIntervalRef.current);
}
// if we're going to timeout before the next refresh, kick off a refresh now!
if (secondsUntilExpiration < refreshInterval) {
attemptTokenRefresh();
}
}
};
refreshIntervalRef.current = setInterval(
attemptTokenRefresh,
refreshInterval * 1000
);
scheduleRefreshAndExpire();
return () => {
clearInterval(refreshIntervalId);
clearTimeout(expireTimeoutId);
};
// If we're going to expire before the next refresh, kick off a refresh now
if (secondsUntilExpiration < refreshInterval) {
attemptTokenRefresh();
}
}
}, [secondsUntilExpiration, user, mutateUser, updateExpirationTime]);
}, [user, setupExpirationTimeout, mutateUser]);
if (!error && !expired) {
return null;
}
console.debug(
`Rendering HealthCheckBanner. Error: ${error}, Expired: ${expired}`
);
if (error instanceof RedirectError || expired) {
// Logged out modal
if (showLoggedOutModal) {
return (
<Modal
width="w-1/4"
width="w-1/3"
className="overflow-y-hidden flex flex-col"
title="You've been logged out"
title="You Have Been Logged Out"
>
<div className="flex flex-col gap-y-4">
<p className="text-sm">
Your session has expired. Please log in again to continue.
</p>
<a
href="/auth/login"
className="w-full mt-4 mx-auto rounded-md text-text-200 py-2 bg-background-900 text-center hover:bg-emphasis animtate duration-300 transition-bg"
>
Log in
</a>
<div className="flex flex-row gap-x-2 justify-end mt-4">
<Button onClick={handleLogin}>Log In</Button>
</div>
</div>
</Modal>
);
}
if (!error && !expired) {
return null;
}
if (error instanceof RedirectError || expired) {
if (!pathname.includes("/auth")) {
setShowLoggedOutModal(true);
}
return null;
} else {
return (
<div className="fixed top-0 left-0 z-[101] w-full text-xs mx-auto bg-gradient-to-r from-red-900 to-red-700 p-2 rounded-sm border-hidden text-text-200">