From 95897507e9eaeb1d5d7d9f53dbe44f05642ef6d5 Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Mon, 22 Jul 2024 17:14:21 -0400 Subject: [PATCH] psbt: AddInput and AddOutput should take only PSBTInput and PSBTOutput --- src/psbt.cpp | 46 ++++++++++++++++++++++++++++---------- src/psbt.h | 4 ++-- src/rpc/rawtransaction.cpp | 8 +++---- src/test/fuzz/psbt.cpp | 8 +++---- 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/psbt.cpp b/src/psbt.cpp index 85c2f25a902..0885519bc2c 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -58,24 +58,46 @@ bool PartiallySignedTransaction::Merge(const PartiallySignedTransaction& psbt) return true; } -bool PartiallySignedTransaction::AddInput(const CTxIn& txin, PSBTInput& psbtin) +bool PartiallySignedTransaction::AddInput(const PSBTInput& psbtin) { - if (std::find(tx->vin.begin(), tx->vin.end(), txin) != tx->vin.end()) { + if (GetVersion() < 2) { + // This is a v0 psbt, so do the v0 AddInput + CTxIn txin(COutPoint(psbtin.prev_txid, psbtin.prev_out)); + if (std::find(tx->vin.begin(), tx->vin.end(), txin) != tx->vin.end()) { + // Prevent duplicate inputs + return false; + } + tx->vin.push_back(std::move(txin)); + inputs.push_back(psbtin); + inputs.back().partial_sigs.clear(); + inputs.back().final_script_sig.clear(); + inputs.back().final_script_witness.SetNull(); + return true; + } + + // Prevent duplicate inputs + if (std::find_if(inputs.begin(), inputs.end(), + [psbtin](const PSBTInput& psbt) { + return psbt.prev_txid == psbtin.prev_txid && psbt.prev_out == psbtin.prev_out; + } + ) != inputs.end()) { return false; } - tx->vin.push_back(txin); - psbtin.partial_sigs.clear(); - psbtin.final_script_sig.clear(); - psbtin.final_script_witness.SetNull(); - inputs.push_back(psbtin); - return true; + + return false; } -bool PartiallySignedTransaction::AddOutput(const CTxOut& txout, const PSBTOutput& psbtout) +bool PartiallySignedTransaction::AddOutput(const PSBTOutput& psbtout) { - tx->vout.push_back(txout); - outputs.push_back(psbtout); - return true; + if (GetVersion() < 2) { + // This is a v0 psbt, do the v0 AddOutput + CTxOut txout(psbtout.amount, psbtout.script); + tx->vout.push_back(txout); + outputs.push_back(psbtout); + return true; + } + + return false; } bool PartiallySignedTransaction::GetInputUTXO(CTxOut& utxo, int input_index) const diff --git a/src/psbt.h b/src/psbt.h index bb5ce49143f..70bdae96f41 100644 --- a/src/psbt.h +++ b/src/psbt.h @@ -1084,8 +1084,8 @@ public: /** Merge psbt into this. The two psbts must have the same underlying CTransaction (i.e. the * same actual Bitcoin transaction.) Returns true if the merge succeeded, false otherwise. */ [[nodiscard]] bool Merge(const PartiallySignedTransaction& psbt); - bool AddInput(const CTxIn& txin, PSBTInput& psbtin); - bool AddOutput(const CTxOut& txout, const PSBTOutput& psbtout); + bool AddInput(const PSBTInput& psbtin); + bool AddOutput(const PSBTOutput& psbtout); explicit PartiallySignedTransaction(const CMutableTransaction& tx); /** * Finds the UTXO for a given input index diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index b5a063859ee..d3076dbeaa9 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -1822,12 +1822,12 @@ static RPCMethod joinpsbts() // Merge for (auto& psbt : psbtxs) { for (unsigned int i = 0; i < psbt.tx->vin.size(); ++i) { - if (!merged_psbt.AddInput(psbt.tx->vin[i], psbt.inputs[i])) { + if (!merged_psbt.AddInput(psbt.inputs[i])) { throw JSONRPCError(RPC_INVALID_PARAMETER, strprintf("Input %s:%d exists in multiple PSBTs", psbt.tx->vin[i].prevout.hash.ToString(), psbt.tx->vin[i].prevout.n)); } } for (unsigned int i = 0; i < psbt.tx->vout.size(); ++i) { - merged_psbt.AddOutput(psbt.tx->vout[i], psbt.outputs[i]); + merged_psbt.AddOutput(psbt.outputs[i]); } for (auto& xpub_pair : psbt.m_xpubs) { if (!merged_psbt.m_xpubs.contains(xpub_pair.first)) { @@ -1851,10 +1851,10 @@ static RPCMethod joinpsbts() PartiallySignedTransaction shuffled_psbt(tx); for (int i : input_indices) { - shuffled_psbt.AddInput(merged_psbt.tx->vin[i], merged_psbt.inputs[i]); + shuffled_psbt.AddInput(merged_psbt.inputs[i]); } for (int i : output_indices) { - shuffled_psbt.AddOutput(merged_psbt.tx->vout[i], merged_psbt.outputs[i]); + shuffled_psbt.AddOutput(merged_psbt.outputs[i]); } shuffled_psbt.unknown.insert(merged_psbt.unknown.begin(), merged_psbt.unknown.end()); diff --git a/src/test/fuzz/psbt.cpp b/src/test/fuzz/psbt.cpp index dd7c61eb7da..d0abfc19a87 100644 --- a/src/test/fuzz/psbt.cpp +++ b/src/test/fuzz/psbt.cpp @@ -98,11 +98,11 @@ FUZZ_TARGET(psbt) if (comb_res) { psbt_mut = *comb_res; } - for (unsigned int i = 0; i < psbt_merge.tx->vin.size(); ++i) { - (void)psbt_mut.AddInput(psbt_merge.tx->vin[i], psbt_merge.inputs[i]); + for (const auto& psbt_in : psbt_merge.inputs) { + (void)psbt_mut.AddInput(psbt_in); } - for (unsigned int i = 0; i < psbt_merge.tx->vout.size(); ++i) { - Assert(psbt_mut.AddOutput(psbt_merge.tx->vout[i], psbt_merge.outputs[i])); + for (const auto& psbt_out : psbt_merge.outputs) { + Assert(psbt_mut.AddOutput(psbt_out)); } psbt_mut.unknown.insert(psbt_merge.unknown.begin(), psbt_merge.unknown.end()); }