
"""Reference implementation for the cryptographic aspects of BIP-324"""

import sys
import random
import hashlib
import hmac
### BIP-340 tagged hash
def TaggedHash(tag, data):
    """Return the BIP-340 tagged hash of data under the given tag string.

    The result is SHA256(SHA256(tag) || SHA256(tag) || data), where tag is
    UTF-8 encoded and data is a byte string.
    """
    tag_digest = hashlib.sha256(tag.encode('utf-8')).digest()
    # The tag digest is used twice as a prefix, followed by the message.
    preimage = tag_digest * 2 + data
    return hashlib.sha256(preimage).digest()
### HKDF-SHA256
def hmac_sha256(key, data):
    """Compute HMAC-SHA256 from specified byte arrays key and data.

    Returns the 32-byte MAC as a bytes object.
    """
    return hmac.new(key, data, hashlib.sha256).digest()
def hkdf_sha256(length, ikm, salt, info):
    """Derive length bytes of key material using HKDF-SHA256.

    ikm is the input keying material, salt the (possibly empty) extraction
    salt, and info the context string mixed into every expansion block.
    """
    hash_len = 32
    # Extract step: an empty salt is replaced by 32 zero bytes.
    if not salt:
        salt = bytes(hash_len)
    prk = hmac_sha256(salt, ikm)
    # Expand step: each block is HMAC(prk, previous block || info || counter),
    # with the one-byte counter starting at 1.
    blocks = []
    previous = b""
    num_blocks = (length + hash_len - 1) // hash_len
    for counter in range(1, num_blocks + 1):
        previous = hmac_sha256(prk, previous + info + bytes([counter]))
        blocks.append(previous)
    return b"".join(blocks)[:length]
### secp256k1 field/group elements
def modinv(a, n):
    """Compute the modular inverse of a modulo n using the extended Euclidean
    Algorithm. See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Modular_integers.

    a is first reduced modulo n; if the result is 0, 0 is returned.

    For a nonzero a that shares a factor with n (so no inverse exists), the
    pure-Python fallback returns None, whereas the pow() fast path used on
    Python 3.8 and later raises ValueError.
    """
    a = a % n
    if a == 0:
        return 0
    if sys.hexversion >= 0x3080000:
        # More efficient version available in Python 3.8.
        return pow(a, -1, n)
    # Extended Euclid, tracking only the Bezout coefficient belonging to a.
    t1, t2 = 0, 1
    r1, r2 = n, a
    while r2 != 0:
        q = r1 // r2
        t1, t2 = t2, t1 - q * t2
        r1, r2 = r2, r1 - q * r2
    # r1 now holds gcd(a, n); an inverse only exists when it equals 1.
    if r1 > 1:
        return None
    if t1 < 0:
        t1 += n
    return t1
