From 47bec7af457c7da97bb457ca6c5029527e5f8be3 Mon Sep 17 00:00:00 2001 From: Ava Chow Date: Wed, 8 Jan 2025 19:58:08 -0500 Subject: [PATCH] psbt: Add sighash types to PSBT when not DEFAULT or ALL When an atypical sighash type is specified by the user, add it to the PSBT so that further signing can enforce sighash type matching. --- src/psbt.cpp | 10 ++++++- test/functional/rpc_psbt.py | 52 +++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/psbt.cpp b/src/psbt.cpp index a43243e4b57..f87849409d9 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -423,6 +423,13 @@ PSBTError SignPSBTInput(const SigningProvider& provider, PartiallySignedTransact if (input.sighash_type && input.sighash_type != sighash) { return PSBTError::SIGHASH_MISMATCH; } + // Set the PSBT sighash field when sighash is not DEFAULT or ALL + // DEFAULT is allowed for non-taproot inputs since DEFAULT may be passed for them (e.g. the psbt being signed also has taproot inputs) + // Note that signing already aliases DEFAULT to ALL for non-taproot inputs. + if (utxo.scriptPubKey.IsPayToTaproot() ? sighash != SIGHASH_DEFAULT : + (sighash != SIGHASH_DEFAULT && sighash != SIGHASH_ALL)) { + input.sighash_type = sighash; + } Assert(sighash.has_value()); // Check all existing signatures use the sighash type @@ -522,7 +529,8 @@ bool FinalizePSBT(PartiallySignedTransaction& psbtx) bool complete = true; const PrecomputedTransactionData txdata = PrecomputePSBTData(psbtx); for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { - complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, std::nullopt, nullptr, true) == PSBTError::OK); + PSBTInput& input = psbtx.inputs.at(i); + complete &= (SignPSBTInput(DUMMY_SIGNING_PROVIDER, psbtx, i, &txdata, input.sighash_type, nullptr, true) == PSBTError::OK); } return complete; diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py index f9762412338..335008473e0 100755 --- a/test/functional/rpc_psbt.py +++ b/test/functional/rpc_psbt.py @@ -230,6 +230,57 @@ class PSBTTest(BitcoinTestFramework): wallet.unloadwallet() + def test_sighash_adding(self): + self.log.info("Test adding of sighash type field") + self.nodes[0].createwallet("sighash_adding") + wallet = self.nodes[0].get_wallet_rpc("sighash_adding") + def_wallet = self.nodes[0].get_wallet_rpc(self.default_wallet_name) + + addr = wallet.getnewaddress(address_type="bech32") + outputs = [{addr: 1}] + if self.options.descriptors: + outputs.append({wallet.getnewaddress(address_type="bech32m"): 1}) + descs = wallet.listdescriptors(True)["descriptors"] + else: + descs = [descsum_create(f"wpkh({wallet.dumpprivkey(addr)})")] + def_wallet.send(outputs) + self.generate(self.nodes[0], 6) + utxos = wallet.listunspent() + + # Make a PSBT + psbt = wallet.walletcreatefundedpsbt(utxos, [{def_wallet.getnewaddress(): 0.5}])["psbt"] + + # Process the PSBT with the wallet + wallet_psbt = wallet.walletprocesspsbt(psbt=psbt, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + # Separately process the PSBT with descriptors + desc_psbt = self.nodes[0].descriptorprocesspsbt(psbt=psbt, descriptors=descs, sighashtype="ALL|ANYONECANPAY", finalize=False)["psbt"] + + for psbt in [wallet_psbt, desc_psbt]: + # Check that the PSBT has a sighash field on all inputs + dec_psbt = self.nodes[0].decodepsbt(psbt) + for input in dec_psbt["inputs"]: + assert_equal(input["sighash"], "ALL|ANYONECANPAY") + + # Make sure we can still finalize the transaction + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], True) + fin_hex = fin_res["hex"] + + # Change the sighash field to a different value and make sure we can no longer finalize + mod_psbt = PSBT.from_base64(psbt) + mod_psbt.i[0].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + if self.options.descriptors: + mod_psbt.i[1].map[PSBT_IN_SIGHASH_TYPE] = (SIGHASH_ALL).to_bytes(4, byteorder="little") + psbt = mod_psbt.to_base64() + fin_res = self.nodes[0].finalizepsbt(psbt) + assert_equal(fin_res["complete"], False) + + self.nodes[0].sendrawtransaction(fin_hex) + self.generate(self.nodes[0], 1) + + wallet.unloadwallet() + def assert_change_type(self, psbtx, expected_type): """Assert that the given PSBT has a change output with the given type.""" @@ -1081,6 +1132,7 @@ class PSBTTest(BitcoinTestFramework): assert_raises_rpc_error(-8, "'all' is not a valid sighash parameter.", self.nodes[2].descriptorprocesspsbt, psbt, [descriptor], sighashtype="all") self.test_sighash_mismatch() + self.test_sighash_adding() if __name__ == '__main__': PSBTTest(__file__).main()