From 813445ab59d387fce73633bafb0d41190e8907e5 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 2 Dec 2024 11:14:31 -0800 Subject: [PATCH] Minor JWT Feature (#3290) * first pass * k * k * finalize * minor cleanup * k * address * minor typing updates --- backend/danswer/configs/app_configs.py | 4 -- backend/danswer/server/manage/users.py | 2 +- backend/ee/danswer/auth/users.py | 60 ++++++++++++++++++- backend/ee/danswer/configs/app_configs.py | 9 +++ .../docker_compose/docker-compose.dev.yml | 1 + 5 files changed, 69 insertions(+), 7 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index e06d3f5c4..1b630b953 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -493,10 +493,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get( # JWT configuration JWT_ALGORITHM = "HS256" -# Super Users -SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]')) -SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") - ##### # API Key Configs diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 4ebda084c..12fe7ef29 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -34,7 +34,6 @@ from danswer.auth.users import optional_user from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import ENABLE_EMAIL_INVITES from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS -from danswer.configs.app_configs import SUPER_USERS from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.constants import AuthType from danswer.db.api_key import is_api_key_email_address @@ -64,6 +63,7 @@ from danswer.server.models import MinimalUserSnapshot from danswer.server.utils import send_user_email_invite from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop +from ee.danswer.configs.app_configs import SUPER_USERS from shared_configs.configs import MULTI_TENANT logger = setup_logger() diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index aab88efa8..3d44acc5e 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -1,23 +1,72 @@ +from functools import lru_cache + +import requests from fastapi import Depends from fastapi import HTTPException from fastapi import Request from fastapi import status +from jwt import decode as jwt_decode +from jwt import InvalidTokenError +from jwt import PyJWTError +from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from danswer.auth.users import current_admin_user from danswer.configs.app_configs import AUTH_TYPE -from danswer.configs.app_configs import SUPER_CLOUD_API_KEY -from danswer.configs.app_configs import SUPER_USERS from danswer.configs.constants import AuthType from danswer.db.models import User from danswer.utils.logger import setup_logger +from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL +from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY +from ee.danswer.configs.app_configs import SUPER_USERS from ee.danswer.db.saml import get_saml_account from ee.danswer.server.seeding import get_seed_config from ee.danswer.utils.secrets import extract_hashed_cookie + logger = setup_logger() +@lru_cache() +def get_public_key() -> str | None: + if JWT_PUBLIC_KEY_URL is None: + logger.error("JWT_PUBLIC_KEY_URL is not set") + return None + + response = requests.get(JWT_PUBLIC_KEY_URL) + response.raise_for_status() + return response.text + + +async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None: + try: + public_key_pem = get_public_key() + if public_key_pem is None: + logger.error("Failed to retrieve public key") + return None + + payload = jwt_decode( + token, + public_key_pem, + algorithms=["RS256"], + audience=None, + ) + email = payload.get("email") + if email: + result = await async_db_session.execute( + select(User).where(func.lower(User.email) == func.lower(email)) + ) + return result.scalars().first() + except InvalidTokenError: + logger.error("Invalid JWT token") + get_public_key.cache_clear() + except PyJWTError as e: + logger.error(f"JWT decoding error: {str(e)}") + get_public_key.cache_clear() + return None + + def verify_auth_setting() -> None: # All the Auth flows are valid for EE version logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") @@ -38,6 +87,13 @@ async def optional_user_( ) user = saml_account.user if saml_account else None + # If user is still None, check for JWT in Authorization header + if user is None and JWT_PUBLIC_KEY_URL is not None: + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + token = auth_header[len("Bearer ") :].strip() + user = await verify_jwt_token(token, async_db_session) + return user diff --git a/backend/ee/danswer/configs/app_configs.py b/backend/ee/danswer/configs/app_configs.py index 7e1ade5f3..f9547d078 100644 --- a/backend/ee/danswer/configs/app_configs.py +++ b/backend/ee/danswer/configs/app_configs.py @@ -1,3 +1,4 @@ +import json import os # Applicable for OIDC Auth @@ -19,3 +20,11 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE") OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY") ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY") COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY") + +# JWT Public Key URL +JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None) + + +# Super Users +SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]')) +SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 101080930..bcd73b729 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -130,6 +130,7 @@ services: restart: always environment: - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} + - JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-} # used for JWT authentication of users via API # Gen AI Settings (Needed by DanswerBot) - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-}