mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-22 22:44:21 +02:00
clean up + tests
This commit is contained in:
parent
185aa07526
commit
38afc8fa3a
@ -1195,10 +1195,9 @@ def get_oauth_router(
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
if "?" in authorization_url:
|
||||
authorization_url += "&access_type=offline&prompt=consent"
|
||||
else:
|
||||
authorization_url += "?access_type=offline&prompt=consent"
|
||||
authorization_url = add_url_params(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
|
43
backend/tests/unit/onyx/auth/conftest.py
Normal file
43
backend/tests/unit/onyx/auth/conftest.py
Normal file
@ -0,0 +1,43 @@
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user() -> MagicMock:
|
||||
"""Creates a mock User instance for testing."""
|
||||
user = MagicMock(spec=User)
|
||||
user.email = "test@example.com"
|
||||
user.id = "test-user-id"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_account() -> MagicMock:
|
||||
"""Creates a mock OAuthAccount instance for testing."""
|
||||
oauth_account = MagicMock(spec=OAuthAccount)
|
||||
oauth_account.oauth_name = "google"
|
||||
oauth_account.refresh_token = "test-refresh-token"
|
||||
oauth_account.access_token = "test-access-token"
|
||||
oauth_account.expires_at = None
|
||||
return oauth_account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_manager() -> MagicMock:
|
||||
"""Creates a mock user manager for testing."""
|
||||
user_manager = MagicMock()
|
||||
user_manager.user_db = MagicMock()
|
||||
user_manager.user_db.update_oauth_account = AsyncMock()
|
||||
user_manager.user_db.update = AsyncMock()
|
||||
return user_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
"""Creates a mock database session for testing."""
|
||||
return MagicMock()
|
255
backend/tests/unit/onyx/auth/test_oauth_refresher.py
Normal file
255
backend/tests/unit/onyx/auth/test_oauth_refresher.py
Normal file
@ -0,0 +1,255 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.oauth_refresher import _test_expire_oauth_token
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
from onyx.auth.oauth_refresher import check_oauth_account_has_refresh_token
|
||||
from onyx.auth.oauth_refresher import get_oauth_accounts_requiring_refresh_token
|
||||
from onyx.auth.oauth_refresher import refresh_oauth_token
|
||||
from onyx.db.models import OAuthAccount
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_success(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
"""Test successful OAuth token refresh."""
|
||||
# Mock HTTP client and response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
# Create async mock for the client post method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Use fixture values but ensure refresh token exists
|
||||
mock_oauth_account.oauth_name = (
|
||||
"google" # Ensure it's google to match the refresh endpoint
|
||||
)
|
||||
mock_oauth_account.refresh_token = "old_refresh_token"
|
||||
|
||||
# Patch at the module level where it's actually being used
|
||||
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
|
||||
# Configure the context manager
|
||||
client_instance = mock_client
|
||||
client_class_mock.return_value.__aenter__.return_value = client_instance
|
||||
|
||||
# Call the function under test
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is True
|
||||
mock_client.post.assert_called_once()
|
||||
mock_user_manager.user_db.update_oauth_account.assert_called_once()
|
||||
|
||||
# Verify token data was updated correctly
|
||||
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
|
||||
assert update_data["access_token"] == "new_token"
|
||||
assert update_data["refresh_token"] == "new_refresh_token"
|
||||
assert "expires_at" in update_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_failure(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
"""Test OAuth token refresh failure due to HTTP error."""
|
||||
# Mock HTTP client with error response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400 # Simulate error
|
||||
|
||||
# Create async mock for the client post method
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
# Ensure refresh token exists and provider is supported
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
mock_oauth_account.refresh_token = "old_refresh_token"
|
||||
|
||||
# Patch at the module level where it's actually being used
|
||||
with patch("onyx.auth.oauth_refresher.httpx.AsyncClient") as client_class_mock:
|
||||
# Configure the context manager
|
||||
client_class_mock.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Call the function under test
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is False
|
||||
mock_client.post.assert_called_once()
|
||||
mock_user_manager.user_db.update_oauth_account.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_no_refresh_token(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
"""Test OAuth token refresh when no refresh token is available."""
|
||||
# Set refresh token to None
|
||||
mock_oauth_account.refresh_token = None
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
|
||||
# No need to mock httpx since it shouldn't be called
|
||||
result = await refresh_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_and_refresh_oauth_tokens(
|
||||
mock_user, mock_user_manager, mock_db_session
|
||||
):
|
||||
"""Test checking and refreshing multiple OAuth tokens."""
|
||||
# Create mock user with OAuth accounts
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Create an account that needs refreshing (expiring soon)
|
||||
expiring_account = MagicMock(spec=OAuthAccount)
|
||||
expiring_account.oauth_name = "google"
|
||||
expiring_account.refresh_token = "refresh_token_1"
|
||||
expiring_account.expires_at = now_timestamp + 60 # Expires in 1 minute
|
||||
|
||||
# Create an account that doesn't need refreshing (expires later)
|
||||
valid_account = MagicMock(spec=OAuthAccount)
|
||||
valid_account.oauth_name = "google"
|
||||
valid_account.refresh_token = "refresh_token_2"
|
||||
valid_account.expires_at = now_timestamp + 3600 # Expires in 1 hour
|
||||
|
||||
# Create an account without a refresh token
|
||||
no_refresh_account = MagicMock(spec=OAuthAccount)
|
||||
no_refresh_account.oauth_name = "google"
|
||||
no_refresh_account.refresh_token = None
|
||||
no_refresh_account.expires_at = (
|
||||
now_timestamp + 60
|
||||
) # Expiring soon but no refresh token
|
||||
|
||||
# Set oauth_accounts on the mock user
|
||||
mock_user.oauth_accounts = [expiring_account, valid_account, no_refresh_account]
|
||||
|
||||
# Mock refresh_oauth_token function
|
||||
with patch(
|
||||
"onyx.auth.oauth_refresher.refresh_oauth_token", AsyncMock(return_value=True)
|
||||
) as mock_refresh:
|
||||
# Call the function under test
|
||||
await check_and_refresh_oauth_tokens(
|
||||
mock_user, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert mock_refresh.call_count == 1 # Should only refresh the expiring account
|
||||
# Check it was called with the expiring account
|
||||
mock_refresh.assert_called_once_with(
|
||||
mock_user, expiring_account, mock_db_session, mock_user_manager
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
|
||||
"""Test identifying OAuth accounts that need refresh tokens."""
|
||||
# Create accounts with and without refresh tokens
|
||||
account_with_token = MagicMock(spec=OAuthAccount)
|
||||
account_with_token.oauth_name = "google"
|
||||
account_with_token.refresh_token = "refresh_token"
|
||||
|
||||
account_without_token = MagicMock(spec=OAuthAccount)
|
||||
account_without_token.oauth_name = "google"
|
||||
account_without_token.refresh_token = None
|
||||
|
||||
second_account_without_token = MagicMock(spec=OAuthAccount)
|
||||
second_account_without_token.oauth_name = "github"
|
||||
second_account_without_token.refresh_token = (
|
||||
"" # Empty string should also be treated as missing
|
||||
)
|
||||
|
||||
# Set accounts on user
|
||||
mock_user.oauth_accounts = [
|
||||
account_with_token,
|
||||
account_without_token,
|
||||
second_account_without_token,
|
||||
]
|
||||
|
||||
# Call the function under test
|
||||
accounts_needing_refresh = await get_oauth_accounts_requiring_refresh_token(
|
||||
mock_user
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert len(accounts_needing_refresh) == 2
|
||||
assert account_without_token in accounts_needing_refresh
|
||||
assert second_account_without_token in accounts_needing_refresh
|
||||
assert account_with_token not in accounts_needing_refresh
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account):
|
||||
"""Test checking if an OAuth account has a refresh token."""
|
||||
# Test with refresh token
|
||||
mock_oauth_account.refresh_token = "refresh_token"
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is True
|
||||
|
||||
# Test with None refresh token
|
||||
mock_oauth_account.refresh_token = None
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is False
|
||||
|
||||
# Test with empty string refresh token
|
||||
mock_oauth_account.refresh_token = ""
|
||||
has_token = await check_oauth_account_has_refresh_token(
|
||||
mock_user, mock_oauth_account
|
||||
)
|
||||
assert has_token is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_expire_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
"""Test the testing utility function for token expiration."""
|
||||
# Set up the mock account
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
mock_oauth_account.refresh_token = "test_refresh_token"
|
||||
mock_oauth_account.access_token = "test_access_token"
|
||||
|
||||
# Call the function under test
|
||||
result = await _test_expire_oauth_token(
|
||||
mock_user,
|
||||
mock_oauth_account,
|
||||
mock_db_session,
|
||||
mock_user_manager,
|
||||
expire_in_seconds=10,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert result is True
|
||||
mock_user_manager.user_db.update_oauth_account.assert_called_once()
|
||||
|
||||
# Verify the expiration time was set correctly
|
||||
update_data = mock_user_manager.user_db.update_oauth_account.call_args[0][2]
|
||||
assert "expires_at" in update_data
|
||||
|
||||
# Now should be within 10-11 seconds of the set expiration
|
||||
now = datetime.now(timezone.utc).timestamp()
|
||||
assert update_data["expires_at"] - now >= 9 # Allow 1 second for test execution
|
||||
assert update_data["expires_at"] - now <= 11 # Allow 1 second for test execution
|
@ -19,98 +19,26 @@ const DEFAULT_AUTH_ERROR_MSG =
|
||||
|
||||
const DEFAULT_ERROR_MSG = "An error occurred while fetching the data.";
|
||||
|
||||
// Keep track of token refresh attempts to prevent infinite loops
|
||||
let isRefreshing = false;
|
||||
let refreshPromise: Promise<boolean> | null = null;
|
||||
export const errorHandlingFetcher = async <T>(url: string): Promise<T> => {
|
||||
const res = await fetch(url);
|
||||
|
||||
// Function to refresh the auth token
|
||||
const refreshAuthToken = async (): Promise<boolean> => {
|
||||
if (isRefreshing) {
|
||||
// If already refreshing, return the existing promise
|
||||
console.debug(
|
||||
"Token refresh already in progress, reusing existing promise"
|
||||
if (res.status === 403) {
|
||||
const redirect = new RedirectError(
|
||||
DEFAULT_AUTH_ERROR_MSG,
|
||||
res.status,
|
||||
await res.json()
|
||||
);
|
||||
return refreshPromise || Promise.resolve(false);
|
||||
throw redirect;
|
||||
}
|
||||
|
||||
console.debug("Starting token refresh due to 401 response");
|
||||
isRefreshing = true;
|
||||
refreshPromise = new Promise<boolean>(async (resolve) => {
|
||||
try {
|
||||
console.debug("Calling /api/auth/refresh endpoint");
|
||||
const response = await fetch("/api/auth/refresh", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
if (!res.ok) {
|
||||
const error = new FetchError(
|
||||
DEFAULT_ERROR_MSG,
|
||||
res.status,
|
||||
await res.json()
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
const success = response.ok;
|
||||
if (success) {
|
||||
console.debug("Token refresh succeeded");
|
||||
} else {
|
||||
console.warn(`Token refresh failed with status: ${response.status}`);
|
||||
}
|
||||
resolve(success);
|
||||
} catch (error) {
|
||||
console.error("Error during token refresh:", error);
|
||||
resolve(false);
|
||||
} finally {
|
||||
console.debug("Token refresh attempt completed");
|
||||
isRefreshing = false;
|
||||
refreshPromise = null;
|
||||
}
|
||||
});
|
||||
|
||||
return refreshPromise;
|
||||
};
|
||||
|
||||
export const errorHandlingFetcher = async <T>(url: string): Promise<T> => {
|
||||
const performFetch = async (retried = false): Promise<T> => {
|
||||
const res = await fetch(url);
|
||||
|
||||
// If unauthorized and not already retried, attempt to refresh token
|
||||
if (res.status === 401 && !retried) {
|
||||
console.debug(
|
||||
`401 Unauthorized received for ${url}, attempting token refresh`
|
||||
);
|
||||
// Try to refresh the token
|
||||
const refreshSucceeded = await refreshAuthToken();
|
||||
|
||||
if (refreshSucceeded) {
|
||||
console.debug(`Token refresh succeeded, retrying request to ${url}`);
|
||||
// If token refresh succeeded, retry the original request
|
||||
return performFetch(true); // Retry with retried flag set to true
|
||||
}
|
||||
|
||||
console.debug(`Token refresh failed, cannot retry request to ${url}`);
|
||||
// If refresh failed, proceed with error handling as usual
|
||||
const error = new FetchError(
|
||||
DEFAULT_AUTH_ERROR_MSG,
|
||||
res.status,
|
||||
await res.json().catch(() => ({}))
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (res.status === 403) {
|
||||
const redirect = new RedirectError(
|
||||
DEFAULT_AUTH_ERROR_MSG,
|
||||
res.status,
|
||||
await res.json().catch(() => ({}))
|
||||
);
|
||||
throw redirect;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
const error = new FetchError(
|
||||
DEFAULT_ERROR_MSG,
|
||||
res.status,
|
||||
await res.json().catch(() => ({}))
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
return res.json();
|
||||
};
|
||||
|
||||
return performFetch();
|
||||
return res.json();
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user