Minor JWT Feature (#3290)

* first pass

* k

* k

* finalize

* minor cleanup

* k

* address

* minor typing updates
This commit is contained in:
pablodanswer
2024-12-02 11:14:31 -08:00
committed by GitHub
parent af814823c8
commit 813445ab59
5 changed files with 69 additions and 7 deletions

View File

@ -493,10 +493,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration # JWT configuration
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
##### #####
# API Key Configs # API Key Configs

View File

@ -34,7 +34,6 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address from danswer.db.api_key import is_api_key_email_address
@ -64,6 +63,7 @@ from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.configs.app_configs import SUPER_USERS
from shared_configs.configs import MULTI_TENANT from shared_configs.configs import MULTI_TENANT
logger = setup_logger() logger = setup_logger()

View File

@ -1,23 +1,72 @@
from functools import lru_cache
import requests
from fastapi import Depends from fastapi import Depends
from fastapi import HTTPException from fastapi import HTTPException
from fastapi import Request from fastapi import Request
from fastapi import status from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType from danswer.configs.constants import AuthType
from danswer.db.models import User from danswer.db.models import User
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL
from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from ee.danswer.configs.app_configs import SUPER_USERS
from ee.danswer.db.saml import get_saml_account from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie from ee.danswer.utils.secrets import extract_hashed_cookie
logger = setup_logger() logger = setup_logger()
@lru_cache()
def get_public_key() -> str | None:
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
return response.text
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
try:
public_key_pem = get_public_key()
if public_key_pem is None:
logger.error("Failed to retrieve public key")
return None
payload = jwt_decode(
token,
public_key_pem,
algorithms=["RS256"],
audience=None,
)
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
except InvalidTokenError:
logger.error("Invalid JWT token")
get_public_key.cache_clear()
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
get_public_key.cache_clear()
return None
def verify_auth_setting() -> None: def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version # All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@ -38,6 +87,13 @@ async def optional_user_(
) )
user = saml_account.user if saml_account else None user = saml_account.user if saml_account else None
# If user is still None, check for JWT in Authorization header
if user is None and JWT_PUBLIC_KEY_URL is not None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
user = await verify_jwt_token(token, async_db_session)
return user return user

View File

@ -1,3 +1,4 @@
import json
import os import os
# Applicable for OIDC Auth # Applicable for OIDC Auth
@ -19,3 +20,11 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY") OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY") ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY") COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

View File

@ -130,6 +130,7 @@ services:
restart: always restart: always
environment: environment:
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-} - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
- JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-} # used for JWT authentication of users via API
# Gen AI Settings (Needed by DanswerBot) # Gen AI Settings (Needed by DanswerBot)
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-} - QA_TIMEOUT=${QA_TIMEOUT:-}