storagenode-updater: add recovery for windows service restart

Reimplement windows service restart part using svc, add recovery
for failed service startup. Added restart-service cmd, to execute
self restart in a separate process.

Addressed issues:
https://storjlabs.atlassian.net/browse/SG-49
https://storjlabs.atlassian.net/browse/SG-136
https://storjlabs.atlassian.net/browse/SG-137

Change-Id: Ic51d9a99e8c1c10800c6c60ff4e218321c674fea
This commit is contained in:
Yaroslav Vorobiov 2020-04-01 14:59:34 +03:00 committed by Ivan Fraixedes
parent bb28851964
commit 516b8cf2be
12 changed files with 733 additions and 412 deletions

View File

@ -0,0 +1,114 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"archive/zip"
"bufio"
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"os"
"os/exec"
"strings"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/common/sync2"
"storj.io/private/version"
)
func binaryVersion(location string) (version.SemVer, error) {
out, err := exec.Command(location, "version").CombinedOutput()
if err != nil {
zap.L().Info("Command output.", zap.ByteString("Output", out))
return version.SemVer{}, err
}
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
line := scanner.Text()
prefix := "Version: "
if strings.HasPrefix(line, prefix) {
line = line[len(prefix):]
return version.NewSemVer(line)
}
}
return version.SemVer{}, errs.New("unable to determine binary version")
}
func downloadBinary(ctx context.Context, url, target string) error {
f, err := ioutil.TempFile("", createPattern(url))
if err != nil {
return errs.New("cannot create temporary archive: %v", err)
}
defer func() {
err = errs.Combine(err,
f.Close(),
os.Remove(f.Name()),
)
}()
zap.L().Info("Download started.", zap.String("From", url), zap.String("To", f.Name()))
if err = downloadArchive(ctx, f, url); err != nil {
return errs.Wrap(err)
}
if err = unpackBinary(ctx, f.Name(), target); err != nil {
return errs.Wrap(err)
}
zap.L().Info("Download finished.", zap.String("From", url), zap.String("To", f.Name()))
return nil
}
func downloadArchive(ctx context.Context, file io.Writer, url string) (err error) {
resp, err := http.Get(url)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK {
return errs.New("bad status: %s", resp.Status)
}
_, err = sync2.Copy(ctx, file, resp.Body)
return err
}
// unpackBinary unpack zip compressed binary.
func unpackBinary(ctx context.Context, archive, target string) (err error) {
zipReader, err := zip.OpenReader(archive)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, zipReader.Close()) }()
if len(zipReader.File) != 1 {
return errs.New("archive should contain only one file")
}
zipedExec, err := zipReader.File[0].Open()
if err != nil {
return err
}
defer func() { err = errs.Combine(err, zipedExec.Close()) }()
newExec, err := os.OpenFile(target, os.O_CREATE|os.O_EXCL|os.O_WRONLY, os.FileMode(0755))
if err != nil {
return err
}
defer func() { err = errs.Combine(err, newExec.Close()) }()
_, err = sync2.Copy(ctx, newExec, zipedExec)
if err != nil {
return errs.Combine(err, os.Remove(newExec.Name()))
}
return nil
}

View File

