0d2d59f884
Change-Id: Idfc93948e59a181321d79b365e638d63e256a16f
204 lines
4.5 KiB
Go
204 lines
4.5 KiB
Go
// Copyright (C) 2020 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
//go:build windows && service
|
|
// +build windows,service
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/spf13/cobra"
|
|
"github.com/zeebo/errs"
|
|
"golang.org/x/sys/windows/svc"
|
|
"golang.org/x/sys/windows/svc/mgr"
|
|
|
|
"storj.io/common/sync2"
|
|
"storj.io/private/process"
|
|
)
|
|
|
|
// unrecoverableErr classifies failures where, after a failed binary swap or a
// failed start of the new binary, the backup copy could not be renamed back to
// the binary location.
var (
	unrecoverableErr = errs.Class("unable to recover binary from backup")
)
|
|
|
|
func cmdRestart(cmd *cobra.Command, args []string) (err error) {
|
|
ctx, _ := process.Ctx(cmd)
|
|
|
|
currentVersion, err := binaryVersion(runCfg.BinaryLocation)
|
|
if err != nil {
|
|
return errs.Wrap(err)
|
|
}
|
|
|
|
newVersionPath := args[0]
|
|
|
|
var backupPath string
|
|
if runCfg.ServiceName == updaterServiceName {
|
|
// NB: don't include old version number for updater binary backup
|
|
backupPath = prependExtension(runCfg.BinaryLocation, "old")
|
|
} else {
|
|
backupPath = prependExtension(runCfg.BinaryLocation, "old."+currentVersion.String())
|
|
}
|
|
|
|
// check if new binary exists
|
|
if _, err := os.Stat(newVersionPath); err != nil {
|
|
return errs.Wrap(err)
|
|
}
|
|
|
|
return restartService(ctx, runCfg.ServiceName, runCfg.BinaryLocation, newVersionPath, backupPath)
|
|
}
|
|
|
|
func restartService(ctx context.Context, service, binaryLocation, newVersionPath, backupPath string) (err error) {
|
|
srvc, err := openService(service)
|
|
if err != nil {
|
|
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
|
|
}
|
|
defer func() {
|
|
err = errs.Combine(err, errs.Wrap(srvc.Close()))
|
|
}()
|
|
|
|
status, err := srvc.Query()
|
|
if err != nil {
|
|
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
|
|
}
|
|
|
|
// stop service if it's not stopped
|
|
if status.State != svc.Stopped && status.State != svc.StopPending {
|
|
if err = serviceControl(ctx, srvc, svc.Stop, svc.Stopped, 10*time.Second); err != nil {
|
|
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
|
|
}
|
|
// if it is stopping wait for it to complete
|
|
} else if status.State == svc.StopPending {
|
|
if err = serviceWaitForState(ctx, srvc, svc.Stopped, 10*time.Second); err != nil {
|
|
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
|
|
}
|
|
}
|
|
|
|
err = func() error {
|
|
if err := os.Rename(binaryLocation, backupPath); err != nil {
|
|
return errs.Combine(err, srvc.Start())
|
|
}
|
|
|
|
if err := os.Rename(newVersionPath, binaryLocation); err != nil {
|
|
if rerr := os.Rename(backupPath, binaryLocation); rerr != nil {
|
|
// unrecoverable error
|
|
return unrecoverableErr.Wrap(errs.Combine(err, rerr))
|
|
}
|
|
|
|
return errs.Combine(err, srvc.Start())
|
|
}
|
|
|
|
return nil
|
|
}()
|
|
if err != nil {
|
|
return errs.Combine(errs.Wrap(err), os.Remove(newVersionPath))
|
|
}
|
|
|
|
// successfully substituted binaries
|
|
err = retry(ctx, 2,
|
|
func() error {
|
|
return srvc.Start()
|
|
},
|
|
)
|
|
// if fail to start the service, try again with backup
|
|
if err != nil {
|
|
if rerr := os.Rename(backupPath, binaryLocation); rerr != nil {
|
|
// unrecoverable error
|
|
return unrecoverableErr.Wrap(errs.Combine(err, rerr))
|
|
}
|
|
|
|
return errs.Combine(err, srvc.Start())
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func openService(name string) (_ *mgr.Service, err error) {
|
|
manager, err := mgr.Connect()
|
|
if err != nil {
|
|
return nil, errs.Wrap(err)
|
|
}
|
|
defer func() {
|
|
err = errs.Combine(err, errs.Wrap(manager.Disconnect()))
|
|
}()
|
|
|
|
service, err := manager.OpenService(name)
|
|
if err != nil {
|
|
return nil, errs.Wrap(err)
|
|
}
|
|
|
|
return service, nil
|
|
}
|
|
|
|
func serviceControl(ctx context.Context, service *mgr.Service, cmd svc.Cmd, state svc.State, delay time.Duration) error {
|
|
status, err := service.Control(cmd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
timeout := time.Now().Add(delay)
|
|
|
|
for status.State != state {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if timeout.Before(time.Now()) {
|
|
return errs.New("timeout")
|
|
}
|
|
|
|
status, err = service.Query()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func serviceWaitForState(ctx context.Context, service *mgr.Service, state svc.State, delay time.Duration) error {
|
|
status, err := service.Query()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
timeout := time.Now().Add(delay)
|
|
|
|
for status.State != state {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
if timeout.Before(time.Now()) {
|
|
return errs.New("timeout")
|
|
}
|
|
|
|
status, err = service.Query()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func retry(ctx context.Context, count int, cb func() error) error {
|
|
var err error
|
|
|
|
if err = cb(); err == nil {
|
|
return nil
|
|
}
|
|
|
|
for i := 1; i < count; i++ {
|
|
delay := time.Duration(math.Pow10(i))
|
|
|
|
if !sync2.Sleep(ctx, delay*time.Second) {
|
|
return ctx.Err()
|
|
}
|
|
if err = cb(); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|