Merge pull request #4452 from yyforyongyu/add-connection-timeout

lnrpc+tor: add network connection timeout
This commit is contained in:
Wilmer Paulino 2020-09-16 12:28:29 -07:00 committed by GitHub
commit a5c5304c09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 994 additions and 724 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
)
// Conn is an implementation of net.Conn which enforces an authenticated key
@ -34,12 +35,12 @@ var _ net.Conn = (*Conn)(nil)
// public key. In the case of a handshake failure, the connection is closed and
// a non-nil error is returned.
func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (*Conn, error) {
timeout time.Duration, dialer tor.DialFunc) (*Conn, error) {
ipAddr := netAddr.Address.String()
var conn net.Conn
var err error
conn, err = dialer("tcp", ipAddr)
conn, err = dialer("tcp", ipAddr, timeout)
if err != nil {
return nil, err
}

View File

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
)
type maybeNetConn struct {
@ -66,7 +67,10 @@ func establishTestConnection() (net.Conn, net.Conn, func(), error) {
// successful.
remoteConnChan := make(chan maybeNetConn, 1)
go func() {
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial)
remoteConn, err := Dial(
remoteKeyECDH, netAddr,
tor.DefaultConnTimeout, net.DialTimeout,
)
remoteConnChan <- maybeNetConn{remoteConn, err}
}()
@ -196,7 +200,10 @@ func TestConcurrentHandshakes(t *testing.T) {
remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
go func() {
remoteConn, err := Dial(remoteKeyECDH, netAddr, net.Dial)
remoteConn, err := Dial(
remoteKeyECDH, netAddr,
tor.DefaultConnTimeout, net.DialTimeout,
)
connChan <- maybeNetConn{remoteConn, err}
}()

View File

@ -775,7 +775,10 @@ func initNeutrinoBackend(cfg *Config, chainDir string) (*neutrino.ChainService,
AddPeers: cfg.NeutrinoMode.AddPeers,
ConnectPeers: cfg.NeutrinoMode.ConnectPeers,
Dialer: func(addr net.Addr) (net.Conn, error) {
return cfg.net.Dial(addr.Network(), addr.String())
return cfg.net.Dial(
addr.Network(), addr.String(),
cfg.ConnectionTimeout,
)
},
NameResolver: func(host string) ([]net.IP, error) {
addrs, err := cfg.net.LookupHost(host)

View File

@ -271,7 +271,7 @@ func (s *server) ConnectPeer(nodePub *btcec.PublicKey, addrs []net.Addr) error {
// Attempt to connect to the peer using this full address. If
// we're unable to connect to them, then we'll try the next
// address in place of it.
err := s.ConnectToPeer(netAddr, true)
err := s.ConnectToPeer(netAddr, true, s.cfg.ConnectionTimeout)
// If we're already connected to this peer, then we don't
// consider this an error, so we'll exit here.

View File

@ -496,6 +496,14 @@ var connectCommand = cli.Command{
Category: "Peers",
Usage: "Connect to a remote lnd peer.",
ArgsUsage: "<pubkey>@host",
Description: `
Connect to a peer using its <pubkey> and host.
A custom timeout on the connection is supported. For instance, to timeout
the connection request in 30 seconds, use the following:
lncli connect <pubkey>@host --timeout 30s
`,
Flags: []cli.Flag{
cli.BoolFlag{
Name: "perm",
@ -503,6 +511,13 @@ var connectCommand = cli.Command{
"connect to the target peer.\n" +
" If not, the call will be synchronous.",
},
cli.DurationFlag{
Name: "timeout",
Usage: "The connection timeout value for current request. " +
"Valid uints are {ms, s, m, h}.\n" +
"If not set, the global connection " +
"timeout value (default to 120s) is used.",
},
},
Action: actionDecorator(connectPeer),
}
@ -524,8 +539,9 @@ func connectPeer(ctx *cli.Context) error {
Host: splitAddr[1],
}
req := &lnrpc.ConnectPeerRequest{
Addr: addr,
Perm: ctx.Bool("perm"),
Addr: addr,
Perm: ctx.Bool("perm"),
Timeout: uint64(ctx.Duration("timeout").Seconds()),
}
lnid, err := client.ConnectPeer(ctxb, req)

View File

@ -191,21 +191,22 @@ type Config struct {
// loadConfig function. We need to expose the 'raw' strings so the
// command line library can access them.
// Only the parsed net.Addrs should be used!
RawRPCListeners []string `long:"rpclisten" description:"Add an interface/port/socket to listen for RPC connections"`
RawRESTListeners []string `long:"restlisten" description:"Add an interface/port/socket to listen for REST connections"`
RawListeners []string `long:"listen" description:"Add an interface/port to listen for peer connections"`
RawExternalIPs []string `long:"externalip" description:"Add an ip:port to the list of local addresses we claim to listen on to peers. If a port is not specified, the default (9735) will be used regardless of other parameters"`
ExternalHosts []string `long:"externalhosts" description:"A set of hosts that should be periodically resolved to announce IPs for"`
RPCListeners []net.Addr
RESTListeners []net.Addr
RestCORS []string `long:"restcors" description:"Add an ip:port/hostname to allow cross origin access from. To allow all origins, set as \"*\"."`
Listeners []net.Addr
ExternalIPs []net.Addr
DisableListen bool `long:"nolisten" description:"Disable listening for incoming peer connections"`
DisableRest bool `long:"norest" description:"Disable REST API"`
NAT bool `long:"nat" description:"Toggle NAT traversal support (using either UPnP or NAT-PMP) to automatically advertise your external IP address to the network -- NOTE this does not support devices behind multiple NATs"`
MinBackoff time.Duration `long:"minbackoff" description:"Shortest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."`
MaxBackoff time.Duration `long:"maxbackoff" description:"Longest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."`
RawRPCListeners []string `long:"rpclisten" description:"Add an interface/port/socket to listen for RPC connections"`
RawRESTListeners []string `long:"restlisten" description:"Add an interface/port/socket to listen for REST connections"`
RawListeners []string `long:"listen" description:"Add an interface/port to listen for peer connections"`
RawExternalIPs []string `long:"externalip" description:"Add an ip:port to the list of local addresses we claim to listen on to peers. If a port is not specified, the default (9735) will be used regardless of other parameters"`
ExternalHosts []string `long:"externalhosts" description:"A set of hosts that should be periodically resolved to announce IPs for"`
RPCListeners []net.Addr
RESTListeners []net.Addr
RestCORS []string `long:"restcors" description:"Add an ip:port/hostname to allow cross origin access from. To allow all origins, set as \"*\"."`
Listeners []net.Addr
ExternalIPs []net.Addr
DisableListen bool `long:"nolisten" description:"Disable listening for incoming peer connections"`
DisableRest bool `long:"norest" description:"Disable REST API"`
NAT bool `long:"nat" description:"Toggle NAT traversal support (using either UPnP or NAT-PMP) to automatically advertise your external IP address to the network -- NOTE this does not support devices behind multiple NATs"`
MinBackoff time.Duration `long:"minbackoff" description:"Shortest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."`
MaxBackoff time.Duration `long:"maxbackoff" description:"Longest backoff when reconnecting to persistent peers. Valid time units are {s, m, h}."`
ConnectionTimeout time.Duration `long:"connectiontimeout" description:"The timeout value for network connections. Valid time units are {ms, s, m, h}."`
DebugLevel string `short:"d" long:"debuglevel" description:"Logging level for all subsystems {trace, debug, info, warn, error, critical} -- You may also specify <subsystem>=<level>,<subsystem2>=<level>,... to set the log level for individual subsystems -- Use show to list available subsystems"`
@ -376,6 +377,7 @@ func DefaultConfig() Config {
NoSeedBackup: defaultNoSeedBackup,
MinBackoff: defaultMinBackoff,
MaxBackoff: defaultMaxBackoff,
ConnectionTimeout: tor.DefaultConnTimeout,
SubRPCServers: &subRPCServerConfigs{
SignRPC: &signrpc.Config{},
RouterRPC: routerrpc.DefaultConfig(),

View File

@ -287,6 +287,10 @@ type DNSSeedBootstrapper struct {
// the network seed.
dnsSeeds [][2]string
net tor.Net
// timeout is the maximum amount of time a dial will wait for a connect to
// complete.
timeout time.Duration
}
// A compile time assertion to ensure that DNSSeedBootstrapper meets the
@ -300,8 +304,10 @@ var _ NetworkPeerBootstrapper = (*ChannelGraphBootstrapper)(nil)
// used as a fallback for manual TCP resolution in the case of an error
// receiving the UDP response. The second host should return a single A record
// with the IP address of the authoritative name server.
func NewDNSSeedBootstrapper(seeds [][2]string, net tor.Net) NetworkPeerBootstrapper {
return &DNSSeedBootstrapper{dnsSeeds: seeds, net: net}
func NewDNSSeedBootstrapper(
seeds [][2]string, net tor.Net,
timeout time.Duration) NetworkPeerBootstrapper {
return &DNSSeedBootstrapper{dnsSeeds: seeds, net: net, timeout: timeout}
}
// fallBackSRVLookup attempts to manually query for SRV records we need to
@ -327,7 +333,7 @@ func (d *DNSSeedBootstrapper) fallBackSRVLookup(soaShim string,
// Once we have the IP address, we'll establish a TCP connection using
// port 53.
dnsServer := net.JoinHostPort(addrs[0], "53")
conn, err := d.net.Dial("tcp", dnsServer)
conn, err := d.net.Dial("tcp", dnsServer, d.timeout)
if err != nil {
return nil, err
}
@ -389,7 +395,9 @@ search:
// obtain a random sample of the encoded public keys of nodes.
// We use the lndLookupSRV function for this task.
primarySeed := dnsSeedTuple[0]
_, addrs, err := d.net.LookupSRV("nodes", "tcp", primarySeed)
_, addrs, err := d.net.LookupSRV(
"nodes", "tcp", primarySeed, d.timeout,
)
if err != nil {
log.Tracef("Unable to lookup SRV records via "+
"primary seed (%v): %v", primarySeed, err)

File diff suppressed because it is too large Load Diff

View File

@ -969,6 +969,12 @@ message ConnectPeerRequest {
/* If set, the daemon will attempt to persistently connect to the target
* peer. Otherwise, the call will be synchronous. */
bool perm = 2;
/*
The connection timeout value (in seconds) for this request. It won't affect
other requests.
*/
uint64 timeout = 3;
}
message ConnectPeerResponse {
}

View File

@ -3223,6 +3223,11 @@
"type": "boolean",
"format": "boolean",
"description": "If set, the daemon will attempt to persistently connect to the target\npeer. Otherwise, the call will be synchronous."
},
"timeout": {
"type": "string",
"format": "uint64",
"description": "The connection timeout value (in seconds) for this request. It won't affect\nother requests."
}
}
},

View File

@ -14378,6 +14378,10 @@ var testsCases = []*testCase{
name: "maximum channel size",
test: testMaxChannelSize,
},
{
name: "connection timeout",
test: testNetworkConnectionTimeout,
},
}
// TestLightningNetworkDaemon performs a series of integration tests amongst a

View File

@ -221,3 +221,9 @@
<time> [ERR] RPCS: [/lnrpc.Lightning/BakeMacaroon]: invalid permission entity. supported actions are [read write generate], supported entities are [onchain offchain address message peers info invoices signer macaroon uri]
<time> [ERR] RPCS: [/lnrpc.Lightning/BakeMacaroon]: permission list cannot be empty. specify at least one action/entity pair. supported actions are [read write generate], supported entities are [onchain offchain address message peers info invoices signer macaroon uri]
<time> [ERR] RPCS: [/lnrpc.Lightning/DeleteMacaroonID]: the specified ID cannot be deleted
<time> [ERR] RPCS: [/lnrpc.Lightning/ConnectPeer]: dial tcp <ip>: i/o timeout
<time> [ERR] RPCS: [connectpeer]: error connecting to peer: dial tcp <ip>: i/o timeout
<time> [ERR] SRVR: Unable to connect to <hex>@<ip>: dial tcp <ip>: i/o timeout
<time> [ERR] RPCS: [/lnrpc.Lightning/ConnectPeer]: dial tcp <ip>: i/o timeout
<time> [ERR] RPCS: [connectpeer]: error connecting to peer: dial tcp <ip>: i/o timeout
<time> [ERR] SRVR: Unable to connect to <hex>@<ip>: dial tcp <ip>: i/o timeout

125
lntest/itest/network.go Normal file
View File

@ -0,0 +1,125 @@
// +build rpctest
package itest
import (
"context"
"fmt"
"strings"
"time"
"github.com/lightningnetwork/lnd"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntest"
"github.com/stretchr/testify/require"
)
// testNetworkConnectionTimeout checks that the connectiontimeout is taking
// effect. It creates a node with a small connection timeout value, and connects
// it to a non-routable IP address.
func testNetworkConnectionTimeout(net *lntest.NetworkHarness, t *harnessTest) {
var (
ctxt, _ = context.WithTimeout(
context.Background(), defaultTimeout,
)
// testPub is a random public key for testing only.
testPub = "0332bda7da70fefe4b6ab92f53b3c4f4ee7999" +
"f312284a8e89c8670bb3f67dbee2"
// testHost is a non-routable IP address. It's used to cause a
// connection timeout.
testHost = "10.255.255.255"
)
// First, test the global timeout settings.
// Create Carol with a connection timeout of 1 millisecond.
carol, err := net.NewNode("Carol", []string{"--connectiontimeout=1ms"})
if err != nil {
t.Fatalf("unable to create new node carol: %v", err)
}
defer shutdownAndAssert(net, t, carol)
// Try to connect Carol to a non-routable IP address, which should give
// us a timeout error.
req := &lnrpc.ConnectPeerRequest{
Addr: &lnrpc.LightningAddress{
Pubkey: testPub,
Host: testHost,
},
}
assertTimeoutError(ctxt, t, carol, req)
// Second, test timeout on the connect peer request.
// Create Dave with the default timeout setting.
dave, err := net.NewNode("Dave", nil)
if err != nil {
t.Fatalf("unable to create new node dave: %v", err)
}
defer shutdownAndAssert(net, t, dave)
// Try to connect Dave to a non-routable IP address, using a timeout
// value of 1ms, which should give us a timeout error immediately.
req = &lnrpc.ConnectPeerRequest{
Addr: &lnrpc.LightningAddress{
Pubkey: testPub,
Host: testHost,
},
Timeout: 1,
}
assertTimeoutError(ctxt, t, dave, req)
}
// assertTimeoutError asserts that a connection timeout error is raised. A
// context with a default timeout is used to make the request. If our customized
// connection timeout is less than the default, we won't see the request context
// times out, instead a network connection timeout will be returned.
func assertTimeoutError(ctxt context.Context, t *harnessTest,
node *lntest.HarnessNode, req *lnrpc.ConnectPeerRequest) {
t.t.Helper()
// Create a context with a timeout value.
ctxt, cancel := context.WithTimeout(ctxt, defaultTimeout)
defer cancel()
err := connect(node, ctxt, req)
// a DeadlineExceeded error will appear in the context if the above
// ctxtTimeout value is reached.
require.NoError(t.t, ctxt.Err(), "context time out")
// Check that the network returns a timeout error.
require.Containsf(
t.t, err.Error(), "i/o timeout",
"expected to get a timeout error, instead got: %v", err,
)
}
func connect(node *lntest.HarnessNode, ctxt context.Context,
req *lnrpc.ConnectPeerRequest) error {
syncTimeout := time.After(15 * time.Second)
ticker := time.NewTicker(time.Millisecond * 100)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, err := node.ConnectPeer(ctxt, req)
// If there's no error, return nil
if err == nil {
return err
}
// If the error is no ErrServerNotActive, return it.
// Otherwise, we will retry until timeout.
if !strings.Contains(err.Error(),
lnd.ErrServerNotActive.Error()) {
return err
}
case <-syncTimeout:
return fmt.Errorf("chain backend did not " +
"finish syncing")
}
}
return nil
}

View File

@ -226,7 +226,9 @@ func initAutoPilot(svr *server, cfg *lncfg.AutoPilot,
"address type %T", addr)
}
err := svr.ConnectToPeer(lnAddr, false)
err := svr.ConnectToPeer(
lnAddr, false, svr.cfg.ConnectionTimeout,
)
if err != nil {
// If we weren't able to connect to the
// peer at this address, then we'll move

View File

@ -1512,8 +1512,25 @@ func (r *rpcServer) ConnectPeer(ctx context.Context,
rpcsLog.Debugf("[connectpeer] requested connection to %x@%s",
peerAddr.IdentityKey.SerializeCompressed(), peerAddr.Address)
if err := r.server.ConnectToPeer(peerAddr, in.Perm); err != nil {
rpcsLog.Errorf("[connectpeer]: error connecting to peer: %v", err)
// By default, we will use the global connection timeout value.
timeout := r.cfg.ConnectionTimeout
// Check if the connection timeout is set. If set, we will use it in our
// request.
if in.Timeout != 0 {
timeout = time.Duration(in.Timeout) * time.Second
rpcsLog.Debugf(
"[connectpeer] connection timeout is set to %v",
timeout,
)
}
if err := r.server.ConnectToPeer(peerAddr,
in.Perm, timeout); err != nil {
rpcsLog.Errorf(
"[connectpeer]: error connecting to peer: %v", err,
)
return nil, err
}
@ -1980,8 +1997,8 @@ out:
// If a final channel open update is being sent, then
// we can break out of our recv loop as we no longer
// need to process any further updates.
switch update := fundingUpdate.Update.(type) {
case *lnrpc.OpenStatusUpdate_ChanOpen:
update, ok := fundingUpdate.Update.(*lnrpc.OpenStatusUpdate_ChanOpen)
if ok {
chanPoint := update.ChanOpen.ChannelPoint
txid, err := GetChanPointFundingTxid(chanPoint)
if err != nil {

View File

@ -158,6 +158,9 @@
; support devices behind multiple NATs.
; nat=true
; The timeout value for network connections in seconds, default to 120 seconds.
; Valid uints are {ms, s, m, h}.
; connectiontimeout=120s
; Debug logging level.
; Valid levels are {trace, debug, info, warn, error, critical}

View File

@ -68,6 +68,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
)
const (
@ -325,11 +326,11 @@ func parseAddr(address string, netCfg tor.Net) (net.Addr, error) {
// noiseDial is a factory function which creates a connmgr compliant dialing
// function by returning a closure which includes the server's identity key.
func noiseDial(idKey keychain.SingleKeyECDH,
netCfg tor.Net) func(net.Addr) (net.Conn, error) {
netCfg tor.Net, timeout time.Duration) func(net.Addr) (net.Conn, error) {
return func(a net.Addr) (net.Conn, error) {
lnAddr := a.(*lnwire.NetAddress)
return brontide.Dial(idKey, lnAddr, netCfg.Dial)
return brontide.Dial(idKey, lnAddr, timeout, netCfg.Dial)
}
}
@ -1236,12 +1237,23 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, err
}
// authDial is the wrapper around the btrontide.Dial for the
// watchtower.
authDial := func(localKey keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress,
dialer tor.DialFunc) (wtserver.Peer, error) {
return brontide.Dial(
localKey, netAddr, cfg.ConnectionTimeout, dialer,
)
}
s.towerClient, err = wtclient.New(&wtclient.Config{
Signer: cc.wallet.Cfg.Signer,
NewAddress: newSweepPkScriptGen(cc.wallet),
SecretKeyRing: s.cc.keyRing,
Dial: cfg.net.Dial,
AuthDial: wtclient.AuthDial,
AuthDial: authDial,
DB: towerClientDB,
Policy: policy,
ChainHash: *s.cfg.ActiveNetParams.GenesisHash,
@ -1332,8 +1344,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
OnAccept: s.InboundPeerConnected,
RetryDuration: time.Second * 5,
TargetOutbound: 100,
Dial: noiseDial(s.identityECDH, s.cfg.net),
OnConnection: s.OutboundPeerConnected,
Dial: noiseDial(
s.identityECDH, s.cfg.net, s.cfg.ConnectionTimeout,
),
OnConnection: s.OutboundPeerConnected,
})
if err != nil {
return nil, err
@ -1847,7 +1861,7 @@ func initNetworkBootstrappers(s *server) ([]discovery.NetworkPeerBootstrapper, e
"seeds: %v", dnsSeeds)
dnsBootStrapper := discovery.NewDNSSeedBootstrapper(
dnsSeeds, s.cfg.net,
dnsSeeds, s.cfg.net, s.cfg.ConnectionTimeout,
)
bootStrappers = append(bootStrappers, dnsBootStrapper)
}
@ -1970,7 +1984,10 @@ func (s *server) peerBootstrapper(numTargetPeers uint32,
// TODO(roasbeef): can do AS, subnet,
// country diversity, etc
errChan := make(chan error, 1)
s.connectToPeer(a, errChan)
s.connectToPeer(
a, errChan,
s.cfg.ConnectionTimeout,
)
select {
case err := <-errChan:
if err == nil {
@ -2072,7 +2089,9 @@ func (s *server) initialPeerBootstrap(ignore map[autopilot.NodeID]struct{},
defer wg.Done()
errChan := make(chan error, 1)
go s.connectToPeer(addr, errChan)
go s.connectToPeer(
addr, errChan, s.cfg.ConnectionTimeout,
)
// We'll only allow this connection attempt to
// take up to 3 seconds. This allows us to move
@ -3383,7 +3402,9 @@ type openChanReq struct {
// connection is established, or the initial handshake process fails.
//
// NOTE: This function is safe for concurrent access.
func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
func (s *server) ConnectToPeer(addr *lnwire.NetAddress,
perm bool, timeout time.Duration) error {
targetPub := string(addr.IdentityKey.SerializeCompressed())
// Acquire mutex, but use explicit unlocking instead of defer for
@ -3444,7 +3465,7 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
// the crypto negotiation breaks down, then return an error to the
// caller.
errChan := make(chan error, 1)
s.connectToPeer(addr, errChan)
s.connectToPeer(addr, errChan, timeout)
select {
case err := <-errChan:
@ -3457,8 +3478,12 @@ func (s *server) ConnectToPeer(addr *lnwire.NetAddress, perm bool) error {
// connectToPeer establishes a connection to a remote peer. errChan is used to
// notify the caller if the connection attempt has failed. Otherwise, it will be
// closed.
func (s *server) connectToPeer(addr *lnwire.NetAddress, errChan chan<- error) {
conn, err := brontide.Dial(s.identityECDH, addr, s.cfg.net.Dial)
func (s *server) connectToPeer(addr *lnwire.NetAddress,
errChan chan<- error, timeout time.Duration) {
conn, err := brontide.Dial(
s.identityECDH, addr, timeout, s.cfg.net.Dial,
)
if err != nil {
srvrLog.Errorf("Unable to connect to %v: %v", addr, err)
select {

View File

@ -1,19 +1,31 @@
package tor
import (
"context"
"errors"
"net"
"time"
)
// TODO: this interface and its implementations should ideally be moved
// elsewhere as they are not Tor-specific.
const (
// DefaultConnTimeout is the maximum amount of time a dial will wait for
// a connect to complete.
DefaultConnTimeout time.Duration = time.Second * 120
)
// DialFunc is a type defines the signature of a dialer used by our Net
// interface.
type DialFunc func(net, addr string, timeout time.Duration) (net.Conn, error)
// Net is an interface housing a Dial function and several DNS functions that
// allows us to abstract the implementations of these functions over different
// networks, e.g. clearnet, Tor net, etc.
type Net interface {
// Dial connects to the address on the named network.
Dial(network, address string) (net.Conn, error)
Dial(network, address string, timeout time.Duration) (net.Conn, error)
// LookupHost performs DNS resolution on a given host and returns its
// addresses.
@ -21,7 +33,8 @@ type Net interface {
// LookupSRV tries to resolve an SRV query of the given service,
// protocol, and domain name.
LookupSRV(service, proto, name string) (string, []*net.SRV, error)
LookupSRV(service, proto, name string,
timeout time.Duration) (string, []*net.SRV, error)
// ResolveTCPAddr resolves TCP addresses.
ResolveTCPAddr(network, address string) (*net.TCPAddr, error)
@ -32,8 +45,10 @@ type Net interface {
type ClearNet struct{}
// Dial on the regular network uses net.Dial
func (r *ClearNet) Dial(network, address string) (net.Conn, error) {
return net.Dial(network, address)
func (r *ClearNet) Dial(
network, address string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(network, address, timeout)
}
// LookupHost for regular network uses the net.LookupHost function
@ -42,8 +57,14 @@ func (r *ClearNet) LookupHost(host string) ([]string, error) {
}
// LookupSRV for regular network uses net.LookupSRV function
func (r *ClearNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) {
return net.LookupSRV(service, proto, name)
func (r *ClearNet) LookupSRV(service, proto, name string,
timeout time.Duration) (string, []*net.SRV, error) {
// Create a context with a timeout value.
ctxt, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return net.DefaultResolver.LookupSRV(ctxt, service, proto, name)
}
// ResolveTCPAddr for regular network uses net.ResolveTCPAddr function
@ -71,13 +92,15 @@ type ProxyNet struct {
// Dial uses the Tor Dial function in order to establish connections through
// Tor. Since Tor only supports TCP connections, only TCP networks are allowed.
func (p *ProxyNet) Dial(network, address string) (net.Conn, error) {
func (p *ProxyNet) Dial(network, address string,
timeout time.Duration) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
default:
return nil, errors.New("cannot dial non-tcp network via Tor")
}
return Dial(address, p.SOCKS, p.StreamIsolation)
return Dial(address, p.SOCKS, p.StreamIsolation, timeout)
}
// LookupHost uses the Tor LookupHost function in order to resolve hosts over
@ -88,8 +111,13 @@ func (p *ProxyNet) LookupHost(host string) ([]string, error) {
// LookupSRV uses the Tor LookupSRV function in order to resolve SRV DNS queries
// over Tor.
func (p *ProxyNet) LookupSRV(service, proto, name string) (string, []*net.SRV, error) {
return LookupSRV(service, proto, name, p.SOCKS, p.DNS, p.StreamIsolation)
func (p *ProxyNet) LookupSRV(service, proto,
name string, timeout time.Duration) (string, []*net.SRV, error) {
return LookupSRV(
service, proto, name, p.SOCKS, p.DNS,
p.StreamIsolation, timeout,
)
}
// ResolveTCPAddr uses the Tor ResolveTCPAddr function in order to resolve TCP

View File

@ -6,6 +6,7 @@ import (
"fmt"
"net"
"strconv"
"time"
"github.com/btcsuite/btcd/connmgr"
"github.com/miekg/dns"
@ -54,8 +55,10 @@ func (c *proxyConn) RemoteAddr() net.Addr {
// Dial is a wrapper over the non-exported dial function that returns a wrapper
// around net.Conn in order to expose the actual remote address we're dialing,
// rather than the proxy's address.
func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
conn, err := dial(address, socksAddr, streamIsolation)
func Dial(address, socksAddr string, streamIsolation bool,
timeout time.Duration) (net.Conn, error) {
conn, err := dial(address, socksAddr, streamIsolation, timeout)
if err != nil {
return nil, err
}
@ -75,11 +78,13 @@ func Dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
}
// dial establishes a connection to the address via Tor's SOCKS proxy. Only TCP
// is supported over Tor. The final argument determines if we should force
// stream isolation for this new connection. If we do, then this means this new
// connection will use a fresh circuit, rather than possibly re-using an
// existing circuit.
func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
// is supported over Tor. The argument streamIsolation determines if we should
// force stream isolation for this new connection. If we do, then this means
// this new connection will use a fresh circuit, rather than possibly re-using
// an existing circuit.
func dial(address, socksAddr string, streamIsolation bool,
timeout time.Duration) (net.Conn, error) {
// If we were requested to force stream isolation for this connection,
// we'll populate the authentication credentials with random data as
// Tor will create a new circuit for each set of credentials.
@ -97,7 +102,8 @@ func dial(address, socksAddr string, streamIsolation bool) (net.Conn, error) {
}
// Establish the connection through Tor's SOCKS proxy.
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxy.Direct)
proxyDialer := &net.Dialer{Timeout: timeout}
dialer, err := proxy.SOCKS5("tcp", socksAddr, auth, proxyDialer)
if err != nil {
return nil, err
}
@ -121,11 +127,12 @@ func LookupHost(host, socksAddr string) ([]string, error) {
// natively support SRV queries so we must route all SRV queries through the
// proxy by connecting directly to a DNS server and querying it. The DNS server
// must have TCP resolution enabled for the given port.
func LookupSRV(service, proto, name, socksAddr, dnsServer string,
streamIsolation bool) (string, []*net.SRV, error) {
func LookupSRV(service, proto, name, socksAddr,
dnsServer string, streamIsolation bool,
timeout time.Duration) (string, []*net.SRV, error) {
// Connect to the DNS server we'll be using to query SRV records.
conn, err := dial(dnsServer, socksAddr, streamIsolation)
conn, err := dial(dnsServer, socksAddr, streamIsolation, timeout)
if err != nil {
return "", nil, err
}

View File

@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
@ -137,7 +138,7 @@ type Config struct {
// Dial connects to an addr using the specified net and returns the
// connection object.
Dial Dial
Dial tor.DialFunc
// AuthDialer establishes a brontide connection over an onion or clear
// network.

View File

@ -16,6 +16,7 @@ import (
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
@ -84,7 +85,9 @@ func newMockNet(cb func(wtserver.Peer)) *mockNet {
}
}
func (m *mockNet) Dial(network string, address string) (net.Conn, error) {
func (m *mockNet) Dial(network string, address string,
timeout time.Duration) (net.Conn, error) {
return nil, nil
}
@ -100,8 +103,9 @@ func (m *mockNet) ResolveTCPAddr(network string, address string) (*net.TCPAddr,
panic("not implemented")
}
func (m *mockNet) AuthDial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
func (m *mockNet) AuthDial(local keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress,
dialer tor.DialFunc) (wtserver.Peer, error) {
localPk := local.PubKey()
localAddr := &net.TCPAddr{
@ -433,10 +437,8 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
clientDB := wtmock.NewClientDB()
clientCfg := &wtclient.Config{
Signer: signer,
Dial: func(string, string) (net.Conn, error) {
return nil, nil
},
Signer: signer,
Dial: mockNet.Dial,
DB: clientDB,
AuthDial: mockNet.AuthDial,
SecretKeyRing: wtmock.NewSecretKeyRing(),

View File

@ -4,9 +4,9 @@ import (
"net"
"github.com/btcsuite/btcd/btcec"
"github.com/lightningnetwork/lnd/brontide"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tor"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver"
)
@ -95,22 +95,12 @@ type DB interface {
AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error
}
// Dial connects to an addr using the specified net and returns the connection
// object.
type Dial func(net, addr string) (net.Conn, error)
// AuthDialer connects to a remote node using an authenticated transport, such as
// brontide. The dialer argument is used to specify a resolver, which allows
// this method to be used over Tor or clear net connections.
type AuthDialer func(localKey keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error)
// AuthDial is the watchtower client's default method of dialing.
func AuthDial(localKey keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
dialer func(string, string) (net.Conn, error)) (wtserver.Peer, error) {
return brontide.Dial(localKey, netAddr, dialer)
}
type AuthDialer func(localKey keychain.SingleKeyECDH,
netAddr *lnwire.NetAddress,
dialer tor.DialFunc) (wtserver.Peer, error)
// ECDHKeyRing abstracts the ability to derive shared ECDH keys given a
// description of the derivation path of a private key.