mirror of
https://github.com/lnbits/lnbits.git
synced 2025-08-02 23:12:34 +02:00
auth bearer fix
x-api-key now says unauthorized
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from base64 import b64decode
|
||||||
|
|
||||||
from fastapi.security import api_key
|
from fastapi.security import api_key
|
||||||
from pydantic.types import UUID4
|
from pydantic.types import UUID4
|
||||||
@@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException
|
|||||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||||
from fastapi.params import Security
|
from fastapi.params import Security
|
||||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
@@ -47,13 +49,13 @@ class KeyChecker(SecurityBase):
|
|||||||
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.")
|
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.")
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
|
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail="`X-API-KEY` header missing.")
|
detail="`X-API-KEY` header missing.")
|
||||||
|
|
||||||
class WalletInvoiceKeyChecker(KeyChecker):
|
class WalletInvoiceKeyChecker(KeyChecker):
|
||||||
"""
|
"""
|
||||||
WalletInvoiceKeyChecker will ensure that the provided invoice
|
WalletInvoiceKeyChecker will ensure that the provided invoice
|
||||||
wallet key is correct and populate g().wallet with the wallet
|
wallet key is correct and populate g().wallet with the wallet
|
||||||
for the key in `X-API-key`.
|
for the key in `X-API-key`.
|
||||||
|
|
||||||
The checker will raise an HTTPException when the key is wrong in some ways.
|
The checker will raise an HTTPException when the key is wrong in some ways.
|
||||||
@@ -65,7 +67,7 @@ class WalletInvoiceKeyChecker(KeyChecker):
|
|||||||
class WalletAdminKeyChecker(KeyChecker):
|
class WalletAdminKeyChecker(KeyChecker):
|
||||||
"""
|
"""
|
||||||
WalletAdminKeyChecker will ensure that the provided admin
|
WalletAdminKeyChecker will ensure that the provided admin
|
||||||
wallet key is correct and populate g().wallet with the wallet
|
wallet key is correct and populate g().wallet with the wallet
|
||||||
for the key in `X-API-key`.
|
for the key in `X-API-key`.
|
||||||
|
|
||||||
The checker will raise an HTTPException when the key is wrong in some ways.
|
The checker will raise an HTTPException when the key is wrong in some ways.
|
||||||
@@ -85,14 +87,19 @@ class WalletTypeInfo():
|
|||||||
|
|
||||||
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's")
|
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's")
|
||||||
api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's")
|
api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's")
|
||||||
async def get_key_type(r: Request,
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
api_key_header: str = Security(api_key_header),
|
async def get_key_type(r: Request,
|
||||||
|
token: str = Security(oauth2_scheme),
|
||||||
|
api_key_header: str = Security(api_key_header),
|
||||||
api_key_query: str = Security(api_key_query)) -> WalletTypeInfo:
|
api_key_query: str = Security(api_key_query)) -> WalletTypeInfo:
|
||||||
# 0: admin
|
# 0: admin
|
||||||
# 1: invoice
|
# 1: invoice
|
||||||
# 2: invalid
|
# 2: invalid
|
||||||
|
# print("TOKEN", b64decode(token).decode("utf-8").split(":"))
|
||||||
|
|
||||||
|
key_type, key = b64decode(token).decode("utf-8").split(":")
|
||||||
try:
|
try:
|
||||||
checker = WalletAdminKeyChecker(api_key=api_key_query)
|
checker = WalletAdminKeyChecker(api_key=key if token else api_key_query)
|
||||||
await checker.__call__(r)
|
await checker.__call__(r)
|
||||||
return WalletTypeInfo(0, checker.wallet)
|
return WalletTypeInfo(0, checker.wallet)
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
@@ -104,7 +111,7 @@ async def get_key_type(r: Request,
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checker = WalletInvoiceKeyChecker()
|
checker = WalletInvoiceKeyChecker(api_key=key if token else None)
|
||||||
await checker.__call__(r)
|
await checker.__call__(r)
|
||||||
return WalletTypeInfo(1, checker.wallet)
|
return WalletTypeInfo(1, checker.wallet)
|
||||||
except HTTPException as e:
|
except HTTPException as e:
|
||||||
@@ -121,7 +128,7 @@ def api_validate_post_request(*, schema: dict):
|
|||||||
async def wrapped_view(**kwargs):
|
async def wrapped_view(**kwargs):
|
||||||
if "application/json" not in request.headers["Content-Type"]:
|
if "application/json" not in request.headers["Content-Type"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail=jsonify({"message": "Content-Type must be `application/json`."})
|
detail=jsonify({"message": "Content-Type must be `application/json`."})
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,10 +138,10 @@ def api_validate_post_request(*, schema: dict):
|
|||||||
|
|
||||||
if not v.validate(g().data):
|
if not v.validate(g().data):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail=jsonify({"message": f"Errors in request data: {v.errors}"})
|
detail=jsonify({"message": f"Errors in request data: {v.errors}"})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return await view(**kwargs)
|
return await view(**kwargs)
|
||||||
|
|
||||||
@@ -144,7 +151,7 @@ def api_validate_post_request(*, schema: dict):
|
|||||||
|
|
||||||
|
|
||||||
async def check_user_exists(usr: UUID4) -> User:
|
async def check_user_exists(usr: UUID4) -> User:
|
||||||
g().user = await get_user(usr.hex)
|
g().user = await get_user(usr.hex)
|
||||||
if not g().user:
|
if not g().user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
Reference in New Issue
Block a user