Merge branch 'FastAPI' of https://github.com/arcbtc/lnbits into FastAPI

This commit is contained in:
Tiago vasconcelos
2021-09-29 10:43:19 +01:00
9 changed files with 73 additions and 64 deletions

View File

@@ -96,9 +96,10 @@ def register_routes(app: FastAPI) -> None:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
ext_route = getattr(ext_module, f"{ext.code}_ext") ext_route = getattr(ext_module, f"{ext.code}_ext")
ext_statics = getattr(ext_module, f"{ext.code}_static_files") if hasattr(ext_module, f"{ext.code}_static_files"):
for s in ext_statics: ext_statics = getattr(ext_module, f"{ext.code}_static_files")
app.mount(s["path"], s["app"], s["name"]) for s in ext_statics:
app.mount(s["path"], s["app"], s["name"])
app.include_router(ext_route) app.include_router(ext_route)
except Exception as e: except Exception as e:

View File

@@ -96,6 +96,8 @@ async def get_key_type(r: Request,
await checker.__call__(r) await checker.__call__(r)
return WalletTypeInfo(0, checker.wallet) return WalletTypeInfo(0, checker.wallet)
except HTTPException as e: except HTTPException as e:
if e.status_code == HTTPStatus.BAD_REQUEST:
raise
if e.status_code == HTTPStatus.UNAUTHORIZED: if e.status_code == HTTPStatus.UNAUTHORIZED:
pass pass
except: except:
@@ -106,6 +108,8 @@ async def get_key_type(r: Request,
await checker.__call__(r) await checker.__call__(r)
return WalletTypeInfo(1, checker.wallet) return WalletTypeInfo(1, checker.wallet)
except HTTPException as e: except HTTPException as e:
if e.status_code == HTTPStatus.BAD_REQUEST:
raise
if e.status_code == HTTPStatus.UNAUTHORIZED: if e.status_code == HTTPStatus.UNAUTHORIZED:
return WalletTypeInfo(2, None) return WalletTypeInfo(2, None)
except: except:

View File

@@ -1,6 +1,4 @@
from fastapi import APIRouter, FastAPI from fastapi import APIRouter
from fastapi.staticfiles import StaticFiles
from starlette.routing import Mount
from lnbits.db import Database from lnbits.db import Database
from lnbits.helpers import template_renderer from lnbits.helpers import template_renderer
@@ -23,7 +21,9 @@ def lnticket_renderer():
from .views_api import * # noqa from .views_api import * # noqa
from .views import * # noqa from .views import * # noqa
from .tasks import register_listeners
@lntickets_ext.on_event("startup") @lnticket_ext.on_event("startup")
def _do_it(): def _do_it():
# FIXME: isn't called yet
register_listeners() register_listeners()

View File

@@ -1,9 +1,10 @@
from lnbits.core.models import Wallet
from typing import List, Optional, Union from typing import List, Optional, Union
from lnbits.helpers import urlsafe_short_hash from lnbits.helpers import urlsafe_short_hash
from . import db from . import db
from .models import Tickets, Forms from .models import CreateFormData, Tickets, Forms
import httpx import httpx
@@ -103,13 +104,8 @@ async def delete_ticket(ticket_id: str) -> None:
async def create_form( async def create_form(
*, data: CreateFormData,
wallet: str, wallet: Wallet,
name: str,
webhook: Optional[str] = None,
description: str,
amount: int,
flatrate: int,
) -> Forms: ) -> Forms:
form_id = urlsafe_short_hash() form_id = urlsafe_short_hash()
await db.execute( await db.execute(
@@ -117,7 +113,7 @@ async def create_form(
INSERT INTO lnticket.form2 (id, wallet, name, webhook, description, flatrate, amount, amountmade) INSERT INTO lnticket.form2 (id, wallet, name, webhook, description, flatrate, amount, amountmade)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
(form_id, wallet, name, webhook, description, flatrate, amount, 0), (form_id, wallet.id, wallet.name, data.webhook, data.description, data.flatrate, data.amount, 0),
) )
form = await get_form(form_id) form = await get_form(form_id)

View File

@@ -1,10 +1,26 @@
from typing import Optional
from fastapi.param_functions import Query
from pydantic import BaseModel from pydantic import BaseModel
class CreateFormData(BaseModel):
name: str = Query(...)
webhook: str = Query(None)
description: str = Query(..., min_length=0)
amount: int = Query(..., ge=0)
flatrate: int = Query(...)
class CreateTicketData(BaseModel):
form: str = Query(...)
name: str = Query(...)
email: str = Query("")
ltext: str = Query(...)
sats: int = Query(..., ge=0)
class Forms(BaseModel): class Forms(BaseModel):
id: str id: str
wallet: str wallet: str
name: str name: str
webhook: str webhook: Optional[str]
description: str description: str
amount: int amount: int
flatrate: int flatrate: int

View File

@@ -1,23 +1,21 @@
import json import asyncio
import trio # type: ignore
from lnbits.core.models import Payment from lnbits.core.models import Payment
from lnbits.core.crud import create_payment from lnbits.tasks import register_invoice_listener
from lnbits.core import db as core_db
from lnbits.tasks import register_invoice_listener, internal_invoice_paid
from lnbits.helpers import urlsafe_short_hash
from .crud import get_ticket, set_ticket_paid from .crud import get_ticket, set_ticket_paid
async def register_listeners(): async def register_listeners():
invoice_paid_chan_send, invoice_paid_chan_recv = trio.open_memory_channel(2) send_queue = asyncio.Queue()
register_invoice_listener(invoice_paid_chan_send) recv_queue = asyncio.Queue()
await wait_for_paid_invoices(invoice_paid_chan_recv) register_invoice_listener(send_queue)
await wait_for_paid_invoices(recv_queue)
async def wait_for_paid_invoices(invoice_paid_chan: trio.MemoryReceiveChannel): async def wait_for_paid_invoices(invoice_paid_queue: asyncio.Queue):
async for payment in invoice_paid_chan: while True:
payment = await invoice_paid_queue.get()
await on_invoice_paid(payment) await on_invoice_paid(payment)

View File

@@ -337,7 +337,7 @@
LNbits.api LNbits.api
.request( .request(
'GET', 'GET',
'/lnticket/api/v1/tickets?all_wallets', '/lnticket/api/v1/tickets?all_wallets=true',
this.g.user.wallets[0].inkey this.g.user.wallets[0].inkey
) )
.then(function (response) { .then(function (response) {
@@ -382,7 +382,7 @@
LNbits.api LNbits.api
.request( .request(
'GET', 'GET',
'/lnticket/api/v1/forms?all_wallets', '/lnticket/api/v1/forms?all_wallets=true',
this.g.user.wallets[0].inkey this.g.user.wallets[0].inkey
) )
.then(function (response) { .then(function (response) {

View File

@@ -1,5 +1,12 @@
<<<<<<< HEAD
=======
from fastapi.param_functions import Depends
from starlette.exceptions import HTTPException
from starlette.responses import HTMLResponse
from lnbits.core.models import User
>>>>>>> f827d2ce181d97368161d46ab8de2e9f061b9872
from lnbits.core.crud import get_wallet from lnbits.core.crud import get_wallet
from lnbits.decorators import check_user_exists, validate_uuids from lnbits.decorators import check_user_exists
from http import HTTPStatus from http import HTTPStatus
from . import lnticket_ext, lnticket_renderer from . import lnticket_ext, lnticket_renderer
@@ -11,7 +18,9 @@ from fastapi.templating import Jinja2Templates
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")
@lnticket_ext.get("/", response_class=HTMLResponse) @lnticket_ext.get("/", response_class=HTMLResponse)
@validate_uuids(["usr"], required=True) # not needed as we automatically get the user with the given ID
# If no user with this ID is found, an error is raised
# @validate_uuids(["usr"], required=True)
# @check_user_exists() # @check_user_exists()
async def index(request: Request, user: User = Depends(check_user_exists)): async def index(request: Request, user: User = Depends(check_user_exists)):
return lnticket_renderer().TemplateResponse("lnticket/index.html", {"request": request,"user": user.dict()}) return lnticket_renderer().TemplateResponse("lnticket/index.html", {"request": request,"user": user.dict()})

View File

@@ -1,9 +1,10 @@
from lnbits.extensions.lnticket.models import CreateFormData, CreateTicketData
import re import re
from http import HTTPStatus from http import HTTPStatus
from typing import List
from fastapi import FastAPI, Query from fastapi import Query
from fastapi.params import Depends from fastapi.params import Depends
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel from pydantic import BaseModel
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
@@ -12,7 +13,7 @@ from starlette.responses import HTMLResponse, JSONResponse # type: ignore
from lnbits.core.crud import get_user, get_wallet from lnbits.core.crud import get_user, get_wallet
from lnbits.core.services import create_invoice, check_invoice_status from lnbits.core.services import create_invoice, check_invoice_status
from lnbits.decorators import api_check_wallet_key, api_validate_post_request from lnbits.decorators import WalletTypeInfo, get_key_type
from . import lnticket_ext from . import lnticket_ext
from .crud import ( from .crud import (
@@ -32,25 +33,18 @@ from .crud import (
# FORMS # FORMS
@lnticket_ext.get("/api/v1/forms", status_code=HTTPStatus.OK) @lnticket_ext.get("/api/v1/forms")
# @api_check_wallet_key("invoice") async def api_forms_get(r: Request, all_wallets: bool = Query(False), wallet: WalletTypeInfo = Depends(get_key_type)):
async def api_forms(r: Request, wallet: WalletTypeInfo = Depends(get_key_type)):
wallet_ids = [wallet.wallet.id] wallet_ids = [wallet.wallet.id]
<<<<<<< HEAD
if "all_wallets" in r.path_parameters: if "all_wallets" in r.path_parameters:
=======
if all_wallets:
>>>>>>> f827d2ce181d97368161d46ab8de2e9f061b9872
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids
return ( return [form.dict() for form in await get_forms(wallet_ids)]
[form._asdict() for form in await get_forms(wallet_ids)],
)
class CreateData(BaseModel):
wallet: str = Query(...)
name: str = Query(...)
webhook: str = Query(None)
description: str = Query(..., min_length=0)
amount: int = Query(..., ge=0)
flatrate: int = Query(...)
@lnticket_ext.post("/api/v1/forms", status_code=HTTPStatus.CREATED) @lnticket_ext.post("/api/v1/forms", status_code=HTTPStatus.CREATED)
@lnticket_ext.put("/api/v1/forms/{form_id}") @lnticket_ext.put("/api/v1/forms/{form_id}")
@@ -65,7 +59,7 @@ class CreateData(BaseModel):
# "flatrate": {"type": "integer", "required": True}, # "flatrate": {"type": "integer", "required": True},
# } # }
# ) # )
async def api_form_create(data: CreateData, form_id=None, wallet: WalletTypeInfo = Depends(get_key_type)): async def api_form_create(data: CreateFormData, form_id=None, wallet: WalletTypeInfo = Depends(get_key_type)):
if form_id: if form_id:
form = await get_form(form_id) form = await get_form(form_id)
@@ -85,8 +79,8 @@ async def api_form_create(data: CreateData, form_id=None, wallet: WalletTypeInfo
form = await update_form(form_id, **data) form = await update_form(form_id, **data)
else: else:
form = await create_form(**data) form = await create_form(data, wallet.wallet)
return form._asdict() return form.dict()
@lnticket_ext.delete("/api/v1/forms/{form_id}") @lnticket_ext.delete("/api/v1/forms/{form_id}")
@@ -117,24 +111,15 @@ async def api_form_delete(form_id, wallet: WalletTypeInfo = Depends(get_key_type
#########tickets########## #########tickets##########
@lnticket_ext.get("/api/v1/tickets", status_code=HTTPStatus.OK) @lnticket_ext.get("/api/v1/tickets")
# @api_check_wallet_key("invoice") # @api_check_wallet_key("invoice")
async def api_tickets(all_wallets: bool = Query(None), wallet: WalletTypeInfo = Depends(get_key_type)): async def api_tickets(all_wallets: bool = Query(False), wallet: WalletTypeInfo = Depends(get_key_type)):
wallet_ids = [wallet.wallet.id] wallet_ids = [wallet.wallet.id]
if all_wallets: if all_wallets:
wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids wallet_ids = (await get_user(wallet.wallet.user)).wallet_ids
return ( return [form.dict() for form in await get_tickets(wallet_ids)]
[form._asdict() for form in await get_tickets(wallet_ids)]
)
class CreateTicketData(BaseModel):
form: str = Query(...)
name: str = Query(...)
email: str = Query("")
ltext: str = Query(...)
sats: int = Query(..., ge=0)
@lnticket_ext.post("/api/v1/tickets/{form_id}", status_code=HTTPStatus.CREATED) @lnticket_ext.post("/api/v1/tickets/{form_id}", status_code=HTTPStatus.CREATED)
# @api_validate_post_request( # @api_validate_post_request(