misc typing

This commit is contained in:
pablonyx 2025-03-15 11:21:26 -07:00
parent 8821f399f0
commit 69638b4c4e
3 changed files with 64 additions and 31 deletions

View File

@ -1,9 +1,13 @@
from datetime import datetime from datetime import datetime
from datetime import timezone from datetime import timezone
from typing import Any from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional from typing import Optional
import httpx import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID from onyx.configs.app_configs import OAUTH_CLIENT_ID
@ -27,7 +31,7 @@ async def _test_expire_oauth_token(
user: User, user: User,
oauth_account: OAuthAccount, oauth_account: OAuthAccount,
db_session: AsyncSession, db_session: AsyncSession,
user_manager: Any, user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10, expire_in_seconds: int = 10,
) -> bool: ) -> bool:
""" """
@ -40,10 +44,10 @@ async def _test_expire_oauth_token(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds) (datetime.now(timezone.utc).timestamp() + expire_in_seconds)
) )
updated_data = {"expires_at": new_expires_at} updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account( await user_manager.user_db.update_oauth_account(
user, oauth_account, updated_data user, cast(Any, oauth_account), updated_data
) )
return True return True
@ -56,7 +60,7 @@ async def refresh_oauth_token(
user: User, user: User,
oauth_account: OAuthAccount, oauth_account: OAuthAccount,
db_session: AsyncSession, db_session: AsyncSession,
user_manager: Any, user_manager: BaseUserManager[User, Any],
) -> bool: ) -> bool:
""" """
Attempt to refresh an OAuth token that's about to expire or has expired. Attempt to refresh an OAuth token that's about to expire or has expired.
@ -110,7 +114,7 @@ async def refresh_oauth_token(
) )
# Update the OAuth account # Update the OAuth account
updated_data = { updated_data: Dict[str, Any] = {
"access_token": new_access_token, "access_token": new_access_token,
"refresh_token": new_refresh_token, "refresh_token": new_refresh_token,
} }
@ -129,7 +133,7 @@ async def refresh_oauth_token(
# Update the OAuth account # Update the OAuth account
await user_manager.user_db.update_oauth_account( await user_manager.user_db.update_oauth_account(
user, oauth_account, updated_data user, cast(Any, oauth_account), updated_data
) )
logger.info(f"Successfully refreshed OAuth token for {user.email}") logger.info(f"Successfully refreshed OAuth token for {user.email}")
@ -143,7 +147,7 @@ async def refresh_oauth_token(
async def check_and_refresh_oauth_tokens( async def check_and_refresh_oauth_tokens(
user: User, user: User,
db_session: AsyncSession, db_session: AsyncSession,
user_manager: Any, user_manager: BaseUserManager[User, Any],
) -> None: ) -> None:
""" """
Check if any OAuth tokens are expired or about to expire and refresh them. Check if any OAuth tokens are expired or about to expire and refresh them.
@ -188,7 +192,7 @@ async def check_oauth_account_has_refresh_token(
return bool(oauth_account.refresh_token) return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> list[OAuthAccount]: async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
""" """
Returns a list of OAuth accounts for a user that are missing refresh tokens. Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens. These accounts will need re-authentication to get refresh tokens.

View File

@ -7,9 +7,9 @@ from collections.abc import AsyncGenerator
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
from datetime import timezone from datetime import timezone
from typing import Any
from typing import cast from typing import cast
from typing import Dict from typing import Dict
from typing import Generic
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Protocol from typing import Protocol
@ -691,15 +691,15 @@ cookie_transport = CookieTransport(
) )
T = TypeVar("T") T = TypeVar("T", covariant=True)
ID = TypeVar("ID") ID = TypeVar("ID", contravariant=True)
# Protocol for strategies that support token refreshing without inheritance. # Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol, Generic[T, ID]): class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing.""" """Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: str, user: T) -> str: async def refresh_token(self, token: Optional[str], user: Any) -> str:
""" """
Refresh an existing token by extending its lifetime. Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token. Returns either the same token with extended expiration or a new token.
@ -792,13 +792,21 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]): class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities.""" """Database strategy with token refreshing capabilities."""
async def refresh_token(self, token: str, user: User) -> str: def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database.""" """Refresh a token by updating its expiration time in the database."""
if token is None: if token is None:
return await self.write_token(user) return await self.write_token(user)
# Find the token in database # Find the token in database
access_token = await self.access_token_db.get_by_token(token) access_token = await self._access_token_db.get_by_token(token)
if access_token is None: if access_token is None:
# Token not found, create new one # Token not found, create new one
@ -806,9 +814,9 @@ class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]
# Update expiration time # Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta( new_expires = datetime.now(timezone.utc) + timedelta(
seconds=self.lifetime_seconds seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
) )
await self.access_token_db.update(access_token, {"expires": new_expires}) await self._access_token_db.update(access_token, {"expires": new_expires})
return token return token
@ -917,7 +925,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
# Check if user has OAuth accounts that need refreshing # Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens( await check_and_refresh_oauth_tokens(
user=user, db_session=db_session, user_manager=user_manager user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
) )
# Check if strategy supports refreshing # Check if strategy supports refreshing
@ -927,7 +937,8 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
if supports_refresh: if supports_refresh:
try: try:
new_token = await strategy.refresh_token(token, user) refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info( logger.info(
f"Successfully refreshed session token for user {user.email}" f"Successfully refreshed session token for user {user.email}"
) )

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.oauth_refresher import _test_expire_oauth_token from onyx.auth.oauth_refresher import _test_expire_oauth_token
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
@ -16,8 +17,11 @@ from onyx.db.models import OAuthAccount
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_oauth_token_success( async def test_refresh_oauth_token_success(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session mock_user: MagicMock,
): mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test successful OAuth token refresh.""" """Test successful OAuth token refresh."""
# Mock HTTP client and response # Mock HTTP client and response
mock_response = MagicMock() mock_response = MagicMock()
@ -63,8 +67,11 @@ async def test_refresh_oauth_token_success(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_oauth_token_failure( async def test_refresh_oauth_token_failure(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session mock_user: MagicMock,
): mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> bool:
"""Test OAuth token refresh failure due to HTTP error.""" """Test OAuth token refresh failure due to HTTP error."""
# Mock HTTP client with error response # Mock HTTP client with error response
mock_response = MagicMock() mock_response = MagicMock()
@ -92,12 +99,16 @@ async def test_refresh_oauth_token_failure(
assert result is False assert result is False
mock_client.post.assert_called_once() mock_client.post.assert_called_once()
mock_user_manager.user_db.update_oauth_account.assert_not_called() mock_user_manager.user_db.update_oauth_account.assert_not_called()
return True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_oauth_token_no_refresh_token( async def test_refresh_oauth_token_no_refresh_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session mock_user: MagicMock,
): mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test OAuth token refresh when no refresh token is available.""" """Test OAuth token refresh when no refresh token is available."""
# Set refresh token to None # Set refresh token to None
mock_oauth_account.refresh_token = None mock_oauth_account.refresh_token = None
@ -114,8 +125,10 @@ async def test_refresh_oauth_token_no_refresh_token(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_and_refresh_oauth_tokens( async def test_check_and_refresh_oauth_tokens(
mock_user, mock_user_manager, mock_db_session mock_user: MagicMock,
): mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test checking and refreshing multiple OAuth tokens.""" """Test checking and refreshing multiple OAuth tokens."""
# Create mock user with OAuth accounts # Create mock user with OAuth accounts
now_timestamp = datetime.now(timezone.utc).timestamp() now_timestamp = datetime.now(timezone.utc).timestamp()
@ -161,7 +174,7 @@ async def test_check_and_refresh_oauth_tokens(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_oauth_accounts_requiring_refresh_token(mock_user): async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
"""Test identifying OAuth accounts that need refresh tokens.""" """Test identifying OAuth accounts that need refresh tokens."""
# Create accounts with and without refresh tokens # Create accounts with and without refresh tokens
account_with_token = MagicMock(spec=OAuthAccount) account_with_token = MagicMock(spec=OAuthAccount)
@ -198,7 +211,9 @@ async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account): async def test_check_oauth_account_has_refresh_token(
mock_user: MagicMock, mock_oauth_account: MagicMock
) -> None:
"""Test checking if an OAuth account has a refresh token.""" """Test checking if an OAuth account has a refresh token."""
# Test with refresh token # Test with refresh token
mock_oauth_account.refresh_token = "refresh_token" mock_oauth_account.refresh_token = "refresh_token"
@ -224,8 +239,11 @@ async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_accou
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_test_expire_oauth_token( async def test_test_expire_oauth_token(
mock_user, mock_oauth_account, mock_user_manager, mock_db_session mock_user: MagicMock,
): mock_oauth_account: MagicMock,
mock_user_manager: MagicMock,
mock_db_session: AsyncSession,
) -> None:
"""Test the testing utility function for token expiration.""" """Test the testing utility function for token expiration."""
# Set up the mock account # Set up the mock account
mock_oauth_account.oauth_name = "google" mock_oauth_account.oauth_name = "google"