From 990b084f11c038367035ad520482174900cdd29c Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Mon, 22 Jul 2024 17:14:07 -0400 Subject: [PATCH] Have PSBTInput and PSBTOutput know the PSBT's version --- src/psbt.cpp | 4 ++-- src/psbt.h | 30 +++++++++++++++++++++++++----- src/rpc/rawtransaction.cpp | 8 ++++---- src/test/fuzz/deserialize.cpp | 4 ++-- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/psbt.cpp b/src/psbt.cpp index 3086ae7775e..416eae11d1f 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -15,8 +15,8 @@ using common::PSBTError; PartiallySignedTransaction::PartiallySignedTransaction(const CMutableTransaction& tx) : tx(tx) { - inputs.resize(tx.vin.size()); - outputs.resize(tx.vout.size()); + inputs.resize(tx.vin.size(), PSBTInput(GetVersion())); + outputs.resize(tx.vout.size(), PSBTOutput(GetVersion())); } bool PartiallySignedTransaction::IsNull() const diff --git a/src/psbt.h b/src/psbt.h index f6bf144ab8d..4272cd3a807 100644 --- a/src/psbt.h +++ b/src/psbt.h @@ -263,6 +263,9 @@ static inline void ExpectedKeySize(const std::string& key_name, const std::vecto /** A structure for PSBTs which contain per-input information */ class PSBTInput { +private: + uint32_t m_psbt_version; + public: CTransactionRef non_witness_utxo; CTxOut witness_utxo; @@ -300,7 +303,12 @@ public: void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); void Merge(const PSBTInput& input); - PSBTInput() = default; + uint32_t GetVersion() const { return m_psbt_version; } + explicit PSBTInput(uint32_t psbt_version) + : m_psbt_version(psbt_version) + { + assert(m_psbt_version == 0); + } template inline void Serialize(Stream& s) const { @@ -794,6 +802,9 @@ public: /** A structure for PSBTs which contains per output information */ class PSBTOutput { +private: + uint32_t m_psbt_version; + public: CScript redeem_script; CScript witness_script; @@ -809,7 +820,12 @@ public: void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); void Merge(const PSBTOutput& output); - PSBTOutput() = default; + uint32_t GetVersion() const { return m_psbt_version; } + explicit PSBTOutput(uint32_t psbt_version) + : m_psbt_version(psbt_version) + { + assert(m_psbt_version == 0); + } template inline void Serialize(Stream& s) const { @@ -1031,6 +1047,9 @@ public: /** A version of CTransaction with the PSBT format*/ class PartiallySignedTransaction { +private: + std::optional m_version; + public: std::optional tx; // We use a vector of CExtPubKey in the event that there happens to be the same KeyOriginInfos for different CExtPubKeys @@ -1039,7 +1058,6 @@ public: std::vector inputs; std::vector outputs; std::map, std::vector> unknown; - std::optional m_version; std::set m_proprietary; bool IsNull() const; @@ -1241,10 +1259,12 @@ public: throw std::ios_base::failure("No unsigned transaction was provided"); } + const uint32_t psbt_ver = GetVersion(); + // Read input data unsigned int i = 0; while (!s.empty() && i < tx->vin.size()) { - PSBTInput input; + PSBTInput input(psbt_ver); s >> input; inputs.push_back(input); @@ -1267,7 +1287,7 @@ public: // Read output data i = 0; while (!s.empty() && i < tx->vout.size()) { - PSBTOutput output; + PSBTOutput output(psbt_ver); s >> output; outputs.push_back(output); ++i; diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index 0f3cebb5dcf..67067110648 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -1652,10 +1652,10 @@ static RPCMethod createpsbt() PartiallySignedTransaction psbtx; psbtx.tx = rawTx; for (unsigned int i = 0; i < rawTx.vin.size(); ++i) { - psbtx.inputs.emplace_back(); + psbtx.inputs.emplace_back(0); } for (unsigned int i = 0; i < rawTx.vout.size(); ++i) { - psbtx.outputs.emplace_back(); + psbtx.outputs.emplace_back(0); } // Serialize the PSBT @@ -1720,10 +1720,10 @@ static RPCMethod converttopsbt() PartiallySignedTransaction psbtx; psbtx.tx = tx; for (unsigned int i = 0; i < tx.vin.size(); ++i) { - psbtx.inputs.emplace_back(); + psbtx.inputs.emplace_back(0); } for (unsigned int i = 0; i < tx.vout.size(); ++i) { - psbtx.outputs.emplace_back(); + psbtx.outputs.emplace_back(0); } // Serialize the PSBT diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp index 1329f471c70..b3f35baa0fc 100644 --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -192,11 +192,11 @@ FUZZ_TARGET_DESERIALIZE(prefilled_transaction_deserialize, { DeserializeFromFuzzingInput(buffer, prefilled_transaction); }) FUZZ_TARGET_DESERIALIZE(psbt_input_deserialize, { - PSBTInput psbt_input; + PSBTInput psbt_input(0); DeserializeFromFuzzingInput(buffer, psbt_input); }) FUZZ_TARGET_DESERIALIZE(psbt_output_deserialize, { - PSBTOutput psbt_output; + PSBTOutput psbt_output(0); DeserializeFromFuzzingInput(buffer, psbt_output); }) FUZZ_TARGET_DESERIALIZE(block_deserialize, {