From 38afc8fa3a5badaf5c62402e76ccd5208e42a615 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Thu, 13 Mar 2025 16:07:37 -0700 Subject: [PATCH] clean up + tests --- backend/onyx/auth/users.py | 7 +- backend/tests/unit/onyx/auth/conftest.py | 43 +++ .../unit/onyx/auth/test_oauth_refresher.py | 255 ++++++++++++++++++ web/src/lib/fetcher.ts | 106 ++------ 4 files changed, 318 insertions(+), 93 deletions(-) create mode 100644 backend/tests/unit/onyx/auth/conftest.py create mode 100644 backend/tests/unit/onyx/auth/test_oauth_refresher.py diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 02028553a..547adc78c 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -1195,10 +1195,9 @@ def get_oauth_router( # For Google OAuth, add parameters to request refresh tokens if oauth_client.name == "google": - if "?" in authorization_url: - authorization_url += "&access_type=offline&prompt=consent" - else: - authorization_url += "?access_type=offline&prompt=consent" + authorization_url = add_url_params( + authorization_url, {"access_type": "offline", "prompt": "consent"} + ) return OAuth2AuthorizeResponse(authorization_url=authorization_url) diff --git a/backend/tests/unit/onyx/auth/conftest.py b/backend/tests/unit/onyx/auth/conftest.py new file mode 100644 index 000000000..52b323c6e --- /dev/null +++ b/backend/tests/unit/onyx/auth/conftest.py @@ -0,0 +1,43 @@ +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +from onyx.db.models import OAuthAccount +from onyx.db.models import User + + +@pytest.fixture +def mock_user() -> MagicMock: + """Creates a mock User instance for testing.""" + user = MagicMock(spec=User) + user.email = "test@example.com" + user.id = "test-user-id" + return user + + +@pytest.fixture +def mock_oauth_account() -> MagicMock: + """Creates a mock OAuthAccount instance for testing.""" + oauth_account = MagicMock(spec=OAuthAccount) + oauth_account.oauth_name = "google" + oauth_account.refresh_token = "test-refresh-token" + oauth_account.access_token = "test-access-token" + oauth_account.expires_at = None + return oauth_account + + +@pytest.fixture +def mock_user_manager() -> MagicMock: + """Creates a mock user manager for testing.""" + user_manager = MagicMock() + user_manager.user_db = MagicMock() + user_manager.user_db.update_oauth_account = AsyncMock() + user_manager.user_db.update = AsyncMock() + return user_manager + + +@pytest.fixture +def mock_db_session() -> MagicMock: + """Creates a mock database session for testing.""" + return MagicMock() diff --git a/backend/tests/unit/onyx/auth/test_oauth_refresher.py b/backend/tests/unit/onyx/auth/test_oauth_refresher.py new file mode 100644 index 000000000..9f4d2a6ea --- /dev/null +++ b/backend/tests/unit/onyx/auth/test_oauth_refresher.py @@ -0,0 +1,255 @@ +from datetime import datetime +from datetime import timezone +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from onyx.auth.oauth_refresher import _test_expire_oauth_token +from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens +from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token +from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token +from onyx.auth.oauth_refresher import refresh_oauth_token +from onyx.db.models import OAuthAccount + + +@pytest.mark.asyncio +async def test_refresh_oauth_token_success( + mock_user, mock_oauth_account, mock_user_manager, mock_db_session +): + """Test successful OAuth token refresh.""" + # Mock HTTP client and response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_token", + "refresh_token": "new_refresh_token", + "expires_in": 3600, + } + + # Create async mock for the client post method + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + # Use fixture values but ensure refresh token exists + mock_oauth_account.oauth_name = ( + "google" # Ensure it's google to match the refresh endpoint + ) + mock_oauth_account.refresh_token = "old_refresh_token" + + # Patch at the module level where it's actually being used + with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock: + # Configure the context manager + client_instance = mock_client + client_class_mock.return_value.__aenter__.return_value = client_instance + + # Call the function under test + result = await refresh_oauth_token( + mock_user, mock_oauth_account, mock_db_session, mock_user_manager + ) + + # Assertions + assert result is True + mock_client.post.assert_called_once() + mock_user_manager.user_db.update_oauth_account.assert_called_once() + + # Verify token data was updated correctly + update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2] + assert update_data["access_token"] == "new_token" + assert update_data["refresh_token"] == "new_refresh_token" + assert "expires_at" in update_data + + +@pytest.mark.asyncio +async def test_refresh_oauth_token_failure( + mock_user, mock_oauth_account, mock_user_manager, mock_db_session +): + """Test OAuth token refresh failure due to HTTP error.""" + # Mock HTTP client with error response + mock_response = MagicMock() + mock_response.status_code = 400 # Simulate error + + # Create async mock for the client post method + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + # Ensure refresh token exists and provider is supported + mock_oauth_account.oauth_name = "google" + mock_oauth_account.refresh_token = "old_refresh_token" + + # Patch at the module level where it's actually being used + with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock: + # Configure the context manager + client_class_mock.return_value.__aenter__.return_value = mock_client + + # Call the function under test + result = await refresh_oauth_token( + mock_user, mock_oauth_account, mock_db_session, mock_user_manager + ) + + # Assertions + assert result is False + mock_client.post.assert_called_once() + mock_user_manager.user_db.update_oauth_account.assert_not_called() + + +@pytest.mark.asyncio +async def test_refresh_oauth_token_no_refresh_token( + mock_user, mock_oauth_account, mock_user_manager, mock_db_session +): + """Test OAuth token refresh when no refresh token is available.""" + # Set refresh token to None + mock_oauth_account.refresh_token = None + mock_oauth_account.oauth_name = "google" + + # No need to mock httpx since it shouldn't be called + result = await refresh_oauth_token( + mock_user, mock_oauth_account, mock_db_session, mock_user_manager + ) + + # Assertions + assert result is False + + +@pytest.mark.asyncio +async def test_check_and_refresh_oauth_tokens( + mock_user, mock_user_manager, mock_db_session +): + """Test checking and refreshing multiple OAuth tokens.""" + # Create mock user with OAuth accounts + now_timestamp = datetime.now(timezone.utc).timestamp() + + # Create an account that needs refreshing (expiring soon) + expiring_account = MagicMock(spec=OAuthAccount) + expiring_account.oauth_name = "google" + expiring_account.refresh_token = "refresh_token_1" + expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute + + # Create an account that doesn't need refreshing (expires later) + valid_account = MagicMock(spec=OAuthAccount) + valid_account.oauth_name = "google" + valid_account.refresh_token = "refresh_token_2" + valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour + + # Create an account without a refresh token + no_refresh_account = MagicMock(spec=OAuthAccount) + no_refresh_account.oauth_name = "google" + no_refresh_account.refresh_token = None + no_refresh_account.expires_at = ( + now_timestamp + 60 + ) # Expiring soon but no refresh token + + # Set oauth_accounts on the mock user + mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account] + + # Mock refresh_oauth_token function + with patch( + "onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True) + ) as mock_refresh: + # Call the function under test + await check_and_refresh_oauth_tokens( + mock_user, mock_db_session, mock_user_manager + ) + + # Assertions + assert mock_refresh.call_count == 1 # Should only refresh the expiring account + # Check it was called with the expiring account + mock_refresh.assert_called_once_with( + mock_user, expiring_account, mock_db_session, mock_user_manager + ) + + +@pytest.mark.asyncio +async def test_get_oauth_accounts_requiring_refresh_token(mock_user): + """Test identifying OAuth accounts that need refresh tokens.""" + # Create accounts with and without refresh tokens + account_with_token = MagicMock(spec=OAuthAccount) + account_with_token.oauth_name = "google" + account_with_token.refresh_token = "refresh_token" + + account_without_token = MagicMock(spec=OAuthAccount) + account_without_token.oauth_name = "google" + account_without_token.refresh_token = None + + second_account_without_token = MagicMock(spec=OAuthAccount) + second_account_without_token.oauth_name = "github" + second_account_without_token.refresh_token = ( + "" # Empty string should also be treated as missing + ) + + # Set accounts on user + mock_user.oauth_accounts = [ + account_with_token, + account_without_token, + second_account_without_token, + ] + + # Call the function under test + accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token( + mock_user + ) + + # Assertions + assert len(accounts_needing_refresh) == 2 + assert account_without_token in accounts_needing_refresh + assert second_account_without_token in accounts_needing_refresh + assert account_with_token not in accounts_needing_refresh + + +@pytest.mark.asyncio +async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account): + """Test checking if an OAuth account has a refresh token.""" + # Test with refresh token + mock_oauth_account.refresh_token = "refresh_token" + has_token = await check_oauth_account_has_refresh_token( + mock_user, mock_oauth_account + ) + assert has_token is True + + # Test with None refresh token + mock_oauth_account.refresh_token = None + has_token = await check_oauth_account_has_refresh_token( + mock_user, mock_oauth_account + ) + assert has_token is False + + # Test with empty string refresh token + mock_oauth_account.refresh_token = "" + has_token = await check_oauth_account_has_refresh_token( + mock_user, mock_oauth_account + ) + assert has_token is False + + +@pytest.mark.asyncio +async def test_test_expire_oauth_token( + mock_user, mock_oauth_account, mock_user_manager, mock_db_session +): + """Test the testing utility function for token expiration.""" + # Set up the mock account + mock_oauth_account.oauth_name = "google" + mock_oauth_account.refresh_token = "test_refresh_token" + mock_oauth_account.access_token = "test_access_token" + + # Call the function under test + result = await _test_expire_oauth_token( + mock_user, + mock_oauth_account, + mock_db_session, + mock_user_manager, + expire_in_seconds=10, + ) + + # Assertions + assert result is True + mock_user_manager.user_db.update_oauth_account.assert_called_once() + + # Verify the expiration time was set correctly + update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2] + assert "expires_at" in update_data + + # Now should be within 10-11 seconds of the set expiration + now = datetime.now(timezone.utc).timestamp() + assert update_data["expires_at"] - now >= 9 # Allow 1 second for test execution + assert update_data["expires_at"] - now <= 11 # Allow 1 second for test execution diff --git a/web/src/lib/fetcher.ts b/web/src/lib/fetcher.ts index 7a6ea88f3..c7f0d204c 100644 --- a/web/src/lib/fetcher.ts +++ b/web/src/lib/fetcher.ts @@ -19,98 +19,26 @@ const DEFAULT_AUTH_ERROR_MSG = const DEFAULT_ERROR_MSG = "An error occurred while fetching the data."; -// Keep track of token refresh attempts to prevent infinite loops -let isRefreshing = false; -let refreshPromise: Promise | null = null; +export const errorHandlingFetcher = async (url: string): Promise => { + const res = await fetch(url); -// Function to refresh the auth token -const refreshAuthToken = async (): Promise => { - if (isRefreshing) { - // If already refreshing, return the existing promise - console.debug( - "Token refresh already in progress, reusing existing promise" + if (res.status === 403) { + const redirect = new RedirectError( + DEFAULT_AUTH_ERROR_MSG, + res.status, + await res.json() ); - return refreshPromise || Promise.resolve(false); + throw redirect; } - console.debug("Starting token refresh due to 401 response"); - isRefreshing = true; - refreshPromise = new Promise(async (resolve) => { - try { - console.debug("Calling /api/auth/refresh endpoint"); - const response = await fetch("/api/auth/refresh", { - method: "POST", - credentials: "include", - }); + if (!res.ok) { + const error = new FetchError( + DEFAULT_ERROR_MSG, + res.status, + await res.json() + ); + throw error; + } - const success = response.ok; - if (success) { - console.debug("Token refresh succeeded"); - } else { - console.warn(`Token refresh failed with status: ${response.status}`); - } - resolve(success); - } catch (error) { - console.error("Error during token refresh:", error); - resolve(false); - } finally { - console.debug("Token refresh attempt completed"); - isRefreshing = false; - refreshPromise = null; - } - }); - - return refreshPromise; -}; - -export const errorHandlingFetcher = async (url: string): Promise => { - const performFetch = async (retried = false): Promise => { - const res = await fetch(url); - - // If unauthorized and not already retried, attempt to refresh token - if (res.status === 401 && !retried) { - console.debug( - `401 Unauthorized received for ${url}, attempting token refresh` - ); - // Try to refresh the token - const refreshSucceeded = await refreshAuthToken(); - - if (refreshSucceeded) { - console.debug(`Token refresh succeeded, retrying request to ${url}`); - // If token refresh succeeded, retry the original request - return performFetch(true); // Retry with retried flag set to true - } - - console.debug(`Token refresh failed, cannot retry request to ${url}`); - // If refresh failed, proceed with error handling as usual - const error = new FetchError( - DEFAULT_AUTH_ERROR_MSG, - res.status, - await res.json().catch(() => ({})) - ); - throw error; - } - - if (res.status === 403) { - const redirect = new RedirectError( - DEFAULT_AUTH_ERROR_MSG, - res.status, - await res.json().catch(() => ({})) - ); - throw redirect; - } - - if (!res.ok) { - const error = new FetchError( - DEFAULT_ERROR_MSG, - res.status, - await res.json().catch(() => ({})) - ); - throw error; - } - - return res.json(); - }; - - return performFetch(); + return res.json(); };