satellite/nodeselection: SelectBySubnet should use placement filters for all nodes

Current node selection logic (in case of using SelectBySubnet):

 1. selects one subnet randomly
 2. selects one node randomly from the subnet
 3. applies the placement NodeFilters to the node and ignore it, if doesn't match

This logic is wrong:

 1. Imagine that we have a subnet with two DE and one GB nodes.
 2. We would like to select DE nodes
 2. In case of GB node is selected (randomly) in step2, step3 will ignore the subnet, even if there are good (DE) nodes in there.

Change-Id: I7673f52c89b46e0cc7b20a9b74137dc689d6c17e
This commit is contained in:
Márton Elek 2023-08-01 13:01:47 +02:00
parent 03c52f184e
commit 0b02a48a10
No known key found for this signature in database
7 changed files with 288 additions and 28 deletions

View File

@ -0,0 +1,66 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"bytes"
"fmt"
"math"
mathrand "math/rand"
"os"
)
// main implements a simple (and slow) prime number generator.
func main() {
dest := bytes.Buffer{}
_, err := dest.WriteString(`// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.
package nodeselection
//go:generate go run ./gen
var primes = []uint64{
`)
if err != nil {
panic(err)
}
min := uint64(1 << 32)
requiredPrimes := 32
for {
n := mathrand.Uint64()
if n < min {
continue
}
prime := true
squareRoot := uint64(math.Floor(math.Sqrt(float64(n))))
for i := uint64(2); i < squareRoot; i++ {
if n%i == 0 {
prime = false
break
}
}
if prime {
fmt.Println("found a prime", n)
requiredPrimes--
_, err = dest.WriteString(fmt.Sprintf(" %d,\n", n))
if err != nil {
panic(err)
}
}
if requiredPrimes == 0 {
break
}
}
_, err = dest.WriteString("}")
if err != nil {
panic(err)
}
err = os.WriteFile("primes.go", dest.Bytes(), 0644)
if err != nil {
panic(err)
}
}

View File

@ -0,0 +1,40 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.
package nodeselection
//go:generate go run ./gen
var primes = []uint64{
4644169889937985027,
6165318155211055777,
3742593631875035779,
15678494965655331763,
5575516698076218241,
2264694452711313617,
3262352908419267653,
5074258193245947331,
3312977600413714507,
5171218356166066181,
18019739453606759147,
530860110108392567,
17812504517332934837,
12307788838370083211,
16850377343872429559,
17793324760692636731,
15906408129443552839,
9998380786003640893,
3005127980230146739,
5330537884366068391,
11741948568425617691,
8672642363526608261,
11163957437174369417,
8050745157603811909,
1538425910175398813,
2302692389243415683,
14057792358509291159,
7458209465324486017,
3542192325647561567,
11143516174227182297,
16664534039506187251,
15387573628527498169,
}

View File

@ -0,0 +1,42 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.
package nodeselection
import mathrand "math/rand"
// RandomOrder as an iterator of a pseudo-random permutation set.
type RandomOrder struct {
count uint64
at uint64
prime uint64
len uint64
}
// NewRandomOrder creates new iterator, returns number between [0,n) in pseudo-random order.
func NewRandomOrder(n int) RandomOrder {
if n == 0 {
return RandomOrder{
count: 0,
}
}
return RandomOrder{
count: uint64(n),
at: uint64(mathrand.Intn(n)),
prime: primes[mathrand.Intn(len(primes))],
len: uint64(n),
}
}
// Next generates the next number.
func (r *RandomOrder) Next() bool {
if r.count == 0 {
return false
}
r.at = (r.at + r.prime) % r.len
r.count--
return true
}
// At returns the current number in the permutations.
func (r *RandomOrder) At() uint64 { return r.at }

View File

