auth bearer fix

x-api-key now says unauthorized
This commit is contained in:
Tiago vasconcelos
2021-10-15 16:21:05 +01:00
parent 9096ed38b9
commit 43653cb84d

View File

@@ -1,5 +1,6 @@
from functools import wraps from functools import wraps
from http import HTTPStatus from http import HTTPStatus
from base64 import b64decode
from fastapi.security import api_key from fastapi.security import api_key
from pydantic.types import UUID4 from pydantic.types import UUID4
@@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.params import Security from fastapi.params import Security
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from starlette.requests import Request from starlette.requests import Request
@@ -47,13 +49,13 @@ class KeyChecker(SecurityBase):
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.") raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.")
except KeyError: except KeyError:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
detail="`X-API-KEY` header missing.") detail="`X-API-KEY` header missing.")
class WalletInvoiceKeyChecker(KeyChecker): class WalletInvoiceKeyChecker(KeyChecker):
""" """
WalletInvoiceKeyChecker will ensure that the provided invoice WalletInvoiceKeyChecker will ensure that the provided invoice
wallet key is correct and populate g().wallet with the wallet wallet key is correct and populate g().wallet with the wallet
for the key in `X-API-key`. for the key in `X-API-key`.
The checker will raise an HTTPException when the key is wrong in some ways. The checker will raise an HTTPException when the key is wrong in some ways.
@@ -65,7 +67,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
class WalletAdminKeyChecker(KeyChecker): class WalletAdminKeyChecker(KeyChecker):
""" """
WalletAdminKeyChecker will ensure that the provided admin WalletAdminKeyChecker will ensure that the provided admin
wallet key is correct and populate g().wallet with the wallet wallet key is correct and populate g().wallet with the wallet
for the key in `X-API-key`. for the key in `X-API-key`.
The checker will raise an HTTPException when the key is wrong in some ways. The checker will raise an HTTPException when the key is wrong in some ways.
@@ -85,14 +87,19 @@ class WalletTypeInfo():
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's") api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's")
api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's") api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's")
async def get_key_type(r: Request, oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
api_key_header: str = Security(api_key_header), async def get_key_type(r: Request,
token: str = Security(oauth2_scheme),
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query)) -> WalletTypeInfo: api_key_query: str = Security(api_key_query)) -> WalletTypeInfo:
# 0: admin # 0: admin
# 1: invoice # 1: invoice
# 2: invalid # 2: invalid
# print("TOKEN", b64decode(token).decode("utf-8").split(":"))
key_type, key = b64decode(token).decode("utf-8").split(":")
try: try:
checker = WalletAdminKeyChecker(api_key=api_key_query) checker = WalletAdminKeyChecker(api_key=key if token else api_key_query)
await checker.__call__(r) await checker.__call__(r)
return WalletTypeInfo(0, checker.wallet) return WalletTypeInfo(0, checker.wallet)
except HTTPException as e: except HTTPException as e:
@@ -104,7 +111,7 @@ async def get_key_type(r: Request,
raise raise
try: try:
checker = WalletInvoiceKeyChecker() checker = WalletInvoiceKeyChecker(api_key=key if token else None)
await checker.__call__(r) await checker.__call__(r)
return WalletTypeInfo(1, checker.wallet) return WalletTypeInfo(1, checker.wallet)
except HTTPException as e: except HTTPException as e:
@@ -121,7 +128,7 @@ def api_validate_post_request(*, schema: dict):
async def wrapped_view(**kwargs): async def wrapped_view(**kwargs):
if "application/json" not in request.headers["Content-Type"]: if "application/json" not in request.headers["Content-Type"]:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=jsonify({"message": "Content-Type must be `application/json`."}) detail=jsonify({"message": "Content-Type must be `application/json`."})
) )
@@ -131,10 +138,10 @@ def api_validate_post_request(*, schema: dict):
if not v.validate(g().data): if not v.validate(g().data):
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=jsonify({"message": f"Errors in request data: {v.errors}"}) detail=jsonify({"message": f"Errors in request data: {v.errors}"})
) )
return await view(**kwargs) return await view(**kwargs)
@@ -144,7 +151,7 @@ def api_validate_post_request(*, schema: dict):
async def check_user_exists(usr: UUID4) -> User: async def check_user_exists(usr: UUID4) -> User:
g().user = await get_user(usr.hex) g().user = await get_user(usr.hex)
if not g().user: if not g().user:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,