From 2ce6228021dcad1e46b2558e3b58704584113559 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:19:58 -0700 Subject: [PATCH 1/9] watchtower/wtmock/tower_db: move mock tower db to wtmock pkg --- watchtower/lookout/lookout_test.go | 3 +- watchtower/wtclient/client_test.go | 4 +- watchtower/wtdb/mock.go | 142 ------------------------- watchtower/wtmock/tower_db.go | 162 +++++++++++++++++++++++++++++ watchtower/wtserver/server_test.go | 6 +- 5 files changed, 168 insertions(+), 149 deletions(-) delete mode 100644 watchtower/wtdb/mock.go create mode 100644 watchtower/wtmock/tower_db.go diff --git a/watchtower/lookout/lookout_test.go b/watchtower/lookout/lookout_test.go index 4232791d5..fb70a9618 100644 --- a/watchtower/lookout/lookout_test.go +++ b/watchtower/lookout/lookout_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -66,7 +67,7 @@ func makeAddrSlice(size int) []byte { } func TestLookoutBreachMatching(t *testing.T) { - db := wtdb.NewMockDB() + db := wtmock.NewTowerDB() // Initialize an mock backend to feed the lookout blocks. backend := lookout.NewMockBackend() diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 68c34e0e1..86811bf02 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -369,7 +369,7 @@ type testHarness struct { clientDB *wtmock.ClientDB clientCfg *wtclient.Config client wtclient.Client - serverDB *wtdb.MockDB + serverDB *wtmock.TowerDB serverCfg *wtserver.Config server *wtserver.Server net *mockNet @@ -406,7 +406,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { } const timeout = 200 * time.Millisecond - serverDB := wtdb.NewMockDB() + serverDB := wtmock.NewTowerDB() serverCfg := &wtserver.Config{ DB: serverDB, diff --git a/watchtower/wtdb/mock.go b/watchtower/wtdb/mock.go deleted file mode 100644 index 902303881..000000000 --- a/watchtower/wtdb/mock.go +++ /dev/null @@ -1,142 +0,0 @@ -// +build dev - -package wtdb - -import ( - "sync" - - "github.com/lightningnetwork/lnd/chainntnfs" -) - -type MockDB struct { - mu sync.Mutex - lastEpoch *chainntnfs.BlockEpoch - sessions map[SessionID]*SessionInfo - blobs map[BreachHint]map[SessionID]*SessionStateUpdate -} - -func NewMockDB() *MockDB { - return &MockDB{ - sessions: make(map[SessionID]*SessionInfo), - blobs: make(map[BreachHint]map[SessionID]*SessionStateUpdate), - } -} - -func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) { - db.mu.Lock() - defer db.mu.Unlock() - - info, ok := db.sessions[update.ID] - if !ok { - return 0, ErrSessionNotFound - } - - err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) - if err != nil { - return info.LastApplied, err - } - - sessionsToUpdates, ok := db.blobs[update.Hint] - if !ok { - sessionsToUpdates = make(map[SessionID]*SessionStateUpdate) - db.blobs[update.Hint] = sessionsToUpdates - } - sessionsToUpdates[update.ID] = update - - return info.LastApplied, nil -} - -func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { - db.mu.Lock() - defer db.mu.Unlock() - - if info, ok := db.sessions[*id]; ok { - return info, nil - } - - return nil, ErrSessionNotFound -} - -func (db *MockDB) InsertSessionInfo(info *SessionInfo) error { - db.mu.Lock() - defer db.mu.Unlock() - - dbInfo, ok := db.sessions[info.ID] - if ok && dbInfo.LastApplied > 0 { - return ErrSessionAlreadyExists - } - - db.sessions[info.ID] = info - - return nil -} - -func (db *MockDB) DeleteSession(target SessionID) error { - db.mu.Lock() - defer db.mu.Unlock() - - // Fail if the session doesn't exit. - if _, ok := db.sessions[target]; !ok { - return ErrSessionNotFound - } - - // Remove the target session. - delete(db.sessions, target) - - // Remove the state updates for any blobs stored under the target - // session identifier. - for hint, sessionUpdates := range db.blobs { - delete(sessionUpdates, target) - - //If this was the last state update, we can also remove the hint - //that would map to an empty set. - if len(sessionUpdates) == 0 { - delete(db.blobs, hint) - } - } - - return nil -} - -func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { - db.mu.Lock() - defer db.mu.Unlock() - - return db.lastEpoch, nil -} - -func (db *MockDB) QueryMatches(breachHints []BreachHint) ([]Match, error) { - db.mu.Lock() - defer db.mu.Unlock() - - var matches []Match - for _, hint := range breachHints { - sessionsToUpdates, ok := db.blobs[hint] - if !ok { - continue - } - - for id, update := range sessionsToUpdates { - info, ok := db.sessions[id] - if !ok { - panic("session not found") - } - - match := Match{ - ID: id, - SeqNum: update.SeqNum, - Hint: hint, - EncryptedBlob: update.EncryptedBlob, - SessionInfo: info, - } - matches = append(matches, match) - } - } - - return matches, nil -} - -func (db *MockDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { - db.lastEpoch = epoch - return nil -} diff --git a/watchtower/wtmock/tower_db.go b/watchtower/wtmock/tower_db.go new file mode 100644 index 000000000..403d61e30 --- /dev/null +++ b/watchtower/wtmock/tower_db.go @@ -0,0 +1,162 @@ +package wtmock + +import ( + "sync" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// TowerDB is a mock, in-memory implementation of a watchtower.DB. +type TowerDB struct { + mu sync.Mutex + lastEpoch *chainntnfs.BlockEpoch + sessions map[wtdb.SessionID]*wtdb.SessionInfo + blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate +} + +// NewTowerDB initializes a fresh mock TowerDB. +func NewTowerDB() *TowerDB { + return &TowerDB{ + sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo), + blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate), + } +} + +// InsertStateUpdate stores an update sent by the client after validating that +// the update is well-formed in the context of other updates sent for the same +// session. This include verifying that the sequence number is incremented +// properly and the last applied values echoed by the client are sane. +func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) { + db.mu.Lock() + defer db.mu.Unlock() + + info, ok := db.sessions[update.ID] + if !ok { + return 0, wtdb.ErrSessionNotFound + } + + err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) + if err != nil { + return info.LastApplied, err + } + + sessionsToUpdates, ok := db.blobs[update.Hint] + if !ok { + sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate) + db.blobs[update.Hint] = sessionsToUpdates + } + sessionsToUpdates[update.ID] = update + + return info.LastApplied, nil +} + +// GetSessionInfo retrieves the session for the passed session id. An error is +// returned if the session could not be found. +func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) { + db.mu.Lock() + defer db.mu.Unlock() + + if info, ok := db.sessions[*id]; ok { + return info, nil + } + + return nil, wtdb.ErrSessionNotFound +} + +// InsertSessionInfo records a negotiated session in the tower database. An +// error is returned if the session already exists. +func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error { + db.mu.Lock() + defer db.mu.Unlock() + + dbInfo, ok := db.sessions[info.ID] + if ok && dbInfo.LastApplied > 0 { + return wtdb.ErrSessionAlreadyExists + } + + db.sessions[info.ID] = info + + return nil +} + +// DeleteSession removes all data associated with a particular session id from +// the tower's database. +func (db *TowerDB) DeleteSession(target wtdb.SessionID) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Fail if the session doesn't exit. + if _, ok := db.sessions[target]; !ok { + return wtdb.ErrSessionNotFound + } + + // Remove the target session. + delete(db.sessions, target) + + // Remove the state updates for any blobs stored under the target + // session identifier. + for hint, sessionUpdates := range db.blobs { + delete(sessionUpdates, target) + + // If this was the last state update, we can also remove the + // hint that would map to an empty set. + if len(sessionUpdates) == 0 { + delete(db.blobs, hint) + } + } + + return nil +} + +// QueryMatches searches against all known state updates for any that match the +// passed breachHints. More than one Match will be returned for a given hint if +// they exist in the database. +func (db *TowerDB) QueryMatches( + breachHints []wtdb.BreachHint) ([]wtdb.Match, error) { + + db.mu.Lock() + defer db.mu.Unlock() + + var matches []wtdb.Match + for _, hint := range breachHints { + sessionsToUpdates, ok := db.blobs[hint] + if !ok { + continue + } + + for id, update := range sessionsToUpdates { + info, ok := db.sessions[id] + if !ok { + panic("session not found") + } + + match := wtdb.Match{ + ID: id, + SeqNum: update.SeqNum, + Hint: hint, + EncryptedBlob: update.EncryptedBlob, + SessionInfo: info, + } + matches = append(matches, match) + } + } + + return matches, nil +} + +// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in +// the tower database. +func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { + db.lastEpoch = epoch + return nil +} + +// GetLookoutTip retrieves the current lookout tip block epoch from the tower +// database. +func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { + db.mu.Lock() + defer db.mu.Unlock() + + return db.lastEpoch, nil +} diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 6d99180fd..c3ae2a33e 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -1,5 +1,3 @@ -// +build dev - package wtserver_test import ( @@ -53,7 +51,7 @@ func initServer(t *testing.T, db wtserver.DB, t.Helper() if db == nil { - db = wtdb.NewMockDB() + db = wtmock.NewTowerDB() } s, err := wtserver.New(&wtserver.Config{ @@ -687,7 +685,7 @@ func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) { // checking that the proper error is returned when the session doesn't exist and // that a successful deletion does not disrupt other sessions. func TestServerDeleteSession(t *testing.T) { - db := wtdb.NewMockDB() + db := wtmock.NewTowerDB() localPub := randPubKey(t) From 678ca9feff0f502069e22c126cf9155a393b2711 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:20:23 -0700 Subject: [PATCH 2/9] channeldb/codec: add NewUnknownElementType constructor --- channeldb/codec.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/channeldb/codec.go b/channeldb/codec.go index f491a8c15..1da362dd7 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -51,6 +51,12 @@ type UnknownElementType struct { element interface{} } +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + // Error returns the name of the method that encountered the error, as well as // the type that was unsupported. func (e UnknownElementType) Error() string { From dccef4c8bf28f2c92fa5c82e5f677864fad0a35a Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:20:36 -0700 Subject: [PATCH 3/9] watchtower/wtdb/codec: import channeldb code for extension --- watchtower/wtdb/codec.go | 143 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 watchtower/wtdb/codec.go diff --git a/watchtower/wtdb/codec.go b/watchtower/wtdb/codec.go new file mode 100644 index 000000000..2fd30196f --- /dev/null +++ b/watchtower/wtdb/codec.go @@ -0,0 +1,143 @@ +package wtdb + +import ( + "io" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// UnknownElementType is an alias for channeldb.UnknownElementType. +type UnknownElementType = channeldb.UnknownElementType + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + err := channeldb.ReadElement(r, element) + switch { + + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + case err != nil: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + + case *SessionID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wtpolicy.Policy: + var ( + blobType uint16 + sweepFeeRate uint64 + ) + err := channeldb.ReadElements(r, + &blobType, + &e.MaxUpdates, + &e.RewardBase, + &e.RewardRate, + &sweepFeeRate, + ) + if err != nil { + return err + } + + e.BlobType = blob.Type(blobType) + e.SweepFeeRate = lnwallet.SatPerKWeight(sweepFeeRate) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "ReadElement", element, + ) + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + err := channeldb.WriteElement(w, element) + switch { + + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + case err != nil: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + + case SessionID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wtpolicy.Policy: + return channeldb.WriteElements(w, + uint16(e.BlobType), + e.MaxUpdates, + e.RewardBase, + e.RewardRate, + uint64(e.SweepFeeRate), + ) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "WriteElement", element, + ) + } + + return nil +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} From 7ba197c6a7f41767e7d2caa913350fcfff506ad8 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:20:52 -0700 Subject: [PATCH 4/9] watchtower/wtdb: add encode/decode to session info + updates --- watchtower/wtdb/session_info.go | 23 +++++++++++++++++++++++ watchtower/wtdb/session_state_update.go | 24 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/watchtower/wtdb/session_info.go b/watchtower/wtdb/session_info.go index 1c7b7f0ff..f4acf764d 100644 --- a/watchtower/wtdb/session_info.go +++ b/watchtower/wtdb/session_info.go @@ -2,6 +2,7 @@ package wtdb import ( "errors" + "io" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -59,6 +60,28 @@ type SessionInfo struct { // TODO(conner): store client metrics, DOS score, etc } +// Encode serializes the session info to the given io.Writer. +func (s *SessionInfo) Encode(w io.Writer) error { + return WriteElements(w, + s.ID, + s.Policy, + s.LastApplied, + s.ClientLastApplied, + s.RewardAddress, + ) +} + +// Decode deserializes the session infor from the given io.Reader. +func (s *SessionInfo) Decode(r io.Reader) error { + return ReadElements(r, + &s.ID, + &s.Policy, + &s.LastApplied, + &s.ClientLastApplied, + &s.RewardAddress, + ) +} + // AcceptUpdateSequence validates that a state update's sequence number and last // applied are valid given our past history with the client. These checks ensure // that clients are properly in sync and following the update protocol properly. diff --git a/watchtower/wtdb/session_state_update.go b/watchtower/wtdb/session_state_update.go index 7711b10f2..de75f6aee 100644 --- a/watchtower/wtdb/session_state_update.go +++ b/watchtower/wtdb/session_state_update.go @@ -1,5 +1,7 @@ package wtdb +import "io" + // SessionStateUpdate holds a state update sent by a client along with its // SessionID. type SessionStateUpdate struct { @@ -21,3 +23,25 @@ type SessionStateUpdate struct { // hint is braodcast. EncryptedBlob []byte } + +// Encode serializes the state update into the provided io.Writer. +func (u *SessionStateUpdate) Encode(w io.Writer) error { + return WriteElements(w, + u.ID, + u.SeqNum, + u.LastApplied, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode deserializes the target state update from the provided io.Reader. +func (u *SessionStateUpdate) Decode(r io.Reader) error { + return ReadElements(r, + &u.ID, + &u.SeqNum, + &u.LastApplied, + &u.Hint, + &u.EncryptedBlob, + ) +} From a36397e21a4a95c19aafc8b9c99d4924bd538b15 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:21:05 -0700 Subject: [PATCH 5/9] watchtower/wtdb/codec_test: encode/decode quick checks --- watchtower/wtdb/codec_test.go | 86 +++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 watchtower/wtdb/codec_test.go diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go new file mode 100644 index 000000000..948ec4ee4 --- /dev/null +++ b/watchtower/wtdb/codec_test.go @@ -0,0 +1,86 @@ +package wtdb_test + +import ( + "bytes" + "io" + "reflect" + "testing" + "testing/quick" + + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// dbObject is abstract object support encoding and decoding. +type dbObject interface { + Encode(io.Writer) error + Decode(io.Reader) error +} + +// TestCodec serializes and deserializes wtdb objects in order to test that that +// the codec understands all of the required field types. The test also asserts +// that decoding an object into another results in an equivalent object. +func TestCodec(t *testing.T) { + mainScenario := func(obj dbObject) bool { + // Ensure encoding the object succeeds. + var b bytes.Buffer + err := obj.Encode(&b) + if err != nil { + t.Fatalf("unable to encode: %v", err) + return false + } + + var obj2 dbObject + switch obj.(type) { + case *wtdb.SessionInfo: + obj2 = &wtdb.SessionInfo{} + case *wtdb.SessionStateUpdate: + obj2 = &wtdb.SessionStateUpdate{} + default: + t.Fatalf("unknown type: %T", obj) + return false + } + + // Ensure decoding the object succeeds. + err = obj2.Decode(bytes.NewReader(b.Bytes())) + if err != nil { + t.Fatalf("unable to decode: %v", err) + return false + } + + // Assert the original and decoded object match. + if !reflect.DeepEqual(obj, obj2) { + t.Fatalf("encode/decode mismatch, want: %v, "+ + "got: %v", obj, obj2) + return false + } + + return true + } + + tests := []struct { + name string + scenario interface{} + }{ + { + name: "SessionInfo", + scenario: func(obj wtdb.SessionInfo) bool { + return mainScenario(&obj) + }, + }, + { + name: "SessionStateUpdate", + scenario: func(obj wtdb.SessionStateUpdate) bool { + return mainScenario(&obj) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := quick.Check(test.scenario, nil); err != nil { + t.Fatalf("fuzz checks for msg=%s failed: %v", + test.name, err) + } + }) + } +} From c99d1313fe89387e38381a1068217e85b3beede6 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:21:20 -0700 Subject: [PATCH 6/9] watchtower/wtdb/log: add WTDB logs --- watchtower/wtdb/log.go | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 watchtower/wtdb/log.go diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go new file mode 100644 index 000000000..b4b30f459 --- /dev/null +++ b/watchtower/wtdb/log.go @@ -0,0 +1,45 @@ +package wtdb + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger("WTDB", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} + +// logClosure is used to provide a closure over expensive logging operations so +// don't have to be performed when the logging level doesn't warrant it. +type logClosure func() string + +// String invokes the underlying function and returns the result. +func (c logClosure) String() string { + return c() +} + +// newLogClosure returns a new closure over a function that returns a string +// which itself provides a Stringer interface so that it can be used with the +// logging system. +func newLogClosure(c func() string) logClosure { + return logClosure(c) +} From 3ef2a3673338ace4db3522151e80b427b822fb50 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:21:35 -0700 Subject: [PATCH 7/9] watchtower/wtdb/tower_db: add TowerDB and db versioning --- watchtower/wtdb/tower_db.go | 733 ++++++++++++++++++++++++++++++++++++ watchtower/wtdb/version.go | 84 +++++ 2 files changed, 817 insertions(+) create mode 100644 watchtower/wtdb/tower_db.go create mode 100644 watchtower/wtdb/version.go diff --git a/watchtower/wtdb/tower_db.go b/watchtower/wtdb/tower_db.go new file mode 100644 index 000000000..0bcd271c0 --- /dev/null +++ b/watchtower/wtdb/tower_db.go @@ -0,0 +1,733 @@ +package wtdb + +import ( + "bytes" + "encoding/binary" + "errors" + "os" + "path/filepath" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" +) + +const ( + // dbName is the filename of tower database. + dbName = "watchtower.db" + + // dbFilePermission requests read+write access to the db file. + dbFilePermission = 0600 +) + +var ( + // sessionsBkt is a bucket containing all negotiated client sessions. + // session id -> session + sessionsBkt = []byte("sessions-bucket") + + // updatesBkt is a bucket containing all state updates sent by clients. + // The updates are further bucketed by session id to prevent clients + // from overwrite each other. + // hint => session id -> update + updatesBkt = []byte("updates-bucket") + + // updateIndexBkt is a bucket that indexes all state updates by their + // overarching session id. This allows for efficient lookup of updates + // by their session id, which is currently used to aide deletion + // performance. + // session id => hint1 -> []byte{} + // => hint2 -> []byte{} + updateIndexBkt = []byte("update-index-bucket") + + // lookoutTipBkt is a bucket containing the last block epoch processed + // by the lookout subsystem. It has one key, lookoutTipKey. + // lookoutTipKey -> block epoch + lookoutTipBkt = []byte("lookout-tip-bucket") + + // lookoutTipKey is a static key used to retrieve lookout tip's block + // epoch from the lookoutTipBkt. + lookoutTipKey = []byte("lookout-tip") + + // metadataBkt stores all the meta information concerning the state of + // the database. + metadataBkt = []byte("metadata-bucket") + + // dbVersionKey is a static key used to retrieve the database version + // number from the metadataBkt. + dbVersionKey = []byte("version") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("tower db not initialized") + + // ErrNoDBVersion signals that the database contains no version info. + ErrNoDBVersion = errors.New("tower db has no version") + + // ErrNoSessionHintIndex signals that an active session does not have an + // initialized index for tracking its own state updates. + ErrNoSessionHintIndex = errors.New("session hint index missing") + + byteOrder = binary.BigEndian +) + +// TowerDB is single database providing a persistent storage engine for the +// wtserver and lookout subsystems. +type TowerDB struct { + db *bbolt.DB + dbPath string +} + +// OpenTowerDB opens the tower database given the path to the database's +// directory. If no such database exists, this method will initialize a fresh +// one using the latest version number and bucket structure. If a database +// exists but has a lower version number than the current version, any necessary +// migrations will be applied before returning. Any attempt to open a database +// with a version number higher that the latest version will fail to prevent +// accidental reversion. +func OpenTowerDB(dbPath string) (*TowerDB, error) { + path := filepath.Join(dbPath, dbName) + + // If the database file doesn't exist, this indicates we much initialize + // a fresh database with the latest version. + firstInit := !fileExists(path) + if firstInit { + // Ensure all parent directories are initialized. + err := os.MkdirAll(dbPath, 0700) + if err != nil { + return nil, err + } + } + + bdb, err := bbolt.Open(path, dbFilePermission, nil) + if err != nil { + return nil, err + } + + // If the file existed previously, we'll now check to see that the + // metadata bucket is properly initialized. It could be the case that + // the database was created, but we failed to actually populate any + // metadata. If the metadata bucket does not actually exist, we'll + // set firstInit to true so that we can treat is initialize the bucket. + if !firstInit { + var metadataExists bool + err = bdb.View(func(tx *bbolt.Tx) error { + metadataExists = tx.Bucket(metadataBkt) != nil + return nil + }) + if err != nil { + return nil, err + } + + if !metadataExists { + firstInit = true + } + } + + towerDB := &TowerDB{ + db: bdb, + dbPath: dbPath, + } + + if firstInit { + // If the database has not yet been created, we'll initialize + // the database version with the latest known version. + err = towerDB.db.Update(func(tx *bbolt.Tx) error { + return initDBVersion(tx, getLatestDBVersion(dbVersions)) + }) + if err != nil { + bdb.Close() + return nil, err + } + } else { + // Otherwise, ensure that any migrations are applied to ensure + // the data is in the format expected by the latest version. + err = towerDB.syncVersions(dbVersions) + if err != nil { + bdb.Close() + return nil, err + } + } + + // Now that the database version fully consistent with our latest known + // version, ensure that all top-level buckets known to this version are + // initialized. This allows us to assume their presence throughout all + // operations. If an known top-level bucket is expected to exist but is + // missing, this will trigger a ErrUninitializedDB error. + err = towerDB.db.Update(initTowerDBBuckets) + if err != nil { + bdb.Close() + return nil, err + } + + return towerDB, nil +} + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// initTowerDBBuckets creates all top-level buckets required to handle database +// operations required by the latest version. +func initTowerDBBuckets(tx *bbolt.Tx) error { + buckets := [][]byte{ + sessionsBkt, + updateIndexBkt, + updatesBkt, + lookoutTipBkt, + } + + for _, bucket := range buckets { + _, err := tx.CreateBucketIfNotExists(bucket) + if err != nil { + return err + } + } + + return nil +} + +// syncVersions ensures the database version is consistent with the highest +// known database version, applying any migrations that have not been made. If +// the highest known version number is lower than the database's version, this +// method will fail to prevent accidental reversions. +func (t *TowerDB) syncVersions(versions []version) error { + curVersion, err := t.Version() + if err != nil { + return err + } + + latestVersion := getLatestDBVersion(versions) + switch { + + // Current version is higher than any known version, fail to prevent + // reversion. + case curVersion > latestVersion: + return channeldb.ErrDBReversion + + // Current version matches highest known version, nothing to do. + case curVersion == latestVersion: + return nil + } + + // Otherwise, apply any migrations in order to bring the database + // version up to the highest known version. + updates := getMigrations(versions, curVersion) + return t.db.Update(func(tx *bbolt.Tx) error { + for _, update := range updates { + if update.migration == nil { + continue + } + + log.Infof("Applying migration #%d", update.number) + + err := update.migration(tx) + if err != nil { + log.Errorf("Unable to apply migration #%d: %v", + err) + return err + } + } + + return putDBVersion(tx, latestVersion) + }) +} + +// Version returns the database's current version number. +func (t *TowerDB) Version() (uint32, error) { + var version uint32 + err := t.db.View(func(tx *bbolt.Tx) error { + var err error + version, err = getDBVersion(tx) + return err + }) + if err != nil { + return 0, err + } + + return version, nil +} + +// Close closes the underlying database. +func (t *TowerDB) Close() error { + return t.db.Close() +} + +// GetSessionInfo retrieves the session for the passed session id. An error is +// returned if the session could not be found. +func (t *TowerDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { + var session *SessionInfo + err := t.db.View(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(sessionsBkt) + if sessions == nil { + return ErrUninitializedDB + } + + var err error + session, err = getSession(sessions, id[:]) + return err + }) + if err != nil { + return nil, err + } + + return session, nil +} + +// InsertSessionInfo records a negotiated session in the tower database. An +// error is returned if the session already exists. +func (t *TowerDB) InsertSessionInfo(session *SessionInfo) error { + return t.db.Update(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(sessionsBkt) + if sessions == nil { + return ErrUninitializedDB + } + + updateIndex := tx.Bucket(updateIndexBkt) + if updateIndex == nil { + return ErrUninitializedDB + } + + dbSession, err := getSession(sessions, session.ID[:]) + switch { + case err == ErrSessionNotFound: + // proceed. + + case err != nil: + return err + + case dbSession.LastApplied > 0: + return ErrSessionAlreadyExists + } + + err = putSession(sessions, session) + if err != nil { + return err + } + + // Initialize the session-hint index which will be used to track + // all updates added for this session. Upon deletion, we will + // consult the index to determine exactly which updates should + // be deleted without needing to iterate over the entire + // database. + return touchSessionHintBkt(updateIndex, &session.ID) + }) +} + +// InsertStateUpdate stores an update sent by the client after validating that +// the update is well-formed in the context of other updates sent for the same +// session. This include verifying that the sequence number is incremented +// properly and the last applied values echoed by the client are sane. +func (t *TowerDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) { + var lastApplied uint16 + err := t.db.Update(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(sessionsBkt) + if sessions == nil { + return ErrUninitializedDB + } + + updates := tx.Bucket(updatesBkt) + if updates == nil { + return ErrUninitializedDB + } + + updateIndex := tx.Bucket(updateIndexBkt) + if updateIndex == nil { + return ErrUninitializedDB + } + + // Fetch the session corresponding to the update's session id. + // This will be used to validate that the update's sequence + // number and last applied values are sane. + session, err := getSession(sessions, update.ID[:]) + if err != nil { + return err + } + + // Validate the update against the current state of the session. + err = session.AcceptUpdateSequence( + update.SeqNum, update.LastApplied, + ) + if err != nil { + return err + } + + // Validation succeeded, therefore the update is committed and + // the session's last applied value is equal to the update's + // sequence number. + lastApplied = session.LastApplied + + // Store the updated session to persist the updated last applied + // values. + err = putSession(sessions, session) + if err != nil { + return err + } + + // Create or load the hint bucket for this state update's hint + // and write the given update. + hints, err := updates.CreateBucketIfNotExists(update.Hint[:]) + if err != nil { + return err + } + + var b bytes.Buffer + err = update.Encode(&b) + if err != nil { + return err + } + + err = hints.Put(update.ID[:], b.Bytes()) + if err != nil { + return err + } + + // Finally, create an entry in the update index to track this + // hint under its session id. This will allow us to delete the + // entries efficiently if the session is ever removed. + return putHintForSession(updateIndex, &update.ID, update.Hint) + }) + if err != nil { + return 0, err + } + + return lastApplied, nil +} + +// DeleteSession removes all data associated with a particular session id from +// the tower's database. +func (t *TowerDB) DeleteSession(target SessionID) error { + return t.db.Update(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(sessionsBkt) + if sessions == nil { + return ErrUninitializedDB + } + + updates := tx.Bucket(updatesBkt) + if updates == nil { + return ErrUninitializedDB + } + + updateIndex := tx.Bucket(updateIndexBkt) + if updateIndex == nil { + return ErrUninitializedDB + } + + // Fail if the session doesn't exit. + _, err := getSession(sessions, target[:]) + if err != nil { + return err + } + + // Remove the target session. + err = sessions.Delete(target[:]) + if err != nil { + return err + } + + // Next, check the update index for any hints that were added + // under this session. + hints, err := getHintsForSession(updateIndex, &target) + if err != nil { + return err + } + + for _, hint := range hints { + // Remove the state updates for any blobs stored under + // the target session identifier. + updatesForHint := updates.Bucket(hint[:]) + if updatesForHint == nil { + continue + } + + update := updatesForHint.Get(target[:]) + if update == nil { + continue + } + + err := updatesForHint.Delete(target[:]) + if err != nil { + return err + } + + // If this was the last state update, we can also remove + // the hint that would map to an empty set. + err = isBucketEmpty(updatesForHint) + switch { + + // Other updates exist for this hint, keep the bucket. + case err == errBucketNotEmpty: + continue + + // Unexpected error. + case err != nil: + return err + + // No more updates for this hint, prune hint bucket. + default: + err = updates.DeleteBucket(hint[:]) + if err != nil { + return err + } + } + } + + // Finally, remove this session from the update index, which + // also removes any of the indexed hints beneath it. + return removeSessionHintBkt(updateIndex, &target) + }) +} + +// QueryMatches searches against all known state updates for any that match the +// passed breachHints. More than one Match will be returned for a given hint if +// they exist in the database. +func (t *TowerDB) QueryMatches(breachHints []BreachHint) ([]Match, error) { + var matches []Match + err := t.db.View(func(tx *bbolt.Tx) error { + sessions := tx.Bucket(sessionsBkt) + if sessions == nil { + return ErrUninitializedDB + } + + updates := tx.Bucket(updatesBkt) + if updates == nil { + return ErrUninitializedDB + } + + // Iterate through the target breach hints, appending any + // matching updates to the set of matches. + for _, hint := range breachHints { + // If a bucket does not exist for this hint, no matches + // are known. + updatesForHint := updates.Bucket(hint[:]) + if updatesForHint == nil { + continue + } + + // Otherwise, iterate through all (session id, update) + // pairs, creating a Match for each. + err := updatesForHint.ForEach(func(k, v []byte) error { + // Load the session via the session id for this + // update. The session info contains further + // instructions for how to process the state + // update. + session, err := getSession(sessions, k) + switch { + case err == ErrSessionNotFound: + log.Warnf("Missing session=%x for "+ + "matched state update hint=%x", + k, hint) + return nil + + case err != nil: + return err + } + + // Decode the state update containing the + // encrypted blob. + update := &SessionStateUpdate{} + err = update.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + var id SessionID + copy(id[:], k) + + // Construct the final match using the found + // update and its session info. + match := Match{ + ID: id, + SeqNum: update.SeqNum, + Hint: hint, + EncryptedBlob: update.EncryptedBlob, + SessionInfo: session, + } + + matches = append(matches, match) + + return nil + }) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return matches, nil +} + +// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in +// the tower database. +func (t *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { + return t.db.Update(func(tx *bbolt.Tx) error { + lookoutTip := tx.Bucket(lookoutTipBkt) + if lookoutTip == nil { + return ErrUninitializedDB + } + + return putLookoutEpoch(lookoutTip, epoch) + }) +} + +// GetLookoutTip retrieves the current lookout tip block epoch from the tower +// database. +func (t *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { + var epoch *chainntnfs.BlockEpoch + err := t.db.View(func(tx *bbolt.Tx) error { + lookoutTip := tx.Bucket(lookoutTipBkt) + if lookoutTip == nil { + return ErrUninitializedDB + } + + epoch = getLookoutEpoch(lookoutTip) + + return nil + }) + if err != nil { + return nil, err + } + + return epoch, nil +} + +// getSession retrieves the session info from the sessions bucket identified by +// its session id. An error is returned if the session is not found or a +// deserialization error occurs. +func getSession(sessions *bbolt.Bucket, id []byte) (*SessionInfo, error) { + sessionBytes := sessions.Get(id) + if sessionBytes == nil { + return nil, ErrSessionNotFound + } + + var session SessionInfo + err := session.Decode(bytes.NewReader(sessionBytes)) + if err != nil { + return nil, err + } + + return &session, nil +} + +// putSession stores the session info in the sessions bucket identified by its +// session id. An error is returned if a serialization error occurs. +func putSession(sessions *bbolt.Bucket, session *SessionInfo) error { + var b bytes.Buffer + err := session.Encode(&b) + if err != nil { + return err + } + + return sessions.Put(session.ID[:], b.Bytes()) +} + +// touchSessionHintBkt initializes the session-hint bucket for a particular +// session id. This ensures that future calls to getHintsForSession or +// putHintForSession can rely on the bucket already being created, and fail if +// index has not been initialized as this points to improper usage. +func touchSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error { + _, err := updateIndex.CreateBucketIfNotExists(id[:]) + return err +} + +// removeSessionHintBkt prunes the session-hint bucket for the given session id +// and all of the hints contained inside. This should be used to clean up the +// index upon session deletion. +func removeSessionHintBkt(updateIndex *bbolt.Bucket, id *SessionID) error { + return updateIndex.DeleteBucket(id[:]) +} + +// getHintsForSession returns all known hints belonging to the given session id. +// If the index for the session has not been initialized, this method returns +// ErrNoSessionHintIndex. +func getHintsForSession(updateIndex *bbolt.Bucket, + id *SessionID) ([]BreachHint, error) { + + sessionHints := updateIndex.Bucket(id[:]) + if sessionHints == nil { + return nil, ErrNoSessionHintIndex + } + + var hints []BreachHint + err := sessionHints.ForEach(func(k, _ []byte) error { + if len(k) != BreachHintSize { + return nil + } + + var hint BreachHint + copy(hint[:], k) + hints = append(hints, hint) + return nil + }) + if err != nil { + return nil, err + } + + return hints, nil +} + +// putHintForSession inserts a record into the update index for a given +// (session, hint) pair. The hints are coalesced under a bucket for the target +// session id, and used to perform efficient removal of updates. If the index +// for the session has not been initialized, this method returns +// ErrNoSessionHintIndex. +func putHintForSession(updateIndex *bbolt.Bucket, id *SessionID, + hint BreachHint) error { + + sessionHints := updateIndex.Bucket(id[:]) + if sessionHints == nil { + return ErrNoSessionHintIndex + } + + return sessionHints.Put(hint[:], []byte{}) +} + +// putLookoutEpoch stores the given lookout tip block epoch in provided bucket. +func putLookoutEpoch(bkt *bbolt.Bucket, epoch *chainntnfs.BlockEpoch) error { + epochBytes := make([]byte, 36) + copy(epochBytes, epoch.Hash[:]) + byteOrder.PutUint32(epochBytes[32:], uint32(epoch.Height)) + + return bkt.Put(lookoutTipKey, epochBytes) +} + +// getLookoutEpoch retrieves the lookout tip block epoch from the given bucket. +// A nil epoch is returned if no update exists. +func getLookoutEpoch(bkt *bbolt.Bucket) *chainntnfs.BlockEpoch { + epochBytes := bkt.Get(lookoutTipKey) + if len(epochBytes) != 36 { + return nil + } + + var hash chainhash.Hash + copy(hash[:], epochBytes[:32]) + height := byteOrder.Uint32(epochBytes[32:]) + + return &chainntnfs.BlockEpoch{ + Hash: &hash, + Height: int32(height), + } +} + +// errBucketNotEmpty is a helper error returned when testing whether a bucket is +// empty or not. +var errBucketNotEmpty = errors.New("bucket not empty") + +// isBucketEmpty returns errBucketNotEmpty if the bucket is not empty. +func isBucketEmpty(bkt *bbolt.Bucket) error { + return bkt.ForEach(func(_, _ []byte) error { + return errBucketNotEmpty + }) +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go new file mode 100644 index 000000000..fd7481aff --- /dev/null +++ b/watchtower/wtdb/version.go @@ -0,0 +1,84 @@ +package wtdb + +import "github.com/coreos/bbolt" + +// migration is a function which takes a prior outdated version of the database +// instances and mutates the key/bucket structure to arrive at a more +// up-to-date version of the database. +type migration func(tx *bbolt.Tx) error + +// version pairs a version number with the migration that would need to be +// applied from the prior version to upgrade. +type version struct { + number uint32 + migration migration +} + +// dbVersions stores all versions and migrations of the database. This list will +// be used when opening the database to determine if any migrations must be +// applied. +var dbVersions = []version{ + { + // Initial version requires no migration. + number: 0, + migration: nil, + }, +} + +// getLatestDBVersion returns the last known database version. +func getLatestDBVersion(versions []version) uint32 { + return versions[len(versions)-1].number +} + +// getMigrations returns a slice of all updates with a greater number that +// curVersion that need to be applied to sync up with the latest version. +func getMigrations(versions []version, curVersion uint32) []version { + var updates []version + for _, v := range versions { + if v.number > curVersion { + updates = append(updates, v) + } + } + + return updates +} + +// getDBVersion retrieves the current database version from the metadata bucket +// using the dbVersionKey. +func getDBVersion(tx *bbolt.Tx) (uint32, error) { + metadata := tx.Bucket(metadataBkt) + if metadata == nil { + return 0, ErrUninitializedDB + } + + versionBytes := metadata.Get(dbVersionKey) + if len(versionBytes) != 4 { + return 0, ErrNoDBVersion + } + + return byteOrder.Uint32(versionBytes), nil +} + +// initDBVersion initializes the top-level metadata bucket and writes the passed +// version number as the current version. +func initDBVersion(tx *bbolt.Tx, version uint32) error { + _, err := tx.CreateBucketIfNotExists(metadataBkt) + if err != nil { + return err + } + + return putDBVersion(tx, version) +} + +// putDBVersion stores the passed database version in the metadata bucket under +// the dbVersionKey. +func putDBVersion(tx *bbolt.Tx, version uint32) error { + metadata := tx.Bucket(metadataBkt) + if metadata == nil { + return ErrUninitializedDB + } + + versionBytes := make([]byte, 4) + byteOrder.PutUint32(versionBytes, version) + return metadata.Put(dbVersionKey, versionBytes) +} From b7cd70f18699a868313ad244ad3ba03d20daed17 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:21:48 -0700 Subject: [PATCH 8/9] watchtower/wtdb/tower_db_test: add mock+bolt iface tests --- watchtower/wtdb/tower_db_test.go | 730 +++++++++++++++++++++++++++++++ 1 file changed, 730 insertions(+) create mode 100644 watchtower/wtdb/tower_db_test.go diff --git a/watchtower/wtdb/tower_db_test.go b/watchtower/wtdb/tower_db_test.go new file mode 100644 index 000000000..c9920bcba --- /dev/null +++ b/watchtower/wtdb/tower_db_test.go @@ -0,0 +1,730 @@ +package wtdb_test + +import ( + "encoding/binary" + "io/ioutil" + "os" + "reflect" + "testing" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower" + "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtmock" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// dbInit is a closure used to initialize a watchtower.DB instance and its +// cleanup function. +type dbInit func(*testing.T) (watchtower.DB, func()) + +// towerDBHarness holds the resources required to execute the tower db tests. +type towerDBHarness struct { + t *testing.T + db watchtower.DB +} + +// newTowerDBHarness initializes a fresh test harness for testing watchtower.DB +// implementations. +func newTowerDBHarness(t *testing.T, init dbInit) (*towerDBHarness, func()) { + db, cleanup := init(t) + + h := &towerDBHarness{ + t: t, + db: db, + } + + return h, cleanup +} + +// insertSession attempts to isnert the passed session and asserts that the +// error returned matches expErr. +func (h *towerDBHarness) insertSession(s *wtdb.SessionInfo, expErr error) { + h.t.Helper() + + err := h.db.InsertSessionInfo(s) + if err != expErr { + h.t.Fatalf("expected insert session error: %v, got : %v", + expErr, err) + } +} + +// getSession retrieves the session identified by id, asserting that the call +// returns expErr. If successful, the found session is returned. +func (h *towerDBHarness) getSession(id *wtdb.SessionID, + expErr error) *wtdb.SessionInfo { + + h.t.Helper() + + session, err := h.db.GetSessionInfo(id) + if err != expErr { + h.t.Fatalf("expected get session error: %v, got: %v", + expErr, err) + } + + return session +} + +// insertUpdate attempts to insert the passed state update and asserts that the +// error returned matches expErr. If successful, the session's last applied +// value is returned. +func (h *towerDBHarness) insertUpdate(s *wtdb.SessionStateUpdate, + expErr error) uint16 { + + h.t.Helper() + + lastApplied, err := h.db.InsertStateUpdate(s) + if err != expErr { + h.t.Fatalf("expected insert update error: %v, got: %v", + expErr, err) + } + + return lastApplied +} + +// deleteSession attempts to delete the session identified by id and asserts +// that the error returned from DeleteSession matches the expected error. +func (h *towerDBHarness) deleteSession(id wtdb.SessionID, expErr error) { + h.t.Helper() + + err := h.db.DeleteSession(id) + if err != expErr { + h.t.Fatalf("expected deletion error: %v, got: %v", + expErr, err) + } +} + +// queryMatches queries that database for the passed breach hint, returning all +// matches found. +func (h *towerDBHarness) queryMatches(hint wtdb.BreachHint) []wtdb.Match { + h.t.Helper() + + matches, err := h.db.QueryMatches([]wtdb.BreachHint{hint}) + if err != nil { + h.t.Fatalf("unable to query matches: %v", err) + } + + return matches +} + +// hasUpdate queries the database for the passed breach hint, asserting that +// only one match is present and that the hints indeed match. If successful, the +// match is returned. +func (h *towerDBHarness) hasUpdate(hint wtdb.BreachHint) wtdb.Match { + h.t.Helper() + + matches := h.queryMatches(hint) + if len(matches) != 1 { + h.t.Fatalf("expected 1 match, found: %d", len(matches)) + } + + match := matches[0] + if match.Hint != hint { + h.t.Fatalf("expected hint: %x, got: %x", hint, match.Hint) + } + + return match +} + +// testInsertSession asserts that a session can only be inserted if a session +// with the same session id does not already exist. +func testInsertSession(h *towerDBHarness) { + var id wtdb.SessionID + h.getSession(&id, wtdb.ErrSessionNotFound) + + session := &wtdb.SessionInfo{ + ID: id, + Policy: wtpolicy.Policy{ + MaxUpdates: 100, + }, + RewardAddress: []byte{0x01, 0x02, 0x03}, + } + + h.insertSession(session, nil) + + session2 := h.getSession(&id, nil) + + if !reflect.DeepEqual(session, session2) { + h.t.Fatalf("expected session: %v, got %v", + session, session2) + } + + h.insertSession(session, nil) + + // Insert a state update to fully commit the session parameters. + update := &wtdb.SessionStateUpdate{ + ID: id, + SeqNum: 1, + } + h.insertUpdate(update, nil) + + // Trying to insert a new session under the same ID should fail. + h.insertSession(session, wtdb.ErrSessionAlreadyExists) +} + +// testMultipleMatches asserts that if multiple sessions insert state updates +// with the same breach hint that all will be returned from QueryMatches. +func testMultipleMatches(h *towerDBHarness) { + const numUpdates = 3 + + // Create a new session and send updates with all the same hint. + var hint wtdb.BreachHint + for i := 0; i < numUpdates; i++ { + id := *id(i) + session := &wtdb.SessionInfo{ + ID: id, + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + } + h.insertSession(session, nil) + + update := &wtdb.SessionStateUpdate{ + ID: id, + SeqNum: 1, + Hint: hint, // Use same hint to cause multiple matches + } + h.insertUpdate(update, nil) + } + + // Query the db for matches on the chosen hint. + matches := h.queryMatches(hint) + if len(matches) != numUpdates { + h.t.Fatalf("num updates mismatch, want: %d, got: %d", + numUpdates, len(matches)) + } + + // Assert that the hints are what we asked for, and compute the set of + // sessions returned. + sessions := make(map[wtdb.SessionID]struct{}) + for _, match := range matches { + if match.Hint != hint { + h.t.Fatalf("hint mismatch, want: %v, got: %v", + hint, match.Hint) + } + sessions[match.ID] = struct{}{} + } + + // Assert that the sessions returned match the session ids of the + // sessions we initially created. + for i := 0; i < numUpdates; i++ { + if _, ok := sessions[*id(i)]; !ok { + h.t.Fatalf("match for session %v not found", *id(i)) + } + } +} + +// testLookoutTip asserts that the database properly stores and returns the +// lookout tip block epochs. It also asserts that the epoch returned is nil when +// no tip has ever been set. +func testLookoutTip(h *towerDBHarness) { + // Retrieve lookout tip on fresh db. + epoch, err := h.db.GetLookoutTip() + if err != nil { + h.t.Fatalf("unable to fetch lookout tip: %v", err) + } + + // Assert that the epoch is nil. + if epoch != nil { + h.t.Fatalf("lookout tip should not be set, found: %v", epoch) + } + + // Create a closure that inserts an epoch, retrieves it, and asserts + // that the returned epoch matches what was inserted. + setAndCheck := func(i int) { + expEpoch := epochFromInt(1) + err = h.db.SetLookoutTip(expEpoch) + if err != nil { + h.t.Fatalf("unable to set lookout tip: %v", err) + } + + epoch, err = h.db.GetLookoutTip() + if err != nil { + h.t.Fatalf("unable to fetch lookout tip: %v", err) + } + + if !reflect.DeepEqual(epoch, expEpoch) { + h.t.Fatalf("lookout tip mismatch, want: %v, got: %v", + expEpoch, epoch) + } + } + + // Set and assert the lookout tip. + for i := 0; i < 5; i++ { + setAndCheck(i) + } +} + +// testDeleteSession asserts the behavior of a tower database when deleting +// session data. The test asserts that the only proper the target session is +// remmoved, and that only updates for a particular session are pruned. +func testDeleteSession(h *towerDBHarness) { + // First, create a session so that the database is not empty. + id0 := id(0) + session0 := &wtdb.SessionInfo{ + ID: *id0, + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + } + h.insertSession(session0, nil) + + // Now, attempt to delete a session which does not exist, that is also + // different from the first one created. + id1 := id(1) + h.deleteSession(*id1, wtdb.ErrSessionNotFound) + + // The first session should still be present. + h.getSession(id0, nil) + + // Now insert a second session under a different id. + session1 := &wtdb.SessionInfo{ + ID: *id1, + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + } + h.insertSession(session1, nil) + + // Create and insert updates for both sessions that have the same hint. + var hint wtdb.BreachHint + update0 := &wtdb.SessionStateUpdate{ + ID: *id0, + Hint: hint, + SeqNum: 1, + EncryptedBlob: []byte{}, + } + update1 := &wtdb.SessionStateUpdate{ + ID: *id1, + Hint: hint, + SeqNum: 1, + EncryptedBlob: []byte{}, + } + + // Insert both updates should succeed. + h.insertUpdate(update0, nil) + h.insertUpdate(update1, nil) + + // Remove the new session, which should succeed. + h.deleteSession(*id1, nil) + + // The first session should still be present. + h.getSession(id0, nil) + + // The second session should be removed. + h.getSession(id1, wtdb.ErrSessionNotFound) + + // Assert that only one update is still present. + matches := h.queryMatches(hint) + if len(matches) != 1 { + h.t.Fatalf("expected one update, found: %d", len(matches)) + } + + // Assert that the update belongs to the first session. + if matches[0].ID != *id0 { + h.t.Fatalf("expected match for %v, instead is for: %v", + *id0, matches[0].ID) + } + + // Finally, remove the first session added. + h.deleteSession(*id0, nil) + + // The session should no longer be present. + h.getSession(id0, wtdb.ErrSessionNotFound) + + // No matches should exist for this hint. + matches = h.queryMatches(hint) + if len(matches) != 0 { + h.t.Fatalf("expected zero updates, found: %d", len(matches)) + } +} + +type stateUpdateTest struct { + session *wtdb.SessionInfo + sessionErr error + updates []*wtdb.SessionStateUpdate + updateErrs []error +} + +func runStateUpdateTest(test stateUpdateTest) func(*towerDBHarness) { + return func(h *towerDBHarness) { + // We may need to modify the initial session as we process + // updates to discern the expected state of the session. We'll + // create a copy of the test session if necessary to prevent + // mutations from impacting other tests. + var expSession *wtdb.SessionInfo + + // Create the session if the tests requests one. + if test.session != nil { + // Copy the initial session and insert it into the + // database. + ogSession := *test.session + expErr := test.sessionErr + h.insertSession(&ogSession, expErr) + + if expErr != nil { + return + } + + // Copy the initial state of the accepted session. + expSession = &wtdb.SessionInfo{} + *expSession = *test.session + } + + if len(test.updates) != len(test.updateErrs) { + h.t.Fatalf("malformed test case, num updates " + + "should match num errors") + } + + // Send any updates provided in the test. + for i, update := range test.updates { + expErr := test.updateErrs[i] + h.insertUpdate(update, expErr) + + if expErr != nil { + continue + } + + // Don't perform the following checks and modfications + // if we don't have an expected session to compare + // against. + if expSession == nil { + continue + } + + // Update the session's last applied and client last + // applied. + expSession.LastApplied = update.SeqNum + expSession.ClientLastApplied = update.LastApplied + + match := h.hasUpdate(update.Hint) + if !reflect.DeepEqual(match.SessionInfo, expSession) { + h.t.Fatalf("expected session: %v, got: %v", + expSession, match.SessionInfo) + } + } + } +} + +var stateUpdateNoSession = stateUpdateTest{ + session: nil, + updates: []*wtdb.SessionStateUpdate{ + {ID: *id(0), SeqNum: 1, LastApplied: 0}, + }, + updateErrs: []error{ + wtdb.ErrSessionNotFound, + }, +} + +var stateUpdateExhaustSession = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 1, 0), + updateFromInt(id(0), 2, 0), + updateFromInt(id(0), 3, 0), + updateFromInt(id(0), 4, 0), + }, + updateErrs: []error{ + nil, nil, nil, wtdb.ErrSessionConsumed, + }, +} + +var stateUpdateSeqNumEqualLastApplied = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 1, 0), + updateFromInt(id(0), 2, 1), + updateFromInt(id(0), 3, 2), + updateFromInt(id(0), 3, 3), + }, + updateErrs: []error{ + nil, nil, nil, wtdb.ErrSeqNumAlreadyApplied, + }, +} + +var stateUpdateSeqNumLTLastApplied = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 1, 0), + updateFromInt(id(0), 2, 1), + updateFromInt(id(0), 1, 2), + }, + updateErrs: []error{ + nil, nil, wtdb.ErrSeqNumAlreadyApplied, + }, +} + +var stateUpdateSeqNumZeroInvalid = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 0, 0), + }, + updateErrs: []error{ + wtdb.ErrSeqNumAlreadyApplied, + }, +} + +var stateUpdateSkipSeqNum = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 2, 0), + }, + updateErrs: []error{ + wtdb.ErrUpdateOutOfOrder, + }, +} + +var stateUpdateRevertSeqNum = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 1, 0), + updateFromInt(id(0), 2, 0), + updateFromInt(id(0), 1, 0), + }, + updateErrs: []error{ + nil, nil, wtdb.ErrUpdateOutOfOrder, + }, +} + +var stateUpdateRevertLastApplied = stateUpdateTest{ + session: &wtdb.SessionInfo{ + ID: *id(0), + Policy: wtpolicy.Policy{ + MaxUpdates: 3, + }, + RewardAddress: []byte{}, + }, + updates: []*wtdb.SessionStateUpdate{ + updateFromInt(id(0), 1, 0), + updateFromInt(id(0), 2, 1), + updateFromInt(id(0), 3, 2), + updateFromInt(id(0), 4, 1), + }, + updateErrs: []error{ + nil, nil, nil, wtdb.ErrLastAppliedReversion, + }, +} + +func TestTowerDB(t *testing.T) { + dbs := []struct { + name string + init dbInit + }{ + { + name: "fresh boltdb", + init: func(t *testing.T) (watchtower.DB, func()) { + path, err := ioutil.TempDir("", "towerdb") + if err != nil { + t.Fatalf("unable to make temp dir: %v", + err) + } + + db, err := wtdb.OpenTowerDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to open db: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(path) + } + + return db, cleanup + }, + }, + { + name: "reopened boltdb", + init: func(t *testing.T) (watchtower.DB, func()) { + path, err := ioutil.TempDir("", "towerdb") + if err != nil { + t.Fatalf("unable to make temp dir: %v", + err) + } + + db, err := wtdb.OpenTowerDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to open db: %v", err) + } + db.Close() + + // Open the db again, ensuring we test a + // different path during open and that all + // buckets remain initialized. + db, err = wtdb.OpenTowerDB(path) + if err != nil { + os.RemoveAll(path) + t.Fatalf("unable to open db: %v", err) + } + + cleanup := func() { + db.Close() + os.RemoveAll(path) + } + + return db, cleanup + }, + }, + { + name: "mock", + init: func(t *testing.T) (watchtower.DB, func()) { + return wtmock.NewTowerDB(), func() {} + }, + }, + } + + tests := []struct { + name string + run func(*towerDBHarness) + }{ + { + name: "create session", + run: testInsertSession, + }, + { + name: "delete session", + run: testDeleteSession, + }, + { + name: "state update no session", + run: runStateUpdateTest(stateUpdateNoSession), + }, + { + name: "state update exhaust session", + run: runStateUpdateTest(stateUpdateExhaustSession), + }, + { + name: "state update seqnum equal last applied", + run: runStateUpdateTest( + stateUpdateSeqNumEqualLastApplied, + ), + }, + { + name: "state update seqnum less than last applied", + run: runStateUpdateTest( + stateUpdateSeqNumLTLastApplied, + ), + }, + { + name: "state update seqnum zero invalid", + run: runStateUpdateTest(stateUpdateSeqNumZeroInvalid), + }, + { + name: "state update skip seqnum", + run: runStateUpdateTest(stateUpdateSkipSeqNum), + }, + { + name: "state update revert seqnum", + run: runStateUpdateTest(stateUpdateRevertSeqNum), + }, + { + name: "state update revert last applied", + run: runStateUpdateTest(stateUpdateRevertLastApplied), + }, + { + name: "multiple breach matches", + run: testMultipleMatches, + }, + { + name: "lookout tip", + run: testLookoutTip, + }, + } + + for _, database := range dbs { + db := database + t.Run(db.name, func(t *testing.T) { + t.Parallel() + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h, cleanup := newTowerDBHarness( + t, db.init, + ) + defer cleanup() + + test.run(h) + }) + } + }) + } +} + +// id creates a session id from an integer. +func id(i int) *wtdb.SessionID { + var id wtdb.SessionID + binary.BigEndian.PutUint32(id[:4], uint32(i)) + return &id +} + +// updateFromInt creates a unique update for a given (session, seqnum) pair. The +// lastApplied argument can be used to construct updates simulating different +// levels of synchronicity between client and db. +func updateFromInt(id *wtdb.SessionID, i int, + lastApplied uint16) *wtdb.SessionStateUpdate { + + // Ensure the hint is unique. + var hint wtdb.BreachHint + copy(hint[:4], id[:4]) + binary.BigEndian.PutUint16(hint[4:6], uint16(i)) + + return &wtdb.SessionStateUpdate{ + ID: *id, + Hint: hint, + SeqNum: uint16(i), + LastApplied: lastApplied, + EncryptedBlob: []byte{byte(i)}, + } +} + +// epochFromInt creates a block epoch from an integer. +func epochFromInt(i int) *chainntnfs.BlockEpoch { + var hash chainhash.Hash + binary.BigEndian.PutUint32(hash[:4], uint32(i)) + + return &chainntnfs.BlockEpoch{ + Hash: &hash, + Height: int32(i), + } +} From 54c908be1ab7bf764b74c4f514aedd4cc8c0b94e Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Fri, 26 Apr 2019 17:22:02 -0700 Subject: [PATCH 9/9] watchtower/wtserver/create_session: log missing error in create session --- watchtower/wtserver/create_session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/watchtower/wtserver/create_session.go b/watchtower/wtserver/create_session.go index 411742e0a..5636b79d7 100644 --- a/watchtower/wtserver/create_session.go +++ b/watchtower/wtserver/create_session.go @@ -102,7 +102,7 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, // successful, the session will now be ready for use. err = s.cfg.DB.InsertSessionInfo(&info) if err != nil { - log.Errorf("unable to create session for %s", id) + log.Errorf("Unable to create session for %s: %v", id, err) return s.replyCreateSession( peer, id, wtwire.CodeTemporaryFailure, 0, nil, )