class FE:
    """Objects of this class represent elements of the field GF(2**256 - 2**32 - 977).

    They are represented internally in numerator / denominator form, in order to delay inversions.
    """

    # The field size (the prime modulus p).
    SIZE = 2**256 - 2**32 - 977

    def __init__(self, a=0, b=1):
        """Initialize an FE as a/b; both a and b can be ints or field elements.

        An integer b must be nonzero modulo the field size (checked with an
        assertion).
        """
        if isinstance(b, FE):
            # Dividing by the fraction b.num/b.den is multiplying by b.den/b.num.
            if isinstance(a, FE):
                self.num = (a.num * b.den) % FE.SIZE
                self.den = (a.den * b.num) % FE.SIZE
            else:
                self.num = (a * b.den) % FE.SIZE
                self.den = b.num
        else:
            b = b % FE.SIZE
            assert b != 0
            if isinstance(a, FE):
                self.num = a.num
                self.den = (a.den * b) % FE.SIZE
            else:
                self.num = a % FE.SIZE
                self.den = b

    def __add__(self, a):
        """Compute the sum of two field elements (second may be int)."""
        if isinstance(a, FE):
            return FE(self.num * a.den + self.den * a.num, self.den * a.den)
        return FE(self.num + self.den * a, self.den)

    def __radd__(self, a):
        """Compute the sum of an integer and a field element."""
        return FE(self.num + self.den * a, self.den)

    def __sub__(self, a):
        """Compute the difference of two field elements (second may be int)."""
        if isinstance(a, FE):
            return FE(self.num * a.den - self.den * a.num, self.den * a.den)
        return FE(self.num - self.den * a, self.den)

    def __rsub__(self, a):
        """Compute the difference between an integer and a field element."""
        return FE(self.den * a - self.num, self.den)

    def __mul__(self, a):
        """Compute the product of two field elements (second may be int)."""
        if isinstance(a, FE):
            return FE(self.num * a.num, self.den * a.den)
        return FE(self.num * a, self.den)

    def __rmul__(self, a):
        """Compute the product of an integer with a field element."""
        return FE(self.num * a, self.den)

    def __truediv__(self, a):
        """Compute the ratio of two field elements (second may be int)."""
        return FE(self, a)

    def __rtruediv__(self, a):
        """Compute the ratio of an integer and a field element."""
        return FE(a, self)

    def __pow__(self, a):
        """Raise a field element to a (positive) integer power."""
        return FE(pow(self.num, a, FE.SIZE), pow(self.den, a, FE.SIZE))

    def __neg__(self):
        """Negate a field element."""
        return FE(-self.num, self.den)

    def __int__(self):
        """Convert a field element to an integer. The result is cached."""
        if self.den != 1:
            # Perform the delayed inversion once and store the normalized form.
            self.num = (self.num * modinv(self.den, FE.SIZE)) % FE.SIZE
            self.den = 1
        return self.num

    def sqrt(self):
        """Compute the square root of a field element.

        Returns an FE, or None if the element has no square root.

        Due to the fact that our modulus p is of the form p = 3 (mod 4), the
        Tonelli-Shanks algorithm is simply raising the argument to the power
        (p + 1) / 4.

        To see why: p-1 = 0 (mod 2), so 2 divides the order of the multiplicative group,
        and thus only half of the non-zero field elements are squares. An element a is
        a (nonzero) square when Euler's criterion, a^((p-1)/2) = 1 (mod p), holds. We're
        looking for x such that x^2 = a (mod p). Given a^((p-1)/2) = 1 (mod p), that is
        equivalent to x^2 = a^(1 + (p-1)/2) (mod p). As (1 + (p-1)/2) is even, this is
        equivalent to x = a^((1 + (p-1)/2)/2) (mod p), or x = a^((p+1)/4) (mod p)."""
        v = int(self)
        s = pow(v, (FE.SIZE + 1) // 4, FE.SIZE)
        # The exponentiation only yields a root when v is a square; verify it.
        if s**2 % FE.SIZE == v:
            return FE(s)
        return None

    def sqrts(self):
        """Compute all square roots of a field element, if any."""
        s = self.sqrt()
        if s is None:
            return []
        return [FE(s), -FE(s)]

    # The cube roots of 1 (mod p).
    CBRT1 = [
        1,
        0x851695d49a83f8ef919bb86153cbcb16630fb68aed0a766a3ec693d68e6afa40,
        0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee]

    def cbrts(self):
        """Compute all cube roots of a field element, if any.

        Due to the fact that our modulus p is of the form p = 7 (mod 9), one cube root
        can always be computed by raising to the power (p + 2) / 9. The other roots
        (if any) can be found by multiplying with the two non-trivial cube roots of 1.

        To see why: p-1 = 0 (mod 3), so 3 divides the order of the multiplicative group,
        and thus only 1/3 of the non-zero field elements are cubes. An element a is a
        (nonzero) cube when a^((p-1)/3) = 1 (mod p). We're looking for x such that
        x^3 = a (mod p). Given a^((p-1)/3) = 1 (mod p), that is equivalent to
        x^3 = a^(1 + (p-1)/3) (mod p). As (1 + (p-1)/3) is a multiple of 3, this is
        equivalent to x = a^((1 + (p-1)/3)/3) (mod p), or x = a^((p+2)/9) (mod p)."""
        v = int(self)
        c = pow(v, (FE.SIZE + 2) // 9, FE.SIZE)
        if pow(c, 3, FE.SIZE) == v:
            return [FE(c * f) for f in FE.CBRT1]
        return []

    def is_square(self):
        """Determine if this field element has a square root."""
        # Compute the Jacobi symbol of (self / p). Since our modulus is prime, this
        # is the same as the Legendre symbol, which determines quadratic residuosity.
        # num * den has the same quadratic character as num / den, so no
        # inversion is needed here.
        n, k, t = (self.num * self.den) % FE.SIZE, FE.SIZE, 0
        if n == 0:
            return True
        while n != 0:
            while n & 1 == 0:
                n >>= 1
                r = k & 7
                t ^= (r in (3, 5))
            n, k = k, n
            t ^= (n & k & 3 == 3)
            n = n % k
        assert k == 1
        return not t

    def __eq__(self, a):
        """Check whether two field elements are equal (second may be an int)."""
        # Compare by cross-multiplication so that no inversion is required.
        if isinstance(a, FE):
            return (self.num * a.den - self.den * a.num) % FE.SIZE == 0
        return (self.num - self.den * a) % FE.SIZE == 0

    def to_bytes(self):
        """Convert a field element to 32-byte big endian encoding."""
        return int(self).to_bytes(32, 'big')

    @staticmethod
    def from_bytes(b):
        """Convert a 32-byte big endian encoding of a field element to an FE.

        Returns None if the encoded value is not below the field size.
        """
        v = int.from_bytes(b, 'big')
        if v >= FE.SIZE:
            return None
        return FE(v)

    def __str__(self):
        """Convert this field element to a string."""
        return f"{int(self):064x}"

    def __repr__(self):
        """Get a string representation of this field element."""
        return f"FE(0x{int(self):x})"


# Sanity check: every listed constant really is a cube root of 1.
assert all(pow(c, 3, FE.SIZE) == 1 for c in FE.CBRT1)
class GE:
    """Objects of this class represent points (group elements) on the secp256k1 curve.

    The point at infinity is represented as None."""

    # Order of the group (number of points on the curve, plus 1 for infinity).
    ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

    def __init__(self, x, y):
        """Initialize a group element with specified x and y coordinates (must be on curve)."""
        fx = FE(x)
        fy = FE(y)
        # Curve equation: y^2 = x^3 + 7.
        assert fy**2 == fx**3 + 7
        self.x = fx
        self.y = fy

    def double(self):
        """Compute the double of a point."""
        l = 3 * self.x**2 / (2 * self.y)
        x3 = l**2 - 2 * self.x
        y3 = l * (self.x - x3) - self.y
        return GE(x3, y3)

    def __add__(self, a):
        """Add two points, or a point and infinity, together."""
        if a is None:
            # Adding point at infinity
            return self
        if self.x != a.x:
            # Adding distinct x coordinates
            l = (a.y - self.y) / (a.x - self.x)
            x3 = l**2 - self.x - a.x
            y3 = l * (self.x - x3) - self.y
            return GE(x3, y3)
        if self.y == a.y:
            # Adding point to itself
            return self.double()
        # Adding point to its negation
        return None

    def __radd__(self, a):
        """Add infinity to a point."""
        assert a is None
        return self

    def __mul__(self, a):
        """Multiply a point with an integer (scalar multiplication)."""
        # Double-and-add over the bits of a, most significant bit first.
        r = None
        for i in range(a.bit_length() - 1, -1, -1):
            if r is not None:
                r = r.double()
            if (a >> i) & 1:
                # While r is still None (infinity) this dispatches to __radd__.
                r += self
        return r

    def __rmul__(self, a):
        """Multiply an integer with a point (scalar multiplication)."""
        return self * a

    @staticmethod
    def lift_x(x):
        """Take an FE, and return the point with that as X coordinate, and square Y.

        Returns None if no curve point has that X coordinate.
        """
        y = (FE(x)**3 + 7).sqrt()
        if y is None:
            return None
        return GE(x, y)

    @staticmethod
    def is_valid_x(x):
        """Determine whether the provided field element is a valid X coordinate."""
        return (FE(x)**3 + 7).is_square()

    def __str__(self):
        """Convert this group element to a string."""
        return f"({self.x},{self.y})"

    def __repr__(self):
        """Get a string representation for this group element."""
        # Format the coordinates in hexadecimal so they match the 0x prefix.
        return f"GE(0x{int(self.x):x},0x{int(self.y):x})"


# The secp256k1 generator point.
SECP256K1_G = GE(
    0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798,
    0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)
### ElligatorSwift
# Precomputed constant square root of -3 (mod p).
MINUS_3_SQRT = FE(-3).sqrt()
def xswiftec(u, t):
    """Map a pair of field elements (u, t) to an X coordinate on the curve."""
    # Zero inputs are replaced by one.
    if u == 0:
        u = FE(1)
    if t == 0:
        t = FE(1)
    # If u^3 + t^2 + 7 = 0 then X below would equal -t, making Y zero;
    # doubling t sidesteps that.
    if u**3 + t**2 + 7 == 0:
        t = 2 * t
    X = (u**3 + 7 - t**2) / (2 * t)
    Y = (X + t) / (MINUS_3_SQRT * u)
    # Return the first of the three candidates that is a valid X coordinate.
    candidates = (
        u + 4 * Y**2,
        (-X / Y - u) / 2,
        (X / Y - u) / 2,
    )
    for candidate in candidates:
        if GE.is_valid_x(candidate):
            return candidate
    assert False
def xswiftec_inv(x, u, case):
    """Given x and u, find t such that xswiftec(u, t) = x, or return None.

    Case selects which of the up to 8 results to return: bit 1 (value 2)
    picks the branch used to derive (v, s), while bits 0 and 2 (mask 5)
    pick the sign and the square-root-of-minus-3 variant of the final formula.
    """
    if case & 2 == 0:
        if GE.is_valid_x(-x - u):
            return None
        v = x
        s = -(u**3 + 7) / (u**2 + u*v + v**2)
    else:
        s = x - u
        if s == 0:
            return None
        r = (-s * (4 * (u**3 + 7) + 3 * s * u**2)).sqrt()
        if r is None:
            return None
        # With r == 0 both values of bit 0 would describe the same v, so only
        # one of them is allowed to produce a result.
        if case & 1 and r == 0:
            return None
        v = (-u + r / s) / 2
    w = s.sqrt()
    if w is None:
        return None
    # case & 5 can only be 0, 1, 4 or 5, so one of these always returns.
    if case & 5 == 0:
        return -w * (u * (1 - MINUS_3_SQRT) / 2 + v)
    if case & 5 == 1:
        return w * (u * (1 + MINUS_3_SQRT) / 2 + v)
    if case & 5 == 4:
        return w * (u * (1 - MINUS_3_SQRT) / 2 + v)
    if case & 5 == 5:
        return -w * (u * (1 + MINUS_3_SQRT) / 2 + v)
def xelligatorswift(x):
    """Find a random pair (u, t) that xswiftec decodes to the X coordinate x.

    Draws a random u and a random case until xswiftec_inv finds a matching t.
    """
    t = None
    while t is None:
        u = FE(random.randrange(1, GE.ORDER))
        case = random.randrange(0, 8)
        t = xswiftec_inv(x, u, case)
    return u, t
def ellswift_create():
    """Generate a (privkey, ellswift_pubkey) pair.

    Returns the private key as 32 big-endian bytes and the public key as the
    concatenated byte encodings of u and t.
    """
    # NOTE(review): the key comes from the `random` module, which is not a
    # cryptographically secure generator.
    priv = random.randrange(1, GE.ORDER)
    pubkey_x = (priv * SECP256K1_G).x
    u, t = xelligatorswift(pubkey_x)
    privkey_bytes = priv.to_bytes(32, 'big')
    return privkey_bytes, u.to_bytes() + t.to_bytes()
def ellswift_decode(ellswift):
    """Decode an ellswift-encoded public key to its 32-byte x-only form.

    The first 32 bytes are read as u and the remainder as t, both big endian.
    """
    u_bytes, t_bytes = ellswift[:32], ellswift[32:]
    u = FE(int.from_bytes(u_bytes, 'big'))
    t = FE(int.from_bytes(t_bytes, 'big'))
    x = xswiftec(u, t)
    return x.to_bytes()
def ellswift_ecdh_xonly(pubkey_theirs, privkey):
    """Compute the X coordinate of the ECDH point shared between an ellswift pubkey and a privkey.

    privkey is a big-endian byte string; the result is the 32-byte encoding
    of the shared point's X coordinate.
    """
    scalar = int.from_bytes(privkey, 'big')
    their_x = FE.from_bytes(ellswift_decode(pubkey_theirs))
    their_point = GE.lift_x(their_x)
    shared_point = scalar * their_point
    return shared_point.x.to_bytes()
### Poly1305
class Poly1305:
    """State of an incremental Poly1305 computation."""

    MODULUS = 2**130 - 5

    def __init__(self, key):
        """Split key into r (first 16 bytes, masked) and s (the remainder)."""
        r_bytes, s_bytes = key[:16], key[16:]
        self.r = int.from_bytes(r_bytes, 'little') & 0x0ffffffc0ffffffc0ffffffc0fffffff
        self.s = int.from_bytes(s_bytes, 'little')
        self.acc = 0

    def add(self, msg, length=None, pad=False):
        """Absorb the first length bytes of msg (all of it by default).

        Input so far must be a multiple of 16 bytes. With pad set, a short
        final chunk is treated as if it were zero-padded to 16 bytes.
        Returns self so calls can be chained.
        """
        if length is None:
            length = len(msg)
        for start in range(0, length, 16):
            chunk = msg[start:min(start + 16, length)]
            # A marker bit is placed just above the chunk (or above byte 16
            # when padding).
            marker = 256**(16 if pad else len(chunk))
            val = int.from_bytes(chunk, 'little') + marker
            self.acc = (self.r * (self.acc + val)) % Poly1305.MODULUS
        return self

    def tag(self):
        """Return the 16-byte tag for everything added so far."""
        low_128_bits = (self.acc + self.s) & ((1 << 128) - 1)
        return low_128_bits.to_bytes(16, 'little')
### ChaCha20
# Index quadruples for the quarter rounds of one double round: first the four
# columns of the 4x4 state, then its four diagonals.
CHACHA20_INDICES = (
    (0, 4, 8, 12), (1, 5, 9, 13), (2, 6, 10, 14), (3, 7, 11, 15),
    (0, 5, 10, 15), (1, 6, 11, 12), (2, 7, 8, 13), (3, 4, 9, 14)
)

# The first four state words: "expand 32-byte k" as little-endian 32-bit integers.
CHACHA20_CONSTANTS = (0x61707865, 0x3320646e, 0x79622d32, 0x6b206574)
def rotl32(v, bits):
    """Rotate the 32-bit value v left by the given number of bits."""
    shifted_out_top = v >> (32 - bits)
    shifted_up = (v << bits) & 0xffffffff
    return shifted_up | shifted_out_top
def chacha20_doubleround(s):
    """Apply a ChaCha20 double round to 16-element state array s.

    The state is modified in place: one quarter round is applied for each
    index quadruple in CHACHA20_INDICES. See RFC 8439 for the definition of
    the quarter round.
    """
    for a, b, c, d in CHACHA20_INDICES:
        s[a] = (s[a] + s[b]) & 0xffffffff
        s[d] = rotl32(s[d] ^ s[a], 16)
        s[c] = (s[c] + s[d]) & 0xffffffff
        s[b] = rotl32(s[b] ^ s[c], 12)
        s[a] = (s[a] + s[b]) & 0xffffffff
        s[d] = rotl32(s[d] ^ s[a], 8)
        s[c] = (s[c] + s[d]) & 0xffffffff
        s[b] = rotl32(s[b] ^ s[c], 7)
def chacha20_block(key, nonce, cnt):
    """Compute the 64-byte output of the ChaCha20 block function.

    Takes as input a 32-byte key, 12-byte nonce, and 32-bit integer counter.
    """
    # Initial state: 4 constant words, 8 key words, the counter, 3 nonce words.
    init = [0 for _ in range(16)]
    for i in range(4):
        init[i] = CHACHA20_CONSTANTS[i]
    for i in range(8):
        init[4 + i] = int.from_bytes(key[4 * i:4 * (i+1)], 'little')
    init[12] = cnt
    for i in range(3):
        init[13 + i] = int.from_bytes(nonce[4 * i:4 * (i+1)], 'little')
    # Perform 20 rounds (10 double rounds).
    state = list(init)
    for _ in range(10):
        chacha20_doubleround(state)
    # Add initial values back into state.
    for i in range(16):
        state[i] = (state[i] + init[i]) & 0xffffffff
    # Produce byte output
    return b''.join(state[i].to_bytes(4, 'little') for i in range(16))
### ChaCha20Poly1305
def aead_chacha20_poly1305_encrypt(key, nonce, aad, plaintext):
    """Encrypt plaintext with ChaCha20Poly1305 and return ciphertext || 16-byte tag."""
    msg_len = len(plaintext)
    ret = bytearray()
    # Keystream block 0 is reserved for the Poly1305 key, so data starts at block 1.
    for block_index, offset in enumerate(range(0, msg_len, 64)):
        keystream = chacha20_block(key, nonce, block_index + 1)
        block = plaintext[offset:offset + 64]
        ret += bytes(p ^ k for p, k in zip(block, keystream))
    # Authenticate padded AAD, padded ciphertext, then both lengths.
    poly1305 = Poly1305(chacha20_block(key, nonce, 0)[:32])
    poly1305.add(aad, pad=True)
    poly1305.add(ret, pad=True)
    lengths = len(aad).to_bytes(8, 'little') + msg_len.to_bytes(8, 'little')
    poly1305.add(lengths)
    ret += poly1305.tag()
    return bytes(ret)
def aead_chacha20_poly1305_decrypt(key, nonce, aad, ciphertext):
    """Decrypt a ChaCha20Poly1305 ciphertext.

    ciphertext is the encrypted message followed by its 16-byte tag. Returns
    the plaintext bytes, or None if the input is shorter than a tag or the
    tag does not verify.
    """
    if len(ciphertext) < 16:
        return None
    msg_len = len(ciphertext) - 16
    # Recompute the tag over padded AAD, padded ciphertext, then both lengths.
    poly1305 = Poly1305(chacha20_block(key, nonce, 0)[:32])
    poly1305.add(aad, pad=True)
    poly1305.add(ciphertext, length=msg_len, pad=True)
    poly1305.add(len(aad).to_bytes(8, 'little') + msg_len.to_bytes(8, 'little'))
    # Compare in constant time so the check does not leak how many leading
    # tag bytes matched.
    if not hmac.compare_digest(ciphertext[-16:], poly1305.tag()):
        return None
    ret = bytearray()
    for i in range((msg_len + 63) // 64):
        now = min(64, msg_len - 64 * i)
        # Keystream block 0 was used for the Poly1305 key; data starts at block 1.
        keystream = chacha20_block(key, nonce, i + 1)
        for j in range(now):
            ret.append(ciphertext[j + 64 * i] ^ keystream[j])
    return bytes(ret)
### FSChaCha20{,Poly1305}
# Number of packets (FSChaCha20Poly1305) or chunks (FSChaCha20) processed under
# one key before the cipher rekeys itself.
REKEY_INTERVAL = 224 # packets
class FSChaCha20Poly1305:
    """Rekeying wrapper AEAD around ChaCha20Poly1305."""

    def __init__(self, initial_key):
        """Start with the given key and a packet counter of zero."""
        self.key = initial_key
        self.packet_counter = 0

    def crypt(self, aad, text, is_decrypt):
        """Encrypt or decrypt the specified (plain/cipher)text.

        Returns the result of the underlying AEAD call (None when decryption
        fails). The packet counter advances, and rekeying happens, regardless
        of the outcome.
        """
        # 12-byte nonce: packet index within the current key epoch (4 bytes)
        # followed by the epoch number (8 bytes), both little endian.
        nonce = ((self.packet_counter % REKEY_INTERVAL).to_bytes(4, 'little') +
                 (self.packet_counter // REKEY_INTERVAL).to_bytes(8, 'little'))
        if is_decrypt:
            ret = aead_chacha20_poly1305_decrypt(self.key, nonce, aad, text)
        else:
            ret = aead_chacha20_poly1305_encrypt(self.key, nonce, aad, text)
        if (self.packet_counter + 1) % REKEY_INTERVAL == 0:
            # Last packet of this epoch: derive the next key under a reserved nonce.
            rekey_nonce = b"\xFF\xFF\xFF\xFF" + nonce[4:]
            newkey1 = aead_chacha20_poly1305_encrypt(self.key, rekey_nonce, b"", b"\x00" * 32)[:32]
            # Encrypting zeros yields the raw keystream, so both derivations must agree.
            newkey2 = chacha20_block(self.key, rekey_nonce, 1)[:32]
            assert newkey1 == newkey2
            self.key = newkey1
        self.packet_counter += 1
        return ret

    def encrypt(self, aad, plaintext):
        """Encrypt the specified plaintext with provided AAD."""
        return self.crypt(aad, plaintext, False)

    def decrypt(self, aad, ciphertext):
        """Decrypt the specified ciphertext with provided AAD."""
        return self.crypt(aad, ciphertext, True)
class FSChaCha20:
    """Rekeying wrapper stream cipher around ChaCha20."""

    def __init__(self, initial_key):
        """Start with the given key, zeroed counters and no buffered keystream."""
        self.key = initial_key
        self.block_counter = 0
        self.chunk_counter = 0
        self.keystream = b''

    def get_keystream_bytes(self, nbytes):
        """Generate nbytes keystream bytes."""
        # Extend the buffer one 64-byte block at a time until it is long enough.
        while len(self.keystream) < nbytes:
            nonce = ((0).to_bytes(4, 'little') +
                     (self.chunk_counter // REKEY_INTERVAL).to_bytes(8, 'little'))
            self.keystream += chacha20_block(self.key, nonce, self.block_counter)
            self.block_counter += 1
        ret = self.keystream[:nbytes]
        self.keystream = self.keystream[nbytes:]
        return ret

    def crypt(self, chunk):
        """Encrypt or decrypt chunk."""
        ks = self.get_keystream_bytes(len(chunk))
        ret = bytes(k ^ c for k, c in zip(ks, chunk))
        if ((self.chunk_counter + 1) % REKEY_INTERVAL) == 0:
            # Last chunk of this epoch: the next 32 keystream bytes become the new key.
            self.key = self.get_keystream_bytes(32)
            self.block_counter = 0
            # Discard keystream still buffered from the old key so that it is
            # never applied to data processed after the rekey.
            self.keystream = b''
        self.chunk_counter += 1
        return ret

    def encrypt(self, chunk):
        """Encrypt chunk."""
        return self.crypt(chunk)

    def decrypt(self, chunk):
        """Decrypt chunk."""
        return self.crypt(chunk)
### Shared secret computation
def v2_ecdh(priv, ellswift_theirs, ellswift_ours, initiating):
    """Compute the BIP324 shared secret.

    The secret is a tagged hash over both ellswift public key encodings,
    the initiator's first, followed by the x-only ECDH result.
    """
    ecdh_point_x32 = ellswift_ecdh_xonly(ellswift_theirs, priv)
    if initiating:
        # We are the initiator, so our encoding goes first.
        first, second = ellswift_ours, ellswift_theirs
    else:
        # We are the responder, so their encoding goes first.
        first, second = ellswift_theirs, ellswift_ours
    return TaggedHash("bip324_ellswift_xonly_ecdh", first + second + ecdh_point_x32)
### Key derivation
# Network magic bytes; appended to the HKDF salt in initialize_v2_transport.
NETWORK_MAGIC = b'\xf9\xbe\xb4\xd9'
def initialize_v2_transport(ecdh_secret, initiating):
    """Return a peer object with various BIP324 derived keys and ciphers.

    The returned dict holds the HKDF-derived per-role keys, the session id
    and both 16-byte garbage terminators, plus 'send_*' / 'recv_*' entries
    bound to the initiator or responder material according to initiating.
    """
    peer = {}
    salt = b'bitcoin_v2_shared_secret' + NETWORK_MAGIC
    for name, length in (
            ('initiator_L', 32), ('initiator_P', 32), ('responder_L', 32), ('responder_P', 32),
            ('garbage_terminators', 32), ('session_id', 32)):
        peer[name] = hkdf_sha256(
            salt=salt, ikm=ecdh_secret, info=name.encode('utf-8'), length=length)
    # One 32-byte derivation is split into the two 16-byte terminators.
    peer['initiator_garbage_terminator'] = peer['garbage_terminators'][:16]
    peer['responder_garbage_terminator'] = peer['garbage_terminators'][16:]
    del peer['garbage_terminators']
    # We send with our own role's keys and receive with the other role's.
    if initiating:
        ours, theirs = 'initiator', 'responder'
    else:
        ours, theirs = 'responder', 'initiator'
    peer['send_L'] = FSChaCha20(peer[ours + '_L'])
    peer['send_P'] = FSChaCha20Poly1305(peer[ours + '_P'])
    peer['send_garbage_terminator'] = peer[ours + '_garbage_terminator']
    peer['recv_L'] = FSChaCha20(peer[theirs + '_L'])
    peer['recv_P'] = FSChaCha20Poly1305(peer[theirs + '_P'])
    peer['recv_garbage_terminator'] = peer[theirs + '_garbage_terminator']
    return peer
### Packet encryption
# Packet layout constants. NOTE(review): these names are used below but not
# defined in the visible part of the file; the values follow the BIP-324 packet
# format (3-byte length field, 1-byte header whose top bit is the ignore flag).
# Confirm against the rest of the file.
LENGTH_FIELD_LEN = 3
HEADER_LEN = 1
IGNORE_BIT_POS = 7

def v2_enc_packet(peer, contents, aad=b'', ignore=False):
    """Encrypt a BIP324 packet.

    peer must provide 'send_P' (AEAD, encrypt(aad, plaintext)) and 'send_L'
    (length cipher, encrypt(chunk)). Returns the encrypted length of contents
    followed by the AEAD ciphertext of header || contents.
    """
    assert len(contents) <= 2**24 - 1
    # The header carries only the ignore flag.
    header = (ignore << IGNORE_BIT_POS).to_bytes(HEADER_LEN, 'little')
    plaintext = header + contents
    aead_ciphertext = peer['send_P'].encrypt(aad, plaintext)
    # The encrypted length covers the contents only, not the header.
    enc_plaintext_len = peer['send_L'].encrypt(len(contents).to_bytes(LENGTH_FIELD_LEN, 'little'))
    return enc_plaintext_len + aead_ciphertext