mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
JWT -> Redis (#3574)
* functional v1 * functional logout * minor clean up * quick clean up * update configuration * ni * nit * finalize * update login page * delete unused import * quick nit * updates * clean up * ni * k * k
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
@@ -29,10 +31,8 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -59,6 +59,8 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
@@ -73,7 +75,6 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
@@ -81,10 +82,10 @@ from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
@@ -581,49 +582,70 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
)(email=user.email)
|
||||
|
||||
data = {
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||
) -> Optional[User]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
async def destroy_token(self, token: str, user: User) -> None:
|
||||
"""Properly delete the token from async redis."""
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
|
||||
) # type: ignore
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
Reference in New Issue
Block a user