224 lines
5.9 KiB
Go
224 lines
5.9 KiB
Go
|
// Copyright (C) 2019 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package process
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"flag"
|
||
|
"io/ioutil"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"sort"
|
||
|
|
||
|
"github.com/spf13/cast"
|
||
|
"github.com/spf13/cobra"
|
||
|
"github.com/spf13/pflag"
|
||
|
"github.com/zeebo/errs"
|
||
|
yaml "gopkg.in/yaml.v2"
|
||
|
)
|
||
|
|
||
|
// SaveConfigOption is a function that updates the options for SaveConfig.
|
||
|
type SaveConfigOption func(*SaveConfigOptions)
|
||
|
|
||
|
// SaveConfigOptions controls the behavior of SaveConfig.
|
||
|
type SaveConfigOptions struct {
|
||
|
Overrides map[string]interface{}
|
||
|
RemoveDeprecated bool
|
||
|
}
|
||
|
|
||
|
// SaveConfigWithOverrides sets the overrides to the provided map.
|
||
|
func SaveConfigWithOverrides(overrides map[string]interface{}) SaveConfigOption {
|
||
|
return func(opts *SaveConfigOptions) {
|
||
|
opts.Overrides = overrides
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SaveConfigWithOverride adds a single override to SaveConfig.
|
||
|
func SaveConfigWithOverride(name string, value interface{}) SaveConfigOption {
|
||
|
return func(opts *SaveConfigOptions) {
|
||
|
if opts.Overrides == nil {
|
||
|
opts.Overrides = make(map[string]interface{})
|
||
|
}
|
||
|
opts.Overrides[name] = value
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SaveConfigRemovingDeprecated tells SaveConfig to not store deprecated flags.
|
||
|
func SaveConfigRemovingDeprecated() SaveConfigOption {
|
||
|
return func(opts *SaveConfigOptions) {
|
||
|
opts.RemoveDeprecated = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SaveConfig will save only the user-specific flags with default values to
|
||
|
// outfile with specific values specified in 'overrides' overridden.
|
||
|
func SaveConfig(cmd *cobra.Command, outfile string, opts ...SaveConfigOption) error {
|
||
|
// step 0. apply any options to change the behavior
|
||
|
//
|
||
|
|
||
|
var options SaveConfigOptions
|
||
|
for _, opt := range opts {
|
||
|
opt(&options)
|
||
|
}
|
||
|
|
||
|
// step 1. load all of the configuration settings we are going to save
|
||
|
//
|
||
|
|
||
|
flags := cmd.Flags()
|
||
|
vip, err := Viper(cmd)
|
||
|
if err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
if err := vip.MergeConfigMap(options.Overrides); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
settings := vip.AllSettings()
|
||
|
|
||
|
// step 2. construct some data describing what exactly we're saving to the
|
||
|
// config file, and how they're saved.
|
||
|
//
|
||
|
|
||
|
type configValue struct {
|
||
|
value interface{}
|
||
|
comment string
|
||
|
set bool
|
||
|
}
|
||
|
flat := make(map[string]configValue)
|
||
|
flatKeys := make([]string, 0)
|
||
|
|
||
|
// N.B. we have to pre-declare the function so that it can make recursive calls.
|
||
|
var filterAndFlatten func(string, map[string]interface{})
|
||
|
filterAndFlatten = func(base string, settings map[string]interface{}) {
|
||
|
for key, value := range settings {
|
||
|
if value, ok := value.(map[string]interface{}); ok {
|
||
|
filterAndFlatten(base+key+".", value)
|
||
|
continue
|
||
|
}
|
||
|
fullKey := base + key
|
||
|
|
||
|
// since this key can't affect anything from the config file and must be present
|
||
|
// on the command line, remove it so as to not mislead anyone
|
||
|
if fullKey == "defaults" {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// gather information about the flag under consideration
|
||
|
var (
|
||
|
changed bool
|
||
|
setup bool
|
||
|
hidden bool
|
||
|
user bool
|
||
|
deprecated bool
|
||
|
comment string
|
||
|
typ string
|
||
|
|
||
|
_, overrideExists = options.Overrides[fullKey]
|
||
|
)
|
||
|
if f := flags.Lookup(fullKey); f != nil { // first check pflags
|
||
|
changed = f.Changed
|
||
|
setup = readBoolAnnotation(f, "setup")
|
||
|
hidden = readBoolAnnotation(f, "hidden")
|
||
|
user = readBoolAnnotation(f, "user")
|
||
|
deprecated = readBoolAnnotation(f, "deprecated")
|
||
|
comment = f.Usage
|
||
|
typ = f.Value.Type()
|
||
|
} else if f := flag.Lookup(fullKey); f != nil { // then stdlib flags
|
||
|
changed = f.Value.String() != f.DefValue
|
||
|
comment = f.Usage
|
||
|
} else {
|
||
|
// by default we store config values we know nothing about. we
|
||
|
// absue the meaning of "changed" to include this case.
|
||
|
changed = true
|
||
|
}
|
||
|
|
||
|
// in any of these cases, don't store the key in the config file
|
||
|
if setup || hidden || options.RemoveDeprecated && deprecated {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// viper is super cool and doesn't cast floats automatically, so we
|
||
|
// handle that ourselves.
|
||
|
if typ == "float64" {
|
||
|
value = cast.ToFloat64(value)
|
||
|
}
|
||
|
|
||
|
flatKeys = append(flatKeys, fullKey)
|
||
|
flat[fullKey] = configValue{
|
||
|
value: value,
|
||
|
comment: comment,
|
||
|
set: user || changed || overrideExists,
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
filterAndFlatten("", settings)
|
||
|
sort.Strings(flatKeys)
|
||
|
|
||
|
// step 3. write out the configuration file
|
||
|
//
|
||
|
|
||
|
var nl = []byte("\n")
|
||
|
var lines [][]byte
|
||
|
for _, key := range flatKeys {
|
||
|
config := flat[key]
|
||
|
|
||
|
if config.comment != "" {
|
||
|
lines = append(lines, []byte("# "+config.comment))
|
||
|
}
|
||
|
|
||
|
data, err := yaml.Marshal(map[string]interface{}{key: config.value})
|
||
|
if err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
dataLines := bytes.Split(bytes.TrimSpace(data), nl)
|
||
|
|
||
|
// if the config value is set, concat in the yaml lines
|
||
|
if config.set {
|
||
|
lines = append(lines, dataLines...)
|
||
|
} else {
|
||
|
// otherwise, add them in but commented out
|
||
|
for _, line := range dataLines {
|
||
|
lines = append(lines, append([]byte("# "), line...))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// add a blank line separator
|
||
|
lines = append(lines, nil)
|
||
|
}
|
||
|
|
||
|
return errs.Wrap(atomicWrite(outfile, 0600, bytes.Join(lines, nl)))
|
||
|
}
|
||
|
|
||
|
// readBoolAnnotation is a helper to see if a boolean annotation is set to true on the flag.
|
||
|
func readBoolAnnotation(flag *pflag.Flag, key string) bool {
|
||
|
annotation := flag.Annotations[key]
|
||
|
return len(annotation) > 0 && annotation[0] == "true"
|
||
|
}
|
||
|
|
||
|
// atomicWrite is a helper to atomically write the data to the outfile.
|
||
|
func atomicWrite(outfile string, mode os.FileMode, data []byte) (err error) {
|
||
|
fh, err := ioutil.TempFile(filepath.Dir(outfile), filepath.Base(outfile))
|
||
|
if err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
err = errs.Combine(err, fh.Close())
|
||
|
err = errs.Combine(err, os.Remove(fh.Name()))
|
||
|
}
|
||
|
}()
|
||
|
if _, err := fh.Write(data); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
if err := fh.Sync(); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
if err := fh.Close(); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
if err := os.Rename(fh.Name(), outfile); err != nil {
|
||
|
return errs.Wrap(err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|