mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 21:32:36 +01:00
misc typing
This commit is contained in:
parent
8821f399f0
commit
69638b4c4e
@ -1,9 +1,13 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
@ -27,7 +31,7 @@ async def _test_expire_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: Any,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
expire_in_seconds: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
@ -40,10 +44,10 @@ async def _test_expire_oauth_token(
|
||||
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
|
||||
)
|
||||
|
||||
updated_data = {"expires_at": new_expires_at}
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, oauth_account, updated_data
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
return True
|
||||
@ -56,7 +60,7 @@ async def refresh_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: Any,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to refresh an OAuth token that's about to expire or has expired.
|
||||
@ -110,7 +114,7 @@ async def refresh_oauth_token(
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
updated_data = {
|
||||
updated_data: Dict[str, Any] = {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}
|
||||
@ -129,7 +133,7 @@ async def refresh_oauth_token(
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, oauth_account, updated_data
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
@ -143,7 +147,7 @@ async def refresh_oauth_token(
|
||||
async def check_and_refresh_oauth_tokens(
|
||||
user: User,
|
||||
db_session: AsyncSession,
|
||||
user_manager: Any,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Check if any OAuth tokens are expired or about to expire and refresh them.
|
||||
@ -188,7 +192,7 @@ async def check_oauth_account_has_refresh_token(
|
||||
return bool(oauth_account.refresh_token)
|
||||
|
||||
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> list[OAuthAccount]:
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
|
||||
"""
|
||||
Returns a list of OAuth accounts for a user that are missing refresh tokens.
|
||||
These accounts will need re-authentication to get refresh tokens.
|
||||
|
@ -7,9 +7,9 @@ from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Generic
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
@ -691,15 +691,15 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
ID = TypeVar("ID")
|
||||
T = TypeVar("T", covariant=True)
|
||||
ID = TypeVar("ID", contravariant=True)
|
||||
|
||||
|
||||
# Protocol for strategies that support token refreshing without inheritance.
|
||||
class RefreshableStrategy(Protocol, Generic[T, ID]):
|
||||
class RefreshableStrategy(Protocol):
|
||||
"""Protocol for authentication strategies that support token refreshing."""
|
||||
|
||||
async def refresh_token(self, token: str, user: T) -> str:
|
||||
async def refresh_token(self, token: Optional[str], user: Any) -> str:
|
||||
"""
|
||||
Refresh an existing token by extending its lifetime.
|
||||
Returns either the same token with extended expiration or a new token.
|
||||
@ -792,13 +792,21 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
|
||||
"""Database strategy with token refreshing capabilities."""
|
||||
|
||||
async def refresh_token(self, token: str, user: User) -> str:
|
||||
def __init__(
|
||||
self,
|
||||
access_token_db: AccessTokenDatabase[AccessToken],
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
):
|
||||
super().__init__(access_token_db, lifetime_seconds)
|
||||
self._access_token_db = access_token_db
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by updating its expiration time in the database."""
|
||||
if token is None:
|
||||
return await self.write_token(user)
|
||||
|
||||
# Find the token in database
|
||||
access_token = await self.access_token_db.get_by_token(token)
|
||||
access_token = await self._access_token_db.get_by_token(token)
|
||||
|
||||
if access_token is None:
|
||||
# Token not found, create new one
|
||||
@ -806,9 +814,9 @@ class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]
|
||||
|
||||
# Update expiration time
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=self.lifetime_seconds
|
||||
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
|
||||
)
|
||||
await self.access_token_db.update(access_token, {"expires": new_expires})
|
||||
await self._access_token_db.update(access_token, {"expires": new_expires})
|
||||
|
||||
return token
|
||||
|
||||
@ -917,7 +925,9 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
# Check if user has OAuth accounts that need refreshing
|
||||
await check_and_refresh_oauth_tokens(
|
||||
user=user, db_session=db_session, user_manager=user_manager
|
||||
user=cast(User, user),
|
||||
db_session=db_session,
|
||||
user_manager=cast(Any, user_manager),
|
||||
)
|
||||
|
||||
# Check if strategy supports refreshing
|
||||
@ -927,7 +937,8 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
if supports_refresh:
|
||||
try:
|
||||
new_token = await strategy.refresh_token(token, user)
|
||||
refresh_method = getattr(strategy, "refresh_token")
|
||||
new_token = await refresh_method(token, user)
|
||||
logger.info(
|
||||
f"Successfully refreshed session token for user {user.email}"
|
||||
)
|
||||
|
@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.oauth_refresher import _test_expire_oauth_token
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
@ -16,8 +17,11 @@ from onyx.db.models import OAuthAccount
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_success(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test successful OAuth token refresh."""
|
||||
# Mock HTTP client and response
|
||||
mock_response = MagicMock()
|
||||
@ -63,8 +67,11 @@ async def test_refresh_oauth_token_success(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_failure(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> bool:
|
||||
"""Test OAuth token refresh failure due to HTTP error."""
|
||||
# Mock HTTP client with error response
|
||||
mock_response = MagicMock()
|
||||
@ -92,12 +99,16 @@ async def test_refresh_oauth_token_failure(
|
||||
assert result is False
|
||||
mock_client.post.assert_called_once()
|
||||
mock_user_manager.user_db.update_oauth_account.assert_not_called()
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_oauth_token_no_refresh_token(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test OAuth token refresh when no refresh token is available."""
|
||||
# Set refresh token to None
|
||||
mock_oauth_account.refresh_token = None
|
||||
@ -114,8 +125,10 @@ async def test_refresh_oauth_token_no_refresh_token(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_and_refresh_oauth_tokens(
|
||||
mock_user, mock_user_manager, mock_db_session
|
||||
):
|
||||
mock_user: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test checking and refreshing multiple OAuth tokens."""
|
||||
# Create mock user with OAuth accounts
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
@ -161,7 +174,7 @@ async def test_check_and_refresh_oauth_tokens(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
|
||||
async def test_get_oauth_accounts_requiring_refresh_token(mock_user: MagicMock) -> None:
|
||||
"""Test identifying OAuth accounts that need refresh tokens."""
|
||||
# Create accounts with and without refresh tokens
|
||||
account_with_token = MagicMock(spec=OAuthAccount)
|
||||
@ -198,7 +211,9 @@ async def test_get_oauth_accounts_requiring_refresh_token(mock_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_account):
|
||||
async def test_check_oauth_account_has_refresh_token(
|
||||
mock_user: MagicMock, mock_oauth_account: MagicMock
|
||||
) -> None:
|
||||
"""Test checking if an OAuth account has a refresh token."""
|
||||
# Test with refresh token
|
||||
mock_oauth_account.refresh_token = "refresh_token"
|
||||
@ -224,8 +239,11 @@ async def test_check_oauth_account_has_refresh_token(mock_user, mock_oauth_accou
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_expire_oauth_token(
|
||||
mock_user, mock_oauth_account, mock_user_manager, mock_db_session
|
||||
):
|
||||
mock_user: MagicMock,
|
||||
mock_oauth_account: MagicMock,
|
||||
mock_user_manager: MagicMock,
|
||||
mock_db_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test the testing utility function for token expiration."""
|
||||
# Set up the mock account
|
||||
mock_oauth_account.oauth_name = "google"
|
||||
|
Loading…
x
Reference in New Issue
Block a user