@ -0,0 +1,43 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.
package nodeselection
import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
func TestRand(t *testing.T) {
// test if we get a full permutation
t.Run("generate real permutations", func(t *testing.T) {
var numbers []uint64
c := NewRandomOrder(20)
for c.Next() {
numbers = append(numbers, c.At())
}
require.Len(t, numbers, 20)
slices.Sort(numbers)
for i := 0; i < len(numbers); i++ {
require.Equal(t, uint64(i), numbers[i])
}
})
t.Run("next always returns with false at the end", func(t *testing.T) {
c := NewRandomOrder(3)
require.True(t, c.Next())
require.True(t, c.Next())
require.True(t, c.Next())
require.False(t, c.Next())
require.False(t, c.Next())
})
t.Run("z ero size is accepted", func(t *testing.T) {
c := NewRandomOrder(0)
require.False(t, c.Next())
})
}

View File

@ -77,15 +77,17 @@ func (subnets SelectBySubnet) Select(n int, filter NodeFilter) []*SelectedNode {
}
selected := []*SelectedNode{}
for _, idx := range mathrand.Perm(len(subnets)) {
subnet := subnets[idx]
node := subnet.Nodes[mathrand.Intn(len(subnet.Nodes))]
r := NewRandomOrder(len(subnets))
for r.Next() {
subnet := subnets[r.At()]
if !filter.MatchInclude(node) {
continue
rs := NewRandomOrder(len(subnet.Nodes))
for rs.Next() {
if filter.MatchInclude(subnet.Nodes[rs.At()]) {
selected = append(selected, subnet.Nodes[rs.At()].Clone())
break
}
}
selected = append(selected, node.Clone())
if len(selected) >= n {
break
}

View File

@ -4,12 +4,15 @@
package nodeselection_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"storj.io/common/identity/testidentity"
"storj.io/common/storj"
"storj.io/common/storj/location"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/satellite/nodeselection"
@ -252,3 +255,29 @@ func TestSelectFiltered(t *testing.T) {
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithAutoExcludeSubnets()), 2)
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithExcludedIDs([]storj.NodeID{thirdID}).WithAutoExcludeSubnets()), 1)
}
func TestSelectFilteredMulti(t *testing.T) {
// four subnets with 3 nodes in each. Only one per subnet is located in Germany.
// Algorithm should pick the German one from each subnet, and 4 nodes should be possible to be picked.
ctx := testcontext.New(t)
defer ctx.Cleanup()
var nodes []*nodeselection.SelectedNode
for i := 0; i < 12; i++ {
nodes = append(nodes, &nodeselection.SelectedNode{
ID: testidentity.MustPregeneratedIdentity(i, storj.LatestIDVersion()).ID,
LastNet: fmt.Sprintf("68.0.%d", i/3),
LastIPPort: fmt.Sprintf("68.0.%d.%d:1000", i/3, i),
CountryCode: location.Germany + location.CountryCode(i%3),
})
}
selector := nodeselection.SelectBySubnetFromNodes(nodes)
for i := 0; i < 100; i++ {
assert.Len(t, selector.Select(4, nodeselection.NodeFilters{}.WithCountryFilter(location.NewSet(location.Germany))), 4)
}
}

View File

@ -88,8 +88,8 @@ func TestRefresh(t *testing.T) {
func addNodesToNodesTable(ctx context.Context, t *testing.T, db overlay.DB, count, makeReputable int) (ids []storj.NodeID) {
for i := 0; i < count; i++ {
subnet := strconv.Itoa(i) + ".1.2"
addr := subnet + ".3:8080"
subnet := strconv.Itoa(i/3) + ".1.2"
addr := fmt.Sprintf("%s.%d:8080", subnet, i%3+1)
n := overlay.NodeCheckInInfo{
NodeID: storj.NodeID{byte(i)},
Address: &pb.NodeAddress{
@ -107,6 +107,7 @@ func addNodesToNodesTable(ctx context.Context, t *testing.T, db overlay.DB, coun
Timestamp: time.Time{},
Release: true,
},
CountryCode: location.Germany + location.CountryCode(i%2),
}
err := db.UpdateCheckIn(ctx, n, time.Now().UTC(), nodeSelectionConfig)
require.NoError(t, err)
@ -212,12 +213,15 @@ func TestGetNodes(t *testing.T) {
DistinctIP: true,
MinimumDiskSpace: 100 * memory.MiB,
}
placementRules := overlay.NewPlacementRules()
placementRules.AddPlacementRule(storj.PlacementConstraint(5), nodeselection.NodeFilters{}.WithCountryFilter(location.NewSet(location.Germany)))
cache, err := overlay.NewUploadSelectionCache(zap.NewNop(),
db.OverlayCache(),
lowStaleness,
nodeSelectionConfig,
nodeselection.NodeFilters{},
overlay.NewPlacementRules().CreateFilters,
placementRules.CreateFilters,
)
require.NoError(t, err)
@ -225,15 +229,21 @@ func TestGetNodes(t *testing.T) {
defer cacheCancel()
ctx.Go(func() error { return cache.Run(cacheCtx) })
// add 4 nodes to the database and vet 2
const nodeCount = 4
nodeIds := addNodesToNodesTable(ctx, t, db.OverlayCache(), nodeCount, 2)
require.Len(t, nodeIds, 2)
// add 10 nodes to the database and vet 8
// 4 subnets [A A A B B B C C C D]
// 2 countries [DE X DE x DE x DE x DE x]
// vetted [1 1 1 1 1 1 1 1 0 0]
const nodeCount = 10
nodeIds := addNodesToNodesTable(ctx, t, db.OverlayCache(), nodeCount, 8)
require.Len(t, nodeIds, 8)
t.Run("normal selection", func(t *testing.T) {
t.Run("get 2", func(t *testing.T) {
// confirm cache.GetNodes returns the correct nodes
selectedNodes, err := cache.GetNodes(ctx, overlay.FindStorageNodesRequest{RequestedCount: 2})
require.NoError(t, err)
require.Equal(t, 2, len(selectedNodes))
require.Len(t, selectedNodes, 2)
for _, node := range selectedNodes {
require.NotEqual(t, node.ID, "")
require.NotEqual(t, node.Address.Address, "")
@ -242,6 +252,33 @@ func TestGetNodes(t *testing.T) {
require.NotEqual(t, node.LastNet, "")
}
})
t.Run("too much", func(t *testing.T) {
// we have 5 subnets (1 new, 4 vetted), with two nodes in each
_, err := cache.GetNodes(ctx, overlay.FindStorageNodesRequest{RequestedCount: 6})
require.Error(t, err)
})
})
t.Run("using country filter", func(t *testing.T) {
t.Run("normal", func(t *testing.T) {
selectedNodes, err := cache.GetNodes(ctx, overlay.FindStorageNodesRequest{
RequestedCount: 3,
Placement: 5,
})
require.NoError(t, err)
require.Len(t, selectedNodes, 3)
})
t.Run("too much", func(t *testing.T) {
_, err := cache.GetNodes(ctx, overlay.FindStorageNodesRequest{
RequestedCount: 4,
Placement: 5,
})
require.Error(t, err)
})
})
})
}
func TestGetNodesExcludeCountryCodes(t *testing.T) {
@ -539,9 +576,10 @@ func TestNewNodeFraction(t *testing.T) {
require.NoError(t, err)
// add some nodes to the database, some are reputable and some are new nodes
const nodeCount = 10
repIDs := addNodesToNodesTable(ctx, t, db.OverlayCache(), nodeCount, 4)
require.Len(t, repIDs, 4)
// 3 nodes per net --> we need 4 net (* 3 node) reputable + 1 net (* 3 node) new to select 5 with 0.2 percentage new
const nodeCount = 15
repIDs := addNodesToNodesTable(ctx, t, db.OverlayCache(), nodeCount, 12)
require.Len(t, repIDs, 12)
// confirm nodes are in the cache once
err = cache.Refresh(ctx)
require.NoError(t, err)