mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-20 13:05:49 +02:00
JWT -> Redis (#3574)
* functional v1 * functional logout * minor clean up * quick clean up * update configuration * ni * nit * finalize * update login page * delete unused import * quick nit * updates * clean up * ni * k * k
This commit is contained in:
@@ -2,15 +2,14 @@ import logging
|
|||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
import jwt
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
|
|
||||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
|
||||||
from onyx.db.engine import is_valid_schema_name
|
from onyx.db.engine import is_valid_schema_name
|
||||||
|
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||||
from shared_configs.configs import MULTI_TENANT
|
from shared_configs.configs import MULTI_TENANT
|
||||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||||
@@ -22,11 +21,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
|||||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||||
) -> Response:
|
) -> Response:
|
||||||
try:
|
try:
|
||||||
tenant_id = (
|
if MULTI_TENANT:
|
||||||
_get_tenant_id_from_request(request, logger)
|
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||||
if MULTI_TENANT
|
else:
|
||||||
else POSTGRES_DEFAULT_SCHEMA
|
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||||
)
|
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
@@ -35,27 +34,36 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
async def _get_tenant_id_from_request(
|
||||||
# First check for API key
|
request: Request, logger: logging.LoggerAdapter
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Attempt to extract tenant_id from:
|
||||||
|
1) The API key header
|
||||||
|
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||||
|
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||||
|
"""
|
||||||
|
# Check for API key
|
||||||
tenant_id = extract_tenant_from_api_key_header(request)
|
tenant_id = extract_tenant_from_api_key_header(request)
|
||||||
if tenant_id is not None:
|
if tenant_id:
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
# Check for cookie-based auth
|
|
||||||
token = request.cookies.get("fastapiusersauth")
|
|
||||||
if not token:
|
|
||||||
return POSTGRES_DEFAULT_SCHEMA
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
# Look up token data in Redis
|
||||||
token,
|
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||||
USER_AUTH_SECRET,
|
|
||||||
audience=["fastapi-users:auth"],
|
|
||||||
algorithms=["HS256"],
|
|
||||||
)
|
|
||||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
|
||||||
|
|
||||||
# Since payload.get() can return None, ensure we have a string
|
if not token_data:
|
||||||
|
logger.debug(
|
||||||
|
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||||
|
)
|
||||||
|
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||||
|
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||||
|
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||||
|
return POSTGRES_DEFAULT_SCHEMA
|
||||||
|
|
||||||
|
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||||
|
|
||||||
|
# Since token_data.get() can return None, ensure we have a string
|
||||||
tenant_id = (
|
tenant_id = (
|
||||||
str(tenant_id_from_payload)
|
str(tenant_id_from_payload)
|
||||||
if tenant_id_from_payload is not None
|
if tenant_id_from_payload is not None
|
||||||
@@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
|
|||||||
|
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
return POSTGRES_DEFAULT_SCHEMA
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
@@ -19,7 +19,7 @@ from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
|||||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||||
from onyx.auth.users import auth_backend
|
from onyx.auth.users import auth_backend
|
||||||
from onyx.auth.users import current_admin_user
|
from onyx.auth.users import current_admin_user
|
||||||
from onyx.auth.users import get_jwt_strategy
|
from onyx.auth.users import get_redis_strategy
|
||||||
from onyx.auth.users import User
|
from onyx.auth.users import User
|
||||||
from onyx.configs.app_configs import WEB_DOMAIN
|
from onyx.configs.app_configs import WEB_DOMAIN
|
||||||
from onyx.db.auth import get_user_count
|
from onyx.db.auth import get_user_count
|
||||||
@@ -112,7 +112,7 @@ async def impersonate_user(
|
|||||||
)
|
)
|
||||||
if user_to_impersonate is None:
|
if user_to_impersonate is None:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||||
|
|
||||||
response = await auth_backend.transport.get_login_response(token)
|
response = await auth_backend.transport.get_login_response(token)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
import json
|
||||||
|
import secrets
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -29,10 +31,8 @@ from fastapi_users import schemas
|
|||||||
from fastapi_users import UUIDIDMixin
|
from fastapi_users import UUIDIDMixin
|
||||||
from fastapi_users.authentication import AuthenticationBackend
|
from fastapi_users.authentication import AuthenticationBackend
|
||||||
from fastapi_users.authentication import CookieTransport
|
from fastapi_users.authentication import CookieTransport
|
||||||
from fastapi_users.authentication import JWTStrategy
|
from fastapi_users.authentication import RedisStrategy
|
||||||
from fastapi_users.authentication import Strategy
|
from fastapi_users.authentication import Strategy
|
||||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
|
||||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
|
||||||
from fastapi_users.exceptions import UserAlreadyExists
|
from fastapi_users.exceptions import UserAlreadyExists
|
||||||
from fastapi_users.jwt import decode_jwt
|
from fastapi_users.jwt import decode_jwt
|
||||||
from fastapi_users.jwt import generate_jwt
|
from fastapi_users.jwt import generate_jwt
|
||||||
@@ -59,6 +59,8 @@ from onyx.auth.schemas import UserUpdate
|
|||||||
from onyx.configs.app_configs import AUTH_TYPE
|
from onyx.configs.app_configs import AUTH_TYPE
|
||||||
from onyx.configs.app_configs import DISABLE_AUTH
|
from onyx.configs.app_configs import DISABLE_AUTH
|
||||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||||
|
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||||
|
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||||
@@ -73,7 +75,6 @@ from onyx.configs.constants import OnyxRedisLocks
|
|||||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||||
from onyx.db.api_key import fetch_user_for_api_key
|
from onyx.db.api_key import fetch_user_for_api_key
|
||||||
from onyx.db.auth import get_access_token_db
|
|
||||||
from onyx.db.auth import get_default_admin_user_emails
|
from onyx.db.auth import get_default_admin_user_emails
|
||||||
from onyx.db.auth import get_user_count
|
from onyx.db.auth import get_user_count
|
||||||
from onyx.db.auth import get_user_db
|
from onyx.db.auth import get_user_db
|
||||||
@@ -81,10 +82,10 @@ from onyx.db.auth import SQLAlchemyUserAdminDB
|
|||||||
from onyx.db.engine import get_async_session
|
from onyx.db.engine import get_async_session
|
||||||
from onyx.db.engine import get_async_session_with_tenant
|
from onyx.db.engine import get_async_session_with_tenant
|
||||||
from onyx.db.engine import get_session_with_tenant
|
from onyx.db.engine import get_session_with_tenant
|
||||||
from onyx.db.models import AccessToken
|
|
||||||
from onyx.db.models import OAuthAccount
|
from onyx.db.models import OAuthAccount
|
||||||
from onyx.db.models import User
|
from onyx.db.models import User
|
||||||
from onyx.db.users import get_user_by_email
|
from onyx.db.users import get_user_by_email
|
||||||
|
from onyx.redis.redis_pool import get_async_redis_connection
|
||||||
from onyx.redis.redis_pool import get_redis_client
|
from onyx.redis.redis_pool import get_redis_client
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from onyx.utils.telemetry import create_milestone_and_report
|
from onyx.utils.telemetry import create_milestone_and_report
|
||||||
@@ -581,49 +582,70 @@ cookie_transport = CookieTransport(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# This strategy is used to add tenant_id to the JWT token
|
def get_redis_strategy() -> RedisStrategy:
|
||||||
class TenantAwareJWTStrategy(JWTStrategy):
|
return TenantAwareRedisStrategy()
|
||||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
|
||||||
|
|
||||||
|
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||||
|
"""
|
||||||
|
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||||
|
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||||
|
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||||
|
):
|
||||||
|
self.lifetime_seconds = lifetime_seconds
|
||||||
|
self.key_prefix = key_prefix
|
||||||
|
|
||||||
|
async def write_token(self, user: User) -> str:
|
||||||
|
redis = await get_async_redis_connection()
|
||||||
|
|
||||||
tenant_id = await fetch_ee_implementation_or_noop(
|
tenant_id = await fetch_ee_implementation_or_noop(
|
||||||
"onyx.server.tenants.provisioning",
|
"onyx.server.tenants.provisioning",
|
||||||
"get_or_provision_tenant",
|
"get_or_provision_tenant",
|
||||||
async_return_default_schema,
|
async_return_default_schema,
|
||||||
)(
|
)(email=user.email)
|
||||||
email=user.email,
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {
|
token_data = {
|
||||||
"sub": str(user.id),
|
"sub": str(user.id),
|
||||||
"aud": self.token_audience,
|
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
}
|
}
|
||||||
return data
|
token = secrets.token_urlsafe()
|
||||||
|
await redis.set(
|
||||||
async def write_token(self, user: User) -> str:
|
f"{self.key_prefix}{token}",
|
||||||
data = await self._create_token_data(user)
|
json.dumps(token_data),
|
||||||
return generate_jwt(
|
ex=self.lifetime_seconds,
|
||||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
|
||||||
)
|
)
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def read_token(
|
||||||
|
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||||
|
) -> Optional[User]:
|
||||||
|
redis = await get_async_redis_connection()
|
||||||
|
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||||
|
if not token_data_str:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
try:
|
||||||
return TenantAwareJWTStrategy(
|
token_data = json.loads(token_data_str)
|
||||||
secret=USER_AUTH_SECRET,
|
user_id = token_data["sub"]
|
||||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
parsed_id = user_manager.parse_id(user_id)
|
||||||
)
|
return await user_manager.get(parsed_id)
|
||||||
|
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def destroy_token(self, token: str, user: User) -> None:
|
||||||
def get_database_strategy(
|
"""Properly delete the token from async redis."""
|
||||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
redis = await get_async_redis_connection()
|
||||||
) -> DatabaseStrategy:
|
await redis.delete(f"{self.key_prefix}{token}")
|
||||||
return DatabaseStrategy(
|
|
||||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
auth_backend = AuthenticationBackend(
|
auth_backend = AuthenticationBackend(
|
||||||
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
|
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||||
) # type: ignore
|
)
|
||||||
|
|
||||||
|
|
||||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||||
|
@@ -54,6 +54,10 @@ MASK_CREDENTIAL_PREFIX = (
|
|||||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||||
|
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 3600
|
||||||
|
)
|
||||||
|
|
||||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||||
) # 7 days
|
) # 7 days
|
||||||
@@ -188,9 +192,11 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
|||||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||||
|
|
||||||
|
|
||||||
|
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||||
|
|
||||||
|
|
||||||
# Rate limiting for auth endpoints
|
# Rate limiting for auth endpoints
|
||||||
|
|
||||||
|
|
||||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||||
if _rate_limit_window_seconds_str is not None:
|
if _rate_limit_window_seconds_str is not None:
|
||||||
@@ -570,7 +576,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
|||||||
# JWT configuration
|
# JWT configuration
|
||||||
JWT_ALGORITHM = "HS256"
|
JWT_ALGORITHM = "HS256"
|
||||||
|
|
||||||
|
|
||||||
#####
|
#####
|
||||||
# API Key Configs
|
# API Key Configs
|
||||||
#####
|
#####
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import ssl
|
import ssl
|
||||||
@@ -14,7 +15,6 @@ from typing import ContextManager
|
|||||||
|
|
||||||
import asyncpg # type: ignore
|
import asyncpg # type: ignore
|
||||||
import boto3
|
import boto3
|
||||||
import jwt
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
@@ -40,9 +40,9 @@ from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
|
|||||||
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||||
from onyx.configs.app_configs import POSTGRES_PORT
|
from onyx.configs.app_configs import POSTGRES_PORT
|
||||||
from onyx.configs.app_configs import POSTGRES_USER
|
from onyx.configs.app_configs import POSTGRES_USER
|
||||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
|
||||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||||
from onyx.configs.constants import SSL_CERT_FILE
|
from onyx.configs.constants import SSL_CERT_FILE
|
||||||
|
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||||
from onyx.server.utils import BasicAuthenticationError
|
from onyx.server.utils import BasicAuthenticationError
|
||||||
from onyx.utils.logger import setup_logger
|
from onyx.utils.logger import setup_logger
|
||||||
from shared_configs.configs import MULTI_TENANT
|
from shared_configs.configs import MULTI_TENANT
|
||||||
@@ -322,31 +322,33 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
|||||||
return _ASYNC_ENGINE
|
return _ASYNC_ENGINE
|
||||||
|
|
||||||
|
|
||||||
def get_current_tenant_id(request: Request) -> str:
|
async def get_current_tenant_id(request: Request) -> str:
|
||||||
if not MULTI_TENANT:
|
if not MULTI_TENANT:
|
||||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
token = request.cookies.get("fastapiusersauth")
|
|
||||||
if not token:
|
|
||||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
|
||||||
return current_value
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
# Look up token data in Redis
|
||||||
token,
|
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||||
USER_AUTH_SECRET,
|
|
||||||
audience=["fastapi-users:auth"],
|
if not token_data:
|
||||||
algorithms=["HS256"],
|
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||||
)
|
logger.debug(
|
||||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
f"Token data not found or expired in Redis, defaulting to {current_value}"
|
||||||
|
)
|
||||||
|
return current_value
|
||||||
|
|
||||||
|
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||||
|
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||||
|
|
||||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||||
return tenant_id
|
return tenant_id
|
||||||
except jwt.InvalidTokenError:
|
except json.JSONDecodeError:
|
||||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
logger.error("Error decoding token data from Redis")
|
||||||
|
return POSTGRES_DEFAULT_SCHEMA
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Internal server error")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
@@ -1,14 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import json
|
||||||
import threading
|
import threading
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
from fastapi import Request
|
||||||
from redis import asyncio as aioredis
|
from redis import asyncio as aioredis
|
||||||
from redis.client import Redis
|
from redis.client import Redis
|
||||||
|
|
||||||
|
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||||
from onyx.configs.app_configs import REDIS_HOST
|
from onyx.configs.app_configs import REDIS_HOST
|
||||||
@@ -228,3 +231,31 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
|||||||
|
|
||||||
# Return the established connection (or pool) for all future operations
|
# Return the established connection (or pool) for all future operations
|
||||||
return _async_redis_connection
|
return _async_redis_connection
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||||
|
token = request.cookies.get("fastapiusersauth")
|
||||||
|
if not token:
|
||||||
|
logger.debug("No auth token cookie found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
redis = await get_async_redis_connection()
|
||||||
|
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||||
|
token_data_str = await redis.get(redis_key)
|
||||||
|
|
||||||
|
if not token_data_str:
|
||||||
|
logger.debug(f"Token key {redis_key} not found or expired in Redis")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return json.loads(token_data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error("Error decoding token data from Redis")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||||
|
)
|
||||||
|
@@ -11,6 +11,7 @@ import React, { useContext, useState, useEffect } from "react";
|
|||||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||||
import { Modal } from "@/components/Modal";
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||||
|
|
||||||
export function Checkbox({
|
export function Checkbox({
|
||||||
label,
|
label,
|
||||||
@@ -218,14 +219,19 @@ export function SettingsForm() {
|
|||||||
handleToggleSettingsField("auto_scroll", e.target.checked)
|
handleToggleSettingsField("auto_scroll", e.target.checked)
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<Checkbox
|
{!NEXT_PUBLIC_CLOUD_ENABLED && (
|
||||||
label="Anonymous Users"
|
<Checkbox
|
||||||
sublabel="If set, users will not be required to sign in to use Danswer."
|
label="Anonymous Users"
|
||||||
checked={settings.anonymous_user_enabled}
|
sublabel="If set, users will not be required to sign in to use Onyx."
|
||||||
onChange={(e) =>
|
checked={settings.anonymous_user_enabled}
|
||||||
handleToggleSettingsField("anonymous_user_enabled", e.target.checked)
|
onChange={(e) =>
|
||||||
}
|
handleToggleSettingsField(
|
||||||
/>
|
"anonymous_user_enabled",
|
||||||
|
e.target.checked
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
{showConfirmModal && (
|
{showConfirmModal && (
|
||||||
<Modal
|
<Modal
|
||||||
width="max-w-3xl w-full"
|
width="max-w-3xl w-full"
|
||||||
|
@@ -104,18 +104,10 @@ const Page = async (props: {
|
|||||||
<span className="px-4 text-gray-500">or</span>
|
<span className="px-4 text-gray-500">or</span>
|
||||||
<div className="flex-grow border-t border-gray-300"></div>
|
<div className="flex-grow border-t border-gray-300"></div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
|
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
|
||||||
|
|
||||||
<div className="flex mt-4 justify-between">
|
<div className="flex mt-4 justify-between">
|
||||||
<Link
|
|
||||||
href={`/auth/signup${
|
|
||||||
searchParams?.next ? `?next=${searchParams.next}` : ""
|
|
||||||
}`}
|
|
||||||
className="text-link font-medium"
|
|
||||||
>
|
|
||||||
Create an account
|
|
||||||
</Link>
|
|
||||||
|
|
||||||
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
|
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
|
||||||
<Link
|
<Link
|
||||||
href="/auth/forgot-password"
|
href="/auth/forgot-password"
|
||||||
|
@@ -16,7 +16,7 @@ export const getCurrentUser = async (): Promise<User | null> => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const logout = async (): Promise<Response> => {
|
export const logout = async (): Promise<Response> => {
|
||||||
const response = await fetch("/auth/logout", {
|
const response = await fetch("/api/auth/logout", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
credentials: "include",
|
credentials: "include",
|
||||||
});
|
});
|
||||||
|
Reference in New Issue
Block a user