mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-12-08 03:41:42 +01:00
lntemp: add supporting methods for testDataLossProtection
This commit is contained in:
@@ -731,6 +731,81 @@ func (hn *HarnessNode) printErrf(format string, a ...interface{}) {
|
||||
fmt.Sprintf(format, a...))
|
||||
}
|
||||
|
||||
// BackupDB creates a backup of the current database.
|
||||
func (hn *HarnessNode) BackupDB() error {
|
||||
if hn.Cfg.backupDbDir != "" {
|
||||
return fmt.Errorf("backup already created")
|
||||
}
|
||||
|
||||
if hn.Cfg.postgresDbName != "" {
|
||||
// Backup database.
|
||||
backupDBName := hn.Cfg.postgresDbName + "_backup"
|
||||
err := executePgQuery(
|
||||
"CREATE DATABASE " + backupDBName + " WITH TEMPLATE " +
|
||||
hn.Cfg.postgresDbName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Backup files.
|
||||
tempDir, err := ioutil.TempDir("", "past-state")
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create temp db folder: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
if err := copyAll(tempDir, hn.Cfg.DBDir()); err != nil {
|
||||
return fmt.Errorf("unable to copy database files: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
hn.Cfg.backupDbDir = tempDir
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RestoreDB restores a database backup.
|
||||
func (hn *HarnessNode) RestoreDB() error {
|
||||
if hn.Cfg.postgresDbName != "" {
|
||||
// Restore database.
|
||||
backupDBName := hn.Cfg.postgresDbName + "_backup"
|
||||
err := executePgQuery(
|
||||
"DROP DATABASE " + hn.Cfg.postgresDbName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = executePgQuery(
|
||||
"ALTER DATABASE " + backupDBName + " RENAME TO " +
|
||||
hn.Cfg.postgresDbName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Restore files.
|
||||
if hn.Cfg.backupDbDir == "" {
|
||||
return fmt.Errorf("no database backup created")
|
||||
}
|
||||
|
||||
err := copyAll(hn.Cfg.DBDir(), hn.Cfg.backupDbDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to copy database files: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(hn.Cfg.backupDbDir); err != nil {
|
||||
return fmt.Errorf("unable to remove backup dir: %w",
|
||||
err)
|
||||
}
|
||||
hn.Cfg.backupDbDir = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func postgresDatabaseDsn(dbName string) string {
|
||||
return fmt.Sprintf(postgresDsn, dbName)
|
||||
}
|
||||
@@ -867,3 +942,38 @@ func addLogFile(hn *HarnessNode) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyAll copies all files and directories from srcDir to dstDir recursively.
|
||||
// Note that this function does not support links.
|
||||
func copyAll(dstDir, srcDir string) error {
|
||||
entries, err := ioutil.ReadDir(srcDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
srcPath := filepath.Join(srcDir, entry.Name())
|
||||
dstPath := filepath.Join(dstDir, entry.Name())
|
||||
|
||||
info, err := os.Stat(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
err := os.Mkdir(dstPath, info.Mode())
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyAll(dstPath, srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := lntest.CopyFile(dstPath, srcPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user