From 69638b4c4e6091a03a92e66e8afb73d2153cf213 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Sat, 15 Mar 2025 11:21:26 -0700 Subject: [PATCH] misc typing --- backend/onyx/auth/oauth_refresher.py | 20 +++++---- backend/onyx/auth/users.py | 33 ++++++++++----- .../unit/onyx/auth/test_oauth_refresher.py | 42 +++++++++++++------ 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/backend/onyx/auth/oauth_refresher.py b/backend/onyx/auth/oauth_refresher.py index 9e5e4b29d..127157d2a 100644 --- a/backend/onyx/auth/oauth_refresher.py +++ b/backend/onyx/auth/oauth_refresher.py @@ -1,9 +1,13 @@ from datetime import datetime from datetime import timezone from typing import Any +from typing import cast +from typing import Dict +from typing import List from typing import Optional import httpx +from fastapi_users.manager import BaseUserManager from sqlalchemy.ext.asyncio import AsyncSession from onyx.configs.app_configs import OAUTH_CLIENT_ID @@ -27,7 +31,7 @@ async def _test_expire_oauth_token( user: User, oauth_account: OAuthAccount, db_session: AsyncSession, - user_manager: Any, + user_manager: BaseUserManager[User, Any], expire_in_seconds: int = 10, ) -> bool: """ @@ -40,10 +44,10 @@ async def _test_expire_oauth_token( (datetime.now(timezone.utc).timestamp() + expire_in_seconds) ) - updated_data = {"expires_at": new_expires_at} + updated_data: Dict[str, Any] = {"expires_at": new_expires_at} await user_manager.user_db.update_oauth_account( - user, oauth_account, updated_data + user, cast(Any, oauth_account), updated_data ) return True @@ -56,7 +60,7 @@ async def refresh_oauth_token( user: User, oauth_account: OAuthAccount, db_session: AsyncSession, - user_manager: Any, + user_manager: BaseUserManager[User, Any], ) -> bool: """ Attempt to refresh an OAuth token that's about to expire or has expired. @@ -110,7 +114,7 @@ async def refresh_oauth_token( ) # Update the OAuth account - updated_data = { + updated_data: Dict[str, Any] = { "access_token": new_access_token, "refresh_token": new_refresh_token, } @@ -129,7 +133,7 @@ async def refresh_oauth_token( # Update the OAuth account await user_manager.user_db.update_oauth_account( - user, oauth_account, updated_data + user, cast(Any, oauth_account), updated_data ) logger.info(f"Successfully refreshed OAuth token for {user.email}") @@ -143,7 +147,7 @@ async def refresh_oauth_token( async def check_and_refresh_oauth_tokens( user: User, db_session: AsyncSession, - user_manager: Any, + user_manager: BaseUserManager[User, Any], ) -> None: """ Check if any OAuth tokens are expired or about to expire and refresh them. @@ -188,7 +192,7 @@ async def check_oauth_account_has_refresh_token( return bool(oauth_account.refresh_token) -async def get_oauth_accounts_requiring_refresh_token(user: User) -> list[OAuthAccount]: +async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]: """ Returns a list of OAuth accounts for a user that are missing refresh tokens. These accounts will need re-authentication to get refresh tokens. diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 4aad1412e..316e76d5d 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -7,9 +7,9 @@ from collections.abc import AsyncGenerator from datetime import datetime from datetime import timedelta from datetime import timezone +from typing import Any from typing import cast from typing import Dict -from typing import Generic from typing import List from typing import Optional from typing import Protocol @@ -691,15 +691,15 @@ cookie_transport = CookieTransport( ) -T = TypeVar("T") -ID = TypeVar("ID") +T = TypeVar("T", covariant=True) +ID = TypeVar("ID", contravariant=True) # Protocol for strategies that support token refreshing without inheritance. -class RefreshableStrategy(Protocol, Generic[T, ID]): +class RefreshableStrategy(Protocol): """Protocol for authentication strategies that support token refreshing.""" - async def refresh_token(self, token: str, user: T) -> str: + async def refresh_token(self, token: Optional[str], user: Any) -> str: """ Refresh an existing token by extending its lifetime. Returns either the same token with extended expiration or a new token. @@ -792,13 +792,21 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]): class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]): """Database strategy with token refreshing capabilities.""" - async def refresh_token(self, token: str, user: User) -> str: + def __init__( + self, + access_token_db: AccessTokenDatabase[AccessToken], + lifetime_seconds: Optional[int] = None, + ): + super().__init__(access_token_db, lifetime_seconds) + self._access_token_db = access_token_db + + async def refresh_token(self, token: Optional[str], user: User) -> str: """Refresh a token by updating its expiration time in the database.""" if token is None: return await self.write_token(user) # Find the token in database - access_token = await self.access_token_db.get_by_token(token) + access_token = await self._access_token_db.get_by_token(token) if access_token is None: # Token not found, create new one @@ -806,9 +814,9 @@ class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken] # Update expiration time new_expires = datetime.now(timezone.utc) + timedelta( - seconds=self.lifetime_seconds + seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS) ) - await self.access_token_db.update(access_token, {"expires": new_expires}) + await self._access_token_db.update(access_token, {"expires": new_expires}) return token @@ -917,7 +925,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): # Check if user has OAuth accounts that need refreshing await check_and_refresh_oauth_tokens( - user=user, db_session=db_session, user_manager=user_manager + user=cast(User, user), + db_session=db_session, + user_manager=cast(Any, user_manager), ) # Check if strategy supports refreshing @@ -927,7 +937,8 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]): if supports_refresh: try: - new_token = await strategy.refresh_token(token, user) + refresh_method = getattr(strategy, "refresh_token") + new_token = await refresh_method(token, user) logger.info( f"Successfully refreshed session token for user {user.email}" ) diff --git a/backend/tests/unit/onyx/auth/test_oauth_refresher.py b/backend/tests/unit/onyx/auth/test_oauth_refresher.py index 9f4d2a6ea..ad8479093 100644 --- a/backend/tests/unit/onyx/auth/test_oauth_refresher.py +++ b/backend/tests/unit/onyx/auth/test_oauth_refresher.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock from unittest.mock import patch import pytest +from sqlalchemy.ext.asyncio import AsyncSession from onyx.auth.oauth_refresher import _test_expire_oauth_token from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens @@ -16,8 +17,11 @@ from onyx.db.models import OAuthAccount @pytest.mark.asyncio async def test_refresh_oauth_token_success( - mock_user, mock_oauth_account, mock_user_manager, mock_db_session -): + mock_user: MagicMock, + mock_oauth_account: MagicMock, + mock_user_manager: MagicMock, + mock_db_session: AsyncSession, +) -> None: """Test successful OAuth token refresh.""" # Mock HTTP client and response mock_response = MagicMock() @@ -63,8 +67,11 @@ async def test_refresh_oauth_token_success( @pytest.mark.asyncio async def test_refresh_oauth_token_failure( - mock_user, mock_oauth_account, mock_user_manager, mock_db_session -): + mock_user: MagicMock, + mock_oauth_account: MagicMock, + mock_user_manager: MagicMock, + mock_db_session: AsyncSession, +) -> bool: """Test OAuth token refresh failure due to HTTP error.""" # Mock HTTP client with error response mock_response = MagicMock() @@ -92,12 +99,16 @@ async def test_refresh_oauth_token_failure( assert result is False mock_client.post.assert_called_once() mock_user_manager.user_db.update_oauth_account.assert_not_called() + return True @pytest.mark.asyncio async def test_refresh_oauth_token_no_refresh_token( - mock_user, mock_oauth_account, mock_user_manager, mock_db_session -): + mock_user: MagicMock, + mock_oauth_account: MagicMock, + mock_user_manager: MagicMock, + mock_db_session: AsyncSession, +) -> None: """Test OAuth token refresh when no refresh token is available.""" # Set refresh token to None mock_oauth_account.refresh_token = None @@ -114,8 +125,10 @@ async def test_refresh_oauth_token_no_refresh_token( @pytest.mark.asyncio async def test_check_and_refresh_oauth_tokens( - mock_user, mock_user_manager, mock_db_session -): + mock_user: MagicMock, + mock_user_manager: MagicMock, + mock_db_session: AsyncSession, +) -> None: """Test checking and refreshing multiple OAuth tokens.""" # Create mock user with OAuth accounts now_timestamp = datetime.now(timezone.utc).timestamp() @@ -161,7 +174,7 @@ async def test_check_and_refresh_oauth_tokens( @pytest.mark.asyncio -async def test_get_oauth_accounts_requiring_refresh_token(mock_user): +async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None: """Test identifying OAuth accounts that need refresh tokens.""" # Create accounts with and without refresh tokens account_with_token = MagicMock(spec=OAuthAccount) @@ -198,7 +211,9 @@ async def test_get_oauth_accounts_requiring_refresh_token(mock_user): @pytest.mark.asyncio -async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account): +async def test_check_oauth_account_has_refresh_token( + mock_user: MagicMock, mock_oauth_account: MagicMock +) -> None: """Test checking if an OAuth account has a refresh token.""" # Test with refresh token mock_oauth_account.refresh_token = "refresh_token" @@ -224,8 +239,11 @@ async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_accou @pytest.mark.asyncio async def test_test_expire_oauth_token( - mock_user, mock_oauth_account, mock_user_manager, mock_db_session -): + mock_user: MagicMock, + mock_oauth_account: MagicMock, + mock_user_manager: MagicMock, + mock_db_session: AsyncSession, +) -> None: """Test the testing utility function for token expiration.""" # Set up the mock account mock_oauth_account.oauth_name = "google"