From 6c19aa1b5ecaff900ee99a0a73e0c1083643bddd Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Mon, 15 Jul 2024 18:57:33 -0300 Subject: [PATCH] nip44: refactor so bizarre var declarations are eliminated. --- nip44/nip44.go | 146 ++++++++++++++++++++++--------------------------- 1 file changed, 66 insertions(+), 80 deletions(-) diff --git a/nip44/nip44.go b/nip44/nip44.go index c666d49..3c036c1 100644 --- a/nip44/nip44.go +++ b/nip44/nip44.go @@ -19,7 +19,7 @@ import ( const version byte = 2 -var ( +const ( MinPlaintextSize = 0x0001 // 1b msg => padded to 32b MaxPlaintextSize = 0xffff // 65535 (64kb-1) => padded to 64kb ) @@ -59,91 +59,89 @@ func Encrypt(plaintext string, conversationKey []byte, applyOptions ...func(opts } } - var ( - enc []byte - nonce []byte - auth []byte - padded []byte - ciphertext []byte - hmac_ []byte - concat []byte - err error - ) + enc, nonce, auth, err := messageKeys(conversationKey, salt) + if err != nil { + return "", err + } - if enc, nonce, auth, err = messageKeys(conversationKey, salt); err != nil { + padded, err := pad(plaintext) + if err != nil { return "", err } - if padded, err = pad(plaintext); err != nil { + + ciphertext, err := chacha20_(enc, nonce, []byte(padded)) + if err != nil { return "", err } - if ciphertext, err = chacha20_(enc, nonce, []byte(padded)); err != nil { + + hmac_, err := sha256Hmac(auth, ciphertext, salt) + if err != nil { return "", err } - if hmac_, err = sha256Hmac(auth, ciphertext, salt); err != nil { - return "", err - } - concat = append(concat, []byte{version}...) - concat = append(concat, salt...) - concat = append(concat, ciphertext...) - concat = append(concat, hmac_...) + + concat := make([]byte, 1+len(salt)+len(ciphertext)+len(hmac_)) + concat[0] = version + copy(concat[1:], salt) + copy(concat[1+len(salt):], ciphertext) + copy(concat[1+len(salt)+len(ciphertext):], hmac_) + return base64.StdEncoding.EncodeToString(concat), nil } func Decrypt(ciphertext string, conversationKey []byte) (string, error) { - var ( - decoded []byte - cLen int - dLen int - salt []byte - ciphertext_ []byte - hmac []byte - hmac_ []byte - enc []byte - nonce []byte - auth []byte - padded []byte - unpaddedLen uint16 - unpadded []byte - err error - ) - cLen = len(ciphertext) + cLen := len(ciphertext) if cLen < 132 || cLen > 87472 { return "", errors.New(fmt.Sprintf("invalid payload length: %d", cLen)) } if ciphertext[0:1] == "#" { return "", errors.New("unknown version") } - if decoded, err = base64.StdEncoding.DecodeString(ciphertext); err != nil { + + decoded, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { return "", errors.New("invalid base64") } + if decoded[0] != version { return "", errors.New(fmt.Sprintf("unknown version %d", decoded[0])) } - dLen = len(decoded) + + dLen := len(decoded) if dLen < 99 || dLen > 65603 { return "", errors.New(fmt.Sprintf("invalid data length: %d", dLen)) } - salt, ciphertext_, hmac_ = decoded[1:33], decoded[33:dLen-32], decoded[dLen-32:] - if enc, nonce, auth, err = messageKeys(conversationKey, salt); err != nil { + + salt, ciphertext_, hmac_ := decoded[1:33], decoded[33:dLen-32], decoded[dLen-32:] + enc, nonce, auth, err := messageKeys(conversationKey, salt) + if err != nil { return "", err } - if hmac, err = sha256Hmac(auth, ciphertext_, salt); err != nil { + + hmac, err := sha256Hmac(auth, ciphertext_, salt) + if err != nil { return "", err } + if !bytes.Equal(hmac_, hmac) { return "", errors.New("invalid hmac") } - if padded, err = chacha20_(enc, nonce, ciphertext_); err != nil { + + padded, err := chacha20_(enc, nonce, ciphertext_) + if err != nil { return "", err } - unpaddedLen = binary.BigEndian.Uint16(padded[0:2]) - if unpaddedLen < uint16(MinPlaintextSize) || unpaddedLen > uint16(MaxPlaintextSize) || len(padded) != 2+calcPadding(int(unpaddedLen)) { + + unpaddedLen := binary.BigEndian.Uint16(padded[0:2]) + if unpaddedLen < uint16(MinPlaintextSize) || + unpaddedLen > uint16(MaxPlaintextSize) || len(padded) != 2+calcPadding(int(unpaddedLen)) { return "", errors.New("invalid padding") } - unpadded = padded[2 : unpaddedLen+2] + + unpadded := padded[2 : unpaddedLen+2] if len(unpadded) == 0 || len(unpadded) != int(unpaddedLen) { return "", errors.New("invalid padding") } + return string(unpadded), nil } @@ -160,14 +158,12 @@ func GenerateConversationKey(pub string, sk string) ([]byte, error) { } func chacha20_(key []byte, nonce []byte, message []byte) ([]byte, error) { - var ( - cipher *chacha20.Cipher - dst = make([]byte, len(message)) - err error - ) - if cipher, err = chacha20.NewUnauthenticatedCipher(key, nonce); err != nil { + cipher, err := chacha20.NewUnauthenticatedCipher(key, nonce) + if err != nil { return nil, err } + + dst := make([]byte, len(message)) cipher.XORKeyStream(dst, message) return dst, nil } @@ -183,46 +179,40 @@ func sha256Hmac(key []byte, ciphertext []byte, aad []byte) ([]byte, error) { } func messageKeys(conversationKey []byte, salt []byte) ([]byte, []byte, []byte, error) { - var ( - r io.Reader - enc []byte = make([]byte, 32) - nonce []byte = make([]byte, 12) - auth []byte = make([]byte, 32) - err error - ) if len(conversationKey) != 32 { return nil, nil, nil, errors.New("conversation key must be 32 bytes") } if len(salt) != 32 { return nil, nil, nil, errors.New("salt must be 32 bytes") } - r = hkdf.Expand(sha256.New, conversationKey, salt) - if _, err = io.ReadFull(r, enc); err != nil { + + r := hkdf.Expand(sha256.New, conversationKey, salt) + enc := make([]byte, 32) + if _, err := io.ReadFull(r, enc); err != nil { return nil, nil, nil, err } - if _, err = io.ReadFull(r, nonce); err != nil { + + nonce := make([]byte, 12) + if _, err := io.ReadFull(r, nonce); err != nil { return nil, nil, nil, err } - if _, err = io.ReadFull(r, auth); err != nil { + + auth := make([]byte, 32) + if _, err := io.ReadFull(r, auth); err != nil { return nil, nil, nil, err } + return enc, nonce, auth, nil } func pad(s string) ([]byte, error) { - var ( - sb []byte - sbLen int - padding int - result []byte - ) - sb = []byte(s) - sbLen = len(sb) + sb := []byte(s) + sbLen := len(sb) if sbLen < 1 || sbLen > MaxPlaintextSize { return nil, errors.New("plaintext should be between 1b and 64kB") } - padding = calcPadding(sbLen) - result = make([]byte, 2) + padding := calcPadding(sbLen) + result := make([]byte, 2) binary.BigEndian.PutUint16(result, uint16(sbLen)) result = append(result, sb...) result = append(result, make([]byte, padding-sbLen)...) @@ -230,14 +220,10 @@ func pad(s string) ([]byte, error) { } func calcPadding(sLen int) int { - var ( - nextPower int - chunk int - ) if sLen <= 32 { return 32 } - nextPower = 1 << int(math.Floor(math.Log2(float64(sLen-1)))+1) - chunk = int(math.Max(32, float64(nextPower/8))) + nextPower := 1 << int(math.Floor(math.Log2(float64(sLen-1)))+1) + chunk := int(math.Max(32, float64(nextPower/8))) return chunk * int(math.Floor(float64((sLen-1)/chunk))+1) }