storj/cmd/storagenode-updater/restart_windows.go

204 lines
4.5 KiB
Go
Raw Normal View History

// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
//go:build windows && service
// +build windows,service
package main
import (
"context"
"math"
"os"
"time"
"github.com/spf13/cobra"
"github.com/zeebo/errs"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"storj.io/common/sync2"
"storj.io/private/process"
)
var (
	// unrecoverableErr classifies failures where an update step went wrong and
	// the previous binary could then not be moved back from its backup path.
	unrecoverableErr = errs.Class("unable to recover binary from backup")
)
// cmdRestart restarts the service named by runCfg.ServiceName so that it runs
// the binary found at args[0] instead of the one at runCfg.BinaryLocation.
//
// The current binary is kept as a backup next to runCfg.BinaryLocation. For
// the updater service itself the backup only gets an "old" extension; for any
// other service the backup name also carries the current version number.
func cmdRestart(cmd *cobra.Command, args []string) (err error) {
	ctx, cancel := process.Ctx(cmd)
	// Release the context's resources once the restart has finished.
	defer cancel()

	// Guard against a missing argument instead of panicking on args[0].
	if len(args) == 0 {
		return errs.New("missing path to the new binary")
	}
	newVersionPath := args[0]

	currentVersion, err := binaryVersion(runCfg.BinaryLocation)
	if err != nil {
		return errs.Wrap(err)
	}

	var backupPath string
	if runCfg.ServiceName == updaterServiceName {
		// NB: don't include old version number for updater binary backup
		backupPath = prependExtension(runCfg.BinaryLocation, "old")
	} else {
		backupPath = prependExtension(runCfg.BinaryLocation, "old."+currentVersion.String())
	}

	// check if new binary exists
	if _, err := os.Stat(newVersionPath); err != nil {
		return errs.Wrap(err)
	}

	return restartService(ctx, runCfg.ServiceName, runCfg.BinaryLocation, newVersionPath, backupPath)
}
// restartService stops the Windows service called service, moves the binary
// at binaryLocation to backupPath, moves the binary at newVersionPath into
// binaryLocation and starts the service again.
//
// Any failure up to and including the binary swap also removes newVersionPath.
// If the swap fails part-way, the old binary is moved back (when possible) and
// the service is started from it. If the service cannot be started with the new
// binary (two attempts), the backup is moved back over binaryLocation and the
// service is started once more. When moving the backup back fails, the error
// is of class unrecoverableErr.
func restartService(ctx context.Context, service, binaryLocation, newVersionPath, backupPath string) (err error) {
	// discardNew wraps cause and deletes the downloaded binary.
	discardNew := func(cause error) error {
		return errs.Combine(errs.Wrap(cause), os.Remove(newVersionPath))
	}

	srvc, err := openService(service)
	if err != nil {
		return discardNew(err)
	}
	defer func() {
		err = errs.Combine(err, errs.Wrap(srvc.Close()))
	}()

	status, err := srvc.Query()
	if err != nil {
		return discardNew(err)
	}

	switch status.State {
	case svc.Stopped:
		// Already stopped; nothing to do.
	case svc.StopPending:
		// Already stopping; wait for it to complete.
		if err = serviceWaitForState(ctx, srvc, svc.Stopped, 10*time.Second); err != nil {
			return discardNew(err)
		}
	default:
		if err = serviceControl(ctx, srvc, svc.Stop, svc.Stopped, 10*time.Second); err != nil {
			return discardNew(err)
		}
	}

	// swap moves the current binary aside and the new one into its place. On
	// failure it tries to leave the old binary in place and running.
	swap := func() error {
		if moveErr := os.Rename(binaryLocation, backupPath); moveErr != nil {
			return errs.Combine(moveErr, srvc.Start())
		}
		moveErr := os.Rename(newVersionPath, binaryLocation)
		if moveErr == nil {
			return nil
		}
		if restoreErr := os.Rename(backupPath, binaryLocation); restoreErr != nil {
			// unrecoverable error
			return unrecoverableErr.Wrap(errs.Combine(moveErr, restoreErr))
		}
		return errs.Combine(moveErr, srvc.Start())
	}
	if err = swap(); err != nil {
		return discardNew(err)
	}

	// The new binary is in place; give the service two chances to start.
	startErr := retry(ctx, 2, func() error {
		return srvc.Start()
	})
	if startErr == nil {
		return nil
	}

	// Starting failed: put the backup back and start the service from it.
	if restoreErr := os.Rename(backupPath, binaryLocation); restoreErr != nil {
		// unrecoverable error
		return unrecoverableErr.Wrap(errs.Combine(startErr, restoreErr))
	}
	return errs.Combine(startErr, srvc.Start())
}
// openService connects to the Windows service control manager, opens the
// service called name and disconnects from the manager again.
//
// On success the caller owns the returned service and must Close it. On any
// error the returned service is nil. If disconnecting fails after the service
// was opened, the service is closed here because callers drop it when err is
// non-nil.
func openService(name string) (service *mgr.Service, err error) {
	manager, err := mgr.Connect()
	if err != nil {
		return nil, errs.Wrap(err)
	}
	defer func() {
		err = errs.Combine(err, errs.Wrap(manager.Disconnect()))
		if err != nil && service != nil {
			// Don't hand back an open handle alongside an error; it would leak.
			err = errs.Combine(err, errs.Wrap(service.Close()))
			service = nil
		}
	}()

	service, err = manager.OpenService(name)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	return service, nil
}
// servicePollInterval is the pause between status queries while waiting for a
// service to reach a state.
const servicePollInterval = 300 * time.Millisecond

// serviceControl sends cmd to service and waits up to delay for the service to
// report state. It returns ctx.Err() if ctx is canceled, an error with message
// "timeout" if delay elapses first, or any error from the service calls.
func serviceControl(ctx context.Context, service *mgr.Service, cmd svc.Cmd, state svc.State, delay time.Duration) error {
	status, err := service.Control(cmd)
	if err != nil {
		return err
	}
	return waitServiceState(ctx, service, status, state, delay)
}

// serviceWaitForState waits up to delay for service to report state. It
// returns ctx.Err() if ctx is canceled, an error with message "timeout" if
// delay elapses first, or any error from querying the service.
func serviceWaitForState(ctx context.Context, service *mgr.Service, state svc.State, delay time.Duration) error {
	status, err := service.Query()
	if err != nil {
		return err
	}
	return waitServiceState(ctx, service, status, state, delay)
}

// waitServiceState polls service, starting from the already known status,
// until it reports state, ctx is canceled, or delay has elapsed.
func waitServiceState(ctx context.Context, service *mgr.Service, status svc.Status, state svc.State, delay time.Duration) error {
	deadline := time.Now().Add(delay)
	for status.State != state {
		if err := ctx.Err(); err != nil {
			return err
		}
		if time.Now().After(deadline) {
			return errs.New("timeout")
		}
		// Pause between queries instead of busy-spinning on the service manager;
		// the sleep is cut short if ctx is canceled.
		if !sync2.Sleep(ctx, servicePollInterval) {
			return ctx.Err()
		}
		var err error
		status, err = service.Query()
		if err != nil {
			return err
		}
	}
	return nil
}
func retry(ctx context.Context, count int, cb func() error) error {
var err error
if err = cb(); err == nil {
return nil
}
for i := 1; i < count; i++ {
delay := time.Duration(math.Pow10(i))
if !sync2.Sleep(ctx, delay*time.Second) {
return ctx.Err()
}
if err = cb(); err == nil {
return nil
}
}
return err
}