diff --git a/src/psbt.cpp b/src/psbt.cpp index d03ee404ff8..f01d915f46d 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -44,12 +44,19 @@ bool PartiallySignedTransaction::Merge(const PartiallySignedTransaction& psbt) if (!this_id || !psbt_id || this_id != psbt_id) { return false; } + if (GetVersion() != psbt.GetVersion()) { + return false; + } for (unsigned int i = 0; i < inputs.size(); ++i) { - inputs[i].Merge(psbt.inputs[i]); + if (!inputs[i].Merge(psbt.inputs[i])) { + return false; + } } for (unsigned int i = 0; i < outputs.size(); ++i) { - outputs[i].Merge(psbt.outputs[i]); + if (!outputs[i].Merge(psbt.outputs[i])) { + return false; + } } for (auto& xpub_pair : psbt.m_xpubs) { if (!m_xpubs.contains(xpub_pair.first)) { @@ -58,6 +65,20 @@ bool PartiallySignedTransaction::Merge(const PartiallySignedTransaction& psbt) m_xpubs[xpub_pair.first].insert(xpub_pair.second.begin(), xpub_pair.second.end()); } } + if (fallback_locktime == std::nullopt && psbt.fallback_locktime != std::nullopt) fallback_locktime = psbt.fallback_locktime; + + // Set m_tx_modifiable only if either PSBT had it set + if (m_tx_modifiable.has_value() || psbt.m_tx_modifiable.has_value()) { + // In general, we AND the modifiable flags + std::bitset<8> this_modifiable = m_tx_modifiable.value_or(0); + std::bitset<8> psbt_modifiable = psbt.m_tx_modifiable.value_or(0); + std::bitset<8> final_modifiable = this_modifiable & psbt_modifiable; + // SIGHASH_SINGLE Modifiable (bit 2) needs to be bitwise OR'd + final_modifiable.set(2, this_modifiable[2] || psbt_modifiable[2]); + + m_tx_modifiable = final_modifiable; + } + unknown.insert(psbt.unknown.begin(), psbt.unknown.end()); return true; @@ -402,7 +423,7 @@ void PSBTInput::FromSignatureData(const SignatureData& sigdata) } } -void PSBTInput::Merge(const PSBTInput& input) +bool PSBTInput::Merge(const PSBTInput& input) { if (!non_witness_utxo && input.non_witness_utxo) non_witness_utxo = input.non_witness_utxo; if (witness_utxo.IsNull() && !input.witness_utxo.IsNull()) { @@ -434,6 +455,11 @@ void PSBTInput::Merge(const PSBTInput& input) for (const auto& [agg_key_lh, psigs] : input.m_musig2_partial_sigs) { m_musig2_partial_sigs[agg_key_lh].insert(psigs.begin(), psigs.end()); } + if (sequence == std::nullopt && input.sequence != std::nullopt) sequence = input.sequence; + if (time_locktime == std::nullopt && input.time_locktime != std::nullopt) time_locktime = input.time_locktime; + if (height_locktime == std::nullopt && input.height_locktime != std::nullopt) height_locktime = input.height_locktime; + + return true; } bool PSBTInput::HasSignatures() const @@ -505,7 +531,7 @@ bool PSBTOutput::IsNull() const return redeem_script.empty() && witness_script.empty() && hd_keypaths.empty() && unknown.empty(); } -void PSBTOutput::Merge(const PSBTOutput& output) +bool PSBTOutput::Merge(const PSBTOutput& output) { hd_keypaths.insert(output.hd_keypaths.begin(), output.hd_keypaths.end()); unknown.insert(output.unknown.begin(), output.unknown.end()); @@ -516,6 +542,8 @@ void PSBTOutput::Merge(const PSBTOutput& output) if (m_tap_internal_key.IsNull() && !output.m_tap_internal_key.IsNull()) m_tap_internal_key = output.m_tap_internal_key; if (m_tap_tree.empty() && !output.m_tap_tree.empty()) m_tap_tree = output.m_tap_tree; m_musig2_participants.insert(output.m_musig2_participants.begin(), output.m_musig2_participants.end()); + + return true; } bool PSBTInputSigned(const PSBTInput& input) diff --git a/src/psbt.h b/src/psbt.h index 083cfa53979..c5bbff891f6 100644 --- a/src/psbt.h +++ b/src/psbt.h @@ -323,7 +323,7 @@ public: bool IsNull() const; void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); - void Merge(const PSBTInput& input); + [[nodiscard]] bool Merge(const PSBTInput& input); uint32_t GetVersion() const { return m_psbt_version; } COutPoint GetOutPoint() const; /** @@ -956,7 +956,7 @@ public: bool IsNull() const; void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); - void Merge(const PSBTOutput& output); + [[nodiscard]] bool Merge(const PSBTOutput& output); uint32_t GetVersion() const { return m_psbt_version; } explicit PSBTOutput(uint32_t psbt_version, CAmount amount, const CScript& script)