From e8dc15dae49451fc469c8d444fecdaae41b80efd Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Thu, 28 Jul 2022 18:24:13 +0800 Subject: [PATCH] lntemp: add supporting methods for `testDataLossProtection` --- lntemp/harness.go | 55 ++++++++++++++++++ lntemp/harness_assertion.go | 56 ++++++++++++++++++ lntemp/node/harness_node.go | 110 ++++++++++++++++++++++++++++++++++++ 3 files changed, 221 insertions(+) diff --git a/lntemp/harness.go b/lntemp/harness.go index 7638a1674..3c0afdae6 100644 --- a/lntemp/harness.go +++ b/lntemp/harness.go @@ -1105,3 +1105,58 @@ func (h *HarnessTest) mineTillForceCloseResolved(hn *node.HarnessNode) { require.NoErrorf(h, err, "assert force close resolved timeout") } + +// CreatePayReqs is a helper method that will create a slice of payment +// requests for the given node. +func (h *HarnessTest) CreatePayReqs(hn *node.HarnessNode, + paymentAmt btcutil.Amount, numInvoices int) ([]string, + [][]byte, []*lnrpc.Invoice) { + + payReqs := make([]string, numInvoices) + rHashes := make([][]byte, numInvoices) + invoices := make([]*lnrpc.Invoice, numInvoices) + for i := 0; i < numInvoices; i++ { + preimage := h.Random32Bytes() + + invoice := &lnrpc.Invoice{ + Memo: "testing", + RPreimage: preimage, + Value: int64(paymentAmt), + } + resp := hn.RPC.AddInvoice(invoice) + + // Set the payment address in the invoice so the caller can + // properly use it. + invoice.PaymentAddr = resp.PaymentAddr + + payReqs[i] = resp.PaymentRequest + rHashes[i] = resp.RHash + invoices[i] = invoice + } + + return payReqs, rHashes, invoices +} + +// BackupDB creates a backup of the current database. It will stop the node +// first, copy the database files, and restart the node. +func (h *HarnessTest) BackupDB(hn *node.HarnessNode) { + restart := h.SuspendNode(hn) + + err := hn.BackupDB() + require.NoErrorf(h, err, "%s: failed to backup db", hn.Name()) + + err = restart() + require.NoErrorf(h, err, "%s: failed to restart", hn.Name()) +} + +// RestartNodeAndRestoreDB restarts a given node with a callback to restore the +// db. +func (h *HarnessTest) RestartNodeAndRestoreDB(hn *node.HarnessNode) { + cb := func() error { return hn.RestoreDB() } + err := h.manager.restartNode(h.runCtx, hn, cb) + require.NoErrorf(h, err, "failed to restart node %s", hn.Name()) + + // Give the node some time to catch up with the chain before we + // continue with the tests. + h.WaitForBlockchainSync(hn) +} diff --git a/lntemp/harness_assertion.go b/lntemp/harness_assertion.go index 33d6d04c7..c0758a134 100644 --- a/lntemp/harness_assertion.go +++ b/lntemp/harness_assertion.go @@ -905,3 +905,59 @@ func (h *HarnessTest) AssertNodeNumChannels(hn *node.HarnessNode, require.NoError(h, err, "timeout checking node's num of channels") } + +// AssertChannelLocalBalance checks the local balance of the given channel is +// expected. The channel found using the specified channel point is returned. +func (h *HarnessTest) AssertChannelLocalBalance(hn *node.HarnessNode, + cp *lnrpc.ChannelPoint, balance int64) *lnrpc.Channel { + + var result *lnrpc.Channel + + // Get the funding point. + err := wait.NoError(func() error { + // Find the target channel first. + target, err := h.findChannel(hn, cp) + + // Exit early if the channel is not found. + if err != nil { + return fmt.Errorf("check balance failed: %w", err) + } + + result = target + + // Check local balance. + if target.LocalBalance == balance { + return nil + } + + return fmt.Errorf("balance is incorrect, got %v, expected %v", + target.LocalBalance, balance) + }, DefaultTimeout) + + require.NoError(h, err, "timeout while checking for balance") + + return result +} + +// AssertChannelNumUpdates checks the num of updates is expected from the given +// channel. +func (h *HarnessTest) AssertChannelNumUpdates(hn *node.HarnessNode, + num uint64, cp *lnrpc.ChannelPoint) { + + old := int(hn.State.OpenChannel.NumUpdates) + + // Find the target channel first. + target, err := h.findChannel(hn, cp) + require.NoError(h, err, "unable to find channel") + + err = wait.NoError(func() error { + total := int(target.NumUpdates) + if total-old == int(num) { + return nil + } + + return errNumNotMatched(hn.Name(), "channel updates", + int(num), total-old, total, old) + }, DefaultTimeout) + require.NoError(h, err, "timeout while checking for num of updates") +} diff --git a/lntemp/node/harness_node.go b/lntemp/node/harness_node.go index 8f7297f7b..26ef329c5 100644 --- a/lntemp/node/harness_node.go +++ b/lntemp/node/harness_node.go @@ -731,6 +731,81 @@ func (hn *HarnessNode) printErrf(format string, a ...interface{}) { fmt.Sprintf(format, a...)) } +// BackupDB creates a backup of the current database. +func (hn *HarnessNode) BackupDB() error { + if hn.Cfg.backupDbDir != "" { + return fmt.Errorf("backup already created") + } + + if hn.Cfg.postgresDbName != "" { + // Backup database. + backupDBName := hn.Cfg.postgresDbName + "_backup" + err := executePgQuery( + "CREATE DATABASE " + backupDBName + " WITH TEMPLATE " + + hn.Cfg.postgresDbName, + ) + if err != nil { + return err + } + } else { + // Backup files. + tempDir, err := ioutil.TempDir("", "past-state") + if err != nil { + return fmt.Errorf("unable to create temp db folder: %w", + err) + } + + if err := copyAll(tempDir, hn.Cfg.DBDir()); err != nil { + return fmt.Errorf("unable to copy database files: %w", + err) + } + + hn.Cfg.backupDbDir = tempDir + } + + return nil +} + +// RestoreDB restores a database backup. +func (hn *HarnessNode) RestoreDB() error { + if hn.Cfg.postgresDbName != "" { + // Restore database. + backupDBName := hn.Cfg.postgresDbName + "_backup" + err := executePgQuery( + "DROP DATABASE " + hn.Cfg.postgresDbName, + ) + if err != nil { + return err + } + err = executePgQuery( + "ALTER DATABASE " + backupDBName + " RENAME TO " + + hn.Cfg.postgresDbName, + ) + if err != nil { + return err + } + } else { + // Restore files. + if hn.Cfg.backupDbDir == "" { + return fmt.Errorf("no database backup created") + } + + err := copyAll(hn.Cfg.DBDir(), hn.Cfg.backupDbDir) + if err != nil { + return fmt.Errorf("unable to copy database files: %w", + err) + } + + if err := os.RemoveAll(hn.Cfg.backupDbDir); err != nil { + return fmt.Errorf("unable to remove backup dir: %w", + err) + } + hn.Cfg.backupDbDir = "" + } + + return nil +} + func postgresDatabaseDsn(dbName string) string { return fmt.Sprintf(postgresDsn, dbName) } @@ -867,3 +942,38 @@ func addLogFile(hn *HarnessNode) error { return nil } + +// copyAll copies all files and directories from srcDir to dstDir recursively. +// Note that this function does not support links. +func copyAll(dstDir, srcDir string) error { + entries, err := ioutil.ReadDir(srcDir) + if err != nil { + return err + } + + for _, entry := range entries { + srcPath := filepath.Join(srcDir, entry.Name()) + dstPath := filepath.Join(dstDir, entry.Name()) + + info, err := os.Stat(srcPath) + if err != nil { + return err + } + + if info.IsDir() { + err := os.Mkdir(dstPath, info.Mode()) + if err != nil && !os.IsExist(err) { + return err + } + + err = copyAll(dstPath, srcPath) + if err != nil { + return err + } + } else if err := lntest.CopyFile(dstPath, srcPath); err != nil { + return err + } + } + + return nil +}