@ -0,0 +1,144 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"log"
"os"
"runtime"
"strings"
"time"
"github.com/spf13/cobra"
"go.uber.org/zap"
"storj.io/common/errs2"
"storj.io/common/fpath"
"storj.io/common/identity"
"storj.io/common/storj"
"storj.io/common/sync2"
"storj.io/private/cfgstruct"
"storj.io/private/process"
_ "storj.io/storj/private/version" // This attaches version information during release builds.
"storj.io/storj/private/version/checker"
)
const (
updaterServiceName = "storagenode-updater"
minCheckInterval = time.Minute
)
var (
// TODO: replace with config value of random bytes in storagenode config.
nodeID storj.NodeID
rootCmd = &cobra.Command{
Use: "storagenode-updater",
Short: "Version updater for storage node",
}
runCmd = &cobra.Command{
Use: "run",
Short: "Run the storagenode-updater for storage node",
Args: cobra.OnlyValidArgs,
RunE: cmdRun,
}
restartCmd = &cobra.Command{
Use: "restart-service <new binary path>",
Short: "Restart service with the new binary",
Args: cobra.ExactArgs(1),
RunE: cmdRestart,
}
runCfg struct {
checker.Config
Identity identity.Config
BinaryLocation string `help:"the storage node executable binary location" default:"storagenode.exe"`
ServiceName string `help:"storage node OS service name" default:"storagenode"`
// deprecated
Log string `help:"deprecated, use --log.output" default:""`
}
confDir string
identityDir string
)
func init() {
defaults := cfgstruct.DefaultsFlag(rootCmd)
defaultConfDir := fpath.ApplicationDir("storj", "storagenode")
defaultIdentityDir := fpath.ApplicationDir("storj", "identity", "storagenode")
cfgstruct.SetupFlag(zap.L(), rootCmd, &confDir, "config-dir", defaultConfDir, "main directory for storagenode configuration")
cfgstruct.SetupFlag(zap.L(), rootCmd, &identityDir, "identity-dir", defaultIdentityDir, "main directory for storagenode identity credentials")
rootCmd.AddCommand(runCmd)
rootCmd.AddCommand(restartCmd)
process.Bind(runCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
process.Bind(restartCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
}
func cmdRun(cmd *cobra.Command, args []string) (err error) {
if runCfg.Log != "" {
if err = openLog(runCfg.Log); err != nil {
zap.L().Error("Error creating new logger.", zap.Error(err))
}
}
if !fileExists(runCfg.BinaryLocation) {
zap.L().Fatal("Unable to find storage node executable binary.")
}
ident, err := runCfg.Identity.Load()
if err != nil {
zap.L().Fatal("Error loading identity.", zap.Error(err))
}
nodeID = ident.ID
if nodeID.IsZero() {
zap.L().Fatal("Empty node ID.")
}
ctx, _ := process.Ctx(cmd)
switch {
case runCfg.CheckInterval <= 0:
err = loopFunc(ctx)
case runCfg.CheckInterval < minCheckInterval:
zap.L().Error("Check interval below minimum. Overriding it minimum.",
zap.Stringer("Check Interval", runCfg.CheckInterval),
zap.Stringer("Minimum Check Interval", minCheckInterval),
)
runCfg.CheckInterval = minCheckInterval
fallthrough
default:
loop := sync2.NewCycle(runCfg.CheckInterval)
err = loop.Run(ctx, loopFunc)
}
if err != nil && !errs2.IsCanceled(err) {
log.Fatal(err)
}
return nil
}
func fileExists(filename string) bool {
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return info.Mode().IsRegular()
}
func openLog(logPath string) error {
if runtime.GOOS == "windows" && !strings.HasPrefix(logPath, "winfile:///") {
logPath = "winfile:///" + logPath
}
logger, err := process.NewLoggerWithOutputPaths(logPath)
if err != nil {
return err
}
zap.ReplaceGlobals(logger)
return nil
}

View File

@ -0,0 +1,39 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
// +build unittest !windows
package main
import (
"context"
"os"
"go.uber.org/zap"
"storj.io/storj/private/version/checker"
)
// loopFunc is func that is run by the update cycle.
func loopFunc(ctx context.Context) error {
zap.L().Info("Downloading versions.", zap.String("Server Address", runCfg.ServerAddress))
all, err := checker.New(runCfg.ClientConfig).All(ctx)
if err != nil {
zap.L().Error("Error retrieving version info.", zap.Error(err))
return nil
}
if err := update(ctx, runCfg.ServiceName, runCfg.BinaryLocation, all.Processes.Storagenode); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", runCfg.ServiceName), zap.Error(err))
}
updaterBinName := os.Args[0]
if err := update(ctx, updaterServiceName, updaterBinName, all.Processes.StoragenodeUpdater); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", updaterServiceName), zap.Error(err))
}
return nil
}

View File

