#!/usr/bin/env python3

"""Reference implementation of DLEQ BIP for secp256k1 with unit tests."""

from hashlib import sha256
import random
from secp256k1 import G, GE
import sys
import unittest


DLEQ_TAG_AUX = "BIP0374/aux"
DLEQ_TAG_NONCE = "BIP0374/nonce"
DLEQ_TAG_CHALLENGE = "BIP0374/challenge"


def TaggedHash(tag: str, data: bytes) -> bytes:
    """Return SHA256(SHA256(tag) || SHA256(tag) || data).

    Prefixing the data with the doubled tag digest domain-separates hashes
    computed under different tags (BIP-340-style tagged hash).
    """
    tag_hash = sha256(tag.encode()).digest()
    return sha256(tag_hash + tag_hash + data).digest()


def xor_bytes(lhs: bytes, rhs: bytes) -> bytes:
    """Return the byte-wise XOR of two byte strings of equal length."""
    assert len(lhs) == len(rhs)
    return bytes(x ^ y for x, y in zip(lhs, rhs))


def dleq_challenge(
    A: GE,
    B: GE,
    C: GE,
    R1: GE,
    R2: GE,
    m: bytes | None,
    G: GE,
) -> int:
    """Compute the Fiat-Shamir challenge for a DLEQ proof.

    The challenge is the big-endian integer value of
        TaggedHash("BIP0374/challenge",
                   A || B || C || G || R1 || R2 || m')
    where every point is serialized with ``to_bytes_compressed()`` and m'
    is ``m`` or the empty string when no message is given. The result is
    not reduced modulo the group order.

    Args:
        A, B, C: the public points the proof is about.
        R1, R2: the prover's commitments.
        m: optional 32-byte message bound into the challenge.
        G: the generator the proof is relative to (shadows the module-level G).
    """
    if m is not None:
        assert len(m) == 32
    m_prime = b"" if m is None else m
    return int.from_bytes(
        TaggedHash(
            DLEQ_TAG_CHALLENGE,
            A.to_bytes_compressed()
            + B.to_bytes_compressed()
            + C.to_bytes_compressed()
            + G.to_bytes_compressed()
            + R1.to_bytes_compressed()
            + R2.to_bytes_compressed()
            + m_prime,
        ),
        "big",
    )


def dleq_generate_proof(
    a: int, B: GE, r: bytes, G: GE = G, m: bytes | None = None
) -> bytes | None:
    """Prove that A = a*G and C = a*B share the same discrete log ``a``.

    Args:
        a: secret scalar; must satisfy 0 < a < GE.ORDER.
        B: the second base point; must not be the point at infinity.
        r: 32 bytes of auxiliary randomness, hashed and XORed into the
            secret before nonce derivation.
        G: the first base point (defaults to the module-level generator).
        m: optional 32-byte message to bind into the proof.

    Returns:
        The 64-byte proof ``e || s`` (each 32 bytes, big-endian), or None
        if ``a`` is out of range, ``B`` is infinity, the derived nonce is
        zero, or the freshly built proof fails self-verification.
    """
    assert len(r) == 32
    if not (0 < a < GE.ORDER):
        return None
    if B.infinity:
        return None
    if m is not None:
        assert len(m) == 32

    A = a * G
    C = a * B

    # Deterministic nonce: mask the secret with hashed aux randomness, then
    # hash it together with the public data and the optional message.
    t = xor_bytes(a.to_bytes(32, "big"), TaggedHash(DLEQ_TAG_AUX, r))
    m_prime = b"" if m is None else m
    rand = TaggedHash(
        DLEQ_TAG_NONCE, t + A.to_bytes_compressed() + C.to_bytes_compressed() + m_prime
    )
    k = int.from_bytes(rand, "big") % GE.ORDER
    if k == 0:
        return None

    R1 = k * G
    R2 = k * B
    e = dleq_challenge(A, B, C, R1, R2, m, G)
    s = (k + e * a) % GE.ORDER
    proof = e.to_bytes(32, "big") + s.to_bytes(32, "big")

    # Defensive check: never hand out a proof that does not verify.
    if not dleq_verify_proof(A, B, C, proof, G=G, m=m):
        return None
    return proof


def dleq_verify_proof(
    A: GE, B: GE, C: GE, proof: bytes, G: GE = G, m: bytes | None = None
) -> bool:
    """Verify a DLEQ proof that log_G(A) == log_B(C).

    Args:
        A, B, C: the public points; none may be the point at infinity.
        proof: 64-byte proof ``e || s`` as produced by dleq_generate_proof.
        G: the first base point (defaults to the module-level generator).
        m: the optional 32-byte message the proof was bound to.

    Returns:
        True iff every sanity check passes and the challenge recomputed
        from R1 = s*G - e*A and R2 = s*B - e*C equals ``e``.
    """
    if A.infinity or B.infinity or C.infinity or G.infinity:
        return False
    assert len(proof) == 64
    e = int.from_bytes(proof[:32], "big")
    s = int.from_bytes(proof[32:], "big")
    if s >= GE.ORDER:
        return False
    # TODO: implement subtraction operator (__sub__) for GE class to simplify these terms
    R1 = s * G + (-e * A)
    if R1.infinity:
        return False
    R2 = s * B + (-e * C)
    if R2.infinity:
        return False
    if e != dleq_challenge(A, B, C, R1, R2, m, G):
        return False
    return True


class DLEQTests(unittest.TestCase):
    def test_dleq(self):
        """Round-trip random proofs, with and without a message, and check
        that tampered proofs and mismatched messages are rejected."""
        seed = random.randrange(sys.maxsize)
        random.seed(seed)
        print(f"PRNG seed is: {seed}")
        for _ in range(10):
            # generate random keypairs for both parties
            a = random.randrange(1, GE.ORDER)
            A = a * G
            b = random.randrange(1, GE.ORDER)
            B = b * G

            # create shared secret
            C = a * B

            # create dleq proof
            rand_aux = random.randbytes(32)
            proof = dleq_generate_proof(a, B, rand_aux)
            self.assertIsNotNone(proof)

            # verify dleq proof
            success = dleq_verify_proof(A, B, C, proof)
            self.assertTrue(success)

            # flip a random bit in the dleq proof and check that verification fails
            for _ in range(5):
                proof_damaged = list(proof)
                proof_damaged[random.randrange(len(proof))] ^= 1 << (
                    random.randrange(8)
                )
                success = dleq_verify_proof(A, B, C, bytes(proof_damaged))
                self.assertFalse(success)

            # create the same dleq proof with a message
            message = random.randbytes(32)
            proof = dleq_generate_proof(a, B, rand_aux, m=message)
            self.assertIsNotNone(proof)

            # verify dleq proof with a message
            success = dleq_verify_proof(A, B, C, proof, m=message)
            self.assertTrue(success)

            # the proof is bound to the message: it must not verify without
            # the message, nor against a different message
            self.assertFalse(dleq_verify_proof(A, B, C, proof))
            wrong_message = bytes([message[0] ^ 1]) + message[1:]
            self.assertFalse(dleq_verify_proof(A, B, C, proof, m=wrong_message))

            # flip a random bit in the dleq proof and check that verification
            # fails even though the correct message is supplied (otherwise the
            # failure could come from the missing message, not the tampering)
            for _ in range(5):
                proof_damaged = list(proof)
                proof_damaged[random.randrange(len(proof))] ^= 1 << (
                    random.randrange(8)
                )
                success = dleq_verify_proof(
                    A, B, C, bytes(proof_damaged), m=message
                )
                self.assertFalse(success)