clean up + tests

This commit is contained in:
pablonyx 2025-03-13 16:07:37 -07:00
parent 185aa07526
commit 38afc8fa3a
4 changed files with 318 additions and 93 deletions

View File

@ -1195,10 +1195,9 @@ def get_oauth_router(
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
if "?" in authorization_url:
authorization_url += "&access_type=offline&prompt=consent"
else:
authorization_url += "?access_type=offline&prompt=consent"
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)

View File

@ -0,0 +1,43 @@
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import pytest
from onyx.db.models import OAuthAccount
from onyx.db.models import User
@pytest.fixture
def mock_user() -> MagicMock:
"""Creates a mock User instance for testing."""
user = MagicMock(spec=User)
user.email = "test@example.com"
user.id = "test-user-id"
return user
@pytest.fixture
def mock_oauth_account() -> MagicMock:
"""Creates a mock OAuthAccount instance for testing."""
oauth_account = MagicMock(spec=OAuthAccount)
oauth_account.oauth_name = "google"
oauth_account.refresh_token = "test-refresh-token"
oauth_account.access_token = "test-access-token"
oauth_account.expires_at = None
return oauth_account
@pytest.fixture
def mock_user_manager() -> MagicMock:
"""Creates a mock user manager for testing."""
user_manager = MagicMock()
user_manager.user_db = MagicMock()
user_manager.user_db.update_oauth_account = AsyncMock()
user_manager.user_db.update = AsyncMock()
return user_manager
@pytest.fixture
def mock_db_session() -> MagicMock:
"""Creates a mock database session for testing."""
return MagicMock()

View File

@ -0,0 +1,255 @@
from datetime import datetime
from datetime import timezone
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
from onyx.auth.oauth_refresher import refresh_oauth_token
from onyx.db.models import OAuthAccount
@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
"""Test successful OAuth token refresh."""
# Mock HTTP client and response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new_token",
"refresh_token": "new_refresh_token",
"expires_in": 3600,
}
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Use fixture values but ensure refresh token exists
mock_oauth_account.oauth_name = (
"google" # Ensure it's google to match the refresh endpoint
)
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_instance = mock_client
client_class_mock.return_value.__aenter__.return_value = client_instance
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is True
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify token data was updated correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert update_data["access_token"] == "new_token"
assert update_data["refresh_token"] == "new_refresh_token"
assert "expires_at" in update_data
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
"""Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response
mock_response = MagicMock()
mock_response.status_code = 400 # Simulate error
# Create async mock for the client post method
mock_client = AsyncMock()
mock_client.post.return_value = mock_response
# Ensure refresh token exists and provider is supported
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "old_refresh_token"
# Patch at the module level where it's actually being used
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
# Configure the context manager
client_class_mock.return_value.__aenter__.return_value = mock_client
# Call the function under test
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called()
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
"""Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None
mock_oauth_account.refresh_token = None
mock_oauth_account.oauth_name = "google"
# No need to mock httpx since it shouldn't be called
result = await refresh_oauth_token(
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
)
# Assertions
assert result is False
@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
mock_user, mock_user_manager, mock_db_session
):
"""Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp()
# Create an account that needs refreshing (expiring soon)
expiring_account = MagicMock(spec=OAuthAccount)
expiring_account.oauth_name = "google"
expiring_account.refresh_token = "refresh_token_1"
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
# Create an account that doesn't need refreshing (expires later)
valid_account = MagicMock(spec=OAuthAccount)
valid_account.oauth_name = "google"
valid_account.refresh_token = "refresh_token_2"
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
# Create an account without a refresh token
no_refresh_account = MagicMock(spec=OAuthAccount)
no_refresh_account.oauth_name = "google"
no_refresh_account.refresh_token = None
no_refresh_account.expires_at = (
now_timestamp + 60
) # Expiring soon but no refresh token
# Set oauth_accounts on the mock user
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
# Mock refresh_oauth_token function
with patch(
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
) as mock_refresh:
# Call the function under test
await check_and_refresh_oauth_tokens(
mock_user, mock_db_session, mock_user_manager
)
# Assertions
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
# Check it was called with the expiring account
mock_refresh.assert_called_once_with(
mock_user, expiring_account, mock_db_session, mock_user_manager
)
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
"""Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount)
account_with_token.oauth_name = "google"
account_with_token.refresh_token = "refresh_token"
account_without_token = MagicMock(spec=OAuthAccount)
account_without_token.oauth_name = "google"
account_without_token.refresh_token = None
second_account_without_token = MagicMock(spec=OAuthAccount)
second_account_without_token.oauth_name = "github"
second_account_without_token.refresh_token = (
"" # Empty string should also be treated as missing
)
# Set accounts on user
mock_user.oauth_accounts = [
account_with_token,
account_without_token,
second_account_without_token,
]
# Call the function under test
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
mock_user
)
# Assertions
assert len(accounts_needing_refresh) == 2
assert account_without_token in accounts_needing_refresh
assert second_account_without_token in accounts_needing_refresh
assert account_with_token not in accounts_needing_refresh
@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account):
"""Test checking if an OAuth account has a refresh token."""
# Test with refresh token
mock_oauth_account.refresh_token = "refresh_token"
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is True
# Test with None refresh token
mock_oauth_account.refresh_token = None
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
# Test with empty string refresh token
mock_oauth_account.refresh_token = ""
has_token = await check_oauth_account_has_refresh_token(
mock_user, mock_oauth_account
)
assert has_token is False
@pytest.mark.asyncio
async def test_test_expire_oauth_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
"""Test the testing utility function for token expiration."""
# Set up the mock account
mock_oauth_account.oauth_name = "google"
mock_oauth_account.refresh_token = "test_refresh_token"
mock_oauth_account.access_token = "test_access_token"
# Call the function under test
result = await _test_expire_oauth_token(
mock_user,
mock_oauth_account,
mock_db_session,
mock_user_manager,
expire_in_seconds=10,
)
# Assertions
assert result is True
mock_user_manager.user_db.update_oauth_account.assert_called_once()
# Verify the expiration time was set correctly
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
assert "expires_at" in update_data
# Now should be within 10-11 seconds of the set expiration
now = datetime.now(timezone.utc).timestamp()
assert update_data["expires_at"] - now >= 9 # Allow 1 second for test execution
assert update_data["expires_at"] - now <= 11 # Allow 1 second for test execution