@ -0,0 +1,94 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
// +build windows,!unittest
package main
import (
"context"
"os"
"os/exec"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/private/version"
"storj.io/storj/private/version/checker"
)
// loopFunc is func that is run by the update cycle.
func loopFunc(ctx context.Context) error {
zap.L().Info("Downloading versions.", zap.String("Server Address", runCfg.ServerAddress))
all, err := checker.New(runCfg.ClientConfig).All(ctx)
if err != nil {
zap.L().Error("Error retrieving version info.", zap.Error(err))
return nil
}
if err := update(ctx, runCfg.ServiceName, runCfg.BinaryLocation, all.Processes.Storagenode); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", runCfg.ServiceName), zap.Error(err))
}
updaterBinName := os.Args[0]
if err := updateSelf(ctx, updaterBinName, all.Processes.StoragenodeUpdater); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", updaterServiceName), zap.Error(err))
}
return nil
}
func updateSelf(ctx context.Context, binaryLocation string, ver version.Process) error {
suggestedVersion, err := ver.Suggested.SemVer()
if err != nil {
return errs.Wrap(err)
}
currentVersion := version.Build.Version
// should update
if currentVersion.Compare(suggestedVersion) >= 0 {
zap.L().Info("Version is up to date.", zap.String("Service", updaterServiceName))
return nil
}
if !version.ShouldUpdate(ver.Rollout, nodeID) {
zap.L().Info("New version available but not rolled out to this nodeID yet", zap.String("Service", updaterServiceName))
return nil
}
newVersionPath := prependExtension(binaryLocation, ver.Suggested.Version)
if err = downloadBinary(ctx, parseDownloadURL(ver.Suggested.URL), newVersionPath); err != nil {
return errs.Wrap(err)
}
downloadedVersion, err := binaryVersion(newVersionPath)
if err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
if suggestedVersion.Compare(downloadedVersion) != 0 {
err := errs.New("invalid version downloaded: wants %s got %s",
suggestedVersion.String(),
downloadedVersion.String(),
)
return errs.Combine(err, os.Remove(newVersionPath))
}
zap.L().Info("Restarting service.", zap.String("Service", updaterServiceName))
return restartSelf(binaryLocation, newVersionPath)
}
func restartSelf(bin, newbin string) error {
args := []string{
"restart-service",
"--binary-location", bin,
"--service-name", updaterServiceName,
newbin,
}
return exec.Command(bin, args...).Start()
}

View File

