diff --git a/satellite/audit/containment.go b/satellite/audit/containment.go index 4d667c48b..e4ced4e5f 100644 --- a/satellite/audit/containment.go +++ b/satellite/audit/containment.go @@ -29,4 +29,5 @@ type Containment interface { Get(ctx context.Context, nodeID pb.NodeID) (*ReverificationJob, error) Insert(ctx context.Context, job *PieceLocator) error Delete(ctx context.Context, job *PieceLocator) (wasDeleted, nodeStillContained bool, err error) + GetAllContainedNodes(ctx context.Context) ([]pb.NodeID, error) } diff --git a/satellite/audit/queue.go b/satellite/audit/queue.go index 73ab31bef..67cf81105 100644 --- a/satellite/audit/queue.go +++ b/satellite/audit/queue.go @@ -34,6 +34,7 @@ type ReverifyQueue interface { GetNextJob(ctx context.Context, retryInterval time.Duration) (job *ReverificationJob, err error) Remove(ctx context.Context, piece *PieceLocator) (wasDeleted bool, err error) GetByNodeID(ctx context.Context, nodeID storj.NodeID) (audit *ReverificationJob, err error) + GetAllContainedNodes(ctx context.Context) ([]storj.NodeID, error) } // ByStreamIDAndPosition allows sorting of a slice of segments by stream ID and position. diff --git a/satellite/satellitedb/containment.go b/satellite/satellitedb/containment.go index 4da32ba14..b70942051 100644 --- a/satellite/satellitedb/containment.go +++ b/satellite/satellitedb/containment.go @@ -56,3 +56,9 @@ func (containment *containment) Delete(ctx context.Context, pendingJob *audit.Pi } return isDeleted, nodeStillContained, audit.ContainError.Wrap(err) } + +func (containment *containment) GetAllContainedNodes(ctx context.Context) (nodes []pb.NodeID, err error) { + defer mon.Task()(&ctx)(&err) + + return containment.reverifyQueue.GetAllContainedNodes(ctx) +} diff --git a/satellite/satellitedb/reverifyqueue.go b/satellite/satellitedb/reverifyqueue.go index 8793a23a9..1551aa58d 100644 --- a/satellite/satellitedb/reverifyqueue.go +++ b/satellite/satellitedb/reverifyqueue.go @@ -9,6 +9,8 @@ import ( "errors" "time" + "github.com/zeebo/errs" + "storj.io/common/storj" "storj.io/common/uuid" "storj.io/storj/satellite/audit" @@ -137,6 +139,32 @@ func (rq *reverifyQueue) GetByNodeID(ctx context.Context, nodeID storj.NodeID) ( return convertDBJob(ctx, pending) } +func (rq *reverifyQueue) GetAllContainedNodes(ctx context.Context) (nodes []storj.NodeID, err error) { + defer mon.Task()(&ctx)(&err) + + result, err := rq.db.QueryContext(ctx, `SELECT DISTINCT node_id FROM reverification_audits`) + if err != nil { + return nil, audit.ContainError.Wrap(err) + } + defer func() { + err = errs.Combine(err, audit.ContainError.Wrap(result.Close())) + }() + + for result.Next() { + var nodeIDBytes []byte + if err := result.Scan(&nodeIDBytes); err != nil { + return nil, audit.ContainError.Wrap(err) + } + nodeID, err := storj.NodeIDFromBytes(nodeIDBytes) + if err != nil { + return nil, audit.ContainError.Wrap(err) + } + nodes = append(nodes, nodeID) + } + + return nodes, audit.ContainError.Wrap(result.Err()) +} + func convertDBJob(ctx context.Context, info *dbx.ReverificationAudits) (pendingJob *audit.ReverificationJob, err error) { defer mon.Task()(&ctx)(&err) if info == nil { diff --git a/satellite/satellitedb/reverifyqueue_test.go b/satellite/satellitedb/reverifyqueue_test.go index 6aaeaeaca..8b25376da 100644 --- a/satellite/satellitedb/reverifyqueue_test.go +++ b/satellite/satellitedb/reverifyqueue_test.go @@ -5,11 +5,13 @@ package satellitedb_test import ( "context" + "sort" "testing" "time" "github.com/stretchr/testify/require" + "storj.io/common/storj" "storj.io/common/sync2" "storj.io/common/testcontext" "storj.io/common/testrand" @@ -52,6 +54,8 @@ func TestReverifyQueue(t *testing.T) { err = reverifyQueue.Insert(ctx, locator2) require.NoError(t, err) + checkGetAllContainedNodes(ctx, t, reverifyQueue, locator1.NodeID, locator2.NodeID) + job1, err := reverifyQueue.GetNextJob(ctx, retryInterval) require.NoError(t, err) require.Equal(t, *locator1, job1.Locator) @@ -62,6 +66,8 @@ func TestReverifyQueue(t *testing.T) { require.Equal(t, *locator2, job2.Locator) require.EqualValues(t, 1, job2.ReverifyCount) + checkGetAllContainedNodes(ctx, t, reverifyQueue, locator1.NodeID, locator2.NodeID) + require.Truef(t, job1.InsertedAt.Before(job2.InsertedAt), "job1 [%s] should have an earlier insertion time than job2 [%s]", job1.InsertedAt, job2.InsertedAt) _, err = reverifyQueue.GetNextJob(ctx, retryInterval) @@ -81,15 +87,20 @@ func TestReverifyQueue(t *testing.T) { require.Equal(t, *locator1, job3.Locator) require.EqualValues(t, 2, job3.ReverifyCount) + checkGetAllContainedNodes(ctx, t, reverifyQueue, locator1.NodeID, locator2.NodeID) + wasDeleted, err := reverifyQueue.Remove(ctx, locator1) require.NoError(t, err) require.True(t, wasDeleted) + checkGetAllContainedNodes(ctx, t, reverifyQueue, locator2.NodeID) wasDeleted, err = reverifyQueue.Remove(ctx, locator2) require.NoError(t, err) require.True(t, wasDeleted) + checkGetAllContainedNodes(ctx, t, reverifyQueue) wasDeleted, err = reverifyQueue.Remove(ctx, locator1) require.NoError(t, err) require.False(t, wasDeleted) + checkGetAllContainedNodes(ctx, t, reverifyQueue) _, err = reverifyQueue.GetNextJob(ctx, retryInterval) require.Truef(t, audit.ErrEmptyQueue.Has(err), "expected empty queue error, but got error %+v", err) @@ -150,3 +161,17 @@ func TestReverifyQueueGetByNodeID(t *testing.T) { require.Nil(t, job3) }) } + +// checkGetAllContainedNodes checks that the GetAllContainedNodes method works as expected +// in a particular situation. +func checkGetAllContainedNodes(ctx context.Context, t testing.TB, reverifyQueue audit.ReverifyQueue, expectedIDs ...storj.NodeID) { + containedNodes, err := reverifyQueue.GetAllContainedNodes(ctx) + require.NoError(t, err) + sort.Slice(containedNodes, func(i, j int) bool { + return containedNodes[i].Compare(containedNodes[j]) < 0 + }) + sort.Slice(expectedIDs, func(i, j int) bool { + return expectedIDs[i].Compare(expectedIDs[j]) < 0 + }) + require.Equal(t, expectedIDs, containedNodes) +}