diff --git a/src/script/interpreter.cpp b/src/script/interpreter.cpp index 61ea7f4503c..07fda48c493 100644 --- a/src/script/interpreter.cpp +++ b/src/script/interpreter.cpp @@ -1564,11 +1564,57 @@ bool SignatureHashSchnorr(uint256& hash_out, ScriptExecutionData& execdata, cons return true; } +int SigHashCache::CacheIndex(int32_t hash_type) const noexcept +{ + // Note that we do not distinguish between BASE and WITNESS_V0 to determine the cache index, + // because no input can simultaneously use both. + return 3 * !!(hash_type & SIGHASH_ANYONECANPAY) + + 2 * ((hash_type & 0x1f) == SIGHASH_SINGLE) + + 1 * ((hash_type & 0x1f) == SIGHASH_NONE); +} + +bool SigHashCache::Load(int32_t hash_type, const CScript& script_code, HashWriter& writer) const noexcept +{ + auto& entry = m_cache_entries[CacheIndex(hash_type)]; + if (entry.has_value()) { + if (script_code == entry->first) { + writer = HashWriter(entry->second); + return true; + } + } + return false; +} + +void SigHashCache::Store(int32_t hash_type, const CScript& script_code, const HashWriter& writer) noexcept +{ + auto& entry = m_cache_entries[CacheIndex(hash_type)]; + entry.emplace(script_code, writer); +} + template -uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn, int32_t nHashType, const CAmount& amount, SigVersion sigversion, const PrecomputedTransactionData* cache) +uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn, int32_t nHashType, const CAmount& amount, SigVersion sigversion, const PrecomputedTransactionData* cache, SigHashCache* sighash_cache) { assert(nIn < txTo.vin.size()); + if (sigversion != SigVersion::WITNESS_V0) { + // Check for invalid use of SIGHASH_SINGLE + if ((nHashType & 0x1f) == SIGHASH_SINGLE) { + if (nIn >= txTo.vout.size()) { + // nOut out of range + return uint256::ONE; + } + } + } + + HashWriter ss{}; + + // Try to compute using cached SHA256 midstate. + if (sighash_cache && sighash_cache->Load(nHashType, scriptCode, ss)) { + // Add sighash type and hash. + ss << nHashType; + return ss.GetHash(); + } + if (sigversion == SigVersion::WITNESS_V0) { uint256 hashPrevouts; uint256 hashSequence; @@ -1583,16 +1629,14 @@ uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn hashSequence = cacheready ? cache->hashSequence : SHA256Uint256(GetSequencesSHA256(txTo)); } - if ((nHashType & 0x1f) != SIGHASH_SINGLE && (nHashType & 0x1f) != SIGHASH_NONE) { hashOutputs = cacheready ? cache->hashOutputs : SHA256Uint256(GetOutputsSHA256(txTo)); } else if ((nHashType & 0x1f) == SIGHASH_SINGLE && nIn < txTo.vout.size()) { - HashWriter ss{}; - ss << txTo.vout[nIn]; - hashOutputs = ss.GetHash(); + HashWriter inner_ss{}; + inner_ss << txTo.vout[nIn]; + hashOutputs = inner_ss.GetHash(); } - HashWriter ss{}; // Version ss << txTo.version; // Input prevouts/nSequence (none/all, depending on flags) @@ -1609,26 +1653,21 @@ uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn ss << hashOutputs; // Locktime ss << txTo.nLockTime; - // Sighash type - ss << nHashType; + } else { + // Wrapper to serialize only the necessary parts of the transaction being signed + CTransactionSignatureSerializer txTmp(txTo, scriptCode, nIn, nHashType); - return ss.GetHash(); + // Serialize + ss << txTmp; } - // Check for invalid use of SIGHASH_SINGLE - if ((nHashType & 0x1f) == SIGHASH_SINGLE) { - if (nIn >= txTo.vout.size()) { - // nOut out of range - return uint256::ONE; - } + // If a cache object was provided, store the midstate there. + if (sighash_cache != nullptr) { + sighash_cache->Store(nHashType, scriptCode, ss); } - // Wrapper to serialize only the necessary parts of the transaction being signed - CTransactionSignatureSerializer txTmp(txTo, scriptCode, nIn, nHashType); - - // Serialize and hash - HashWriter ss{}; - ss << txTmp << nHashType; + // Add sighash type and hash. + ss << nHashType; return ss.GetHash(); } @@ -1661,7 +1700,7 @@ bool GenericTransactionSignatureChecker::CheckECDSASignature(const std::vecto // Witness sighashes need the amount. if (sigversion == SigVersion::WITNESS_V0 && amount < 0) return HandleMissingData(m_mdb); - uint256 sighash = SignatureHash(scriptCode, *txTo, nIn, nHashType, amount, sigversion, this->txdata); + uint256 sighash = SignatureHash(scriptCode, *txTo, nIn, nHashType, amount, sigversion, this->txdata, &m_sighash_cache); if (!VerifyECDSASignature(vchSig, pubkey, sighash)) return false; diff --git a/src/script/interpreter.h b/src/script/interpreter.h index e8c5b09045f..2108e3992fc 100644 --- a/src/script/interpreter.h +++ b/src/script/interpreter.h @@ -239,8 +239,27 @@ extern const HashWriter HASHER_TAPSIGHASH; //!< Hasher with tag "TapSighash" pre extern const HashWriter HASHER_TAPLEAF; //!< Hasher with tag "TapLeaf" pre-fed to it. extern const HashWriter HASHER_TAPBRANCH; //!< Hasher with tag "TapBranch" pre-fed to it. +/** Data structure to cache SHA256 midstates for the ECDSA sighash calculations + * (bare, P2SH, P2WPKH, P2WSH). */ +class SigHashCache +{ + /** For each sighash mode (ALL, SINGLE, NONE, ALL|ANYONE, SINGLE|ANYONE, NONE|ANYONE), + * optionally store a scriptCode which the hash is for, plus a midstate for the SHA256 + * computation just before adding the hash_type itself. */ + std::optional> m_cache_entries[6]; + + /** Given a hash_type, find which of the 6 cache entries is to be used. */ + int CacheIndex(int32_t hash_type) const noexcept; + +public: + /** Load into writer the SHA256 midstate if found in this cache. */ + [[nodiscard]] bool Load(int32_t hash_type, const CScript& script_code, HashWriter& writer) const noexcept; + /** Store into this cache object the provided SHA256 midstate. */ + void Store(int32_t hash_type, const CScript& script_code, const HashWriter& writer) noexcept; +}; + template -uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn, int32_t nHashType, const CAmount& amount, SigVersion sigversion, const PrecomputedTransactionData* cache = nullptr); +uint256 SignatureHash(const CScript& scriptCode, const T& txTo, unsigned int nIn, int32_t nHashType, const CAmount& amount, SigVersion sigversion, const PrecomputedTransactionData* cache = nullptr, SigHashCache* sighash_cache = nullptr); class BaseSignatureChecker { @@ -289,6 +308,7 @@ private: unsigned int nIn; const CAmount amount; const PrecomputedTransactionData* txdata; + mutable SigHashCache m_sighash_cache; protected: virtual bool VerifyECDSASignature(const std::vector& vchSig, const CPubKey& vchPubKey, const uint256& sighash) const; diff --git a/src/test/fuzz/script_interpreter.cpp b/src/test/fuzz/script_interpreter.cpp index 9e3ad02b2e5..2c2ce855d47 100644 --- a/src/test/fuzz/script_interpreter.cpp +++ b/src/test/fuzz/script_interpreter.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -45,3 +46,27 @@ FUZZ_TARGET(script_interpreter) (void)CastToBool(ConsumeRandomLengthByteVector(fuzzed_data_provider)); } } + +/** Differential fuzzing for SignatureHash with and without cache. */ +FUZZ_TARGET(sighash_cache) +{ + FuzzedDataProvider provider(buffer.data(), buffer.size()); + + // Get inputs to the sighash function that won't change across types. + const auto scriptcode{ConsumeScript(provider)}; + const auto tx{ConsumeTransaction(provider, std::nullopt)}; + if (tx.vin.empty()) return; + const auto in_index{provider.ConsumeIntegralInRange(0, tx.vin.size() - 1)}; + const auto amount{ConsumeMoney(provider)}; + const auto sigversion{(SigVersion)provider.ConsumeIntegralInRange(0, 1)}; + + // Check the sighash function will give the same result for 100 fuzzer-generated hash types whether or not a cache is + // provided. The cache is conserved across types to exercise cache hits. + SigHashCache sighash_cache{}; + for (int i{0}; i < 100; ++i) { + const auto hash_type{((i & 2) == 0) ? provider.ConsumeIntegral() : provider.ConsumeIntegral()}; + const auto nocache_res{SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion)}; + const auto cache_res{SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &sighash_cache)}; + Assert(nocache_res == cache_res); + } +} diff --git a/src/test/sighash_tests.cpp b/src/test/sighash_tests.cpp index d3320878ec0..6e2ec800e74 100644 --- a/src/test/sighash_tests.cpp +++ b/src/test/sighash_tests.cpp @@ -207,4 +207,94 @@ BOOST_AUTO_TEST_CASE(sighash_from_data) BOOST_CHECK_MESSAGE(sh.GetHex() == sigHashHex, strTest); } } + +BOOST_AUTO_TEST_CASE(sighash_caching) +{ + // Get a script, transaction and parameters as inputs to the sighash function. + CScript scriptcode; + RandomScript(scriptcode); + CScript diff_scriptcode{scriptcode}; + diff_scriptcode << OP_1; + CMutableTransaction tx; + RandomTransaction(tx, /*fSingle=*/false); + const auto in_index{static_cast(m_rng.randrange(tx.vin.size()))}; + const auto amount{m_rng.rand()}; + + // Exercise the sighash function under both legacy and segwit v0. + for (const auto sigversion: {SigVersion::BASE, SigVersion::WITNESS_V0}) { + // For each, run it against all the 6 standard hash types and a few additional random ones. + std::vector hash_types{{SIGHASH_ALL, SIGHASH_SINGLE, SIGHASH_NONE, SIGHASH_ALL | SIGHASH_ANYONECANPAY, + SIGHASH_SINGLE | SIGHASH_ANYONECANPAY, SIGHASH_NONE | SIGHASH_ANYONECANPAY, + SIGHASH_ANYONECANPAY, 0, std::numeric_limits::max()}}; + for (int i{0}; i < 10; ++i) { + hash_types.push_back(i % 2 == 0 ? m_rng.rand() : m_rng.rand()); + } + + // Reuse the same cache across script types. This must not cause any issue as the cached value for one hash type must never + // be confused for another (instantiating the cache within the loop instead would prevent testing this). + SigHashCache cache; + for (const auto hash_type: hash_types) { + const bool expect_one{sigversion == SigVersion::BASE && ((hash_type & 0x1f) == SIGHASH_SINGLE) && in_index >= tx.vout.size()}; + + // The result of computing the sighash should be the same with or without cache. + const auto sighash_with_cache{SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache)}; + const auto sighash_no_cache{SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, nullptr)}; + BOOST_CHECK_EQUAL(sighash_with_cache, sighash_no_cache); + + // Calling the cached version again should return the same value again. + BOOST_CHECK_EQUAL(sighash_with_cache, SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache)); + + // While here we might as well also check that the result for legacy is the same as for the old SignatureHash() function. + if (sigversion == SigVersion::BASE) { + BOOST_CHECK_EQUAL(sighash_with_cache, SignatureHashOld(scriptcode, CTransaction(tx), in_index, hash_type)); + } + + // Calling with a different scriptcode (for instance in case a CODESEP is encountered) will not return the cache value but + // overwrite it. The sighash will always be different except in case of legacy SIGHASH_SINGLE bug. + const auto sighash_with_cache2{SignatureHash(diff_scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache)}; + const auto sighash_no_cache2{SignatureHash(diff_scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, nullptr)}; + BOOST_CHECK_EQUAL(sighash_with_cache2, sighash_no_cache2); + if (!expect_one) { + BOOST_CHECK_NE(sighash_with_cache, sighash_with_cache2); + } else { + BOOST_CHECK_EQUAL(sighash_with_cache, sighash_with_cache2); + BOOST_CHECK_EQUAL(sighash_with_cache, uint256::ONE); + } + + // Calling the cached version again should return the same value again. + BOOST_CHECK_EQUAL(sighash_with_cache2, SignatureHash(diff_scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache)); + + // And if we store a different value for this scriptcode and hash type it will return that instead. + { + HashWriter h{}; + h << 42; + cache.Store(hash_type, scriptcode, h); + const auto stored_hash{h.GetHash()}; + BOOST_CHECK(cache.Load(hash_type, scriptcode, h)); + const auto loaded_hash{h.GetHash()}; + BOOST_CHECK_EQUAL(stored_hash, loaded_hash); + } + + // And using this mutated cache with the sighash function will return the new value (except in the legacy SIGHASH_SINGLE bug + // case in which it'll return 1). + if (!expect_one) { + BOOST_CHECK_NE(SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache), sighash_with_cache); + HashWriter h{}; + BOOST_CHECK(cache.Load(hash_type, scriptcode, h)); + h << hash_type; + const auto new_hash{h.GetHash()}; + BOOST_CHECK_EQUAL(SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache), new_hash); + } else { + BOOST_CHECK_EQUAL(SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache), uint256::ONE); + } + + // Wipe the cache and restore the correct cached value for this scriptcode and hash_type before starting the next iteration. + HashWriter dummy{}; + cache.Store(hash_type, diff_scriptcode, dummy); + (void)SignatureHash(scriptcode, tx, in_index, hash_type, amount, sigversion, nullptr, &cache); + BOOST_CHECK(cache.Load(hash_type, scriptcode, dummy) || expect_one); + } + } +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/functional/feature_taproot.py b/test/functional/feature_taproot.py index 6109fd9b9d8..b22d5731871 100755 --- a/test/functional/feature_taproot.py +++ b/test/functional/feature_taproot.py @@ -71,6 +71,7 @@ from test_framework.script import ( OP_PUSHDATA1, OP_RETURN, OP_SWAP, + OP_TUCK, OP_VERIFY, SIGHASH_DEFAULT, SIGHASH_ALL, @@ -172,9 +173,9 @@ def get(ctx, name): ctx[name] = expr return expr.value -def getter(name): +def getter(name, **kwargs): """Return a callable that evaluates name in its passed context.""" - return lambda ctx: get(ctx, name) + return lambda ctx: get({**ctx, **kwargs}, name) def override(expr, **kwargs): """Return a callable that evaluates expr in a modified context.""" @@ -218,6 +219,20 @@ def default_controlblock(ctx): """Default expression for "controlblock": combine leafversion, negflag, pubkey_internal, merklebranch.""" return bytes([get(ctx, "leafversion") + get(ctx, "negflag")]) + get(ctx, "pubkey_internal") + get(ctx, "merklebranch") +def default_scriptcode_suffix(ctx): + """Default expression for "scriptcode_suffix", the actually used portion of the scriptcode.""" + scriptcode = get(ctx, "scriptcode") + codesepnum = get(ctx, "codesepnum") + if codesepnum == -1: + return scriptcode + codeseps = 0 + for (opcode, data, sop_idx) in scriptcode.raw_iter(): + if opcode == OP_CODESEPARATOR: + if codeseps == codesepnum: + return CScript(scriptcode[sop_idx+1:]) + codeseps += 1 + assert False + def default_sigmsg(ctx): """Default expression for "sigmsg": depending on mode, compute BIP341, BIP143, or legacy sigmsg.""" tx = get(ctx, "tx") @@ -237,12 +252,12 @@ def default_sigmsg(ctx): return TaprootSignatureMsg(tx, utxos, hashtype, idx, scriptpath=False, annex=annex) elif mode == "witv0": # BIP143 signature hash - scriptcode = get(ctx, "scriptcode") + scriptcode = get(ctx, "scriptcode_suffix") utxos = get(ctx, "utxos") return SegwitV0SignatureMsg(scriptcode, tx, idx, hashtype, utxos[idx].nValue) else: # Pre-segwit signature hash - scriptcode = get(ctx, "scriptcode") + scriptcode = get(ctx, "scriptcode_suffix") return LegacySignatureMsg(scriptcode, tx, idx, hashtype)[0] def default_sighash(ctx): @@ -302,7 +317,12 @@ def default_hashtype_actual(ctx): def default_bytes_hashtype(ctx): """Default expression for "bytes_hashtype": bytes([hashtype_actual]) if not 0, b"" otherwise.""" - return bytes([x for x in [get(ctx, "hashtype_actual")] if x != 0]) + mode = get(ctx, "mode") + hashtype_actual = get(ctx, "hashtype_actual") + if mode != "taproot" or hashtype_actual != 0: + return bytes([hashtype_actual]) + else: + return bytes() def default_sign(ctx): """Default expression for "sign": concatenation of signature and bytes_hashtype.""" @@ -380,6 +400,8 @@ DEFAULT_CONTEXT = { "key_tweaked": default_key_tweaked, # The tweak to use (None for script path spends, the actual tweak for key path spends). "tweak": default_tweak, + # The part of the scriptcode after the last executed OP_CODESEPARATOR. + "scriptcode_suffix": default_scriptcode_suffix, # The sigmsg value (preimage of sighash) "sigmsg": default_sigmsg, # The sighash value (32 bytes) @@ -410,6 +432,8 @@ DEFAULT_CONTEXT = { "annex": None, # The codeseparator position (only when mode=="taproot"). "codeseppos": -1, + # Which OP_CODESEPARATOR is the last executed one in the script (in legacy/P2SH/P2WSH). + "codesepnum": -1, # The redeemscript to add to the scriptSig (if P2SH; None implies not P2SH). "script_p2sh": None, # The script to add to the witness in (if P2WSH; None implies P2WPKH) @@ -1211,6 +1235,70 @@ def spenders_taproot_active(): standard = hashtype in VALID_SIGHASHES_ECDSA and (p2sh or witv0) add_spender(spenders, "compat/nocsa", hashtype=hashtype, p2sh=p2sh, witv0=witv0, standard=standard, script=CScript([OP_IF, OP_11, pubkey1, OP_CHECKSIGADD, OP_12, OP_EQUAL, OP_ELSE, pubkey1, OP_CHECKSIG, OP_ENDIF]), key=eckey1, sigops_weight=4-3*witv0, inputs=[getter("sign"), b''], failure={"inputs": [getter("sign"), b'\x01']}, **ERR_BAD_OPCODE) + # == sighash caching tests == + + # Sighash caching in legacy. + for p2sh in [False, True]: + for witv0 in [False, True]: + eckey1, pubkey1 = generate_keypair(compressed=compressed) + for _ in range(10): + # Construct a script with 20 checksig operations (10 sighash types, each 2 times), + # randomly ordered and interleaved with 4 OP_CODESEPARATORS. + ops = [1, 2, 3, 0x21, 0x42, 0x63, 0x81, 0x83, 0xe1, 0xc2, -1, -1] * 2 + # Make sure no OP_CODESEPARATOR appears last. + while True: + random.shuffle(ops) + if ops[-1] != -1: + break + script = [pubkey1] + inputs = [] + codeseps = -1 + for pos, op in enumerate(ops): + if op == -1: + codeseps += 1 + script.append(OP_CODESEPARATOR) + elif pos + 1 != len(ops): + script += [OP_TUCK, OP_CHECKSIGVERIFY] + inputs.append(getter("sign", codesepnum=codeseps, hashtype=op)) + else: + script += [OP_CHECKSIG] + inputs.append(getter("sign", codesepnum=codeseps, hashtype=op)) + inputs.reverse() + script = CScript(script) + add_spender(spenders, "sighashcache/legacy", p2sh=p2sh, witv0=witv0, standard=False, script=script, inputs=inputs, key=eckey1, sigops_weight=12*8*(4-3*witv0), no_fail=True) + + # Sighash caching in tapscript. + for _ in range(10): + # Construct a script with 700 checksig operations (7 sighash types, each 100 times), + # randomly ordered and interleaved with 100 OP_CODESEPARATORS. + ops = [0, 1, 2, 3, 0x81, 0x82, 0x83, -1] * 100 + # Make sure no OP_CODESEPARATOR appears last. + while True: + random.shuffle(ops) + if ops[-1] != -1: + break + script = [pubs[1]] + inputs = [] + opcount = 1 + codeseppos = -1 + for pos, op in enumerate(ops): + if op == -1: + codeseppos = opcount + opcount += 1 + script.append(OP_CODESEPARATOR) + elif pos + 1 != len(ops): + opcount += 2 + script += [OP_TUCK, OP_CHECKSIGVERIFY] + inputs.append(getter("sign", codeseppos=codeseppos, hashtype=op)) + else: + opcount += 1 + script += [OP_CHECKSIG] + inputs.append(getter("sign", codeseppos=codeseppos, hashtype=op)) + inputs.reverse() + script = CScript(script) + tap = taproot_construct(pubs[0], [("leaf", script)]) + add_spender(spenders, "sighashcache/taproot", tap=tap, leaf="leaf", inputs=inputs, standard=True, key=secs[1], no_fail=True) + return spenders