diff --git a/satellite/metainfo/endpoint_object.go b/satellite/metainfo/endpoint_object.go index 48e27cd1b..b737be36c 100644 --- a/satellite/metainfo/endpoint_object.go +++ b/satellite/metainfo/endpoint_object.go @@ -12,6 +12,7 @@ import ( "github.com/jtolio/eventkit" "github.com/spacemonkeygo/monkit/v3" "go.uber.org/zap" + "golang.org/x/sync/errgroup" "storj.io/common/context2" "storj.io/common/encryption" @@ -1167,10 +1168,24 @@ func (endpoint *Endpoint) GetObjectIPs(ctx context.Context, req *pb.ObjectGetIPs return nil, endpoint.convertMetabaseErr(err) } - pieceCountByNodeID, err := endpoint.metabase.GetStreamPieceCountByNodeID(ctx, - metabase.GetStreamPieceCountByNodeID{ - StreamID: object.StreamID, - }) + var pieceCountByNodeID map[storj.NodeID]int64 + var placement storj.PlacementConstraint + + // TODO this is a short-term fix to easily filter out IPs that are outside the bucket/object placement + // this request is not heavily used, so adding an extra DB request should be fine for now. + var group errgroup.Group + group.Go(func() error { + placement, err = endpoint.buckets.GetBucketPlacement(ctx, req.Bucket, keyInfo.ProjectID) + return err + }) + group.Go(func() (err error) { + pieceCountByNodeID, err = endpoint.metabase.GetStreamPieceCountByNodeID(ctx, + metabase.GetStreamPieceCountByNodeID{ + StreamID: object.StreamID, + }) + return err + }) + err = group.Wait() if err != nil { return nil, endpoint.convertMetabaseErr(err) } @@ -1180,7 +1195,7 @@ func (endpoint *Endpoint) GetObjectIPs(ctx context.Context, req *pb.ObjectGetIPs nodeIDs = append(nodeIDs, nodeID) } - nodeIPMap, err := endpoint.overlay.GetNodeIPs(ctx, nodeIDs) + nodeIPMap, err := endpoint.overlay.GetNodeIPsFromPlacement(ctx, nodeIDs, placement) if err != nil { endpoint.log.Error("internal", zap.Error(err)) return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) diff --git a/satellite/metainfo/endpoint_object_test.go b/satellite/metainfo/endpoint_object_test.go index 0626419e4..890a6ef24 100644 --- 
a/satellite/metainfo/endpoint_object_test.go +++ b/satellite/metainfo/endpoint_object_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/require" "github.com/zeebo/errs" "go.uber.org/zap" + "golang.org/x/exp/maps" "storj.io/common/errs2" "storj.io/common/identity" @@ -992,6 +993,9 @@ func TestEndpoint_Object_With_StorageNodes(t *testing.T) { require.NoError(t, uplnk.CreateBucket(uplinkCtx, sat, bucketName)) require.NoError(t, uplnk.Upload(uplinkCtx, sat, bucketName, "jones", testrand.Bytes(20*memory.KB))) + jonesSegments, err := planet.Satellites[0].Metabase.DB.TestingAllSegments(ctx) + require.NoError(t, err) + project, err := uplnk.OpenProject(ctx, planet.Satellites[0]) require.NoError(t, err) defer ctx.Check(project.Close) @@ -1007,24 +1011,45 @@ func TestEndpoint_Object_With_StorageNodes(t *testing.T) { copyIPs, err := object.GetObjectIPs(ctx, uplink.Config{}, access, bucketName, "jones_copy") require.NoError(t, err) - sort.Slice(ips, func(i, j int) bool { - return bytes.Compare(ips[i], ips[j]) < 0 - }) - sort.Slice(copyIPs, func(i, j int) bool { - return bytes.Compare(copyIPs[i], copyIPs[j]) < 0 - }) - // verify that orignal and copy has the same results - require.Equal(t, ips, copyIPs) + require.ElementsMatch(t, ips, copyIPs) - // verify it's a real IP with valid host and port - for _, ip := range ips { - host, port, err := net.SplitHostPort(string(ip)) - require.NoError(t, err) - netIP := net.ParseIP(host) - require.NotNil(t, netIP) - _, err = strconv.Atoi(port) - require.NoError(t, err) + expectedIPsMap := map[string]struct{}{} + for _, segment := range jonesSegments { + for _, piece := range segment.Pieces { + node, err := planet.Satellites[0].Overlay.Service.Get(ctx, piece.StorageNode) + require.NoError(t, err) + expectedIPsMap[node.LastIPPort] = struct{}{} + } + } + + expectedIPs := [][]byte{} + for _, ip := range maps.Keys(expectedIPsMap) { + expectedIPs = append(expectedIPs, []byte(ip)) + } + require.ElementsMatch(t, expectedIPs, ips) + + 
// set bucket geofencing + _, err = planet.Satellites[0].DB.Buckets().UpdateBucket(ctx, buckets.Bucket{ + ProjectID: planet.Uplinks[0].Projects[0].ID, + Name: bucketName, + Placement: storj.EU, + }) + require.NoError(t, err) + + // set one node to US to filter it out from IP results + usNode := planet.FindNode(jonesSegments[0].Pieces[0].StorageNode) + require.NoError(t, planet.Satellites[0].Overlay.Service.TestNodeCountryCode(ctx, usNode.ID(), "US")) + require.NoError(t, planet.Satellites[0].API.Overlay.Service.DownloadSelectionCache.Refresh(ctx)) + + geoFencedIPs, err := object.GetObjectIPs(ctx, uplink.Config{}, access, bucketName, "jones") + require.NoError(t, err) + + require.Len(t, geoFencedIPs, len(expectedIPs)-1) + for _, ip := range geoFencedIPs { + if string(ip) == usNode.Addr() { + t.Fatal("this IP should be removed from results because of geofencing") + } } }) diff --git a/satellite/overlay/downloadselection.go b/satellite/overlay/downloadselection.go index 63910c714..e1c96cda5 100644 --- a/satellite/overlay/downloadselection.go +++ b/satellite/overlay/downloadselection.go @@ -75,8 +75,8 @@ func (cache *DownloadSelectionCache) read(ctx context.Context) (_ *DownloadSelec return NewDownloadSelectionCacheState(onlineNodes), nil } -// GetNodeIPs gets the last node ip:port from the cache, refreshing when needed. -func (cache *DownloadSelectionCache) GetNodeIPs(ctx context.Context, nodes []storj.NodeID) (_ map[storj.NodeID]string, err error) { +// GetNodeIPsFromPlacement gets the last node ip:port from the cache, refreshing when needed. Results are filtered by placement. 
+func (cache *DownloadSelectionCache) GetNodeIPsFromPlacement(ctx context.Context, nodes []storj.NodeID, placement storj.PlacementConstraint) (_ map[storj.NodeID]string, err error) { defer mon.Task()(&ctx)(&err) state, err := cache.cache.Get(ctx, time.Now()) @@ -84,7 +84,7 @@ func (cache *DownloadSelectionCache) GetNodeIPs(ctx context.Context, nodes []sto return nil, Error.Wrap(err) } - return state.IPs(nodes), nil + return state.IPsFromPlacement(nodes, placement), nil } // GetNodes gets nodes by ID from the cache, and refreshes the cache if it is stale. @@ -140,6 +140,17 @@ func (state *DownloadSelectionCacheState) IPs(nodes []storj.NodeID) map[storj.No return xs } +// IPsFromPlacement returns node ip:port for nodes that are in state. Results are filtered by placement. +func (state *DownloadSelectionCacheState) IPsFromPlacement(nodes []storj.NodeID, placement storj.PlacementConstraint) map[storj.NodeID]string { + xs := make(map[storj.NodeID]string, len(nodes)) + for _, nodeID := range nodes { + if n, exists := state.byID[nodeID]; exists && placement.AllowedCountry(n.CountryCode) { + xs[nodeID] = n.LastIPPort + } + } + return xs +} + // Nodes returns node ip:port for nodes that are in state. 
func (state *DownloadSelectionCacheState) Nodes(nodes []storj.NodeID) map[storj.NodeID]*SelectedNode { xs := make(map[storj.NodeID]*SelectedNode, len(nodes)) diff --git a/satellite/overlay/downloadselection_test.go b/satellite/overlay/downloadselection_test.go index 0d20705e9..9de484fe2 100644 --- a/satellite/overlay/downloadselection_test.go +++ b/satellite/overlay/downloadselection_test.go @@ -75,7 +75,7 @@ func TestDownloadSelectionCacheState_GetNodeIPs(t *testing.T) { ids := addNodesToNodesTable(ctx, t, db.OverlayCache(), nodeCount, 0) // confirm nodes are in the cache once - nodeips, err := cache.GetNodeIPs(ctx, ids) + nodeips, err := cache.GetNodeIPsFromPlacement(ctx, ids, storj.EveryCountry) require.NoError(t, err) for _, id := range ids { require.NotEmpty(t, nodeips[id]) diff --git a/satellite/overlay/service.go b/satellite/overlay/service.go index 2446c05ee..cf0565c21 100644 --- a/satellite/overlay/service.go +++ b/satellite/overlay/service.go @@ -416,10 +416,10 @@ func (service *Service) GetOnlineNodesForAuditRepair(ctx context.Context, nodeID return service.db.GetOnlineNodesForAuditRepair(ctx, nodeIDs, service.config.Node.OnlineWindow) } -// GetNodeIPs returns a map of node ip:port for the supplied nodeIDs. -func (service *Service) GetNodeIPsFromPlacement(ctx context.Context, nodeIDs []storj.NodeID) (_ map[storj.NodeID]string, err error) { defer mon.Task()(&ctx)(&err) - return service.DownloadSelectionCache.GetNodeIPs(ctx, nodeIDs) + return service.DownloadSelectionCache.GetNodeIPsFromPlacement(ctx, nodeIDs, placement) } // IsOnline checks if a node is 'online' based on the collected statistics.