diff --git a/bip-DLEQ/reference.py b/bip-DLEQ/reference.py index 231617ac..ac431985 100644 --- a/bip-DLEQ/reference.py +++ b/bip-DLEQ/reference.py @@ -25,11 +25,11 @@ def xor_bytes(lhs: bytes, rhs: bytes) -> bytes: def dleq_challenge( - A: GE, B: GE, C: GE, R1: GE, R2: GE, G: GE = G, m: bytes | None = None + A: GE, B: GE, C: GE, R1: GE, R2: GE, m: bytes | None, G: GE = G, ) -> int: if m is not None: assert len(m) == 32 - m = bytes([]) if m is None else m.to_bytes(32, "big") + m = bytes([]) if m is None else m return int.from_bytes( TaggedHash( DLEQ_TAG_CHALLENGE, @@ -64,10 +64,10 @@ def dleq_generate_proof( return None R1 = k * G R2 = k * B - e = dleq_challenge(A, B, C, R1, R2) + e = dleq_challenge(A, B, C, R1, R2, m) s = (k + e * a) % GE.ORDER proof = e.to_bytes(32, "big") + s.to_bytes(32, "big") - if not dleq_verify_proof(A, B, C, proof): + if not dleq_verify_proof(A, B, C, proof, m=m): return None return proof @@ -87,7 +87,7 @@ def dleq_verify_proof( R2 = s * B + (-e * C) if R2.infinity: return False - if e != dleq_challenge(A, B, C, R1, R2): + if e != dleq_challenge(A, B, C, R1, R2, m): return False return True