Updated refreshing (#2327)

* clean up + add environment variables

* remove log

* update

* update api settings

* somewhat cleaner refresh functionality

* fully functional

* update settings

* validated

* remove random logs

* remove unneeded paramter + log

* move to ee + remove comments

* Cleanup unused

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
This commit is contained in:
pablodanswer
2024-09-05 21:36:55 -07:00
committed by GitHub
parent 2bd3833c55
commit 69c0419146
11 changed files with 208 additions and 23 deletions

View File

@@ -1,14 +1,24 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME
@@ -18,10 +28,117 @@ from ee.danswer.server.enterprise_settings.store import load_settings
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import store_settings
from ee.danswer.server.enterprise_settings.store import upload_logo
from shared_configs.configs import CUSTOM_REFRESH_URL
admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")
logger = setup_logger()
def mocked_refresh_token() -> dict:
"""
This function mocks the response from a token refresh endpoint.
It generates a mock access token, refresh token, and user information
with an expiration time set to 1 hour from now.
This is useful for testing or development when the actual refresh endpoint is not available.
"""
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
data = {
"access_token": "asdf Mock access token",
"refresh_token": "asdf Mock refresh token",
"session": {"exp": mock_exp},
"userinfo": {
"sub": "Mock email",
"familyName": "Mock name",
"givenName": "Mock name",
"fullName": "Mock name",
"userId": "Mock User ID",
"email": "test_email@danswer.ai",
},
}
return data
@basic_router.get("/refresh-token")
async def refresh_access_token(
user: User = Depends(current_user),
user_manager: UserManager = Depends(get_user_manager),
) -> None:
# return
if CUSTOM_REFRESH_URL is None:
logger.error(
"Custom refresh URL is not set and client is attempting to custom refresh"
)
raise HTTPException(
status_code=500,
detail="Custom refresh URL is not set",
)
try:
async with httpx.AsyncClient() as client:
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
access_token = user.oauth_accounts[0].access_token
response = await client.get(
CUSTOM_REFRESH_URL,
params={"info": "json", "access_token_refresh_interval": 3600},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
# NOTE: Here is where we can mock the response
# data = mocked_refresh_token()
logger.debug(f"Received response from Meechum auth URL for user {user.id}")
# Extract new tokens
new_access_token = data["access_token"]
new_refresh_token = data["refresh_token"]
new_expiry = datetime.fromtimestamp(
data["session"]["exp"] / 1000, tz=timezone.utc
)
expires_at_timestamp = int(new_expiry.timestamp())
logger.debug(f"Access token has been refreshed for user {user.id}")
await user_manager.oauth_callback(
oauth_name="custom",
access_token=new_access_token,
account_id=data["userinfo"]["userId"],
account_email=data["userinfo"]["email"],
expires_at=expires_at_timestamp,
refresh_token=new_refresh_token,
associate_by_email=True,
)
logger.info(f"Successfully refreshed tokens for user {user.id}")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
logger.warning(f"Full authentication required for user {user.id}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Full authentication required",
)
logger.error(
f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to refresh token",
)
except Exception as e:
logger.error(
f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred",
)
@admin_router.put("")
def put_settings(