View File

@ -19,98 +19,26 @@ const DEFAULT_AUTH_ERROR_MSG =
const DEFAULT_ERROR_MSG = "An error occurred while fetching the data.";
// Keep track of token refresh attempts to prevent infinite loops
let isRefreshing = false;
let refreshPromise: Promise<boolean> | null = null;
export const errorHandlingFetcher = async <T>(url: string): Promise<T> => {
const res = await fetch(url);
// Function to refresh the auth token
const refreshAuthToken = async (): Promise<boolean> => {
if (isRefreshing) {
// If already refreshing, return the existing promise
console.debug(
"Token refresh already in progress, reusing existing promise"
if (res.status === 403) {
const redirect = new RedirectError(
DEFAULT_AUTH_ERROR_MSG,
res.status,
await res.json()
);
return refreshPromise || Promise.resolve(false);
throw redirect;
}
console.debug("Starting token refresh due to 401 response");
isRefreshing = true;
refreshPromise = new Promise<boolean>(async (resolve) => {
try {
console.debug("Calling /api/auth/refresh endpoint");
const response = await fetch("/api/auth/refresh", {
method: "POST",
credentials: "include",
});
if (!res.ok) {
const error = new FetchError(
DEFAULT_ERROR_MSG,
res.status,
await res.json()
);
throw error;
}
const success = response.ok;
if (success) {
console.debug("Token refresh succeeded");
} else {
console.warn(`Token refresh failed with status: ${response.status}`);
}
resolve(success);
} catch (error) {
console.error("Error during token refresh:", error);
resolve(false);
} finally {
console.debug("Token refresh attempt completed");
isRefreshing = false;
refreshPromise = null;
}
});
return refreshPromise;
};
export const errorHandlingFetcher = async <T>(url: string): Promise<T> => {
const performFetch = async (retried = false): Promise<T> => {
const res = await fetch(url);
// If unauthorized and not already retried, attempt to refresh token
if (res.status === 401 && !retried) {
console.debug(
`401 Unauthorized received for ${url}, attempting token refresh`
);
// Try to refresh the token
const refreshSucceeded = await refreshAuthToken();
if (refreshSucceeded) {
console.debug(`Token refresh succeeded, retrying request to ${url}`);
// If token refresh succeeded, retry the original request
return performFetch(true); // Retry with retried flag set to true
}
console.debug(`Token refresh failed, cannot retry request to ${url}`);
// If refresh failed, proceed with error handling as usual
const error = new FetchError(
DEFAULT_AUTH_ERROR_MSG,
res.status,
await res.json().catch(() => ({}))
);
throw error;
}
if (res.status === 403) {
const redirect = new RedirectError(
DEFAULT_AUTH_ERROR_MSG,
res.status,
await res.json().catch(() => ({}))
);
throw redirect;
}
if (!res.ok) {
const error = new FetchError(
DEFAULT_ERROR_MSG,
res.status,
await res.json().catch(() => ({}))
);
throw error;
}
return res.json();
};
return performFetch();
return res.json();
};