mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 13:25:50 +02:00
Custom Refresh on Client Side (#2376)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
@@ -9,10 +9,12 @@ from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import current_user_with_expired_token
|
||||
from danswer.auth.users import get_user_manager
|
||||
from danswer.auth.users import UserManager
|
||||
from danswer.db.engine import get_session
|
||||
@@ -28,7 +30,6 @@ from ee.danswer.server.enterprise_settings.store import load_settings
|
||||
from ee.danswer.server.enterprise_settings.store import store_analytics_script
|
||||
from ee.danswer.server.enterprise_settings.store import store_settings
|
||||
from ee.danswer.server.enterprise_settings.store import upload_logo
|
||||
from shared_configs.configs import CUSTOM_REFRESH_URL
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/enterprise-settings")
|
||||
basic_router = APIRouter(prefix="/enterprise-settings")
|
||||
@@ -36,69 +37,37 @@ basic_router = APIRouter(prefix="/enterprise-settings")
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def mocked_refresh_token() -> dict:
|
||||
"""
|
||||
This function mocks the response from a token refresh endpoint.
|
||||
It generates a mock access token, refresh token, and user information
|
||||
with an expiration time set to 1 hour from now.
|
||||
This is useful for testing or development when the actual refresh endpoint is not available.
|
||||
"""
|
||||
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
|
||||
data = {
|
||||
"access_token": "asdf Mock access token",
|
||||
"refresh_token": "asdf Mock refresh token",
|
||||
"session": {"exp": mock_exp},
|
||||
"userinfo": {
|
||||
"sub": "Mock email",
|
||||
"familyName": "Mock name",
|
||||
"givenName": "Mock name",
|
||||
"fullName": "Mock name",
|
||||
"userId": "Mock User ID",
|
||||
"email": "test_email@danswer.ai",
|
||||
},
|
||||
}
|
||||
return data
|
||||
class RefreshTokenData(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
session: dict = Field(..., description="Contains session information")
|
||||
userinfo: dict = Field(..., description="Contains user information")
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
if "exp" not in self.session:
|
||||
raise ValueError("'exp' must be set in the session dictionary")
|
||||
if "userId" not in self.userinfo or "email" not in self.userinfo:
|
||||
raise ValueError(
|
||||
"'userId' and 'email' must be set in the userinfo dictionary"
|
||||
)
|
||||
|
||||
|
||||
@basic_router.get("/refresh-token")
|
||||
@basic_router.post("/refresh-token")
|
||||
async def refresh_access_token(
|
||||
user: User = Depends(current_user),
|
||||
refresh_token: RefreshTokenData,
|
||||
user: User = Depends(current_user_with_expired_token),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> None:
|
||||
# return
|
||||
if CUSTOM_REFRESH_URL is None:
|
||||
logger.error(
|
||||
"Custom refresh URL is not set and client is attempting to custom refresh"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Custom refresh URL is not set",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
|
||||
access_token = user.oauth_accounts[0].access_token
|
||||
|
||||
response = await client.get(
|
||||
CUSTOM_REFRESH_URL,
|
||||
params={"info": "json", "access_token_refresh_interval": 3600},
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
# NOTE: Here is where we can mock the response
|
||||
# data = mocked_refresh_token()
|
||||
|
||||
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
|
||||
|
||||
# Extract new tokens
|
||||
new_access_token = data["access_token"]
|
||||
new_refresh_token = data["refresh_token"]
|
||||
new_access_token = refresh_token.access_token
|
||||
new_refresh_token = refresh_token.refresh_token
|
||||
|
||||
new_expiry = datetime.fromtimestamp(
|
||||
data["session"]["exp"] / 1000, tz=timezone.utc
|
||||
refresh_token.session["exp"] / 1000, tz=timezone.utc
|
||||
)
|
||||
expires_at_timestamp = int(new_expiry.timestamp())
|
||||
|
||||
@@ -107,8 +76,8 @@ async def refresh_access_token(
|
||||
await user_manager.oauth_callback(
|
||||
oauth_name="custom",
|
||||
access_token=new_access_token,
|
||||
account_id=data["userinfo"]["userId"],
|
||||
account_email=data["userinfo"]["email"],
|
||||
account_id=refresh_token.userinfo["userId"],
|
||||
account_email=refresh_token.userinfo["email"],
|
||||
expires_at=expires_at_timestamp,
|
||||
refresh_token=new_refresh_token,
|
||||
associate_by_email=True,
|
||||
|
Reference in New Issue
Block a user