diff --git a/nip44/nip44.go b/nip44/nip44.go index 8aebe70..c666d49 100644 --- a/nip44/nip44.go +++ b/nip44/nip44.go @@ -7,32 +7,59 @@ import ( "crypto/sha256" "encoding/base64" "encoding/binary" - "encoding/hex" "errors" "fmt" "io" "math" - "math/big" - "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/nbd-wtf/go-nostr/nip04" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/hkdf" ) +const version byte = 2 + var ( MinPlaintextSize = 0x0001 // 1b msg => padded to 32b MaxPlaintextSize = 0xffff // 65535 (64kb-1) => padded to 64kb ) -type EncryptOptions struct { - Salt []byte - Version int +type encryptOptions struct { + err error + salt []byte } -func Encrypt(conversationKey []byte, plaintext string, options *EncryptOptions) (string, error) { +func WithCustomSalt(salt []byte) func(opts *encryptOptions) { + return func(opts *encryptOptions) { + if len(salt) != 32 { + opts.err = errors.New("salt must be 32 bytes") + } + opts.salt = salt + } +} + +func Encrypt(plaintext string, conversationKey []byte, applyOptions ...func(opts *encryptOptions)) (string, error) { + opts := encryptOptions{ + salt: nil, + } + + for _, apply := range applyOptions { + apply(&opts) + } + + if opts.err != nil { + return "", opts.err + } + + salt := opts.salt + if salt == nil { + salt := make([]byte, 32) + if _, err := rand.Read(salt); err != nil { + return "", err + } + } + var ( - version int = 2 - salt []byte enc []byte nonce []byte auth []byte @@ -42,22 +69,7 @@ func Encrypt(conversationKey []byte, plaintext string, options *EncryptOptions) concat []byte err error ) - if options.Version != 0 { - version = options.Version - } - if options.Salt != nil { - salt = options.Salt - } else { - if salt, err = randomBytes(32); err != nil { - return "", err - } - } - if version != 2 { - return "", errors.New(fmt.Sprintf("unknown version %d", version)) - } - if len(salt) != 32 { - return "", errors.New("salt must be 32 bytes") - } + if enc, nonce, auth, err = messageKeys(conversationKey, salt); err != nil { return "", err } @@ -70,16 +82,15 @@ func Encrypt(conversationKey []byte, plaintext string, options *EncryptOptions) if hmac_, err = sha256Hmac(auth, ciphertext, salt); err != nil { return "", err } - concat = append(concat, []byte{byte(version)}...) + concat = append(concat, []byte{version}...) concat = append(concat, salt...) concat = append(concat, ciphertext...) concat = append(concat, hmac_...) return base64.StdEncoding.EncodeToString(concat), nil } -func Decrypt(conversationKey []byte, ciphertext string) (string, error) { +func Decrypt(ciphertext string, conversationKey []byte) (string, error) { var ( - version int = 2 decoded []byte cLen int dLen int @@ -105,8 +116,8 @@ func Decrypt(conversationKey []byte, ciphertext string) (string, error) { if decoded, err = base64.StdEncoding.DecodeString(ciphertext); err != nil { return "", errors.New("invalid base64") } - if version = int(decoded[0]); version != 2 { - return "", errors.New(fmt.Sprintf("unknown version %d", version)) + if decoded[0] != version { + return "", errors.New(fmt.Sprintf("unknown version %d", decoded[0])) } dLen = len(decoded) if dLen < 99 || dLen > 65603 { @@ -136,24 +147,15 @@ func Decrypt(conversationKey []byte, ciphertext string) (string, error) { return string(unpadded), nil } -func GenerateConversationKey(sendPrivkey []byte, recvPubkey []byte) ([]byte, error) { - var ( - N = secp256k1.S256().N - sk *secp256k1.PrivateKey - pk *secp256k1.PublicKey - err error - ) - // make sure that private key is on curve before using unsafe secp256k1.PrivKeyFromBytes - // see https://pkg.go.dev/github.com/decred/dcrd/dcrec/secp256k1/v4#PrivKeyFromBytes - skX := new(big.Int).SetBytes(sendPrivkey) - if skX.Cmp(big.NewInt(0)) == 0 || skX.Cmp(N) >= 0 { - return []byte{}, fmt.Errorf("invalid private key: x coordinate %s is not on the secp256k1 curve", hex.EncodeToString(sendPrivkey)) +func GenerateConversationKey(pub string, sk string) ([]byte, error) { + if sk >= "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" || sk == "0000000000000000000000000000000000000000000000000000000000000000" { + return nil, fmt.Errorf("invalid private key: x coordinate %s is not on the secp256k1 curve", sk) } - sk = secp256k1.PrivKeyFromBytes(sendPrivkey) - if pk, err = secp256k1.ParsePubKey(recvPubkey); err != nil { - return []byte{}, err + + shared, err := nip04.ComputeSharedSecret(pub, sk) + if err != nil { + return nil, err } - shared := secp256k1.GenerateSharedSecret(sk, pk) return hkdf.Extract(sha256.New, shared, []byte("nip44-v2")), nil } @@ -170,14 +172,6 @@ func chacha20_(key []byte, nonce []byte, message []byte) ([]byte, error) { return dst, nil } -func randomBytes(n int) ([]byte, error) { - buf := make([]byte, n) - if _, err := rand.Read(buf); err != nil { - return nil, err - } - return buf, nil -} - func sha256Hmac(key []byte, ciphertext []byte, aad []byte) ([]byte, error) { if len(aad) != 32 { return nil, errors.New("aad data must be 32 bytes") diff --git a/nip44/nip44_test.go b/nip44/nip44_test.go index f7c14d6..5f3576a 100644 --- a/nip44/nip44_test.go +++ b/nip44/nip44_test.go @@ -6,7 +6,7 @@ import ( "hash" "testing" - "github.com/decred/dcrd/dcrec/secp256k1/v4" + "github.com/nbd-wtf/go-nostr" "github.com/stretchr/testify/assert" ) @@ -30,14 +30,14 @@ func assertCryptPriv(t *testing.T, sk1 string, sk2 string, conversationKey strin if ok = assert.NoErrorf(t, err, "hex decode failed for salt: %v", err); !ok { return } - actual, err = Encrypt(k1, plaintext, &EncryptOptions{Salt: s}) + actual, err = Encrypt(plaintext, k1, WithCustomSalt(s)) if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok { return } if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok { return } - decrypted, err = Decrypt(k1, expected) + decrypted, err = Decrypt(expected, k1) if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok { return } @@ -64,14 +64,14 @@ func assertCryptPub(t *testing.T, sk1 string, pub2 string, conversationKey strin if ok = assert.NoErrorf(t, err, "hex decode failed for salt: %v", err); !ok { return } - actual, err = Encrypt(k1, plaintext, &EncryptOptions{Salt: s}) + actual, err = Encrypt(plaintext, k1, WithCustomSalt(s)) if ok = assert.NoError(t, err, "encryption failed: %v", err); !ok { return } if ok = assert.Equalf(t, expected, actual, "wrong encryption"); !ok { return } - decrypted, err = Decrypt(k1, expected) + decrypted, err = Decrypt(expected, k1) if ok = assert.NoErrorf(t, err, "decryption failed: %v", err); !ok { return } @@ -88,30 +88,16 @@ func assertDecryptFail(t *testing.T, conversationKey string, plaintext string, c if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return } - _, err = Decrypt(k1, ciphertext) + _, err = Decrypt(ciphertext, k1) assert.ErrorContains(t, err, msg) } -func assertConversationKeyFail(t *testing.T, sk1 string, pub2 string, msg string) { - var ( - sk1Decoded []byte - pub2Decoded []byte - ok bool - err error - ) - sk1Decoded, err = hex.DecodeString(sk1) - if ok = assert.NoErrorf(t, err, "hex decode failed for sk1: %v", err); !ok { - return - } - pub2Decoded, err = hex.DecodeString("02" + pub2) - if ok = assert.NoErrorf(t, err, "hex decode failed for pub2: %v", err); !ok { - return - } - _, err = GenerateConversationKey(sk1Decoded, pub2Decoded) +func assertConversationKeyFail(t *testing.T, priv string, pub string, msg string) { + _, err := GenerateConversationKey(pub, priv) assert.ErrorContains(t, err, msg) } -func assertConversationKeyGeneration(t *testing.T, sendPrivkey []byte, recvPubkey []byte, conversationKey string) bool { +func assertConversationKeyGeneration(t *testing.T, priv string, pub string, conversationKey string) bool { var ( actualConversationKey []byte expectedConversationKey []byte @@ -122,7 +108,7 @@ func assertConversationKeyGeneration(t *testing.T, sendPrivkey []byte, recvPubke if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return false } - actualConversationKey, err = GenerateConversationKey(sendPrivkey, recvPubkey) + actualConversationKey, err = GenerateConversationKey(pub, priv) if ok = assert.NoErrorf(t, err, "conversation key generation failed: %v", err); !ok { return false } @@ -133,48 +119,15 @@ func assertConversationKeyGeneration(t *testing.T, sendPrivkey []byte, recvPubke } func assertConversationKeyGenerationSec(t *testing.T, sk1 string, sk2 string, conversationKey string) bool { - var ( - sk1Decoded []byte - pub2Decoded []byte - ok bool - err error - ) - sk1Decoded, err = hex.DecodeString(sk1) - if ok = assert.NoErrorf(t, err, "hex decode failed for sk1: %v", err); !ok { + pub2, err := nostr.GetPublicKey(sk2) + if ok := assert.NoErrorf(t, err, "failed to derive pubkey from sk2: %v", err); !ok { return false } - if decoded, err := hex.DecodeString(sk2); err == nil { - pub2Decoded = secp256k1.PrivKeyFromBytes(decoded).PubKey().SerializeCompressed() - } - if ok = assert.NoErrorf(t, err, "hex decode failed for sk2: %v", err); !ok { - return false - } - return assertConversationKeyGeneration(t, sk1Decoded, pub2Decoded, conversationKey) + return assertConversationKeyGeneration(t, sk1, pub2, conversationKey) } -func assertConversationKeyGenerationPub(t *testing.T, sk1 string, pub2 string, conversationKey string) bool { - var ( - sk1Decoded []byte - pub2Decoded []byte - ok bool - err error - ) - sk1Decoded, err = hex.DecodeString(sk1) - if ok = assert.NoErrorf(t, err, "hex decode failed for sk1: %v", err); !ok { - return false - } - if decoded, err := hex.DecodeString("02" + pub2); err == nil { - if recvPubkey, err := secp256k1.ParsePubKey(decoded); err == nil { - pub2Decoded = recvPubkey.SerializeCompressed() - } - if ok = assert.NoErrorf(t, err, "parse pubkey failed: %v", err); !ok { - return false - } - } - if ok = assert.NoErrorf(t, err, "hex decode failed for pub2: %v", err); !ok { - return false - } - return assertConversationKeyGeneration(t, sk1Decoded, pub2Decoded, conversationKey) +func assertConversationKeyGenerationPub(t *testing.T, sk string, pub string, conversationKey string) bool { + return assertConversationKeyGeneration(t, sk, pub, conversationKey) } func assertMessageKeyGeneration(t *testing.T, conversationKey string, salt string, chachaKey string, chachaSalt string, hmacKey string) bool { @@ -256,7 +209,7 @@ func assertCryptLong(t *testing.T, conversationKey string, salt string, pattern if ok = assert.Equalf(t, plaintextSha256, actualPlaintextSha256, "invalid plaintext sha256 hash: %v", err); !ok { return } - actualPayload, err = Encrypt(convKey, plaintext, &EncryptOptions{Salt: convSalt}) + actualPayload, err = Encrypt(plaintext, convKey, WithCustomSalt(convSalt)) if ok = assert.NoErrorf(t, err, "encryption failed: %v", err); !ok { return }