tests: Test that PSBT_OUT_TAP_TREE is combined correctly

Github-Pull: #25858
Rebased-From: 22c051ca70bae73e0430b05fb9d879591df27699
This commit is contained in:
Andrew Chow 2022-10-06 15:32:33 -04:00 committed by fanquake
parent 4abd2ab18e
commit a9419eff0c
No known key found for this signature in database
GPG Key ID: 2EEB9F5CC09526C1
2 changed files with 29 additions and 2 deletions

View File

@ -27,6 +27,7 @@ from test_framework.psbt import (
PSBT_IN_SHA256,
PSBT_IN_HASH160,
PSBT_IN_HASH256,
PSBT_OUT_TAP_TREE,
)
from test_framework.test_framework import BitcoinTestFramework
from test_framework.util import (
@ -779,9 +780,18 @@ class PSBTTest(BitcoinTestFramework):
self.generate(self.nodes[0], 1)
self.nodes[0].importdescriptors([{"desc": descsum_create("tr({})".format(privkey)), "timestamp":"now"}])
psbt = watchonly.sendall([wallet.getnewaddress()])["psbt"]
psbt = watchonly.sendall([wallet.getnewaddress(), addr])["psbt"]
psbt = self.nodes[0].walletprocesspsbt(psbt)["psbt"]
self.nodes[0].sendrawtransaction(self.nodes[0].finalizepsbt(psbt)["hex"])
txid = self.nodes[0].sendrawtransaction(self.nodes[0].finalizepsbt(psbt)["hex"])
vout = find_vout_for_address(self.nodes[0], txid, addr)
# Make sure tap tree is in psbt
parsed_psbt = PSBT.from_base64(psbt)
assert_greater_than(len(parsed_psbt.o[vout].map[PSBT_OUT_TAP_TREE]), 0)
assert "taproot_tree" in self.nodes[0].decodepsbt(psbt)["outputs"][vout]
parsed_psbt.make_blank()
comb_psbt = self.nodes[0].combinepsbt([psbt, parsed_psbt.to_base64()])
assert_equal(comb_psbt, psbt)
self.log.info("Test that walletprocesspsbt both updates and signs a non-updated psbt containing Taproot inputs")
addr = self.nodes[0].getnewaddress("", "bech32m")
@ -793,6 +803,14 @@ class PSBTTest(BitcoinTestFramework):
self.nodes[0].sendrawtransaction(rawtx)
self.generate(self.nodes[0], 1)
# Make sure tap tree is not in psbt
parsed_psbt = PSBT.from_base64(psbt)
assert PSBT_OUT_TAP_TREE not in parsed_psbt.o[0].map
assert "taproot_tree" not in self.nodes[0].decodepsbt(psbt)["outputs"][0]
parsed_psbt.make_blank()
comb_psbt = self.nodes[0].combinepsbt([psbt, parsed_psbt.to_base64()])
assert_equal(comb_psbt, psbt)
self.log.info("Test decoding PSBT with per-input preimage types")
# note that the decodepsbt RPC doesn't check whether preimages and hashes match
hash_ripemd160, preimage_ripemd160 = random_bytes(20), random_bytes(50)

View File

@ -123,6 +123,15 @@ class PSBT:
psbt = [x.serialize() for x in [self.g] + self.i + self.o]
return b"psbt\xff" + b"".join(psbt)
def make_blank(self):
"""
Remove all fields except for PSBT_GLOBAL_UNSIGNED_TX
"""
for m in self.i + self.o:
m.map.clear()
self.g = PSBTMap(map={0: self.g.map[0]})
def to_base64(self):
return base64.b64encode(self.serialize()).decode("utf8")