@ -1,363 +1,11 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build !windows
package main
import (
"archive/zip"
"bufio"
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/common/errs2"
"storj.io/common/fpath"
"storj.io/common/identity"
"storj.io/common/storj"
"storj.io/common/sync2"
"storj.io/private/cfgstruct"
"storj.io/private/process"
"storj.io/private/version"
_ "storj.io/storj/private/version" // This attaches version information during release builds.
"storj.io/storj/private/version/checker"
)
const (
updaterServiceName = "storagenode-updater"
minCheckInterval = time.Minute
)
var (
cancel context.CancelFunc
// TODO: replace with config value of random bytes in storagenode config.
nodeID storj.NodeID
rootCmd = &cobra.Command{
Use: "storagenode-updater",
Short: "Version updater for storage node",
}
runCmd = &cobra.Command{
Use: "run",
Short: "Run the storagenode-updater for storage node",
Args: cobra.OnlyValidArgs,
RunE: cmdRun,
}
runCfg struct {
// TODO: check interval default has changed from 6 hours to 15 min.
checker.Config
Identity identity.Config
BinaryLocation string `help:"the storage node executable binary location" default:"storagenode.exe"`
ServiceName string `help:"storage node OS service name" default:"storagenode"`
// deprecated
Log string `help:"deprecated, use --log.output" default:""`
}
confDir string
identityDir string
)
func init() {
// TODO: this will probably generate warnings for mismatched config fields.
defaultConfDir := fpath.ApplicationDir("storj", "storagenode")
defaultIdentityDir := fpath.ApplicationDir("storj", "identity", "storagenode")
cfgstruct.SetupFlag(zap.L(), rootCmd, &confDir, "config-dir", defaultConfDir, "main directory for storagenode configuration")
cfgstruct.SetupFlag(zap.L(), rootCmd, &identityDir, "identity-dir", defaultIdentityDir, "main directory for storagenode identity credentials")
defaults := cfgstruct.DefaultsFlag(rootCmd)
rootCmd.AddCommand(runCmd)
process.Bind(runCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
}
func cmdRun(cmd *cobra.Command, args []string) (err error) {
err = openLog()
if err != nil {
zap.L().Error("Error creating new logger.", zap.Error(err))
}
if !fileExists(runCfg.BinaryLocation) {
zap.L().Fatal("Unable to find storage node executable binary.")
}
ident, err := runCfg.Identity.Load()
if err != nil {
zap.L().Fatal("Error loading identity.", zap.Error(err))
}
nodeID = ident.ID
if nodeID.IsZero() {
zap.L().Fatal("Empty node ID.")
}
var ctx context.Context
ctx, cancel = process.Ctx(cmd)
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-c
signal.Stop(c)
cancel()
}()
loopFunc := func(ctx context.Context) (err error) {
if err := update(ctx, runCfg.BinaryLocation, runCfg.ServiceName); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", runCfg.ServiceName), zap.Error(err))
}
updaterBinName := os.Args[0]
if err := update(ctx, updaterBinName, updaterServiceName); err != nil {
// don't finish loop in case of error just wait for another execution
zap.L().Error("Error updating service.", zap.String("Service", updaterServiceName), zap.Error(err))
}
return nil
}
switch {
case runCfg.CheckInterval <= 0:
err = loopFunc(ctx)
case runCfg.CheckInterval < minCheckInterval:
zap.L().Error("Check interval below minimum. Overriding it minimum.",
zap.Stringer("Check Interval", runCfg.CheckInterval),
zap.Stringer("Minimum Check Interval", minCheckInterval),
)
runCfg.CheckInterval = minCheckInterval
fallthrough
default:
loop := sync2.NewCycle(runCfg.CheckInterval)
err = loop.Run(ctx, loopFunc)
}
if err != nil && !errs2.IsCanceled(err) {
log.Fatal(err)
}
return nil
}
func update(ctx context.Context, binPath, serviceName string) (err error) {
if nodeID.IsZero() {
zap.L().Fatal("Empty node ID.")
}
var currentVersion version.SemVer
if serviceName == updaterServiceName {
// TODO: find better way to check this binary version
currentVersion = version.Build.Version
} else {
currentVersion, err = binaryVersion(binPath)
if err != nil {
return errs.Wrap(err)
}
}
client := checker.New(runCfg.ClientConfig)
zap.L().Info("Downloading versions.", zap.String("Server Address", runCfg.ServerAddress))
processVersion, err := client.Process(ctx, serviceName)
if err != nil {
return errs.Wrap(err)
}
// TODO: consolidate semver.Version and version.SemVer
suggestedVersion, err := processVersion.Suggested.SemVer()
if err != nil {
return errs.Wrap(err)
}
if currentVersion.Compare(suggestedVersion) >= 0 {
zap.L().Info("Version is up to date.", zap.String("Service", serviceName))
return nil
}
if !version.ShouldUpdate(processVersion.Rollout, nodeID) {
zap.L().Info("New version available but not rolled out to this nodeID yet", zap.String("Service", serviceName))
return nil
}
tempArchive, err := ioutil.TempFile("", serviceName)
if err != nil {
return errs.New("cannot create temporary archive: %v", err)
}
defer func() {
err = errs.Combine(err,
tempArchive.Close(),
os.Remove(tempArchive.Name()),
)
}()
downloadURL := parseDownloadURL(processVersion.Suggested.URL)
zap.L().Info("Download started.", zap.String("From", downloadURL), zap.String("To", tempArchive.Name()))
err = downloadArchive(ctx, tempArchive, downloadURL)
if err != nil {
return errs.Wrap(err)
}
zap.L().Info("Download finished.", zap.String("From", downloadURL), zap.String("To", tempArchive.Name()))
newVersionPath := prependExtension(binPath, suggestedVersion.String())
err = unpackBinary(ctx, tempArchive.Name(), newVersionPath)
if err != nil {
return errs.Wrap(err)
}
// TODO add here recovery even before starting service (if version command cannot be executed)
downloadedVersion, err := binaryVersion(newVersionPath)
if err != nil {
return errs.Wrap(err)
}
if suggestedVersion.Compare(downloadedVersion) != 0 {
return errs.New("invalid version downloaded: wants %s got %s", suggestedVersion.String(), downloadedVersion.String())
}
// backup original binary
var backupPath string
if serviceName == updaterServiceName {
// NB: don't include old version number for updater binary backup
backupPath = prependExtension(binPath, "old")
} else {
backupPath = prependExtension(binPath, "old."+currentVersion.String())
}
if err := os.Rename(binPath, backupPath); err != nil {
return errs.Wrap(err)
}
// rename new binary to replace original
if err := os.Rename(newVersionPath, binPath); err != nil {
return errs.Wrap(err)
}
zap.L().Info("Restarting service.", zap.String("Service", serviceName))
err = restartService(serviceName)
if err != nil {
// TODO: should we try to recover from this?
return errs.New("Unable to restart service: %v", err)
}
zap.L().Info("Service restarted successfully.", zap.String("Service", serviceName))
// TODO remove old binary ??
return nil
}
func prependExtension(path, ext string) string {
originalExt := filepath.Ext(path)
dir, base := filepath.Split(path)
base = base[:len(base)-len(originalExt)]
return filepath.Join(dir, base+"."+ext+originalExt)
}
func parseDownloadURL(template string) string {
url := strings.Replace(template, "{os}", runtime.GOOS, 1)
url = strings.Replace(url, "{arch}", runtime.GOARCH, 1)
return url
}
func binaryVersion(location string) (version.SemVer, error) {
out, err := exec.Command(location, "version").CombinedOutput()
if err != nil {
zap.L().Info("Command output.", zap.ByteString("Output", out))
return version.SemVer{}, err
}
scanner := bufio.NewScanner(bytes.NewReader(out))
for scanner.Scan() {
line := scanner.Text()
prefix := "Version: "
if strings.HasPrefix(line, prefix) {
line = line[len(prefix):]
return version.NewSemVer(line)
}
}
return version.SemVer{}, errs.New("unable to determine binary version")
}
func downloadArchive(ctx context.Context, file io.Writer, url string) (err error) {
resp, err := http.Get(url)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK {
return errs.New("bad status: %s", resp.Status)
}
_, err = sync2.Copy(ctx, file, resp.Body)
return err
}
func unpackBinary(ctx context.Context, archive, target string) (err error) {
// TODO support different compression types e.g. tar.gz
zipReader, err := zip.OpenReader(archive)
if err != nil {
return err
}
defer func() { err = errs.Combine(err, zipReader.Close()) }()
if len(zipReader.File) != 1 {
return errors.New("archive should contain only binary file")
}
zipedExec, err := zipReader.File[0].Open()
if err != nil {
return err
}
defer func() { err = errs.Combine(err, zipedExec.Close()) }()
newExec, err := os.OpenFile(target, os.O_CREATE|os.O_EXCL|os.O_WRONLY, os.FileMode(0755))
if err != nil {
return err
}
defer func() { err = errs.Combine(err, newExec.Close()) }()
_, err = sync2.Copy(ctx, newExec, zipedExec)
if err != nil {
return errs.Combine(err, os.Remove(newExec.Name()))
}
return nil
}
func fileExists(filename string) bool {
info, err := os.Stat(filename)
if os.IsNotExist(err) {
return false
}
return info.Mode().IsRegular()
}
func openLog() error {
if runCfg.Log != "" {
logPath := runCfg.Log
if runtime.GOOS == "windows" && !strings.HasPrefix(logPath, "winfile:///") {
logPath = "winfile:///" + logPath
}
logger, err := process.NewLoggerWithOutputPaths(logPath)
if err != nil {
return err
}
zap.ReplaceGlobals(logger)
}
return nil
}
import "storj.io/private/process"
func main() {
process.Exec(rootCmd)

View File

@ -7,7 +7,7 @@
//
// sc.exe create storagenode-updater binpath= "C:\Users\MyUser\storagenode-updater.exe run ..."
// +build windows,!unittest
// +build windows
package main
@ -22,33 +22,25 @@ import (
"storj.io/private/process"
)
func init() {
// Check if session is interactive
interactive, err := svc.IsAnInteractiveSession()
func isRunCmd() bool {
return len(os.Args) > 1 && os.Args[1] == "run"
}
func main() {
isInteractive, err := svc.IsAnInteractiveSession()
if err != nil {
zap.L().Fatal("Failed to determine if session is interactive.", zap.Error(err))
}
if interactive {
if isInteractive || !isRunCmd() {
process.Exec(rootCmd)
return
}
// Check if the 'run' command is invoked
if len(os.Args) < 2 {
return
}
if os.Args[1] != "run" {
return
}
// Initialize the Windows Service handler
err = svc.Run("storagenode-updater", &service{})
if err != nil {
zap.L().Fatal("Service failed.", zap.Error(err))
}
// avoid starting main() when service was stopped
os.Exit(0)
}
type service struct{}
@ -87,5 +79,6 @@ func (m *service) Execute(args []string, r <-chan svc.ChangeRequest, changes cha
zap.L().Info("Unexpected control request.", zap.Uint32("Event Type", c.EventType))
}
}
return
}

View File

@ -0,0 +1,33 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"path"
"path/filepath"
"runtime"
"strings"
)
func prependExtension(path, ext string) string {
originalExt := filepath.Ext(path)
dir, base := filepath.Split(path)
base = base[:len(base)-len(originalExt)]
return filepath.Join(dir, base+"."+ext+originalExt)
}
func parseDownloadURL(template string) string {
url := strings.Replace(template, "{os}", runtime.GOOS, 1)
url = strings.Replace(url, "{arch}", runtime.GOARCH, 1)
return url
}
func createPattern(url string) string {
_, binary := path.Split(url)
if ext := path.Ext(binary); ext != "" {
return binary[:len(binary)-len(ext)] + ".*" + ext
}
return binary + ".*"
}

