Merge pull request #1303 from lnbits/fix/mypy-watchonly

fix mypy issues for watchonly
This commit is contained in:
calle
2023-01-09 14:30:05 +01:00
committed by GitHub
6 changed files with 44 additions and 44 deletions

View File

@@ -41,8 +41,9 @@ async def create_watch_wallet(user: str, w: WalletAccount) -> WalletAccount:
w.meta, w.meta,
), ),
) )
wallet = await get_watch_wallet(wallet_id)
return await get_watch_wallet(wallet_id) assert wallet
return wallet
async def get_watch_wallet(wallet_id: str) -> Optional[WalletAccount]: async def get_watch_wallet(wallet_id: str) -> Optional[WalletAccount]:
@@ -121,11 +122,11 @@ async def create_fresh_addresses(
change_address=False, change_address=False,
) -> List[Address]: ) -> List[Address]:
if start_address_index > end_address_index: if start_address_index > end_address_index:
return None return []
wallet = await get_watch_wallet(wallet_id) wallet = await get_watch_wallet(wallet_id)
if not wallet: if not wallet:
return None return []
branch_index = 1 if change_address else 0 branch_index = 1 if change_address else 0
@@ -150,7 +151,7 @@ async def create_fresh_addresses(
# return fresh addresses # return fresh addresses
rows = await db.fetchall( rows = await db.fetchall(
""" """
SELECT * FROM watchonly.addresses SELECT * FROM watchonly.addresses
WHERE wallet = ? AND branch_index = ? AND address_index >= ? AND address_index < ? WHERE wallet = ? AND branch_index = ? AND address_index >= ? AND address_index < ?
ORDER BY branch_index, address_index ORDER BY branch_index, address_index
""", """,
@@ -172,7 +173,7 @@ async def get_address_at_index(
) -> Optional[Address]: ) -> Optional[Address]:
row = await db.fetchone( row = await db.fetchone(
""" """
SELECT * FROM watchonly.addresses SELECT * FROM watchonly.addresses
WHERE wallet = ? AND branch_index = ? AND address_index = ? WHERE wallet = ? AND branch_index = ? AND address_index = ?
""", """,
( (

View File

@@ -1,6 +1,6 @@
from embit.descriptor import Descriptor, Key # type: ignore from embit.descriptor import Descriptor, Key
from embit.descriptor.arguments import AllowedDerivation # type: ignore from embit.descriptor.arguments import AllowedDerivation
from embit.networks import NETWORKS # type: ignore from embit.networks import NETWORKS
def detect_network(k): def detect_network(k):

View File

@@ -1,7 +1,7 @@
from sqlite3 import Row from sqlite3 import Row
from typing import List, Optional from typing import List, Optional
from fastapi.param_functions import Query from fastapi import Query
from pydantic import BaseModel from pydantic import BaseModel
@@ -35,7 +35,7 @@ class Address(BaseModel):
amount: int = 0 amount: int = 0
branch_index: int = 0 branch_index: int = 0
address_index: int address_index: int
note: str = None note: Optional[str] = None
has_activity: bool = False has_activity: bool = False
@classmethod @classmethod
@@ -57,9 +57,9 @@ class TransactionInput(BaseModel):
class TransactionOutput(BaseModel): class TransactionOutput(BaseModel):
amount: int amount: int
address: str address: str
branch_index: int = None branch_index: Optional[int] = None
address_index: int = None address_index: Optional[int] = None
wallet: str = None wallet: Optional[str] = None
class MasterPublicKey(BaseModel): class MasterPublicKey(BaseModel):

View File

@@ -1,6 +1,5 @@
from fastapi.params import Depends from fastapi import Depends, Request
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from starlette.requests import Request
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
from lnbits.core.models import User from lnbits.core.models import User

View File

@@ -1,5 +1,6 @@
import json import json
from http import HTTPStatus from http import HTTPStatus
from typing import List
import httpx import httpx
from embit import finalizer, script from embit import finalizer, script
@@ -7,9 +8,7 @@ from embit.ec import PublicKey
from embit.networks import NETWORKS from embit.networks import NETWORKS
from embit.psbt import PSBT, DerivationPath from embit.psbt import PSBT, DerivationPath
from embit.transaction import Transaction, TransactionInput, TransactionOutput from embit.transaction import Transaction, TransactionInput, TransactionOutput
from fastapi import Query, Request from fastapi import Depends, HTTPException, Query, Request
from fastapi.params import Depends
from starlette.exceptions import HTTPException
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key
from lnbits.extensions.watchonly import watchonly_ext from lnbits.extensions.watchonly import watchonly_ext
@@ -57,10 +56,8 @@ async def api_wallets_retrieve(
return [] return []
@watchonly_ext.get("/api/v1/wallet/{wallet_id}") @watchonly_ext.get("/api/v1/wallet/{wallet_id}", dependencies=[Depends(get_key_type)])
async def api_wallet_retrieve( async def api_wallet_retrieve(wallet_id: str):
wallet_id, wallet: WalletTypeInfo = Depends(get_key_type)
):
w_wallet = await get_watch_wallet(wallet_id) w_wallet = await get_watch_wallet(wallet_id)
if not w_wallet: if not w_wallet:
@@ -126,8 +123,10 @@ async def api_wallet_create_or_update(
return wallet.dict() return wallet.dict()
@watchonly_ext.delete("/api/v1/wallet/{wallet_id}") @watchonly_ext.delete(
async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin_key)): "/api/v1/wallet/{wallet_id}", dependencies=[Depends(require_admin_key)]
)
async def api_wallet_delete(wallet_id: str):
wallet = await get_watch_wallet(wallet_id) wallet = await get_watch_wallet(wallet_id)
if not wallet: if not wallet:
@@ -144,16 +143,15 @@ async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin
#############################ADDRESSES########################## #############################ADDRESSES##########################
@watchonly_ext.get("/api/v1/address/{wallet_id}") @watchonly_ext.get("/api/v1/address/{wallet_id}", dependencies=[Depends(get_key_type)])
async def api_fresh_address(wallet_id, w: WalletTypeInfo = Depends(get_key_type)): async def api_fresh_address(wallet_id: str):
address = await get_fresh_address(wallet_id) address = await get_fresh_address(wallet_id)
assert address
return address.dict() return address.dict()
@watchonly_ext.put("/api/v1/address/{id}") @watchonly_ext.put("/api/v1/address/{id}", dependencies=[Depends(require_admin_key)])
async def api_update_address( async def api_update_address(id: str, req: Request):
id: str, req: Request, w: WalletTypeInfo = Depends(require_admin_key)
):
body = await req.json() body = await req.json()
params = {} params = {}
# amout is only updated if the address has history # amout is only updated if the address has history
@@ -162,9 +160,10 @@ async def api_update_address(
params["has_activity"] = True params["has_activity"] = True
if "note" in body: if "note" in body:
params["note"] = str(body["note"]) params["note"] = body["note"]
address = await update_address(**params, id=id) address = await update_address(**params, id=id)
assert address
wallet = ( wallet = (
await get_watch_wallet(address.wallet) await get_watch_wallet(address.wallet)
@@ -189,6 +188,7 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type)
addresses = await get_addresses(wallet_id) addresses = await get_addresses(wallet_id)
config = await get_config(w.wallet.user) config = await get_config(w.wallet.user)
assert config
if not addresses: if not addresses:
await create_fresh_addresses(wallet_id, 0, config.receive_gap_limit) await create_fresh_addresses(wallet_id, 0, config.receive_gap_limit)
@@ -229,10 +229,8 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type)
#############################PSBT########################## #############################PSBT##########################
@watchonly_ext.post("/api/v1/psbt") @watchonly_ext.post("/api/v1/psbt", dependencies=[Depends(require_admin_key)])
async def api_psbt_create( async def api_psbt_create(data: CreatePsbt):
data: CreatePsbt, w: WalletTypeInfo = Depends(require_admin_key)
):
try: try:
vin = [ vin = [
TransactionInput(bytes.fromhex(inp.tx_id), inp.vout) for inp in data.inputs TransactionInput(bytes.fromhex(inp.tx_id), inp.vout) for inp in data.inputs
@@ -246,7 +244,7 @@ async def api_psbt_create(
for _, masterpub in enumerate(data.masterpubs): for _, masterpub in enumerate(data.masterpubs):
descriptors[masterpub.id] = parse_key(masterpub.public_key) descriptors[masterpub.id] = parse_key(masterpub.public_key)
inputs_extra = [] inputs_extra: List[dict] = []
for i, inp in enumerate(data.inputs): for i, inp in enumerate(data.inputs):
bip32_derivations = {} bip32_derivations = {}
@@ -266,14 +264,15 @@ async def api_psbt_create(
tx = Transaction(vin=vin, vout=vout) tx = Transaction(vin=vin, vout=vout)
psbt = PSBT(tx) psbt = PSBT(tx)
for i, inp in enumerate(inputs_extra): for i, inp_extra in enumerate(inputs_extra):
psbt.inputs[i].bip32_derivations = inp["bip32_derivations"] psbt.inputs[i].bip32_derivations = inp_extra["bip32_derivations"]
psbt.inputs[i].non_witness_utxo = inp.get("non_witness_utxo", None) psbt.inputs[i].non_witness_utxo = inp_extra.get("non_witness_utxo", None)
outputs_extra = [] outputs_extra = []
bip32_derivations = {} bip32_derivations = {}
for i, out in enumerate(data.outputs): for i, out in enumerate(data.outputs):
if out.branch_index == 1: if out.branch_index == 1:
assert out.wallet
descriptor = descriptors[out.wallet][0] descriptor = descriptors[out.wallet][0]
d = descriptor.derive(out.address_index, out.branch_index) d = descriptor.derive(out.address_index, out.branch_index)
for k in d.keys: for k in d.keys:
@@ -282,8 +281,8 @@ async def api_psbt_create(
) )
outputs_extra.append({"bip32_derivations": bip32_derivations}) outputs_extra.append({"bip32_derivations": bip32_derivations})
for i, out in enumerate(outputs_extra): for i, out_extra in enumerate(outputs_extra):
psbt.outputs[i].bip32_derivations = out["bip32_derivations"] psbt.outputs[i].bip32_derivations = out_extra["bip32_derivations"]
return psbt.to_string() return psbt.to_string()
@@ -360,7 +359,8 @@ async def api_tx_broadcast(
else config.mempool_endpoint + "/testnet" else config.mempool_endpoint + "/testnet"
) )
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
r = await client.post(endpoint + "/api/tx", data=data.tx_hex) r = await client.post(endpoint + "/api/tx", content=data.tx_hex)
r.raise_for_status()
tx_id = r.text tx_id = r.text
return tx_id return tx_id
except Exception as e: except Exception as e:
@@ -375,6 +375,7 @@ async def api_update_config(
data: Config, w: WalletTypeInfo = Depends(require_admin_key) data: Config, w: WalletTypeInfo = Depends(require_admin_key)
): ):
config = await update_config(data, user=w.wallet.user) config = await update_config(data, user=w.wallet.user)
assert config
return config.dict() return config.dict()

View File

@@ -93,7 +93,6 @@ exclude = """(?x)(
| ^lnbits/extensions/boltz. | ^lnbits/extensions/boltz.
| ^lnbits/extensions/livestream. | ^lnbits/extensions/livestream.
| ^lnbits/extensions/lnurldevice. | ^lnbits/extensions/lnurldevice.
| ^lnbits/extensions/watchonly.
| ^lnbits/wallets/lnd_grpc_files. | ^lnbits/wallets/lnd_grpc_files.
)""" )"""