mirror of
https://github.com/bitcoin/bitcoin.git
synced 2026-05-12 15:03:18 +02:00
Update test_framework/psbt.py for PSBTv2
This commit is contained in:
@@ -66,8 +66,8 @@ def signet_txs(block, challenge):
|
||||
def decode_challenge_psbt(b64psbt):
|
||||
psbt = PSBT.from_base64(b64psbt)
|
||||
|
||||
assert len(psbt.tx.vin) == 1
|
||||
assert len(psbt.tx.vout) == 1
|
||||
assert len(psbt.i) == 1
|
||||
assert len(psbt.o) == 1
|
||||
assert PSBT_SIGNET_BLOCK in psbt.g.map
|
||||
return psbt
|
||||
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from .util import assert_equal
|
||||
from .messages import (
|
||||
CTransaction,
|
||||
deser_string,
|
||||
deser_compact_size,
|
||||
from_binary,
|
||||
ser_compact_size,
|
||||
)
|
||||
@@ -108,37 +112,81 @@ class PSBT:
|
||||
self.g = g if g is not None else PSBTMap()
|
||||
self.i = i if i is not None else []
|
||||
self.o = o if o is not None else []
|
||||
self.tx = None
|
||||
self.version = None
|
||||
|
||||
def deserialize(self, f):
|
||||
assert_equal(f.read(5), b"psbt\xff")
|
||||
self.g = from_binary(PSBTMap, f)
|
||||
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
|
||||
self.tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
|
||||
self.i = [from_binary(PSBTMap, f) for _ in self.tx.vin]
|
||||
self.o = [from_binary(PSBTMap, f) for _ in self.tx.vout]
|
||||
|
||||
self.version = 0
|
||||
if PSBT_GLOBAL_VERSION in self.g.map:
|
||||
self.version = struct.unpack("<I", self.g.map[PSBT_GLOBAL_VERSION])[0]
|
||||
assert self.version in [0, 2]
|
||||
if self.version == 2:
|
||||
assert PSBT_GLOBAL_INPUT_COUNT in self.g.map
|
||||
assert PSBT_GLOBAL_OUTPUT_COUNT in self.g.map
|
||||
in_count = deser_compact_size(BytesIO(self.g.map[PSBT_GLOBAL_INPUT_COUNT]))
|
||||
out_count = deser_compact_size(BytesIO(self.g.map[PSBT_GLOBAL_OUTPUT_COUNT]))
|
||||
else:
|
||||
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
|
||||
tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
|
||||
in_count = len(tx.vin)
|
||||
out_count = len(tx.vout)
|
||||
|
||||
self.i = [from_binary(PSBTMap, f) for _ in range(in_count)]
|
||||
self.o = [from_binary(PSBTMap, f) for _ in range(out_count)]
|
||||
return self
|
||||
|
||||
def serialize(self):
|
||||
assert isinstance(self.g, PSBTMap)
|
||||
assert isinstance(self.i, list) and all(isinstance(x, PSBTMap) for x in self.i)
|
||||
assert isinstance(self.o, list) and all(isinstance(x, PSBTMap) for x in self.o)
|
||||
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
|
||||
tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
|
||||
assert_equal(len(tx.vin), len(self.i))
|
||||
assert_equal(len(tx.vout), len(self.o))
|
||||
if self.version is not None and self.version == 2:
|
||||
self.g.map[PSBT_GLOBAL_INPUT_COUNT] = ser_compact_size(len(self.i))
|
||||
self.g.map[PSBT_GLOBAL_OUTPUT_COUNT] = ser_compact_size(len(self.o))
|
||||
if self.version is None or (self.version is not None and self.version == 0):
|
||||
assert PSBT_GLOBAL_UNSIGNED_TX in self.g.map
|
||||
tx = from_binary(CTransaction, self.g.map[PSBT_GLOBAL_UNSIGNED_TX])
|
||||
assert_equal(len(tx.vin), len(self.i))
|
||||
assert_equal(len(tx.vout), len(self.o))
|
||||
|
||||
psbt = [x.serialize() for x in [self.g] + self.i + self.o]
|
||||
return b"psbt\xff" + b"".join(psbt)
|
||||
|
||||
def make_blank(self):
|
||||
"""
|
||||
Remove all fields except for PSBT_GLOBAL_UNSIGNED_TX
|
||||
Remove all fields except for required fields depending on version
|
||||
"""
|
||||
for m in self.i + self.o:
|
||||
m.map.clear()
|
||||
if self.version == 0:
|
||||
for m in self.i + self.o:
|
||||
m.map.clear()
|
||||
|
||||
self.g = PSBTMap(map={PSBT_GLOBAL_UNSIGNED_TX: self.g.map[PSBT_GLOBAL_UNSIGNED_TX]})
|
||||
self.g = PSBTMap(map={PSBT_GLOBAL_UNSIGNED_TX: self.g.map[PSBT_GLOBAL_UNSIGNED_TX]})
|
||||
elif self.version == 2:
|
||||
self.g = PSBTMap(map={
|
||||
PSBT_GLOBAL_TX_VERSION: self.g.map[PSBT_GLOBAL_TX_VERSION],
|
||||
PSBT_GLOBAL_INPUT_COUNT: self.g.map[PSBT_GLOBAL_INPUT_COUNT],
|
||||
PSBT_GLOBAL_OUTPUT_COUNT: self.g.map[PSBT_GLOBAL_OUTPUT_COUNT],
|
||||
PSBT_GLOBAL_VERSION: self.g.map[PSBT_GLOBAL_VERSION],
|
||||
})
|
||||
|
||||
new_i = []
|
||||
for m in self.i:
|
||||
new_i.append(PSBTMap(map={
|
||||
PSBT_IN_PREVIOUS_TXID: m.map[PSBT_IN_PREVIOUS_TXID],
|
||||
PSBT_IN_OUTPUT_INDEX: m.map[PSBT_IN_OUTPUT_INDEX],
|
||||
}))
|
||||
self.i = new_i
|
||||
|
||||
new_o = []
|
||||
for m in self.o:
|
||||
new_o.append(PSBTMap(map={
|
||||
PSBT_OUT_SCRIPT: m.map[PSBT_OUT_SCRIPT],
|
||||
PSBT_OUT_AMOUNT: m.map[PSBT_OUT_AMOUNT],
|
||||
}))
|
||||
self.o = new_o
|
||||
else:
|
||||
assert False
|
||||
|
||||
def to_base64(self):
|
||||
return base64.b64encode(self.serialize()).decode("utf8")
|
||||
|
||||
Reference in New Issue
Block a user