diff --git a/watchtower/wtwire/init.go b/watchtower/wtwire/init.go index ff9018554..79a5fbf8b 100644 --- a/watchtower/wtwire/init.go +++ b/watchtower/wtwire/init.go @@ -1,6 +1,7 @@ package wtwire import ( + "fmt" "io" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -72,3 +73,67 @@ func (msg *Init) MaxPayloadLength(uint32) uint32 { // A compile-time constraint to ensure Init implements the Message interface. var _ Message = (*Init)(nil) + +// CheckRemoteInit performs basic validation of the remote party's Init message. +// This method checks that the remote Init's chain hash matches our advertised +// chain hash and that the remote Init does not contain any required feature +// bits that we don't understand. +func (msg *Init) CheckRemoteInit(remoteInit *Init, + featureNames map[lnwire.FeatureBit]string) error { + + // Check that the remote peer is on the same chain. + if msg.ChainHash != remoteInit.ChainHash { + return NewErrUnknownChainHash(remoteInit.ChainHash) + } + + remoteConnFeatures := lnwire.NewFeatureVector( + remoteInit.ConnFeatures, featureNames, + ) + + // Check that the remote peer doesn't have any required connection + // feature bits that we ourselves are unaware of. + unknownConnFeatures := remoteConnFeatures.UnknownRequiredFeatures() + if len(unknownConnFeatures) > 0 { + return NewErrUnknownRequiredFeatures(unknownConnFeatures...) + } + + return nil +} + +// ErrUnknownChainHash signals that the remote Init has a different chain hash +// from the one we advertised. +type ErrUnknownChainHash struct { + hash chainhash.Hash +} + +// NewErrUnknownChainHash creates an ErrUnknownChainHash using the remote Init's +// chain hash. +func NewErrUnknownChainHash(hash chainhash.Hash) *ErrUnknownChainHash { + return &ErrUnknownChainHash{hash} +} + +// Error returns a human-readable error displaying the unknown chain hash. +func (e *ErrUnknownChainHash) Error() string { + return fmt.Sprintf("remote init has unknown chain hash: %s", e.hash) +} + +// ErrUnknownRequiredFeatures signals that the remote Init has required feature +// bits that were unknown to us. +type ErrUnknownRequiredFeatures struct { + unknownFeatures []lnwire.FeatureBit +} + +// NewErrUnknownRequiredFeatures creates an ErrUnknownRequiredFeatures using the +// remote Init's required features that were unknown to us. +func NewErrUnknownRequiredFeatures( + unknownFeatures ...lnwire.FeatureBit) *ErrUnknownRequiredFeatures { + + return &ErrUnknownRequiredFeatures{unknownFeatures} +} + +// Error returns a human-readable error displaying the unknown required feature +// bits. +func (e *ErrUnknownRequiredFeatures) Error() string { + return fmt.Sprintf("remote init has unknown required features: %v", + e.unknownFeatures) +}