fix mypy watchonly issues

This commit is contained in:
dni ⚡ 2023-01-05 13:40:44 +01:00
parent d5009a7d0a
commit 6ea5069835
5 changed files with 43 additions and 42 deletions

@ -41,8 +41,9 @@ async def create_watch_wallet(user: str, w: WalletAccount) -> WalletAccount:
w.meta,
),
)
return await get_watch_wallet(wallet_id)
wallet = await get_watch_wallet(wallet_id)
assert wallet
return wallet
async def get_watch_wallet(wallet_id: str) -> Optional[WalletAccount]:
@ -121,11 +122,11 @@ async def create_fresh_addresses(
change_address=False,
) -> List[Address]:
if start_address_index > end_address_index:
return None
return []
wallet = await get_watch_wallet(wallet_id)
if not wallet:
return None
return []
branch_index = 1 if change_address else 0
@ -150,7 +151,7 @@ async def create_fresh_addresses(
# return fresh addresses
rows = await db.fetchall(
"""
SELECT * FROM watchonly.addresses
SELECT * FROM watchonly.addresses
WHERE wallet = ? AND branch_index = ? AND address_index >= ? AND address_index < ?
ORDER BY branch_index, address_index
""",
@ -172,7 +173,7 @@ async def get_address_at_index(
) -> Optional[Address]:
row = await db.fetchone(
"""
SELECT * FROM watchonly.addresses
SELECT * FROM watchonly.addresses
WHERE wallet = ? AND branch_index = ? AND address_index = ?
""",
(

@ -1,7 +1,7 @@
from sqlite3 import Row
from typing import List, Optional
from fastapi.param_functions import Query
from fastapi import Query
from pydantic import BaseModel
@ -35,7 +35,7 @@ class Address(BaseModel):
amount: int = 0
branch_index: int = 0
address_index: int
note: str = None
note: Optional[str] = None
has_activity: bool = False
@classmethod
@ -57,9 +57,9 @@ class TransactionInput(BaseModel):
class TransactionOutput(BaseModel):
amount: int
address: str
branch_index: int = None
address_index: int = None
wallet: str = None
branch_index: Optional[int] = None
address_index: Optional[int] = None
wallet: Optional[str] = None
class MasterPublicKey(BaseModel):

@ -1,6 +1,5 @@
from fastapi.params import Depends
from fastapi import Depends, Request
from fastapi.templating import Jinja2Templates
from starlette.requests import Request
from starlette.responses import HTMLResponse
from lnbits.core.models import User

@ -1,5 +1,6 @@
import json
from http import HTTPStatus
from typing import List
import httpx
from embit import finalizer, script
@ -7,9 +8,7 @@ from embit.ec import PublicKey
from embit.networks import NETWORKS
from embit.psbt import PSBT, DerivationPath
from embit.transaction import Transaction, TransactionInput, TransactionOutput
from fastapi import Query, Request
from fastapi.params import Depends
from starlette.exceptions import HTTPException
from fastapi import Depends, HTTPException, Query, Request
from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key
from lnbits.extensions.watchonly import watchonly_ext
@ -57,10 +56,8 @@ async def api_wallets_retrieve(
return []
@watchonly_ext.get("/api/v1/wallet/{wallet_id}")
async def api_wallet_retrieve(
wallet_id, wallet: WalletTypeInfo = Depends(get_key_type)
):
@watchonly_ext.get("/api/v1/wallet/{wallet_id}", dependencies=[Depends(get_key_type)])
async def api_wallet_retrieve(wallet_id: str):
w_wallet = await get_watch_wallet(wallet_id)
if not w_wallet:
@ -76,7 +73,8 @@ async def api_wallet_create_or_update(
data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key)
):
try:
(descriptor, network) = parse_key(data.masterpub)
# TODO: talk to motorina about this
(descriptor, network) = parse_key(data.masterpub) # type: ignore
if data.network != network["name"]:
raise ValueError(
"Account network error. This account is for '{}'".format(
@ -126,8 +124,10 @@ async def api_wallet_create_or_update(
return wallet.dict()
@watchonly_ext.delete("/api/v1/wallet/{wallet_id}")
async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin_key)):
@watchonly_ext.delete(
"/api/v1/wallet/{wallet_id}", dependencies=[Depends(require_admin_key)]
)
async def api_wallet_delete(wallet_id: str):
wallet = await get_watch_wallet(wallet_id)
if not wallet:
@ -144,16 +144,15 @@ async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin
#############################ADDRESSES##########################
@watchonly_ext.get("/api/v1/address/{wallet_id}")
async def api_fresh_address(wallet_id, w: WalletTypeInfo = Depends(get_key_type)):
@watchonly_ext.get("/api/v1/address/{wallet_id}", dependencies=[Depends(get_key_type)])
async def api_fresh_address(wallet_id: str):
address = await get_fresh_address(wallet_id)
assert address
return address.dict()
@watchonly_ext.put("/api/v1/address/{id}")
async def api_update_address(
id: str, req: Request, w: WalletTypeInfo = Depends(require_admin_key)
):
@watchonly_ext.put("/api/v1/address/{id}", dependencies=[Depends(require_admin_key)])
async def api_update_address(id: str, req: Request):
body = await req.json()
params = {}
# amout is only updated if the address has history
@ -162,9 +161,10 @@ async def api_update_address(
params["has_activity"] = True
if "note" in body:
params["note"] = str(body["note"])
params["note"] = body["note"]
address = await update_address(**params, id=id)
assert address
wallet = (
await get_watch_wallet(address.wallet)
@ -189,6 +189,7 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type)
addresses = await get_addresses(wallet_id)
config = await get_config(w.wallet.user)
assert config
if not addresses:
await create_fresh_addresses(wallet_id, 0, config.receive_gap_limit)
@ -229,10 +230,8 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type)
#############################PSBT##########################
@watchonly_ext.post("/api/v1/psbt")
async def api_psbt_create(
data: CreatePsbt, w: WalletTypeInfo = Depends(require_admin_key)
):
@watchonly_ext.post("/api/v1/psbt", dependencies=[Depends(require_admin_key)])
async def api_psbt_create(data: CreatePsbt):
try:
vin = [
TransactionInput(bytes.fromhex(inp.tx_id), inp.vout) for inp in data.inputs
@ -246,7 +245,7 @@ async def api_psbt_create(
for _, masterpub in enumerate(data.masterpubs):
descriptors[masterpub.id] = parse_key(masterpub.public_key)
inputs_extra = []
inputs_extra: List[dict] = []
for i, inp in enumerate(data.inputs):
bip32_derivations = {}
@ -266,14 +265,15 @@ async def api_psbt_create(
tx = Transaction(vin=vin, vout=vout)
psbt = PSBT(tx)
for i, inp in enumerate(inputs_extra):
psbt.inputs[i].bip32_derivations = inp["bip32_derivations"]
psbt.inputs[i].non_witness_utxo = inp.get("non_witness_utxo", None)
for i, inp_extra in enumerate(inputs_extra):
psbt.inputs[i].bip32_derivations = inp_extra["bip32_derivations"]
psbt.inputs[i].non_witness_utxo = inp_extra.get("non_witness_utxo", None)
outputs_extra = []
bip32_derivations = {}
for i, out in enumerate(data.outputs):
if out.branch_index == 1:
assert out.wallet
descriptor = descriptors[out.wallet][0]
d = descriptor.derive(out.address_index, out.branch_index)
for k in d.keys:
@ -282,8 +282,8 @@ async def api_psbt_create(
)
outputs_extra.append({"bip32_derivations": bip32_derivations})
for i, out in enumerate(outputs_extra):
psbt.outputs[i].bip32_derivations = out["bip32_derivations"]
for i, out_extra in enumerate(outputs_extra):
psbt.outputs[i].bip32_derivations = out_extra["bip32_derivations"]
return psbt.to_string()
@ -360,7 +360,8 @@ async def api_tx_broadcast(
else config.mempool_endpoint + "/testnet"
)
async with httpx.AsyncClient() as client:
r = await client.post(endpoint + "/api/tx", data=data.tx_hex)
r = await client.post(endpoint + "/api/tx", content=data.tx_hex)
r.raise_for_status()
tx_id = r.text
return tx_id
except Exception as e:
@ -375,6 +376,7 @@ async def api_update_config(
data: Config, w: WalletTypeInfo = Depends(require_admin_key)
):
config = await update_config(data, user=w.wallet.user)
assert config
return config.dict()

@ -93,7 +93,6 @@ exclude = """(?x)(
| ^lnbits/extensions/boltz.
| ^lnbits/extensions/livestream.
| ^lnbits/extensions/lnurldevice.
| ^lnbits/extensions/watchonly.
| ^lnbits/wallets/lnd_grpc_files.
)"""