diff --git a/watchtower/lookout/lookout_test.go b/watchtower/lookout/lookout_test.go index 4232791d5..fb70a9618 100644 --- a/watchtower/lookout/lookout_test.go +++ b/watchtower/lookout/lookout_test.go @@ -15,6 +15,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/wtdb" + "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" ) @@ -66,7 +67,7 @@ func makeAddrSlice(size int) []byte { } func TestLookoutBreachMatching(t *testing.T) { - db := wtdb.NewMockDB() + db := wtmock.NewTowerDB() // Initialize an mock backend to feed the lookout blocks. backend := lookout.NewMockBackend() diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 68c34e0e1..86811bf02 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -369,7 +369,7 @@ type testHarness struct { clientDB *wtmock.ClientDB clientCfg *wtclient.Config client wtclient.Client - serverDB *wtdb.MockDB + serverDB *wtmock.TowerDB serverCfg *wtserver.Config server *wtserver.Server net *mockNet @@ -406,7 +406,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { } const timeout = 200 * time.Millisecond - serverDB := wtdb.NewMockDB() + serverDB := wtmock.NewTowerDB() serverCfg := &wtserver.Config{ DB: serverDB, diff --git a/watchtower/wtdb/mock.go b/watchtower/wtdb/mock.go deleted file mode 100644 index 902303881..000000000 --- a/watchtower/wtdb/mock.go +++ /dev/null @@ -1,142 +0,0 @@ -// +build dev - -package wtdb - -import ( - "sync" - - "github.com/lightningnetwork/lnd/chainntnfs" -) - -type MockDB struct { - mu sync.Mutex - lastEpoch *chainntnfs.BlockEpoch - sessions map[SessionID]*SessionInfo - blobs map[BreachHint]map[SessionID]*SessionStateUpdate -} - -func NewMockDB() *MockDB { - return &MockDB{ - sessions: make(map[SessionID]*SessionInfo), - blobs: make(map[BreachHint]map[SessionID]*SessionStateUpdate), - } -} - -func (db *MockDB) InsertStateUpdate(update *SessionStateUpdate) (uint16, error) { - db.mu.Lock() - defer db.mu.Unlock() - - info, ok := db.sessions[update.ID] - if !ok { - return 0, ErrSessionNotFound - } - - err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) - if err != nil { - return info.LastApplied, err - } - - sessionsToUpdates, ok := db.blobs[update.Hint] - if !ok { - sessionsToUpdates = make(map[SessionID]*SessionStateUpdate) - db.blobs[update.Hint] = sessionsToUpdates - } - sessionsToUpdates[update.ID] = update - - return info.LastApplied, nil -} - -func (db *MockDB) GetSessionInfo(id *SessionID) (*SessionInfo, error) { - db.mu.Lock() - defer db.mu.Unlock() - - if info, ok := db.sessions[*id]; ok { - return info, nil - } - - return nil, ErrSessionNotFound -} - -func (db *MockDB) InsertSessionInfo(info *SessionInfo) error { - db.mu.Lock() - defer db.mu.Unlock() - - dbInfo, ok := db.sessions[info.ID] - if ok && dbInfo.LastApplied > 0 { - return ErrSessionAlreadyExists - } - - db.sessions[info.ID] = info - - return nil -} - -func (db *MockDB) DeleteSession(target SessionID) error { - db.mu.Lock() - defer db.mu.Unlock() - - // Fail if the session doesn't exit. - if _, ok := db.sessions[target]; !ok { - return ErrSessionNotFound - } - - // Remove the target session. - delete(db.sessions, target) - - // Remove the state updates for any blobs stored under the target - // session identifier. - for hint, sessionUpdates := range db.blobs { - delete(sessionUpdates, target) - - //If this was the last state update, we can also remove the hint - //that would map to an empty set. - if len(sessionUpdates) == 0 { - delete(db.blobs, hint) - } - } - - return nil -} - -func (db *MockDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { - db.mu.Lock() - defer db.mu.Unlock() - - return db.lastEpoch, nil -} - -func (db *MockDB) QueryMatches(breachHints []BreachHint) ([]Match, error) { - db.mu.Lock() - defer db.mu.Unlock() - - var matches []Match - for _, hint := range breachHints { - sessionsToUpdates, ok := db.blobs[hint] - if !ok { - continue - } - - for id, update := range sessionsToUpdates { - info, ok := db.sessions[id] - if !ok { - panic("session not found") - } - - match := Match{ - ID: id, - SeqNum: update.SeqNum, - Hint: hint, - EncryptedBlob: update.EncryptedBlob, - SessionInfo: info, - } - matches = append(matches, match) - } - } - - return matches, nil -} - -func (db *MockDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { - db.lastEpoch = epoch - return nil -} diff --git a/watchtower/wtmock/tower_db.go b/watchtower/wtmock/tower_db.go new file mode 100644 index 000000000..403d61e30 --- /dev/null +++ b/watchtower/wtmock/tower_db.go @@ -0,0 +1,162 @@ +package wtmock + +import ( + "sync" + + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// TowerDB is a mock, in-memory implementation of a watchtower.DB. +type TowerDB struct { + mu sync.Mutex + lastEpoch *chainntnfs.BlockEpoch + sessions map[wtdb.SessionID]*wtdb.SessionInfo + blobs map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate +} + +// NewTowerDB initializes a fresh mock TowerDB. +func NewTowerDB() *TowerDB { + return &TowerDB{ + sessions: make(map[wtdb.SessionID]*wtdb.SessionInfo), + blobs: make(map[wtdb.BreachHint]map[wtdb.SessionID]*wtdb.SessionStateUpdate), + } +} + +// InsertStateUpdate stores an update sent by the client after validating that +// the update is well-formed in the context of other updates sent for the same +// session. This include verifying that the sequence number is incremented +// properly and the last applied values echoed by the client are sane. +func (db *TowerDB) InsertStateUpdate(update *wtdb.SessionStateUpdate) (uint16, error) { + db.mu.Lock() + defer db.mu.Unlock() + + info, ok := db.sessions[update.ID] + if !ok { + return 0, wtdb.ErrSessionNotFound + } + + err := info.AcceptUpdateSequence(update.SeqNum, update.LastApplied) + if err != nil { + return info.LastApplied, err + } + + sessionsToUpdates, ok := db.blobs[update.Hint] + if !ok { + sessionsToUpdates = make(map[wtdb.SessionID]*wtdb.SessionStateUpdate) + db.blobs[update.Hint] = sessionsToUpdates + } + sessionsToUpdates[update.ID] = update + + return info.LastApplied, nil +} + +// GetSessionInfo retrieves the session for the passed session id. An error is +// returned if the session could not be found. +func (db *TowerDB) GetSessionInfo(id *wtdb.SessionID) (*wtdb.SessionInfo, error) { + db.mu.Lock() + defer db.mu.Unlock() + + if info, ok := db.sessions[*id]; ok { + return info, nil + } + + return nil, wtdb.ErrSessionNotFound +} + +// InsertSessionInfo records a negotiated session in the tower database. An +// error is returned if the session already exists. +func (db *TowerDB) InsertSessionInfo(info *wtdb.SessionInfo) error { + db.mu.Lock() + defer db.mu.Unlock() + + dbInfo, ok := db.sessions[info.ID] + if ok && dbInfo.LastApplied > 0 { + return wtdb.ErrSessionAlreadyExists + } + + db.sessions[info.ID] = info + + return nil +} + +// DeleteSession removes all data associated with a particular session id from +// the tower's database. +func (db *TowerDB) DeleteSession(target wtdb.SessionID) error { + db.mu.Lock() + defer db.mu.Unlock() + + // Fail if the session doesn't exit. + if _, ok := db.sessions[target]; !ok { + return wtdb.ErrSessionNotFound + } + + // Remove the target session. + delete(db.sessions, target) + + // Remove the state updates for any blobs stored under the target + // session identifier. + for hint, sessionUpdates := range db.blobs { + delete(sessionUpdates, target) + + // If this was the last state update, we can also remove the + // hint that would map to an empty set. + if len(sessionUpdates) == 0 { + delete(db.blobs, hint) + } + } + + return nil +} + +// QueryMatches searches against all known state updates for any that match the +// passed breachHints. More than one Match will be returned for a given hint if +// they exist in the database. +func (db *TowerDB) QueryMatches( + breachHints []wtdb.BreachHint) ([]wtdb.Match, error) { + + db.mu.Lock() + defer db.mu.Unlock() + + var matches []wtdb.Match + for _, hint := range breachHints { + sessionsToUpdates, ok := db.blobs[hint] + if !ok { + continue + } + + for id, update := range sessionsToUpdates { + info, ok := db.sessions[id] + if !ok { + panic("session not found") + } + + match := wtdb.Match{ + ID: id, + SeqNum: update.SeqNum, + Hint: hint, + EncryptedBlob: update.EncryptedBlob, + SessionInfo: info, + } + matches = append(matches, match) + } + } + + return matches, nil +} + +// SetLookoutTip stores the provided epoch as the latest lookout tip epoch in +// the tower database. +func (db *TowerDB) SetLookoutTip(epoch *chainntnfs.BlockEpoch) error { + db.lastEpoch = epoch + return nil +} + +// GetLookoutTip retrieves the current lookout tip block epoch from the tower +// database. +func (db *TowerDB) GetLookoutTip() (*chainntnfs.BlockEpoch, error) { + db.mu.Lock() + defer db.mu.Unlock() + + return db.lastEpoch, nil +} diff --git a/watchtower/wtserver/server_test.go b/watchtower/wtserver/server_test.go index 6d99180fd..c3ae2a33e 100644 --- a/watchtower/wtserver/server_test.go +++ b/watchtower/wtserver/server_test.go @@ -1,5 +1,3 @@ -// +build dev - package wtserver_test import ( @@ -53,7 +51,7 @@ func initServer(t *testing.T, db wtserver.DB, t.Helper() if db == nil { - db = wtdb.NewMockDB() + db = wtmock.NewTowerDB() } s, err := wtserver.New(&wtserver.Config{ @@ -687,7 +685,7 @@ func testServerStateUpdates(t *testing.T, test stateUpdateTestCase) { // checking that the proper error is returned when the session doesn't exist and // that a successful deletion does not disrupt other sessions. func TestServerDeleteSession(t *testing.T) { - db := wtdb.NewMockDB() + db := wtmock.NewTowerDB() localPub := randPubKey(t)