mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 09:40:50 +02:00
Custom Refresh on Client Side (#2376)
This commit is contained in:
@ -268,6 +268,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||||||
)
|
)
|
||||||
user.is_verified = is_verified_by_default
|
user.is_verified = is_verified_by_default
|
||||||
user.has_web_login = True
|
user.has_web_login = True
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def on_after_register(
|
async def on_after_register(
|
||||||
@ -414,6 +415,7 @@ async def optional_user(
|
|||||||
async def double_check_user(
|
async def double_check_user(
|
||||||
user: User | None,
|
user: User | None,
|
||||||
optional: bool = DISABLE_AUTH,
|
optional: bool = DISABLE_AUTH,
|
||||||
|
include_expired: bool = False,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
if optional:
|
if optional:
|
||||||
return None
|
return None
|
||||||
@ -430,7 +432,11 @@ async def double_check_user(
|
|||||||
detail="Access denied. User is not verified.",
|
detail="Access denied. User is not verified.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
if (
|
||||||
|
user.oidc_expiry
|
||||||
|
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||||
|
and not include_expired
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Access denied. User's OIDC token has expired.",
|
detail="Access denied. User's OIDC token has expired.",
|
||||||
@ -439,6 +445,12 @@ async def double_check_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def current_user_with_expired_token(
|
||||||
|
user: User | None = Depends(optional_user),
|
||||||
|
) -> User | None:
|
||||||
|
return await double_check_user(user, include_expired=True)
|
||||||
|
|
||||||
|
|
||||||
async def current_user(
|
async def current_user(
|
||||||
user: User | None = Depends(optional_user),
|
user: User | None = Depends(optional_user),
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
|
@ -7,6 +7,7 @@ from starlette.routing import BaseRoute
|
|||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.auth.users import current_curator_or_admin_user
|
from danswer.auth.users import current_curator_or_admin_user
|
||||||
from danswer.auth.users import current_user
|
from danswer.auth.users import current_user
|
||||||
|
from danswer.auth.users import current_user_with_expired_token
|
||||||
from danswer.configs.app_configs import APP_API_PREFIX
|
from danswer.configs.app_configs import APP_API_PREFIX
|
||||||
from danswer.server.danswer_api.ingestion import api_key_dep
|
from danswer.server.danswer_api.ingestion import api_key_dep
|
||||||
|
|
||||||
@ -96,6 +97,7 @@ def check_router_auth(
|
|||||||
or depends_fn == current_admin_user
|
or depends_fn == current_admin_user
|
||||||
or depends_fn == current_curator_or_admin_user
|
or depends_fn == current_curator_or_admin_user
|
||||||
or depends_fn == api_key_dep
|
or depends_fn == api_key_dep
|
||||||
|
or depends_fn == current_user_with_expired_token
|
||||||
):
|
):
|
||||||
found_auth = True
|
found_auth = True
|
||||||
break
|
break
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from datetime import timedelta
|
|
||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
@ -9,10 +9,12 @@ from fastapi import HTTPException
|
|||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi import status
|
from fastapi import status
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from pydantic import Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.auth.users import current_user
|
from danswer.auth.users import current_user_with_expired_token
|
||||||
from danswer.auth.users import get_user_manager
|
from danswer.auth.users import get_user_manager
|
||||||
from danswer.auth.users import UserManager
|
from danswer.auth.users import UserManager
|
||||||
from danswer.db.engine import get_session
|
from danswer.db.engine import get_session
|
||||||
@ -28,7 +30,6 @@ from ee.danswer.server.enterprise_settings.store import load_settings
|
|||||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||||
from ee.danswer.server.enterprise_settings.store import store_settings
|
from ee.danswer.server.enterprise_settings.store import store_settings
|
||||||
from ee.danswer.server.enterprise_settings.store import upload_logo
|
from ee.danswer.server.enterprise_settings.store import upload_logo
|
||||||
from shared_configs.configs import CUSTOM_REFRESH_URL
|
|
||||||
|
|
||||||
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
||||||
basic_router = APIRouter(prefix="/enterprise-settings")
|
basic_router = APIRouter(prefix="/enterprise-settings")
|
||||||
@ -36,69 +37,37 @@ basic_router = APIRouter(prefix="/enterprise-settings")
|
|||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
def mocked_refresh_token() -> dict:
|
class RefreshTokenData(BaseModel):
|
||||||
"""
|
access_token: str
|
||||||
This function mocks the response from a token refresh endpoint.
|
refresh_token: str
|
||||||
It generates a mock access token, refresh token, and user information
|
session: dict = Field(..., description="Contains session information")
|
||||||
with an expiration time set to 1 hour from now.
|
userinfo: dict = Field(..., description="Contains user information")
|
||||||
This is useful for testing or development when the actual refresh endpoint is not available.
|
|
||||||
"""
|
def __init__(self, **data: Any) -> None:
|
||||||
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
|
super().__init__(**data)
|
||||||
data = {
|
if "exp" not in self.session:
|
||||||
"access_token": "asdf Mock access token",
|
raise ValueError("'exp' must be set in the session dictionary")
|
||||||
"refresh_token": "asdf Mock refresh token",
|
if "userId" not in self.userinfo or "email" not in self.userinfo:
|
||||||
"session": {"exp": mock_exp},
|
raise ValueError(
|
||||||
"userinfo": {
|
"'userId' and 'email' must be set in the userinfo dictionary"
|
||||||
"sub": "Mock email",
|
)
|
||||||
"familyName": "Mock name",
|
|
||||||
"givenName": "Mock name",
|
|
||||||
"fullName": "Mock name",
|
|
||||||
"userId": "Mock User ID",
|
|
||||||
"email": "test_email@danswer.ai",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@basic_router.get("/refresh-token")
|
@basic_router.post("/refresh-token")
|
||||||
async def refresh_access_token(
|
async def refresh_access_token(
|
||||||
user: User = Depends(current_user),
|
refresh_token: RefreshTokenData,
|
||||||
|
user: User = Depends(current_user_with_expired_token),
|
||||||
user_manager: UserManager = Depends(get_user_manager),
|
user_manager: UserManager = Depends(get_user_manager),
|
||||||
) -> None:
|
) -> None:
|
||||||
# return
|
|
||||||
if CUSTOM_REFRESH_URL is None:
|
|
||||||
logger.error(
|
|
||||||
"Custom refresh URL is not set and client is attempting to custom refresh"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Custom refresh URL is not set",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
|
|
||||||
access_token = user.oauth_accounts[0].access_token
|
|
||||||
|
|
||||||
response = await client.get(
|
|
||||||
CUSTOM_REFRESH_URL,
|
|
||||||
params={"info": "json", "access_token_refresh_interval": 3600},
|
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# NOTE: Here is where we can mock the response
|
|
||||||
# data = mocked_refresh_token()
|
|
||||||
|
|
||||||
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
|
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
|
||||||
|
|
||||||
# Extract new tokens
|
# Extract new tokens
|
||||||
new_access_token = data["access_token"]
|
new_access_token = refresh_token.access_token
|
||||||
new_refresh_token = data["refresh_token"]
|
new_refresh_token = refresh_token.refresh_token
|
||||||
|
|
||||||
new_expiry = datetime.fromtimestamp(
|
new_expiry = datetime.fromtimestamp(
|
||||||
data["session"]["exp"] / 1000, tz=timezone.utc
|
refresh_token.session["exp"] / 1000, tz=timezone.utc
|
||||||
)
|
)
|
||||||
expires_at_timestamp = int(new_expiry.timestamp())
|
expires_at_timestamp = int(new_expiry.timestamp())
|
||||||
|
|
||||||
@ -107,8 +76,8 @@ async def refresh_access_token(
|
|||||||
await user_manager.oauth_callback(
|
await user_manager.oauth_callback(
|
||||||
oauth_name="custom",
|
oauth_name="custom",
|
||||||
access_token=new_access_token,
|
access_token=new_access_token,
|
||||||
account_id=data["userinfo"]["userId"],
|
account_id=refresh_token.userinfo["userId"],
|
||||||
account_email=data["userinfo"]["email"],
|
account_email=refresh_token.userinfo["email"],
|
||||||
expires_at=expires_at_timestamp,
|
expires_at=expires_at_timestamp,
|
||||||
refresh_token=new_refresh_token,
|
refresh_token=new_refresh_token,
|
||||||
associate_by_email=True,
|
associate_by_email=True,
|
||||||
|
@ -73,5 +73,3 @@ PRESERVED_SEARCH_FIELDS = [
|
|||||||
"passage_prefix",
|
"passage_prefix",
|
||||||
"query_prefix",
|
"query_prefix",
|
||||||
]
|
]
|
||||||
|
|
||||||
CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token"
|
|
||||||
|
@ -180,7 +180,6 @@ services:
|
|||||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||||
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
|
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
|
||||||
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
|
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
|
||||||
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
|
|
||||||
# Logging
|
# Logging
|
||||||
# Leave this on pretty please? Nothing sensitive is collected!
|
# Leave this on pretty please? Nothing sensitive is collected!
|
||||||
# https://docs.danswer.dev/more/telemetry
|
# https://docs.danswer.dev/more/telemetry
|
||||||
@ -229,7 +228,7 @@ services:
|
|||||||
|
|
||||||
# Enterprise Edition only
|
# Enterprise Edition only
|
||||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||||
- CUSTOM_REFRESH_URL=${CUSTOM_REFRESH_URL:-}
|
- NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-}
|
||||||
|
|
||||||
inference_model_server:
|
inference_model_server:
|
||||||
image: danswer/danswer-model-server:${IMAGE_TAG:-latest}
|
image: danswer/danswer-model-server:${IMAGE_TAG:-latest}
|
||||||
|
@ -58,6 +58,9 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
|
|||||||
ARG NEXT_PUBLIC_DISABLE_LOGOUT
|
ARG NEXT_PUBLIC_DISABLE_LOGOUT
|
||||||
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
||||||
|
|
||||||
|
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
|
||||||
|
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
|
||||||
|
|
||||||
|
|
||||||
RUN npx next build
|
RUN npx next build
|
||||||
|
|
||||||
@ -116,6 +119,9 @@ ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
|
|||||||
ARG NEXT_PUBLIC_DISABLE_LOGOUT
|
ARG NEXT_PUBLIC_DISABLE_LOGOUT
|
||||||
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
||||||
|
|
||||||
|
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
|
||||||
|
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
|
||||||
|
|
||||||
# Note: Don't expose ports here, Compose will handle that for us if necessary.
|
# Note: Don't expose ports here, Compose will handle that for us if necessary.
|
||||||
# If you want to run this without compose, specify the ports to
|
# If you want to run this without compose, specify the ports to
|
||||||
# expose via cli
|
# expose via cli
|
||||||
|
@ -6,6 +6,8 @@ import { Modal } from "../Modal";
|
|||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { getSecondsUntilExpiration } from "@/lib/time";
|
import { getSecondsUntilExpiration } from "@/lib/time";
|
||||||
import { User } from "@/lib/types";
|
import { User } from "@/lib/types";
|
||||||
|
import { mockedRefreshToken, refreshToken } from "./refreshUtils";
|
||||||
|
import { CUSTOM_REFRESH_URL } from "@/lib/constants";
|
||||||
|
|
||||||
export const HealthCheckBanner = () => {
|
export const HealthCheckBanner = () => {
|
||||||
const { error } = useSWR("/api/health", errorHandlingFetcher);
|
const { error } = useSWR("/api/health", errorHandlingFetcher);
|
||||||
@ -33,27 +35,32 @@ export const HealthCheckBanner = () => {
|
|||||||
}, [user]);
|
}, [user]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (true) {
|
if (CUSTOM_REFRESH_URL) {
|
||||||
|
const refreshUrl = CUSTOM_REFRESH_URL;
|
||||||
let refreshTimeoutId: NodeJS.Timeout;
|
let refreshTimeoutId: NodeJS.Timeout;
|
||||||
let expireTimeoutId: NodeJS.Timeout;
|
let expireTimeoutId: NodeJS.Timeout;
|
||||||
|
|
||||||
const refreshToken = async () => {
|
const attemptTokenRefresh = async () => {
|
||||||
try {
|
try {
|
||||||
|
// NOTE: This is a mocked refresh token for testing purposes.
|
||||||
|
// const refreshTokenData = mockedRefreshToken();
|
||||||
|
|
||||||
|
const refreshTokenData = await refreshToken(refreshUrl);
|
||||||
|
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
"/api/enterprise-settings/refresh-token",
|
"/api/enterprise-settings/refresh-token",
|
||||||
{
|
{
|
||||||
method: "GET",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
|
body: JSON.stringify(refreshTokenData),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`HTTP error! status: ${response.status}`);
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
}
|
}
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 4000));
|
||||||
console.debug("Token refresh successful");
|
|
||||||
// Force revalidation of user data
|
|
||||||
|
|
||||||
await mutateUser(undefined, { revalidate: true });
|
await mutateUser(undefined, { revalidate: true });
|
||||||
updateExpirationTime();
|
updateExpirationTime();
|
||||||
@ -65,7 +72,7 @@ export const HealthCheckBanner = () => {
|
|||||||
const scheduleRefreshAndExpire = () => {
|
const scheduleRefreshAndExpire = () => {
|
||||||
if (secondsUntilExpiration !== null) {
|
if (secondsUntilExpiration !== null) {
|
||||||
const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000;
|
const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000;
|
||||||
refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh);
|
refreshTimeoutId = setTimeout(attemptTokenRefresh, timeUntilRefresh);
|
||||||
|
|
||||||
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
|
const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
|
||||||
expireTimeoutId = setTimeout(() => {
|
expireTimeoutId = setTimeout(() => {
|
||||||
|
59
web/src/components/health/refreshUtils.ts
Normal file
59
web/src/components/health/refreshUtils.ts
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import { User } from "@/lib/types";
|
||||||
|
|
||||||
|
export interface CustomRefreshTokenResponse {
|
||||||
|
access_token: string;
|
||||||
|
refresh_token: string;
|
||||||
|
session: {
|
||||||
|
exp: number;
|
||||||
|
};
|
||||||
|
userinfo: {
|
||||||
|
sub: string;
|
||||||
|
familyName: string;
|
||||||
|
givenName: string;
|
||||||
|
fullName: string;
|
||||||
|
userId: string;
|
||||||
|
email: string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function mockedRefreshToken(): CustomRefreshTokenResponse {
|
||||||
|
/**
|
||||||
|
* This function mocks the response from a token refresh endpoint.
|
||||||
|
* It generates a mock access token, refresh token, and user information
|
||||||
|
* with an expiration time set to 1 hour from now.
|
||||||
|
* This is useful for testing or development when the actual refresh endpoint is not available.
|
||||||
|
*/
|
||||||
|
const mockExp = Date.now() + 3600000; // 1 hour from now in milliseconds
|
||||||
|
const data: CustomRefreshTokenResponse = {
|
||||||
|
access_token: "asdf Mock access token",
|
||||||
|
refresh_token: "asdf Mock refresh token",
|
||||||
|
session: { exp: mockExp },
|
||||||
|
userinfo: {
|
||||||
|
sub: "Mock email",
|
||||||
|
familyName: "Mock name",
|
||||||
|
givenName: "Mock name",
|
||||||
|
fullName: "Mock name",
|
||||||
|
userId: "Mock User ID",
|
||||||
|
email: "email@danswer.ai",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function refreshToken(
|
||||||
|
customRefreshUrl: string
|
||||||
|
): Promise<CustomRefreshTokenResponse> {
|
||||||
|
try {
|
||||||
|
console.debug("Sending request to custom refresh URL");
|
||||||
|
const url = new URL(customRefreshUrl);
|
||||||
|
url.searchParams.append("info", "json");
|
||||||
|
url.searchParams.append("access_token_refresh_interval", "3600");
|
||||||
|
|
||||||
|
const response = await fetch(url.toString());
|
||||||
|
|
||||||
|
return await response.json();
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error refreshing token:", error);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
@ -36,6 +36,7 @@ export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
|
|||||||
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";
|
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";
|
||||||
|
|
||||||
/* Enterprise-only settings */
|
/* Enterprise-only settings */
|
||||||
|
export const CUSTOM_REFRESH_URL = process.env.NEXT_PUBLIC_CUSTOM_REFRESH_URL;
|
||||||
|
|
||||||
// NOTE: this should ONLY be used on the server-side. If used client side,
|
// NOTE: this should ONLY be used on the server-side. If used client side,
|
||||||
// it will not be accurate (will always be false).
|
// it will not be accurate (will always be false).
|
||||||
|
Reference in New Issue
Block a user