View File

@ -1,45 +1,30 @@
// Copyright (C) 2019 Storj Labs, Inc.
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
// +build !unittest
// +build unittest !windows
package main
import (
"fmt"
"context"
"os"
"os/exec"
"path/filepath"
"runtime"
"github.com/spf13/cobra"
"github.com/zeebo/errs"
)
func restartService(name string) error {
switch runtime.GOOS {
case "windows":
// TODO: cleanup temp .bat file
restartSvcBatPath := filepath.Join(os.TempDir(), "restartservice.bat")
restartSvcBat, err := os.Create(restartSvcBatPath)
if err != nil {
return err
}
restartStr := fmt.Sprintf("net stop %s && net start %s", name, name)
_, err = restartSvcBat.WriteString(restartStr)
if err != nil {
return err
}
if err := restartSvcBat.Close(); err != nil {
return err
}
out, err := exec.Command(restartSvcBat.Name()).CombinedOutput()
if err != nil {
return errs.New("%s", string(out))
}
default:
return nil
}
func cmdRestart(cmd *cobra.Command, args []string) error {
return nil
}
func restartService(ctx context.Context, service, binaryLocation, newVersionPath, backupPath string) error {
if err := os.Rename(binaryLocation, backupPath); err != nil {
return errs.Wrap(err)
}
if err := os.Rename(newVersionPath, binaryLocation); err != nil {
return errs.Combine(err, os.Rename(backupPath, binaryLocation), os.Remove(newVersionPath))
}
return nil
}

View File

@ -1,8 +0,0 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build unittest
package main
func restartService(name string) error { return nil }

View File

@ -0,0 +1,202 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
// +build windows,!unittest
package main
import (
"context"
"math"
"os"
"time"
"github.com/spf13/cobra"
"github.com/zeebo/errs"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"storj.io/common/sync2"
"storj.io/private/process"
)
var unrecoverableErr = errs.Class("unable to recoverrecover binary from backup")
func cmdRestart(cmd *cobra.Command, args []string) (err error) {
ctx, _ := process.Ctx(cmd)
currentVersion, err := binaryVersion(runCfg.BinaryLocation)
if err != nil {
return errs.Wrap(err)
}
newVersionPath := args[0]
var backupPath string
if runCfg.ServiceName == updaterServiceName {
// NB: don't include old version number for updater binary backup
backupPath = prependExtension(runCfg.BinaryLocation, "old")
} else {
backupPath = prependExtension(runCfg.BinaryLocation, "old."+currentVersion.String())
}
// check if new binary exists
if _, err := os.Stat(newVersionPath); err != nil {
return errs.Wrap(err)
}
return restartService(ctx, runCfg.ServiceName, runCfg.BinaryLocation, newVersionPath, backupPath)
}
func restartService(ctx context.Context, service, binaryLocation, newVersionPath, backupPath string) (err error) {
srvc, err := openService(service)
if err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
defer func() {
err = errs.Combine(err, errs.Wrap(srvc.Close()))
}()
status, err := srvc.Query()
if err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
// stop service if it's not stopped
if status.State != svc.Stopped && status.State != svc.StopPending {
if err = serviceControl(srvc, ctx, svc.Stop, svc.Stopped, 10*time.Second); err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
// if it is stopping wait for it to complete
} else if status.State == svc.StopPending {
if err = serviceWaitForState(srvc, ctx, svc.Stopped, 10*time.Second); err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
}
err = func() error {
if err := os.Rename(binaryLocation, backupPath); err != nil {
return errs.Combine(err, srvc.Start())
}
if err := os.Rename(newVersionPath, binaryLocation); err != nil {
if rerr := os.Rename(backupPath, binaryLocation); rerr != nil {
// unrecoverable error
return unrecoverableErr.Wrap(errs.Combine(err, rerr))
}
return errs.Combine(err, srvc.Start())
}
return nil
}()
if err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
// successfully substituted binaries
err = retry(ctx, 2,
func() error {
return srvc.Start()
},
)
// if fail to start the service, try again with backup
if err != nil {
if rerr := os.Rename(backupPath, binaryLocation); rerr != nil {
// unrecoverable error
return unrecoverableErr.Wrap(errs.Combine(err, rerr))
}
return errs.Combine(err, srvc.Start())
}
return nil
}
func openService(name string) (_ *mgr.Service, err error) {
manager, err := mgr.Connect()
if err != nil {
return nil, errs.Wrap(err)
}
defer func() {
err = errs.Combine(err, errs.Wrap(manager.Disconnect()))
}()
service, err := manager.OpenService(name)
if err != nil {
return nil, errs.Wrap(err)
}
return service, nil
}
func serviceControl(service *mgr.Service, ctx context.Context, cmd svc.Cmd, state svc.State, delay time.Duration) error {
status, err := service.Control(cmd)
if err != nil {
return err
}
timeout := time.Now().Add(delay)
for status.State != state {
if err := ctx.Err(); err != nil {
return err
}
if timeout.Before(time.Now()) {
return errs.New("timeout")
}
status, err = service.Query()
if err != nil {
return err
}
}
return nil
}
func serviceWaitForState(service *mgr.Service, ctx context.Context, state svc.State, delay time.Duration) error {
status, err := service.Query()
if err != nil {
return err
}
timeout := time.Now().Add(delay)
for status.State != state {
if err := ctx.Err(); err != nil {
return err
}
if timeout.Before(time.Now()) {
return errs.New("timeout")
}
status, err = service.Query()
if err != nil {
return err
}
}
return nil
}
func retry(ctx context.Context, count int, cb func() error) error {
var err error
if err = cb(); err == nil {
return nil
}
for i := 1; i < count; i++ {
delay := time.Duration(math.Pow10(i))
if !sync2.Sleep(ctx, delay*time.Second) {
return ctx.Err()
}
if err = cb(); err == nil {
return nil
}
}
return err
}

View File

@ -0,0 +1,77 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"context"
"os"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/private/version"
)
func update(ctx context.Context, serviceName, binaryLocation string, ver version.Process) error {
suggestedVersion, err := ver.Suggested.SemVer()
if err != nil {
return errs.Wrap(err)
}
var currentVersion version.SemVer
if serviceName == updaterServiceName {
currentVersion = version.Build.Version
} else {
currentVersion, err = binaryVersion(binaryLocation)
if err != nil {
return errs.Wrap(err)
}
}
// should update
if currentVersion.Compare(suggestedVersion) >= 0 {
zap.L().Info("Version is up to date.", zap.String("Service", serviceName))
return nil
}
if !version.ShouldUpdate(ver.Rollout, nodeID) {
zap.L().Info("New version available but not rolled out to this nodeID yet", zap.String("Service", serviceName))
return nil
}
newVersionPath := prependExtension(binaryLocation, ver.Suggested.Version)
if err = downloadBinary(ctx, parseDownloadURL(ver.Suggested.URL), newVersionPath); err != nil {
return errs.Wrap(err)
}
downloadedVersion, err := binaryVersion(newVersionPath)
if err != nil {
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
}
if suggestedVersion.Compare(downloadedVersion) != 0 {
err := errs.New("invalid version downloaded: wants %s got %s",
suggestedVersion.String(),
downloadedVersion.String(),
)
return errs.Combine(err, os.Remove(newVersionPath))
}
var backupPath string
if serviceName == updaterServiceName {
// NB: don't include old version number for updater binary backup
backupPath = prependExtension(binaryLocation, "old")
} else {
backupPath = prependExtension(binaryLocation, "old."+currentVersion.String())
}
zap.L().Info("Restarting service.", zap.String("Service", serviceName))
if err = restartService(ctx, serviceName, binaryLocation, newVersionPath, backupPath); err != nil {
return errs.Wrap(err)
}
zap.L().Info("Service restarted successfully.", zap.String("Service", serviceName))
return nil
}