diff --git a/brontide/conn.go b/brontide/conn.go index eb7735486..e2f339eb2 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -47,9 +47,11 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress) (*Conn, error // Initiate the handshake by sending the first act to the receiver. actOne, err := b.noise.GenActOne() if err != nil { + b.conn.Close() return nil, err } if _, err := conn.Write(actOne[:]); err != nil { + b.conn.Close() return nil, err } @@ -59,9 +61,11 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress) (*Conn, error // secrecy. var actTwo [ActTwoSize]byte if _, err := io.ReadFull(conn, actTwo[:]); err != nil { + b.conn.Close() return nil, err } if err := b.noise.RecvActTwo(actTwo); err != nil { + b.conn.Close() return nil, err } @@ -69,9 +73,11 @@ func Dial(localPriv *btcec.PrivateKey, netAddr *lnwire.NetAddress) (*Conn, error // key and execute the final ECDH operation. actThree, err := b.noise.GenActThree() if err != nil { + b.conn.Close() return nil, err } if _, err := conn.Write(actThree[:]); err != nil { + b.conn.Close() return nil, err } diff --git a/brontide/listener.go b/brontide/listener.go index a8c0e03d0..478c791b2 100644 --- a/brontide/listener.go +++ b/brontide/listener.go @@ -64,9 +64,11 @@ func (l *Listener) Accept() (net.Conn, error) { // this portion will fail with a non-nil error. var actOne [ActOneSize]byte if _, err := io.ReadFull(conn, actOne[:]); err != nil { + brontideConn.conn.Close() return nil, err } if err := brontideConn.noise.RecvActOne(actOne); err != nil { + brontideConn.conn.Close() return nil, err } @@ -74,9 +76,11 @@ func (l *Listener) Accept() (net.Conn, error) { // key for the session along with an authenticating tag. actTwo, err := brontideConn.noise.GenActTwo() if err != nil { + brontideConn.conn.Close() return nil, err } if _, err := conn.Write(actTwo[:]); err != nil { + brontideConn.conn.Close() return nil, err } @@ -85,9 +89,11 @@ func (l *Listener) Accept() (net.Conn, error) { // sides have mutually authenticated each other. var actThree [ActThreeSize]byte if _, err := io.ReadFull(conn, actThree[:]); err != nil { + brontideConn.conn.Close() return nil, err } if err := brontideConn.noise.RecvActThree(actThree); err != nil { + brontideConn.conn.Close() return nil, err } diff --git a/brontide/noise.go b/brontide/noise.go index b78ba63ec..2ea986db5 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/binary" "errors" + "fmt" "io" "math" @@ -330,30 +331,34 @@ func NewBrontideMachine(initiator bool, localPub *btcec.PrivateKey, return &BrontideMachine{handshakeState: handshake} } -// TODO(roasbeef): add version bytes, paramterize in constructor above - const ( + // HandshakeVersion is the expected version of the brontide handshake. + // Any messages that carry a different version will cause the handshake + // to abort immediately. + HandshakeVersion = byte(0) + // ActOneSize is the size of the packet sent from initiator to - // responder in ActOne. The packet consists of an ephemeral key in - // compressed format, and a 16-byte poly1305 tag. + // responder in ActOne. The packet consists of a handshake version, an + // ephemeral key in compressed format, and a 16-byte poly1305 tag. // - // 33 + 16 - ActOneSize = 49 + // 1 + 33 + 16 + ActOneSize = 50 // ActTwoSize is the size the packet sent from responder to initiator - // in ActTwo. The packet consists of an ephemeral key in compressed - // format and a 16-byte poly1305 tag. + // in ActTwo. The packet consists of a handshake version, an ephemeral + // key in compressed format and a 16-byte poly1305 tag. // - // 33 + 16 - ActTwoSize = 49 + // 1 + 33 + 16 + ActTwoSize = 50 // ActThreeSize is the size of the packet sent from initiator to - // responder in ActThree. The packet consists of the initiators static - // key encrypted with strong forward secrecy and a 16-byte poly1035 + // responder in ActThree. The packet consists of a handshake version, + // the initiators static key encrypted with strong forward secrecy and + // a 16-byte poly1035 // tag. // - // 33 + 16 + 16 - ActThreeSize = 65 + // 1 + 33 + 16 + 16 + ActThreeSize = 66 ) // GenActOne generates the initial packet (act one) to be sent from initiator @@ -384,8 +389,9 @@ func (b *BrontideMachine) GenActOne() ([ActOneSize]byte, error) { authPayload := b.EncryptAndHash([]byte{}) - copy(actOne[:33], ephemeral) - copy(actOne[33:], authPayload) + actOne[0] = HandshakeVersion + copy(actOne[1:34], ephemeral) + copy(actOne[34:], authPayload) return actOne, nil } @@ -401,8 +407,15 @@ func (b *BrontideMachine) RecvActOne(actOne [ActOneSize]byte) error { p [16]byte ) - copy(e[:], actOne[:33]) - copy(p[:], actOne[33:]) + // If the handshake version is unknown, then the handshake fails + // immediately. + if actOne[0] != HandshakeVersion { + return fmt.Errorf("Invalid handshake version: %v, only %v is "+ + "valid", actOne[0], HandshakeVersion) + } + + copy(e[:], actOne[1:34]) + copy(p[:], actOne[34:]) // e b.remoteEphemeral, err = btcec.ParsePubKey(e[:], btcec.S256()) @@ -451,8 +464,9 @@ func (b *BrontideMachine) GenActTwo() ([ActTwoSize]byte, error) { authPayload := b.EncryptAndHash([]byte{}) - copy(actTwo[:33], ephemeral) - copy(actTwo[33:], authPayload) + actTwo[0] = HandshakeVersion + copy(actTwo[1:34], ephemeral) + copy(actTwo[34:], authPayload) return actTwo, nil } @@ -467,8 +481,15 @@ func (b *BrontideMachine) RecvActTwo(actTwo [ActTwoSize]byte) error { p [16]byte ) - copy(e[:], actTwo[:33]) - copy(p[:], actTwo[33:]) + // If the handshake version is unknown, then the handshake fails + // immediately. + if actTwo[0] != HandshakeVersion { + return fmt.Errorf("Invalid handshake version: %v, only %v is "+ + "valid", actTwo[0], HandshakeVersion) + } + + copy(e[:], actTwo[1:34]) + copy(p[:], actTwo[34:]) // e b.remoteEphemeral, err = btcec.ParsePubKey(e[:], btcec.S256()) @@ -506,8 +527,9 @@ func (b *BrontideMachine) GenActThree() ([ActThreeSize]byte, error) { authPayload := b.EncryptAndHash([]byte{}) - copy(actThree[:49], ciphertext) - copy(actThree[49:], authPayload) + actThree[0] = HandshakeVersion + copy(actThree[1:50], ciphertext) + copy(actThree[50:], authPayload) // With the final ECDH operation complete, derive the session sending // and receiving keys. @@ -527,8 +549,15 @@ func (b *BrontideMachine) RecvActThree(actThree [ActThreeSize]byte) error { p [16]byte ) - copy(s[:], actThree[:33+16]) - copy(p[:], actThree[33+16:]) + // If the handshake version is unknown, then the handshake fails + // immediately. + if actThree[0] != HandshakeVersion { + return fmt.Errorf("Invalid handshake version: %v, only %v is "+ + "valid", actThree[0], HandshakeVersion) + } + + copy(s[:], actThree[1:33+16+1]) + copy(p[:], actThree[33+16+1:]) // s remotePub, err := b.DecryptAndHash(s[:]) diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 6d2771793..3c4e58414 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -46,16 +46,17 @@ func establishTestConnection() (net.Conn, net.Conn, error) { connChan := make(chan net.Conn) go func() { conn, err := Dial(remotePriv, netAddr) + errChan <- err connChan <- conn }() localConn, listenErr := listener.Accept() if listenErr != nil { - return nil, nil, err + return nil, nil, listenErr } - if dialErr := <-errChan; err != nil { + if dialErr := <-errChan; dialErr != nil { return nil, nil, dialErr } remoteConn := <-connChan