diff --git a/internal/version/service.go b/internal/version/service.go index a3a2ef9ab..86b834552 100644 --- a/internal/version/service.go +++ b/internal/version/service.go @@ -128,6 +128,11 @@ func (srv *Service) checkVersion(ctx context.Context) (allowed bool) { return false } +// isAcceptedVersion compares and checks if the passed version is greater/equal than the minimum required version +func isAcceptedVersion(test SemVer, target SemVer) bool { + return test.Major > target.Major || (test.Major == target.Major && (test.Minor > target.Minor || (test.Minor == target.Minor && test.Patch >= target.Patch))) +} + // QueryVersionFromControlServer handles the HTTP request to gather the allowed and latest version information func (srv *Service) queryVersionFromControlServer(ctx context.Context) (ver AllowedVersions, err error) { defer mon.Task()(&ctx)(&err) diff --git a/internal/version/version.go b/internal/version/version.go index bb3746f2b..8d90f2ce1 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -4,6 +4,7 @@ package version import ( + "encoding/hex" "encoding/json" "fmt" "regexp" @@ -17,10 +18,19 @@ import ( "storj.io/storj/pkg/pb" ) +// semVerRegex is the regular expression used to parse a semantic version. +// https://github.com/Masterminds/semver/blob/master/LICENSE.txt +const ( + semVerRegex string = `v?([0-9]+)\.([0-9]+)\.([0-9]+)` + quote = byte('"') +) + var ( mon = monkit.Package() - verError = errs.Class("version error") + // VerError is the error class for version-related errors. + VerError = errs.Class("version error") + // the following fields are set by linker flags. if any of them // are set and fail to parse, the program will fail to start buildTimestamp string // unix seconds since epoch @@ -30,6 +40,8 @@ var ( // Build is a struct containing all relevant build information associated with the binary Build Info + + versionRegex = regexp.MustCompile("^" + semVerRegex + "$") ) // Info is the versioning information for a binary @@ -51,6 +63,7 @@ type SemVer struct { } // AllowedVersions provides the Minimum SemVer per Service +// TODO: I don't think this name is representative of what this struct now holds. type AllowedVersions struct { Satellite SemVer Storagenode SemVer @@ -62,6 +75,7 @@ type AllowedVersions struct { } // Processes describes versions for each binary. +// TODO: this name is inconsistent with the versioncontrol server pkg's analogue, `Versions`. type Processes struct { Satellite Process `json:"satellite"` Storagenode Process `json:"storagenode"` @@ -74,6 +88,7 @@ type Processes struct { type Process struct { Minimum Version `json:"minimum"` Suggested Version `json:"suggested"` + Rollout Rollout `json:"rollout"` } // Version represents version and download URL for binary. @@ -82,34 +97,54 @@ type Version struct { URL string `json:"url"` } -// SemVerRegex is the regular expression used to parse a semantic version. -// https://github.com/Masterminds/semver/blob/master/LICENSE.txt -const SemVerRegex string = `v?([0-9]+)\.([0-9]+)\.([0-9]+)` +// Rollout represents the state of a version rollout. +type Rollout struct { + Seed RolloutBytes `json:"seed"` + Cursor RolloutBytes `json:"cursor"` +} -var versionRegex = regexp.MustCompile("^" + SemVerRegex + "$") +// RolloutBytes implements json un/marshalling using hex de/encoding. +type RolloutBytes [32]byte + +// MarshalJSON hex-encodes RolloutBytes and pre/appends JSON string literal quotes. +func (rb RolloutBytes) MarshalJSON() ([]byte, error) { + hexBytes := make([]byte, hex.EncodedLen(len(rb))) + hex.Encode(hexBytes, rb[:]) + encoded := append([]byte{quote}, hexBytes...) + encoded = append(encoded, quote) + return encoded, nil +} + +// UnmarshalJSON drops the JSON string literal quotes and hex-decodes RolloutBytes . +func (rb *RolloutBytes) UnmarshalJSON(b []byte) error { + if _, err := hex.Decode(rb[:], b[1:len(b)-1]); err != nil { + return VerError.Wrap(err) + } + return nil +} // NewSemVer parses a given version and returns an instance of SemVer or // an error if unable to parse the version. func NewSemVer(v string) (sv SemVer, err error) { m := versionRegex.FindStringSubmatch(v) if m == nil { - return SemVer{}, verError.New("invalid semantic version for build %s", v) + return SemVer{}, VerError.New("invalid semantic version for build %s", v) } // first entry of m is the entire version string sv.Major, err = strconv.ParseInt(m[1], 10, 64) if err != nil { - return SemVer{}, err + return SemVer{}, VerError.Wrap(err) } sv.Minor, err = strconv.ParseInt(m[2], 10, 64) if err != nil { - return SemVer{}, err + return SemVer{}, VerError.Wrap(err) } sv.Patch, err = strconv.ParseInt(m[3], 10, 64) if err != nil { - return SemVer{}, err + return SemVer{}, VerError.Wrap(err) } return sv, nil @@ -146,13 +181,16 @@ func (sem *SemVer) String() (version string) { // New creates Version_Info from a json byte array func New(data []byte) (v Info, err error) { err = json.Unmarshal(data, &v) - return v, err + return v, VerError.Wrap(err) } // Marshal converts the existing Version Info to any json byte array -func (v Info) Marshal() (data []byte, err error) { - data, err = json.Marshal(v) - return +func (v Info) Marshal() ([]byte, error) { + data, err := json.Marshal(v) + if err != nil { + return nil, VerError.Wrap(err) + } + return data, nil } // Proto converts an Info struct to a pb.NodeVersion @@ -167,18 +205,13 @@ func (v Info) Proto() (*pb.NodeVersion, error) { }, nil } -// isAcceptedVersion compares and checks if the passed version is greater/equal than the minimum required version -func isAcceptedVersion(test SemVer, target SemVer) bool { - return test.Major > target.Major || (test.Major == target.Major && (test.Minor > target.Minor || (test.Minor == target.Minor && test.Patch >= target.Patch))) -} - func init() { if buildVersion == "" && buildTimestamp == "" && buildCommitHash == "" && buildRelease == "" { return } timestamp, err := strconv.ParseInt(buildTimestamp, 10, 64) if err != nil { - panic(verError.Wrap(err)) + panic(VerError.Wrap(err)) } Build = Info{ Timestamp: time.Unix(timestamp, 0), diff --git a/internal/version/version_test.go b/internal/version/version_test.go index 757bbb45a..1fe809ee5 100644 --- a/internal/version/version_test.go +++ b/internal/version/version_test.go @@ -4,6 +4,7 @@ package version_test import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" @@ -42,3 +43,26 @@ func TestSemVer_Compare(t *testing.T) { require.True(t, version030.Compare(version002) > 0) require.True(t, version600.Compare(version040) > 0) } + +func TestRollout_MarshalJSON_UnmarshalJSON(t *testing.T) { + var expectedRollout, actualRollout version.Rollout + + for i := 0; i < len(version.RolloutBytes{}); i++ { + expectedRollout.Seed[i] = byte(i) + expectedRollout.Cursor[i] = byte(i * 2) + } + + _, err := json.Marshal(actualRollout.Seed) + require.NoError(t, err) + + emptyJSONRollout, err := json.Marshal(actualRollout) + require.NoError(t, err) + + jsonRollout, err := json.Marshal(expectedRollout) + require.NoError(t, err) + require.NotEqual(t, emptyJSONRollout, jsonRollout) + + err = json.Unmarshal(jsonRollout, &actualRollout) + require.NoError(t, err) + require.Equal(t, expectedRollout, actualRollout) +} diff --git a/versioncontrol/peer.go b/versioncontrol/peer.go index 965d9bded..c0e4f1cf5 100644 --- a/versioncontrol/peer.go +++ b/versioncontrol/peer.go @@ -5,9 +5,12 @@ package versioncontrol import ( "context" + "encoding/hex" "encoding/json" + "math/big" "net" "net/http" + "reflect" "github.com/zeebo/errs" "go.uber.org/zap" @@ -17,7 +20,17 @@ import ( "storj.io/storj/internal/version" ) -// Config is all the configuration parameters for a Version Control Server +// seedLength is the number of bytes in a rollout seed. +const seedLength = 32 + +var ( + // RolloutErr defines the rollout config error class. + RolloutErr = errs.Class("rollout config error") + // EmptySeedErr is used when the rollout contains an empty seed value. + EmptySeedErr = RolloutErr.New("empty seed") +) + +// Config is all the configuration parameters for a Version Control Server. type Config struct { Address string `user:"true" help:"public address to listen on" default:":8080"` Versions ServiceVersions @@ -25,7 +38,7 @@ type Config struct { Binary Versions } -// ServiceVersions provides a list of allowed Versions per Service +// ServiceVersions provides a list of allowed Versions per Service. type ServiceVersions struct { Satellite string `user:"true" help:"Allowed Satellite Versions" default:"v0.0.1"` Storagenode string `user:"true" help:"Allowed Storagenode Versions" default:"v0.0.1"` @@ -34,7 +47,8 @@ type ServiceVersions struct { Identity string `user:"true" help:"Allowed Identity Versions" default:"v0.0.1"` } -// Versions represents versions for all binaries +// Versions represents versions for all binaries. +// TODO: this name is inconsistent with the internal/version pkg's analogue, `Processes`. type Versions struct { Satellite Binary Storagenode Binary @@ -43,18 +57,26 @@ type Versions struct { Identity Binary } -// Binary represents versions for single binary +// Binary represents versions for single binary. +// TODO: This name is inconsistent with the internal/version pkg's analogue, `Process`. type Binary struct { Minimum Version Suggested Version + Rollout Rollout } -// Version single version +// Version single version. type Version struct { Version string `user:"true" help:"peer version" default:"v0.0.1"` URL string `user:"true" help:"URL for specific binary" default:""` } +// Rollout represents the state of a version rollout of a binary to the suggested version. +type Rollout struct { + Seed string `user:"true" help:"random 32 byte, hex-encoded string"` + Cursor int `user:"true" help:"percentage of nodes which should roll-out to the suggested version" default:"0"` +} + // Peer is the representation of a VersionControl Server. type Peer struct { // core dependencies @@ -71,7 +93,7 @@ type Peer struct { response []byte } -// HandleGet contains the request handler for the version control web server +// HandleGet contains the request handler for the version control web server. func (peer *Peer) HandleGet(w http.ResponseWriter, r *http.Request) { // Only handle GET Requests if r.Method != http.MethodGet { @@ -94,6 +116,10 @@ func (peer *Peer) HandleGet(w http.ResponseWriter, r *http.Request) { // New creates a new VersionControl Server. func New(log *zap.Logger, config *Config) (peer *Peer, err error) { + if err := config.Binary.ValidateRollouts(log); err != nil { + return nil, RolloutErr.Wrap(err) + } + peer = &Peer{ Log: log, } @@ -125,14 +151,32 @@ func New(log *zap.Logger, config *Config) (peer *Peer, err error) { } peer.Versions.Processes = version.Processes{} - peer.Versions.Processes.Satellite = configToProcess(config.Binary.Satellite) - peer.Versions.Processes.Storagenode = configToProcess(config.Binary.Storagenode) - peer.Versions.Processes.Uplink = configToProcess(config.Binary.Uplink) - peer.Versions.Processes.Gateway = configToProcess(config.Binary.Gateway) - peer.Versions.Processes.Identity = configToProcess(config.Binary.Identity) + peer.Versions.Processes.Satellite, err = configToProcess(config.Binary.Satellite) + if err != nil { + return nil, RolloutErr.Wrap(err) + } + + peer.Versions.Processes.Storagenode, err = configToProcess(config.Binary.Storagenode) + if err != nil { + return nil, RolloutErr.Wrap(err) + } + + peer.Versions.Processes.Uplink, err = configToProcess(config.Binary.Uplink) + if err != nil { + return nil, RolloutErr.Wrap(err) + } + + peer.Versions.Processes.Gateway, err = configToProcess(config.Binary.Gateway) + if err != nil { + return nil, RolloutErr.Wrap(err) + } + + peer.Versions.Processes.Identity, err = configToProcess(config.Binary.Identity) + if err != nil { + return nil, RolloutErr.Wrap(err) + } peer.response, err = json.Marshal(peer.Versions) - if err != nil { peer.Log.Sugar().Fatalf("Error marshalling version info: %v", err) } @@ -178,8 +222,67 @@ func (peer *Peer) Close() (err error) { // Addr returns the public address. func (peer *Peer) Addr() string { return peer.Server.Listener.Addr().String() } -func configToProcess(binary Binary) version.Process { - return version.Process{ +// ValidateRollouts validates the rollout field of each field in the Versions struct. +func (versions Versions) ValidateRollouts(log *zap.Logger) error { + value := reflect.ValueOf(versions) + fieldCount := value.NumField() + validationErrs := errs.Group{} + for i := 1; i < fieldCount; i++ { + binary, ok := value.Field(i).Interface().(Binary) + if !ok { + log.Warn("non-binary field in versions config struct", zap.String("field name", value.Type().Field(i).Name)) + continue + } + if err := binary.Rollout.Validate(); err != nil { + if err == EmptySeedErr { + log.Warn(err.Error(), zap.String("binary", value.Type().Field(i).Name)) + continue + } + validationErrs.Add(err) + } + } + return validationErrs.Err() +} + +// Validate validates the rollout seed and cursor config values. +func (rollout Rollout) Validate() error { + seedLen := len(rollout.Seed) + if seedLen == 0 { + return EmptySeedErr + } + + if seedLen != hex.EncodedLen(seedLength) { + return RolloutErr.New("invalid seed length: %d", seedLen) + } + + if rollout.Cursor < 0 || rollout.Cursor > 100 { + return RolloutErr.New("invalid cursor percentage: %d", rollout.Cursor) + } + + if _, err := hex.DecodeString(rollout.Seed); err != nil { + return RolloutErr.New("invalid seed: %s", rollout.Seed) + } + return nil +} + +func percentageToCursor(pct int) version.RolloutBytes { + // NB: convert the max value to a number, multiply by the percentage, convert back. + var maxInt, maskInt big.Int + var maxBytes version.RolloutBytes + for i := 0; i < len(maxBytes); i++ { + maxBytes[i] = 255 + } + maxInt.SetBytes(maxBytes[:]) + maskInt.Div(maskInt.Mul(&maxInt, big.NewInt(int64(pct))), big.NewInt(100)) + + var cursor version.RolloutBytes + copy(cursor[:], maskInt.Bytes()) + + return cursor +} + +func configToProcess(binary Binary) (version.Process, error) { + process := version.Process{ Minimum: version.Version{ Version: binary.Minimum.Version, URL: binary.Minimum.URL, @@ -188,5 +291,15 @@ func configToProcess(binary Binary) version.Process { Version: binary.Suggested.Version, URL: binary.Suggested.URL, }, + Rollout: version.Rollout{ + Cursor: percentageToCursor(binary.Rollout.Cursor), + }, } + + seedBytes, err := hex.DecodeString(binary.Rollout.Seed) + if err != nil { + return version.Process{}, err + } + copy(process.Rollout.Seed[:], seedBytes) + return process, nil } diff --git a/versioncontrol/peer_test.go b/versioncontrol/peer_test.go new file mode 100644 index 000000000..a58d95911 --- /dev/null +++ b/versioncontrol/peer_test.go @@ -0,0 +1,223 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package versioncontrol_test + +import ( + "encoding/hex" + "math/rand" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "storj.io/storj/versioncontrol" +) + +var rolloutErrScenarios = []struct { + name string + rollout versioncontrol.Rollout + errContains string +}{ + { + "short seed", + versioncontrol.Rollout{ + // 31 byte seed + Seed: "00000000000000000000000000000000000000000000000000000000000000", + Cursor: 0, + }, + "invalid seed length:", + }, + { + "long seed", + versioncontrol.Rollout{ + // 33 byte seed + Seed: "000000000000000000000000000000000000000000000000000000000000000000", + Cursor: 0, + }, + "invalid seed length:", + }, + { + "invalid seed", + versioncontrol.Rollout{ + // non-hex seed + Seed: "G000000000000000000000000000000000000000000000000000000000000000", + Cursor: 0, + }, + "invalid seed:", + }, + { + "negative cursor", + versioncontrol.Rollout{ + Seed: "0000000000000000000000000000000000000000000000000000000000000000", + Cursor: -1, + }, + "invalid cursor percentage:", + }, + { + "cursor too big", + versioncontrol.Rollout{ + Seed: "0000000000000000000000000000000000000000000000000000000000000000", + Cursor: 101, + }, + "invalid cursor percentage:", + }, +} + +func TestPeer_Run(t *testing.T) { + testVersion := "v0.0.1" + testServiceVersions := versioncontrol.ServiceVersions{ + Gateway: testVersion, + Identity: testVersion, + Satellite: testVersion, + Storagenode: testVersion, + Uplink: testVersion, + } + + t.Run("random rollouts", func(t *testing.T) { + for i := 0; i < 100; i++ { + config := versioncontrol.Config{ + Versions: testServiceVersions, + Binary: validRandVersions(t), + } + + peer, err := versioncontrol.New(zaptest.NewLogger(t), &config) + require.NoError(t, err) + require.NotNil(t, peer) + } + }) + + t.Run("empty rollout seed", func(t *testing.T) { + versionsType := reflect.TypeOf(versioncontrol.Versions{}) + fieldCount := versionsType.NumField() + + // test invalid rollout for each binary + for i := 1; i < fieldCount; i++ { + versions := versioncontrol.Versions{} + versionsValue := reflect.ValueOf(&versions) + field := reflect.Indirect(versionsValue).Field(i) + + binary := versioncontrol.Binary{ + Rollout: versioncontrol.Rollout{ + Seed: "", + Cursor: 0, + }, + } + + field.Set(reflect.ValueOf(binary)) + + config := versioncontrol.Config{ + Versions: testServiceVersions, + Binary: versions, + } + + peer, err := versioncontrol.New(zaptest.NewLogger(t), &config) + require.NoError(t, err) + require.NotNil(t, peer) + } + }) +} + +func TestPeer_Run_error(t *testing.T) { + for _, scenario := range rolloutErrScenarios { + scenario := scenario + t.Run(scenario.name, func(t *testing.T) { + versionsType := reflect.TypeOf(versioncontrol.Versions{}) + fieldCount := versionsType.NumField() + + // test invalid rollout for each binary + for i := 1; i < fieldCount; i++ { + versions := versioncontrol.Versions{} + versionsValue := reflect.ValueOf(&versions) + field := reflect.Indirect(versionsValue).Field(i) + + binary := versioncontrol.Binary{ + Rollout: scenario.rollout, + } + + field.Set(reflect.ValueOf(binary)) + + config := versioncontrol.Config{ + Binary: versions, + } + + peer, err := versioncontrol.New(zaptest.NewLogger(t), &config) + require.Nil(t, peer) + require.Error(t, err) + require.Contains(t, err.Error(), scenario.errContains) + } + }) + } +} + +func TestVersions_ValidateRollouts(t *testing.T) { + versions := validRandVersions(t) + err := versions.ValidateRollouts(zaptest.NewLogger(t)) + require.NoError(t, err) +} + +func TestRollout_Validate(t *testing.T) { + for i := 0; i < 100; i++ { + rollout := versioncontrol.Rollout{ + Seed: randSeedString(t), + Cursor: i, + } + + err := rollout.Validate() + require.NoError(t, err) + } +} + +func TestRollout_Validate_error(t *testing.T) { + for _, scenario := range rolloutErrScenarios { + scenario := scenario + t.Run(scenario.name, func(t *testing.T) { + err := scenario.rollout.Validate() + require.Error(t, err) + require.True(t, versioncontrol.RolloutErr.Has(err)) + require.Contains(t, err.Error(), scenario.errContains) + }) + } +} + +func validRandVersions(t *testing.T) versioncontrol.Versions { + t.Helper() + + return versioncontrol.Versions{ + Satellite: versioncontrol.Binary{ + Rollout: randRollout(t), + }, + Storagenode: versioncontrol.Binary{ + Rollout: randRollout(t), + }, + Uplink: versioncontrol.Binary{ + Rollout: randRollout(t), + }, + Gateway: versioncontrol.Binary{ + Rollout: randRollout(t), + }, + Identity: versioncontrol.Binary{ + Rollout: randRollout(t), + }, + } +} + +func randRollout(t *testing.T) versioncontrol.Rollout { + t.Helper() + + return versioncontrol.Rollout{ + Seed: randSeedString(t), + Cursor: rand.Intn(101), + } +} + +func randSeedString(t *testing.T) string { + t.Helper() + + seed := make([]byte, 32) + _, err := rand.Read(seed) + require.NoError(t, err) + + return hex.EncodeToString(seed) +}