diff --git a/nip44/nip44.go b/nip44/nip44.go index 96d94ae..2c3a269 100644 --- a/nip44/nip44.go +++ b/nip44/nip44.go @@ -7,12 +7,14 @@ import ( "crypto/sha256" "encoding/base64" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" "math" - "github.com/nbd-wtf/go-nostr/nip04" + "github.com/btcsuite/btcd/btcec/v2" + "github.com/decred/dcrd/dcrec/secp256k1/v4" "golang.org/x/crypto/chacha20" "golang.org/x/crypto/hkdf" ) @@ -41,7 +43,7 @@ func WithCustomNonce(salt []byte) func(opts *encryptOptions) { } } -func Encrypt(plaintext string, conversationKey []byte, applyOptions ...func(opts *encryptOptions)) (string, error) { +func Encrypt(plaintext string, conversationKey [32]byte, applyOptions ...func(opts *encryptOptions)) (string, error) { opts := encryptOptions{} for _, apply := range applyOptions { apply(&opts) @@ -94,7 +96,7 @@ func Encrypt(plaintext string, conversationKey []byte, applyOptions ...func(opts return base64.StdEncoding.EncodeToString(concat), nil } -func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (string, error) { +func Decrypt(b64ciphertextWrapped string, conversationKey [32]byte) (string, error) { cLen := len(b64ciphertextWrapped) if cLen < 132 || cLen > 87472 { return "", errors.New(fmt.Sprintf("invalid payload length: %d", cLen)) @@ -151,16 +153,22 @@ func Decrypt(b64ciphertextWrapped string, conversationKey []byte) (string, error return string(unpadded), nil } -func GenerateConversationKey(pub string, sk string) ([]byte, error) { +func GenerateConversationKey(pub string, sk string) ([32]byte, error) { + var ck [32]byte + if sk >= "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" || sk == "0000000000000000000000000000000000000000000000000000000000000000" { - return nil, fmt.Errorf("invalid private key: x coordinate %s is not on the secp256k1 curve", sk) + return ck, fmt.Errorf("invalid private key: x coordinate %s is not on the secp256k1 curve", sk) } - shared, err := nip04.ComputeSharedSecret(pub, sk) + shared, err := computeSharedSecret(pub, sk) if err != nil { - return nil, err + return ck, err } - return hkdf.Extract(sha256.New, shared, []byte("nip44-v2")), nil + + buf := hkdf.Extract(sha256.New, shared[:], []byte("nip44-v2")) + copy(ck[:], buf) + + return ck, nil } func chacha(key []byte, nonce []byte, message []byte) ([]byte, error) { @@ -184,15 +192,12 @@ func sha256Hmac(key []byte, ciphertext []byte, nonce []byte) ([]byte, error) { return h.Sum(nil), nil } -func messageKeys(conversationKey []byte, nonce []byte) ([]byte, []byte, []byte, error) { - if len(conversationKey) != 32 { - return nil, nil, nil, errors.New("conversation key must be 32 bytes") - } +func messageKeys(conversationKey [32]byte, nonce []byte) ([]byte, []byte, []byte, error) { if len(nonce) != 32 { return nil, nil, nil, errors.New("nonce must be 32 bytes") } - r := hkdf.Expand(sha256.New, conversationKey, nonce) + r := hkdf.Expand(sha256.New, conversationKey[:], nonce) enc := make([]byte, 32) if _, err := io.ReadFull(r, enc); err != nil { return nil, nil, nil, err @@ -219,3 +224,29 @@ func calcPadding(sLen int) int { chunk := int(math.Max(32, float64(nextPower/8))) return chunk * int(math.Floor(float64((sLen-1)/chunk))+1) } + +func computeSharedSecret(pub string, sk string) (sharedSecret [32]byte, err error) { + privKeyBytes, err := hex.DecodeString(sk) + if err != nil { + return sharedSecret, fmt.Errorf("error decoding sender private key: %w", err) + } + privKey, _ := btcec.PrivKeyFromBytes(privKeyBytes) + + // adding 02 to signal that this is a compressed public key (33 bytes) + pubKeyBytes, err := hex.DecodeString("02" + pub) + if err != nil { + return sharedSecret, fmt.Errorf("error decoding hex string of receiver public key '%s': %w", "02"+pub, err) + } + pubKey, err := btcec.ParsePubKey(pubKeyBytes) + if err != nil { + return sharedSecret, fmt.Errorf("error parsing receiver public key '%s': %w", "02"+pub, err) + } + + var point, result secp256k1.JacobianPoint + pubKey.AsJacobian(&point) + secp256k1.ScalarMultNonConst(&privKey.Key, &point, &result) + result.ToAffine() + + result.X.PutBytesUnchecked(sharedSecret[:]) + return sharedSecret, nil +} diff --git a/nip44/nip44_test.go b/nip44/nip44_test.go index 89f2fb6..3ee01a5 100644 --- a/nip44/nip44_test.go +++ b/nip44/nip44_test.go @@ -15,14 +15,14 @@ import ( func assertCryptPriv(t *testing.T, sk1 string, sk2 string, conversationKey string, salt string, plaintext string, expected string) { var ( - k1 []byte + k1 [32]byte s []byte actual string decrypted string ok bool err error ) - k1, err = hex.DecodeString(conversationKey) + k1, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return } @@ -49,11 +49,11 @@ func assertCryptPriv(t *testing.T, sk1 string, sk2 string, conversationKey strin func assertDecryptFail(t *testing.T, conversationKey string, _ string, ciphertext string, msg string) { var ( - k1 []byte + k1 [32]byte ok bool err error ) - k1, err = hex.DecodeString(conversationKey) + k1, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return } @@ -68,12 +68,12 @@ func assertConversationKeyFail(t *testing.T, priv string, pub string, msg string func assertConversationKeyGeneration(t *testing.T, priv string, pub string, conversationKey string) bool { var ( - actualConversationKey []byte - expectedConversationKey []byte + actualConversationKey [32]byte + expectedConversationKey [32]byte ok bool err error ) - expectedConversationKey, err = hex.DecodeString(conversationKey) + expectedConversationKey, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return false } @@ -101,7 +101,7 @@ func assertConversationKeyGenerationPub(t *testing.T, sk string, pub string, con func assertMessageKeyGeneration(t *testing.T, conversationKey string, salt string, chachaKey string, chachaSalt string, hmacKey string) bool { var ( - convKey []byte + convKey [32]byte convSalt []byte actualChaChaKey []byte expectedChaChaKey []byte @@ -112,7 +112,7 @@ func assertMessageKeyGeneration(t *testing.T, conversationKey string, salt strin ok bool err error ) - convKey, err = hex.DecodeString(conversationKey) + convKey, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for convKey: %v", err); !ok { return false } @@ -150,7 +150,7 @@ func assertMessageKeyGeneration(t *testing.T, conversationKey string, salt strin func assertCryptLong(t *testing.T, conversationKey string, salt string, pattern string, repeat int, plaintextSha256 string, payloadSha256 string) { var ( - convKey []byte + convKey [32]byte convSalt []byte plaintext string actualPlaintextSha256 string @@ -160,7 +160,7 @@ func assertCryptLong(t *testing.T, conversationKey string, salt string, pattern ok bool err error ) - convKey, err = hex.DecodeString(conversationKey) + convKey, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for convKey: %v", err); !ok { return } @@ -1164,14 +1164,14 @@ func TestMaxLength(t *testing.T) { func assertCryptPub(t *testing.T, sk1 string, pub2 string, conversationKey string, salt string, plaintext string, expected string) { var ( - k1 []byte + k1 [32]byte s []byte actual string decrypted string ok bool err error ) - k1, err = hex.DecodeString(conversationKey) + k1, err = hexDecode32Array(conversationKey) if ok = assert.NoErrorf(t, err, "hex decode failed for conversation key: %v", err); !ok { return } @@ -1195,3 +1195,8 @@ func assertCryptPub(t *testing.T, sk1 string, pub2 string, conversationKey strin } assert.Equal(t, decrypted, plaintext, "wrong decryption") } + +func hexDecode32Array(hexString string) (res [32]byte, err error) { + _, err = hex.Decode(res[:], []byte(hexString)) + return res, err +} diff --git a/nip46/dynamic-signer.go b/nip46/dynamic-signer.go index 57f26c6..c2e512d 100644 --- a/nip46/dynamic-signer.go +++ b/nip46/dynamic-signer.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "slices" - "strings" "sync" "github.com/mailru/easyjson" @@ -155,7 +154,7 @@ func (p *DynamicSigner) HandleRequest(event *nostr.Event) ( case "get_relays": jrelays, _ := json.Marshal(p.RelaysToAdvertise) result = string(jrelays) - case "nip04_encrypt", "nip44_encrypt": + case "nip44_encrypt": if len(req.Params) != 2 { resultErr = fmt.Errorf("wrong number of arguments to 'nip04_encrypt'") break @@ -171,25 +170,18 @@ func (p *DynamicSigner) HandleRequest(event *nostr.Event) ( } plaintext := req.Params[1] - getKey := nip04.ComputeSharedSecret - encrypt := nip04.Encrypt - if strings.HasPrefix(req.Method, "nip44") { - getKey = nip44.GenerateConversationKey - encrypt = func(message string, key []byte) (string, error) { return nip44.Encrypt(message, key) } - } - - sharedSecret, err := getKey(thirdPartyPubkey, privateKey) + sharedSecret, err := nip44.GenerateConversationKey(thirdPartyPubkey, privateKey) if err != nil { resultErr = fmt.Errorf("failed to compute shared secret: %w", err) break } - ciphertext, err := encrypt(plaintext, sharedSecret) + ciphertext, err := nip44.Encrypt(plaintext, sharedSecret) if err != nil { resultErr = fmt.Errorf("failed to encrypt: %w", err) break } result = ciphertext - case "nip04_decrypt", "nip44_decrypt": + case "nip44_decrypt": if len(req.Params) != 2 { resultErr = fmt.Errorf("wrong number of arguments to 'nip04_decrypt'") break @@ -205,19 +197,66 @@ func (p *DynamicSigner) HandleRequest(event *nostr.Event) ( } ciphertext := req.Params[1] - getKey := nip04.ComputeSharedSecret - decrypt := nip04.Decrypt - if strings.HasPrefix(req.Method, "nip44") { - getKey = nip44.GenerateConversationKey - decrypt = nip44.Decrypt - } - - sharedSecret, err := getKey(thirdPartyPubkey, privateKey) + sharedSecret, err := nip44.GenerateConversationKey(thirdPartyPubkey, privateKey) if err != nil { resultErr = fmt.Errorf("failed to compute shared secret: %w", err) break } - plaintext, err := decrypt(ciphertext, sharedSecret) + plaintext, err := nip44.Decrypt(ciphertext, sharedSecret) + if err != nil { + resultErr = fmt.Errorf("failed to encrypt: %w", err) + break + } + result = plaintext + case "nip04_encrypt": + if len(req.Params) != 2 { + resultErr = fmt.Errorf("wrong number of arguments to 'nip04_encrypt'") + break + } + thirdPartyPubkey := req.Params[0] + if !nostr.IsValidPublicKey(thirdPartyPubkey) { + resultErr = fmt.Errorf("first argument to 'nip04_encrypt' is not a pubkey string") + break + } + if !p.authorizeEncryption(event.PubKey, secret) { + resultErr = fmt.Errorf("refusing to encrypt") + break + } + plaintext := req.Params[1] + + sharedSecret, err := nip04.ComputeSharedSecret(thirdPartyPubkey, privateKey) + if err != nil { + resultErr = fmt.Errorf("failed to compute shared secret: %w", err) + break + } + ciphertext, err := nip04.Encrypt(plaintext, sharedSecret) + if err != nil { + resultErr = fmt.Errorf("failed to encrypt: %w", err) + break + } + result = ciphertext + case "nip04_decrypt": + if len(req.Params) != 2 { + resultErr = fmt.Errorf("wrong number of arguments to 'nip04_decrypt'") + break + } + thirdPartyPubkey := req.Params[0] + if !nostr.IsValidPublicKey(thirdPartyPubkey) { + resultErr = fmt.Errorf("first argument to 'nip04_decrypt' is not a pubkey string") + break + } + if !p.authorizeEncryption(event.PubKey, secret) { + resultErr = fmt.Errorf("refusing to decrypt") + break + } + ciphertext := req.Params[1] + + sharedSecret, err := nip04.ComputeSharedSecret(thirdPartyPubkey, privateKey) + if err != nil { + resultErr = fmt.Errorf("failed to compute shared secret: %w", err) + break + } + plaintext, err := nip04.Decrypt(ciphertext, sharedSecret) if err != nil { resultErr = fmt.Errorf("failed to encrypt: %w", err) break diff --git a/nip46/nip46.go b/nip46/nip46.go index 0b50517..38bb920 100644 --- a/nip46/nip46.go +++ b/nip46/nip46.go @@ -30,8 +30,8 @@ type Signer interface { } type Session struct { - SharedKey []byte // nip04 - ConversationKey []byte // nip44 + SharedKey []byte // nip04 + ConversationKey [32]byte // nip44 } type RelayReadWrite struct { diff --git a/nip46/static-key-signer.go b/nip46/static-key-signer.go index 148220d..73e9193 100644 --- a/nip46/static-key-signer.go +++ b/nip46/static-key-signer.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "slices" - "strings" "sync" "github.com/mailru/easyjson" @@ -141,7 +140,7 @@ func (p *StaticKeySigner) HandleRequest(event *nostr.Event) ( jrelays, _ := json.Marshal(p.RelaysToAdvertise) result = string(jrelays) harmless = true - case "nip04_encrypt", "nip44_encrypt": + case "nip44_encrypt": if len(req.Params) != 2 { resultErr = fmt.Errorf("wrong number of arguments to 'nip04_encrypt'") break @@ -153,25 +152,18 @@ func (p *StaticKeySigner) HandleRequest(event *nostr.Event) ( } plaintext := req.Params[1] - getKey := nip04.ComputeSharedSecret - encrypt := nip04.Encrypt - if strings.HasPrefix(req.Method, "nip44") { - getKey = nip44.GenerateConversationKey - encrypt = func(message string, key []byte) (string, error) { return nip44.Encrypt(message, key) } - } - - sharedSecret, err := getKey(thirdPartyPubkey, p.secretKey) + sharedSecret, err := nip44.GenerateConversationKey(thirdPartyPubkey, p.secretKey) if err != nil { resultErr = fmt.Errorf("failed to compute shared secret: %w", err) break } - ciphertext, err := encrypt(plaintext, sharedSecret) + ciphertext, err := nip44.Encrypt(plaintext, sharedSecret) if err != nil { resultErr = fmt.Errorf("failed to encrypt: %w", err) break } result = ciphertext - case "nip04_decrypt", "nip44_decrypt": + case "nip44_decrypt": if len(req.Params) != 2 { resultErr = fmt.Errorf("wrong number of arguments to 'nip04_decrypt'") break @@ -183,19 +175,58 @@ func (p *StaticKeySigner) HandleRequest(event *nostr.Event) ( } ciphertext := req.Params[1] - getKey := nip04.ComputeSharedSecret - decrypt := nip04.Decrypt - if strings.HasPrefix(req.Method, "nip44") { - getKey = nip44.GenerateConversationKey - decrypt = nip44.Decrypt - } - - sharedSecret, err := getKey(thirdPartyPubkey, p.secretKey) + sharedSecret, err := nip44.GenerateConversationKey(thirdPartyPubkey, p.secretKey) if err != nil { resultErr = fmt.Errorf("failed to compute shared secret: %w", err) break } - plaintext, err := decrypt(ciphertext, sharedSecret) + plaintext, err := nip44.Decrypt(ciphertext, sharedSecret) + if err != nil { + resultErr = fmt.Errorf("failed to encrypt: %w", err) + break + } + result = plaintext + case "nip04_encrypt": + if len(req.Params) != 2 { + resultErr = fmt.Errorf("wrong number of arguments to 'nip04_encrypt'") + break + } + thirdPartyPubkey := req.Params[0] + if !nostr.IsValidPublicKey(thirdPartyPubkey) { + resultErr = fmt.Errorf("first argument to 'nip04_encrypt' is not a pubkey string") + break + } + plaintext := req.Params[1] + + sharedSecret, err := nip04.ComputeSharedSecret(thirdPartyPubkey, p.secretKey) + if err != nil { + resultErr = fmt.Errorf("failed to compute shared secret: %w", err) + break + } + ciphertext, err := nip04.Encrypt(plaintext, sharedSecret) + if err != nil { + resultErr = fmt.Errorf("failed to encrypt: %w", err) + break + } + result = ciphertext + case "nip04_decrypt": + if len(req.Params) != 2 { + resultErr = fmt.Errorf("wrong number of arguments to 'nip04_decrypt'") + break + } + thirdPartyPubkey := req.Params[0] + if !nostr.IsValidPublicKey(thirdPartyPubkey) { + resultErr = fmt.Errorf("first argument to 'nip04_decrypt' is not a pubkey string") + break + } + ciphertext := req.Params[1] + + sharedSecret, err := nip04.ComputeSharedSecret(thirdPartyPubkey, p.secretKey) + if err != nil { + resultErr = fmt.Errorf("failed to compute shared secret: %w", err) + break + } + plaintext, err := nip04.Decrypt(ciphertext, sharedSecret) if err != nil { resultErr = fmt.Errorf("failed to encrypt: %w", err) break