Custom Refresh on Client Side (#2376)

This commit is contained in:
pablodanswer 2024-09-13 00:04:03 -07:00 committed by GitHub
parent 6dd91414be
commit 31ca6857fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 122 additions and 69 deletions

View File

@ -268,6 +268,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
user.is_verified = is_verified_by_default
user.has_web_login = True
return user
async def on_after_register(
@ -414,6 +415,7 @@ async def optional_user(
async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
) -> User | None:
if optional:
return None
@ -430,7 +432,11 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
@ -439,6 +445,12 @@ async def double_check_user(
return user
async def current_user_with_expired_token(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user, include_expired=True)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:

View File

@ -7,6 +7,7 @@ from starlette.routing import BaseRoute
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
@ -96,6 +97,7 @@ def check_router_auth(
or depends_fn == current_admin_user
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
):
found_auth = True
break

View File

@ -1,6 +1,6 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
import httpx
from fastapi import APIRouter
@ -9,10 +9,12 @@ from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
from danswer.db.engine import get_session
@ -28,7 +30,6 @@ from ee.danswer.server.enterprise_settings.store import load_settings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import store_settings
from ee.danswer.server.enterprise_settings.store import upload_logo
from shared_configs.configs import CUSTOM_REFRESH_URL
admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")
@ -36,69 +37,37 @@ basic_router = APIRouter(prefix="/enterprise-settings")
logger = setup_logger()
def mocked_refresh_token() -> dict:
"""
This function mocks the response from a token refresh endpoint.
It generates a mock access token, refresh token, and user information
with an expiration time set to 1 hour from now.
This is useful for testing or development when the actual refresh endpoint is not available.
"""
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
data = {
"access_token": "asdf Mock access token",
"refresh_token": "asdf Mock refresh token",
"session": {"exp": mock_exp},
"userinfo": {
"sub": "Mock email",
"familyName": "Mock name",
"givenName": "Mock name",
"fullName": "Mock name",
"userId": "Mock User ID",
"email": "test_email@danswer.ai",
},
}
return data
class RefreshTokenData(BaseModel):
access_token: str
refresh_token: str
session: dict = Field(..., description="Contains session information")
userinfo: dict = Field(..., description="Contains user information")
def __init__(self, **data: Any) -> None:
super().__init__(**data)
if "exp" not in self.session:
raise ValueError("'exp' must be set in the session dictionary")
if "userId" not in self.userinfo or "email" not in self.userinfo:
raise ValueError(
"'userId' and 'email' must be set in the userinfo dictionary"
)
@basic_router.get("/refresh-token")
@basic_router.post("/refresh-token")
async def refresh_access_token(
user: User = Depends(current_user),
refresh_token: RefreshTokenData,
user: User = Depends(current_user_with_expired_token),
user_manager: UserManager = Depends(get_user_manager),
) -> None:
# return
if CUSTOM_REFRESH_URL is None:
logger.error(
"Custom refresh URL is not set and client is attempting to custom refresh"
)
raise HTTPException(
status_code=500,
detail="Custom refresh URL is not set",
)
try:
async with httpx.AsyncClient() as client:
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
access_token = user.oauth_accounts[0].access_token
response = await client.get(
CUSTOM_REFRESH_URL,
params={"info": "json", "access_token_refresh_interval": 3600},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
# NOTE: Here is where we can mock the response
# data = mocked_refresh_token()
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
# Extract new tokens
new_access_token = data["access_token"]
new_refresh_token = data["refresh_token"]
new_access_token = refresh_token.access_token
new_refresh_token = refresh_token.refresh_token
new_expiry = datetime.fromtimestamp(
data["session"]["exp"] / 1000, tz=timezone.utc
refresh_token.session["exp"] / 1000, tz=timezone.utc
)
expires_at_timestamp = int(new_expiry.timestamp())
@ -107,8 +76,8 @@ async def refresh_access_token(
await user_manager.oauth_callback(
oauth_name="custom",
access_token=new_access_token,
account_id=data["userinfo"]["userId"],
account_email=data["userinfo"]["email"],
account_id=refresh_token.userinfo["userId"],
account_email=refresh_token.userinfo["email"],
expires_at=expires_at_timestamp,
refresh_token=new_refresh_token,
associate_by_email=True,

View File

@ -73,5 +73,3 @@ PRESERVED_SEARCH_FIELDS = [
"passage_prefix",
"query_prefix",
]
CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token"

View File

@ -180,7 +180,6 @@ services:
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
# Logging
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
@ -229,7 +228,7 @@ services:
# Enterprise Edition only
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
- NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-}
inference_model_server:
image: danswer/danswer-model-server:${IMAGE_TAG:-latest}

View File

@ -58,6 +58,9 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
RUN npx next build
@ -116,6 +119,9 @@ ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to
# expose via cli

View File

@ -6,6 +6,8 @@ import { Modal } from "../Modal";
import { useEffect, useState } from "react";
import { getSecondsUntilExpiration } from "@/lib/time";
import { User } from "@/lib/types";
import { mockedRefreshToken, refreshToken } from "./refreshUtils";
import { CUSTOM_REFRESH_URL } from "@/lib/constants";
export const HealthCheckBanner = () => {
const { error } = useSWR("/api/health", errorHandlingFetcher);
@ -33,27 +35,32 @@ export const HealthCheckBanner = () => {
}, [user]);
useEffect(() => {
if (true) {
if (CUSTOM_REFRESH_URL) {
const refreshUrl = CUSTOM_REFRESH_URL;
let refreshTimeoutId: NodeJS.Timeout;
let expireTimeoutId: NodeJS.Timeout;
const refreshToken = async () => {
const attemptTokenRefresh = async () => {
try {
// NOTE: This is a mocked refresh token for testing purposes.
// const refreshTokenData = mockedRefreshToken();
const refreshTokenData = await refreshToken(refreshUrl);
const response = await fetch(
"/api/enterprise-settings/refresh-token",
{
method: "GET",
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(refreshTokenData),
}
);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
console.debug("Token refresh successful");
// Force revalidation of user data
await new Promise((resolve) => setTimeout(resolve, 4000));
await mutateUser(undefined, { revalidate: true });
updateExpirationTime();
@ -65,7 +72,7 @@ export const HealthCheckBanner = () => {
const scheduleRefreshAndExpire = () => {
if (secondsUntilExpiration !== null) {
const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000;
refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh);
refreshTimeoutId = setTimeout(attemptTokenRefresh, timeUntilRefresh);
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expireTimeoutId = setTimeout(() => {

View File

@ -0,0 +1,59 @@
import { User } from "@/lib/types";
export interface CustomRefreshTokenResponse {
access_token: string;
refresh_token: string;
session: {
exp: number;
};
userinfo: {
sub: string;
familyName: string;
givenName: string;
fullName: string;
userId: string;
email: string;
};
}
export function mockedRefreshToken(): CustomRefreshTokenResponse {
/**
* This function mocks the response from a token refresh endpoint.
* It generates a mock access token, refresh token, and user information
* with an expiration time set to 1 hour from now.
* This is useful for testing or development when the actual refresh endpoint is not available.
*/
const mockExp = Date.now() + 3600000; // 1 hour from now in milliseconds
const data: CustomRefreshTokenResponse = {
access_token: "asdf Mock access token",
refresh_token: "asdf Mock refresh token",
session: { exp: mockExp },
userinfo: {
sub: "Mock email",
familyName: "Mock name",
givenName: "Mock name",
fullName: "Mock name",
userId: "Mock User ID",
email: "email@danswer.ai",
},
};
return data;
}
export async function refreshToken(
customRefreshUrl: string
): Promise<CustomRefreshTokenResponse> {
try {
console.debug("Sending request to custom refresh URL");
const url = new URL(customRefreshUrl);
url.searchParams.append("info", "json");
url.searchParams.append("access_token_refresh_interval", "3600");
const response = await fetch(url.toString());
return await response.json();
} catch (error) {
console.error("Error refreshing token:", error);
throw error;
}
}

View File

@ -36,6 +36,7 @@ export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";
/* Enterprise-only settings */
export const CUSTOM_REFRESH_URL = process.env.NEXT_PUBLIC_CUSTOM_REFRESH_URL;
// NOTE: this should ONLY be used on the server-side. If used client side,
// it will not be accurate (will always be false).