From 31ca6857fb3b254bbd9e98cf8a210d2cd5ba6f7c Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 13 Sep 2024 00:04:03 -0700 Subject: [PATCH] Custom Refresh on Client Side (#2376) --- backend/danswer/auth/users.py | 14 +++- backend/danswer/server/auth_check.py | 2 + .../danswer/server/enterprise_settings/api.py | 83 ++++++------------- backend/shared_configs/configs.py | 2 - .../docker_compose/docker-compose.dev.yml | 3 +- web/Dockerfile | 6 ++ web/src/components/health/healthcheck.tsx | 21 +++-- web/src/components/health/refreshUtils.ts | 59 +++++++++++++ web/src/lib/constants.ts | 1 + 9 files changed, 122 insertions(+), 69 deletions(-) create mode 100644 web/src/components/health/refreshUtils.ts diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 44a801e84..1776217d3 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -268,6 +268,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): ) user.is_verified = is_verified_by_default user.has_web_login = True + return user async def on_after_register( @@ -414,6 +415,7 @@ async def optional_user( async def double_check_user( user: User | None, optional: bool = DISABLE_AUTH, + include_expired: bool = False, ) -> User | None: if optional: return None @@ -430,7 +432,11 @@ async def double_check_user( detail="Access denied. User is not verified.", ) - if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc): + if ( + user.oidc_expiry + and user.oidc_expiry < datetime.now(timezone.utc) + and not include_expired + ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access denied. User's OIDC token has expired.", @@ -439,6 +445,12 @@ async def double_check_user( return user +async def current_user_with_expired_token( + user: User | None = Depends(optional_user), +) -> User | None: + return await double_check_user(user, include_expired=True) + + async def current_user( user: User | None = Depends(optional_user), ) -> User | None: diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 12258eba2..8a35a560a 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -7,6 +7,7 @@ from starlette.routing import BaseRoute from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user +from danswer.auth.users import current_user_with_expired_token from danswer.configs.app_configs import APP_API_PREFIX from danswer.server.danswer_api.ingestion import api_key_dep @@ -96,6 +97,7 @@ def check_router_auth( or depends_fn == current_admin_user or depends_fn == current_curator_or_admin_user or depends_fn == api_key_dep + or depends_fn == current_user_with_expired_token ): found_auth = True break diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 8590fd6c5..385adcf68 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -1,6 +1,6 @@ from datetime import datetime -from datetime import timedelta from datetime import timezone +from typing import Any import httpx from fastapi import APIRouter @@ -9,10 +9,12 @@ from fastapi import HTTPException from fastapi import Response from fastapi import status from fastapi import UploadFile +from pydantic import BaseModel +from pydantic import Field from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user -from danswer.auth.users import current_user +from danswer.auth.users import current_user_with_expired_token from danswer.auth.users import get_user_manager from danswer.auth.users import UserManager from danswer.db.engine import get_session @@ -28,7 +30,6 @@ from ee.danswer.server.enterprise_settings.store import load_settings from ee.danswer.server.enterprise_settings.store import store_analytics_script from ee.danswer.server.enterprise_settings.store import store_settings from ee.danswer.server.enterprise_settings.store import upload_logo -from shared_configs.configs import CUSTOM_REFRESH_URL admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") @@ -36,69 +37,37 @@ basic_router = APIRouter(prefix="/enterprise-settings") logger = setup_logger() -def mocked_refresh_token() -> dict: - """ - This function mocks the response from a token refresh endpoint. - It generates a mock access token, refresh token, and user information - with an expiration time set to 1 hour from now. - This is useful for testing or development when the actual refresh endpoint is not available. - """ - mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000) - data = { - "access_token": "asdf Mock access token", - "refresh_token": "asdf Mock refresh token", - "session": {"exp": mock_exp}, - "userinfo": { - "sub": "Mock email", - "familyName": "Mock name", - "givenName": "Mock name", - "fullName": "Mock name", - "userId": "Mock User ID", - "email": "test_email@danswer.ai", - }, - } - return data +class RefreshTokenData(BaseModel): + access_token: str + refresh_token: str + session: dict = Field(..., description="Contains session information") + userinfo: dict = Field(..., description="Contains user information") + + def __init__(self, **data: Any) -> None: + super().__init__(**data) + if "exp" not in self.session: + raise ValueError("'exp' must be set in the session dictionary") + if "userId" not in self.userinfo or "email" not in self.userinfo: + raise ValueError( + "'userId' and 'email' must be set in the userinfo dictionary" + ) -@basic_router.get("/refresh-token") +@basic_router.post("/refresh-token") async def refresh_access_token( - user: User = Depends(current_user), + refresh_token: RefreshTokenData, + user: User = Depends(current_user_with_expired_token), user_manager: UserManager = Depends(get_user_manager), ) -> None: - # return - if CUSTOM_REFRESH_URL is None: - logger.error( - "Custom refresh URL is not set and client is attempting to custom refresh" - ) - raise HTTPException( - status_code=500, - detail="Custom refresh URL is not set", - ) - try: - async with httpx.AsyncClient() as client: - logger.debug(f"Sending request to custom refresh URL for user {user.id}") - access_token = user.oauth_accounts[0].access_token - - response = await client.get( - CUSTOM_REFRESH_URL, - params={"info": "json", "access_token_refresh_interval": 3600}, - headers={"Authorization": f"Bearer {access_token}"}, - ) - response.raise_for_status() - data = response.json() - - # NOTE: Here is where we can mock the response - # data = mocked_refresh_token() - logger.debug(f"Received response from Meechum auth URL for user {user.id}") # Extract new tokens - new_access_token = data["access_token"] - new_refresh_token = data["refresh_token"] + new_access_token = refresh_token.access_token + new_refresh_token = refresh_token.refresh_token new_expiry = datetime.fromtimestamp( - data["session"]["exp"] / 1000, tz=timezone.utc + refresh_token.session["exp"] / 1000, tz=timezone.utc ) expires_at_timestamp = int(new_expiry.timestamp()) @@ -107,8 +76,8 @@ async def refresh_access_token( await user_manager.oauth_callback( oauth_name="custom", access_token=new_access_token, - account_id=data["userinfo"]["userId"], - account_email=data["userinfo"]["email"], + account_id=refresh_token.userinfo["userId"], + account_email=refresh_token.userinfo["email"], expires_at=expires_at_timestamp, refresh_token=new_refresh_token, associate_by_email=True, diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index fe9332270..23d3c7f89 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -73,5 +73,3 @@ PRESERVED_SEARCH_FIELDS = [ "passage_prefix", "query_prefix", ] - -CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token" diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 66ece7ed0..d61419607 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -180,7 +180,6 @@ services: - NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-} - DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-} - DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-} - - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} # Logging # Leave this on pretty please? Nothing sensitive is collected! # https://docs.danswer.dev/more/telemetry @@ -229,7 +228,7 @@ services: # Enterprise Edition only - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false} - - CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-} + - NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-} inference_model_server: image: danswer/danswer-model-server:${IMAGE_TAG:-latest} diff --git a/web/Dockerfile b/web/Dockerfile index 710cf653f..6ea85752b 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -58,6 +58,9 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} +ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL +ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL} + RUN npx next build @@ -116,6 +119,9 @@ ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN} ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} +ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL +ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL} + # Note: Don't expose ports here, Compose will handle that for us if necessary. # If you want to run this without compose, specify the ports to # expose via cli diff --git a/web/src/components/health/healthcheck.tsx b/web/src/components/health/healthcheck.tsx index 2cba8be82..037082ca0 100644 --- a/web/src/components/health/healthcheck.tsx +++ b/web/src/components/health/healthcheck.tsx @@ -6,6 +6,8 @@ import { Modal } from "../Modal"; import { useEffect, useState } from "react"; import { getSecondsUntilExpiration } from "@/lib/time"; import { User } from "@/lib/types"; +import { mockedRefreshToken, refreshToken } from "./refreshUtils"; +import { CUSTOM_REFRESH_URL } from "@/lib/constants"; export const HealthCheckBanner = () => { const { error } = useSWR("/api/health", errorHandlingFetcher); @@ -33,27 +35,32 @@ export const HealthCheckBanner = () => { }, [user]); useEffect(() => { - if (true) { + if (CUSTOM_REFRESH_URL) { + const refreshUrl = CUSTOM_REFRESH_URL; let refreshTimeoutId: NodeJS.Timeout; let expireTimeoutId: NodeJS.Timeout; - const refreshToken = async () => { + const attemptTokenRefresh = async () => { try { + // NOTE: This is a mocked refresh token for testing purposes. + // const refreshTokenData = mockedRefreshToken(); + + const refreshTokenData = await refreshToken(refreshUrl); + const response = await fetch( "/api/enterprise-settings/refresh-token", { - method: "GET", + method: "POST", headers: { "Content-Type": "application/json", }, + body: JSON.stringify(refreshTokenData), } ); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } - - console.debug("Token refresh successful"); - // Force revalidation of user data + await new Promise((resolve) => setTimeout(resolve, 4000)); await mutateUser(undefined, { revalidate: true }); updateExpirationTime(); @@ -65,7 +72,7 @@ export const HealthCheckBanner = () => { const scheduleRefreshAndExpire = () => { if (secondsUntilExpiration !== null) { const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000; - refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh); + refreshTimeoutId = setTimeout(attemptTokenRefresh, timeUntilRefresh); const timeUntilExpire = (secondsUntilExpiration + 10) * 1000; expireTimeoutId = setTimeout(() => { diff --git a/web/src/components/health/refreshUtils.ts b/web/src/components/health/refreshUtils.ts new file mode 100644 index 000000000..f478b5e4a --- /dev/null +++ b/web/src/components/health/refreshUtils.ts @@ -0,0 +1,59 @@ +import { User } from "@/lib/types"; + +export interface CustomRefreshTokenResponse { + access_token: string; + refresh_token: string; + session: { + exp: number; + }; + userinfo: { + sub: string; + familyName: string; + givenName: string; + fullName: string; + userId: string; + email: string; + }; +} + +export function mockedRefreshToken(): CustomRefreshTokenResponse { + /** + * This function mocks the response from a token refresh endpoint. + * It generates a mock access token, refresh token, and user information + * with an expiration time set to 1 hour from now. + * This is useful for testing or development when the actual refresh endpoint is not available. + */ + const mockExp = Date.now() + 3600000; // 1 hour from now in milliseconds + const data: CustomRefreshTokenResponse = { + access_token: "asdf Mock access token", + refresh_token: "asdf Mock refresh token", + session: { exp: mockExp }, + userinfo: { + sub: "Mock email", + familyName: "Mock name", + givenName: "Mock name", + fullName: "Mock name", + userId: "Mock User ID", + email: "email@danswer.ai", + }, + }; + return data; +} + +export async function refreshToken( + customRefreshUrl: string +): Promise { + try { + console.debug("Sending request to custom refresh URL"); + const url = new URL(customRefreshUrl); + url.searchParams.append("info", "json"); + url.searchParams.append("access_token_refresh_interval", "3600"); + + const response = await fetch(url.toString()); + + return await response.json(); + } catch (error) { + console.error("Error refreshing token:", error); + throw error; + } +} diff --git a/web/src/lib/constants.ts b/web/src/lib/constants.ts index a694f157e..974695a83 100644 --- a/web/src/lib/constants.ts +++ b/web/src/lib/constants.ts @@ -36,6 +36,7 @@ export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN = export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors"; /* Enterprise-only settings */ +export const CUSTOM_REFRESH_URL = process.env.NEXT_PUBLIC_CUSTOM_REFRESH_URL; // NOTE: this should ONLY be used on the server-side. If used client side, // it will not be accurate (will always be false).