added admin key required

This commit is contained in:
Tiago vasconcelos
2021-10-18 16:06:06 +01:00
parent cfd37ec31e
commit 4739a0811d

View File

@@ -1,24 +1,17 @@
from functools import wraps
from http import HTTPStatus from http import HTTPStatus
from base64 import b64decode
from fastapi.security import api_key
from pydantic.types import UUID4
from lnbits.core.models import User, Wallet
from typing import List, Union
from uuid import UUID
from cerberus import Validator # type: ignore from cerberus import Validator # type: ignore
from fastapi import status
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.params import Security from fastapi.params import Security
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from fastapi.security import OAuth2PasswordBearer
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
from fastapi import status from pydantic.types import UUID4
from starlette.requests import Request from starlette.requests import Request
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.core.models import User, Wallet
from lnbits.requestvars import g from lnbits.requestvars import g
from lnbits.settings import LNBITS_ALLOWED_USERS from lnbits.settings import LNBITS_ALLOWED_USERS
@@ -160,71 +153,23 @@ async def get_key_type(
raise raise
# api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's") async def require_admin_key(
# api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's") r: Request,
# oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") api_key_header: str = Security(api_key_header),
# async def get_key_type(r: Request, api_key_query: str = Security(api_key_query),
# token: str = Security(oauth2_scheme), ):
# api_key_header: str = Security(api_key_header), token = api_key_header if api_key_header else api_key_query
# api_key_query: str = Security(api_key_query)) -> WalletTypeInfo:
# # 0: admin
# # 1: invoice
# # 2: invalid
# # print("TOKEN", b64decode(token).decode("utf-8").split(":"))
#
# key_type, key = b64decode(token).decode("utf-8").split(":")
# try:
# checker = WalletAdminKeyChecker(api_key=key if token else api_key_query)
# await checker.__call__(r)
# return WalletTypeInfo(0, checker.wallet)
# except HTTPException as e:
# if e.status_code == HTTPStatus.BAD_REQUEST:
# raise
# if e.status_code == HTTPStatus.UNAUTHORIZED:
# pass
# except:
# raise
#
# try:
# checker = WalletInvoiceKeyChecker(api_key=key if token else None)
# await checker.__call__(r)
# return WalletTypeInfo(1, checker.wallet)
# except HTTPException as e:
# if e.status_code == HTTPStatus.BAD_REQUEST:
# raise
# if e.status_code == HTTPStatus.UNAUTHORIZED:
# return WalletTypeInfo(2, None)
# except:
# raise
wallet = await get_key_type(r, token)
def api_validate_post_request(*, schema: dict): if wallet.wallet_type != 0:
def wrap(view): # If wallet type is not admin then return the unauthorized status
@wraps(view) # This also covers when the user passes an invalid key type
async def wrapped_view(**kwargs): raise HTTPException(
if "application/json" not in request.headers["Content-Type"]: status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin key required."
raise HTTPException( )
status_code=HTTPStatus.BAD_REQUEST, else:
detail=jsonify( return wallet
{"message": "Content-Type must be `application/json`."}
),
)
v = Validator(schema)
data = await request.get_json()
g().data = {key: data[key] for key in schema.keys() if key in data}
if not v.validate(g().data):
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=jsonify({"message": f"Errors in request data: {v.errors}"}),
)
return await view(**kwargs)
return wrapped_view
return wrap
async def check_user_exists(usr: UUID4) -> User: async def check_user_exists(usr: UUID4) -> User: