diff --git a/satellite/overlay/service.go b/satellite/overlay/service.go index 2ccf8c620..eec6171e3 100644 --- a/satellite/overlay/service.go +++ b/satellite/overlay/service.go @@ -300,10 +300,11 @@ type NodeReputation struct { func (node *SelectedNode) Clone() *SelectedNode { copy := pb.CopyNode(&pb.Node{Id: node.ID, Address: node.Address}) return &SelectedNode{ - ID: copy.Id, - Address: copy.Address, - LastNet: node.LastNet, - LastIPPort: node.LastIPPort, + ID: copy.Id, + Address: copy.Address, + LastNet: node.LastNet, + LastIPPort: node.LastIPPort, + CountryCode: node.CountryCode, } } diff --git a/satellite/satellitedb/overlaycache.go b/satellite/satellitedb/overlaycache.go index 0afe1409f..a5848c3ce 100644 --- a/satellite/satellitedb/overlaycache.go +++ b/satellite/satellitedb/overlaycache.go @@ -57,7 +57,7 @@ func (cache *overlaycache) selectAllStorageNodesUpload(ctx context.Context, sele defer mon.Task()(&ctx)(&err) query := ` - SELECT id, address, last_net, last_ip_port, vetted_at, country_code, noise_proto, noise_public_key, debounce_limit + SELECT id, address, last_net, last_ip_port, vetted_at, country_code, noise_proto, noise_public_key, debounce_limit, country_code FROM nodes ` + cache.db.impl.AsOfSystemInterval(selectionCfg.AsOfSystemTime.Interval()) + ` WHERE disqualified IS NULL @@ -102,7 +102,8 @@ func (cache *overlaycache) selectAllStorageNodesUpload(ctx context.Context, sele var lastIPPort sql.NullString var vettedAt *time.Time var noise noiseScanner - err = rows.Scan(&node.ID, &node.Address.Address, &node.LastNet, &lastIPPort, &vettedAt, &node.CountryCode, &noise.Proto, &noise.PublicKey, &node.Address.DebounceLimit) + err = rows.Scan(&node.ID, &node.Address.Address, &node.LastNet, &lastIPPort, &vettedAt, &node.CountryCode, &noise.Proto, + &noise.PublicKey, &node.Address.DebounceLimit, &node.CountryCode) if err != nil { return nil, nil, err } @@ -141,7 +142,7 @@ func (cache *overlaycache) selectAllStorageNodesDownload(ctx context.Context, on defer mon.Task()(&ctx)(&err) query := ` - SELECT id, address, last_net, last_ip_port, noise_proto, noise_public_key, debounce_limit + SELECT id, address, last_net, last_ip_port, noise_proto, noise_public_key, debounce_limit, country_code FROM nodes ` + cache.db.impl.AsOfSystemInterval(asOfConfig.Interval()) + ` WHERE disqualified IS NULL @@ -165,7 +166,8 @@ func (cache *overlaycache) selectAllStorageNodesDownload(ctx context.Context, on node.Address = &pb.NodeAddress{} var lastIPPort sql.NullString var noise noiseScanner - err = rows.Scan(&node.ID, &node.Address.Address, &node.LastNet, &lastIPPort, &noise.Proto, &noise.PublicKey, &node.Address.DebounceLimit) + err = rows.Scan(&node.ID, &node.Address.Address, &node.LastNet, &lastIPPort, &noise.Proto, + &noise.PublicKey, &node.Address.DebounceLimit, &node.CountryCode) if err != nil { return nil, err } diff --git a/satellite/satellitedb/overlaycache_test.go b/satellite/satellitedb/overlaycache_test.go index 8b5a3e074..e76fa21e2 100644 --- a/satellite/satellitedb/overlaycache_test.go +++ b/satellite/satellitedb/overlaycache_test.go @@ -15,6 +15,7 @@ import ( "storj.io/common/pb" "storj.io/common/storj" + "storj.io/common/storj/location" "storj.io/common/testcontext" "storj.io/common/testrand" "storj.io/private/version" @@ -358,3 +359,62 @@ func TestGetNodesNetwork(t *testing.T) { }) }) } + +func TestOverlayCache_SelectAllStorageNodesDownloadUpload(t *testing.T) { + satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + cache := db.OverlayCache() + const netMask = 28 + mask := net.CIDRMask(netMask, 32) + + infos := make([]overlay.NodeCheckInInfo, 5) + + for n := range infos { + id := testrand.NodeID() + ip := net.IP{0, 0, 1, byte(n)} + lastNet := ip.Mask(mask).String() + + infos[n] = overlay.NodeCheckInInfo{ + IsUp: true, + Address: &pb.NodeAddress{Address: ip.String()}, + LastNet: lastNet, + LastIPPort: "0.0.0.0:0", + Version: &pb.NodeVersion{Version: "v0.0.0"}, + NodeID: id, + CountryCode: location.Canada, + } + err := cache.UpdateCheckIn(ctx, infos[n], time.Now().UTC(), overlay.NodeSelectionConfig{}) + require.NoError(t, err) + } + + checkNodes := func(selectedNodes []*overlay.SelectedNode) { + selectedNodesMap := map[storj.NodeID]*overlay.SelectedNode{} + for _, node := range selectedNodes { + selectedNodesMap[node.ID] = node + } + + for _, info := range infos { + selectedNode, ok := selectedNodesMap[info.NodeID] + require.True(t, ok) + + require.Equal(t, info.NodeID, selectedNode.ID) + require.Equal(t, info.Address, selectedNode.Address) + require.Equal(t, info.CountryCode, selectedNode.CountryCode) + require.Equal(t, info.LastIPPort, selectedNode.LastIPPort) + require.Equal(t, info.LastNet, selectedNode.LastNet) + } + } + + selectedNodes, err := cache.SelectAllStorageNodesDownload(ctx, time.Minute, overlay.AsOfSystemTimeConfig{}) + require.NoError(t, err) + + checkNodes(selectedNodes) + + reputableNodes, newNodes, err := cache.SelectAllStorageNodesUpload(ctx, overlay.NodeSelectionConfig{ + OnlineWindow: time.Minute, + }) + require.NoError(t, err) + + checkNodes(append(reputableNodes, newNodes...)) + }) + +}