misc typing

This commit is contained in:
pablonyx 2025-03-15 11:21:26 -07:00
parent 8821f399f0
commit 69638b4c4e
3 changed files with 64 additions and 31 deletions

View File

@ -1,9 +1,13 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
@ -27,7 +31,7 @@ async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: Any,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
@ -40,10 +44,10 @@ async def _test_expire_oauth_token(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data = {"expires_at": new_expires_at}
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, oauth_account, updated_data
user, cast(Any, oauth_account), updated_data
)
return True
@ -56,7 +60,7 @@ async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: Any,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
@ -110,7 +114,7 @@ async def refresh_oauth_token(
)
# Update the OAuth account
updated_data = {
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
@ -129,7 +133,7 @@ async def refresh_oauth_token(
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, oauth_account, updated_data
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
@ -143,7 +147,7 @@ async def refresh_oauth_token(
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: Any,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
@ -188,7 +192,7 @@ async def check_oauth_account_has_refresh_token(
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> list[OAuthAccount]:
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.

View File

@ -7,9 +7,9 @@ from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Protocol
@ -691,15 +691,15 @@ cookie_transport = CookieTransport(
)
T = TypeVar("T")
ID = TypeVar("ID")
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol, Generic[T, ID]):
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: str, user: T) -> str:
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
@ -792,13 +792,21 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
async def refresh_token(self, token: str, user: User) -> str:
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self.access_token_db.get_by_token(token)
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
@ -806,9 +814,9 @@ class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=self.lifetime_seconds
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self.access_token_db.update(access_token, {"expires": new_expires})
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
@ -917,7 +925,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=user, db_session=db_session, user_manager=user_manager
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
@ -927,7 +937,8 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
if supports_refresh:
try:
new_token = await strategy.refresh_token(token, user)
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
@ -16,8 +17,11 @@ from onyx.db.models import OAuthAccount
@pytest.mark.asyncio
async def test_refresh_oauth_token_success(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test successful OAuth token refresh."""
# Mock HTTP client and response
mock_response = MagicMock()
@ -63,8 +67,11 @@ async def test_refresh_oauth_token_success(
@pytest.mark.asyncio
async def test_refresh_oauth_token_failure(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> bool:
"""Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response
mock_response = MagicMock()
@ -92,12 +99,16 @@ async def test_refresh_oauth_token_failure(
assert result is False
mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called()
return True
@pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None
mock_oauth_account.refresh_token = None
@ -114,8 +125,10 @@ async def test_refresh_oauth_token_no_refresh_token(
@pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens(
mock_user, mock_user_manager, mock_db_session
):
mock_user: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp()
@ -161,7 +174,7 @@ async def test_check_and_refresh_oauth_tokens(
@pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
"""Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount)
@ -198,7 +211,9 @@ async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
@pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account):
async def test_check_oauth_account_has_refresh_token(
mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
"""Test checking if an OAuth account has a refresh token."""
# Test with refresh token
mock_oauth_account.refresh_token = "refresh_token"
@ -224,8 +239,11 @@ async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_accou
@pytest.mark.asyncio
async def test_test_expire_oauth_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
):
mock_user: MagicMock,
mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test the testing utility function for token expiration."""
# Set up the mock account
mock_oauth_account.oauth_name = "google"