From 8b4387a498ecc44830bf5ae20aac78c8527cb613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Elek?= Date: Fri, 30 Jun 2023 14:47:52 +0200 Subject: [PATCH] satellite/satellitedb: add tag information to nodes selected for upload/downloads Change-Id: I0fa7daebcf83f7949726e5fffe68e0bdc6fd1d7a --- satellite/satellitedb/dbx/node.dbx | 6 +- satellite/satellitedb/dbx/satellitedb.dbx.go | 98 ++++++++++++++++++++ satellite/satellitedb/overlaycache.go | 44 +++++++++ satellite/satellitedb/overlaycache_test.go | 28 ++++++ 4 files changed, 175 insertions(+), 1 deletion(-) diff --git a/satellite/satellitedb/dbx/node.dbx b/satellite/satellitedb/dbx/node.dbx index 81d4efd9d..e3c761f95 100644 --- a/satellite/satellitedb/dbx/node.dbx +++ b/satellite/satellitedb/dbx/node.dbx @@ -273,4 +273,8 @@ create node_tags ( noreturn, replace ) read all ( select node_tags where node_tags.node_id = ? -) \ No newline at end of file +) + +read all ( + select node_tags +) diff --git a/satellite/satellitedb/dbx/satellitedb.dbx.go b/satellite/satellitedb/dbx/satellitedb.dbx.go index b60124321..f6925615b 100644 --- a/satellite/satellitedb/dbx/satellitedb.dbx.go +++ b/satellite/satellitedb/dbx/satellitedb.dbx.go @@ -15202,6 +15202,49 @@ func (obj *pgxImpl) All_NodeTags_By_NodeId(ctx context.Context, } +func (obj *pgxImpl) All_NodeTags(ctx context.Context) ( + rows []*NodeTags, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT node_tags.node_id, node_tags.name, node_tags.value, node_tags.signed_at, node_tags.signer FROM node_tags") + + var __values []interface{} + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*NodeTags, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + node_tags := &NodeTags{} + err = __rows.Scan(&node_tags.NodeId, &node_tags.Name, &node_tags.Value, &node_tags.SignedAt, &node_tags.Signer) + if err != nil { + return nil, err + } + rows = append(rows, node_tags) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxImpl) Get_StoragenodePaystub_By_NodeId_And_Period(ctx context.Context, storagenode_paystub_node_id StoragenodePaystub_NodeId_Field, storagenode_paystub_period StoragenodePaystub_Period_Field) ( @@ -23294,6 +23337,49 @@ func (obj *pgxcockroachImpl) All_NodeTags_By_NodeId(ctx context.Context, } +func (obj *pgxcockroachImpl) All_NodeTags(ctx context.Context) ( + rows []*NodeTags, err error) { + defer mon.Task()(&ctx)(&err) + + var __embed_stmt = __sqlbundle_Literal("SELECT node_tags.node_id, node_tags.name, node_tags.value, node_tags.signed_at, node_tags.signer FROM node_tags") + + var __values []interface{} + + var __stmt = __sqlbundle_Render(obj.dialect, __embed_stmt) + obj.logStmt(__stmt, __values...) + + for { + rows, err = func() (rows []*NodeTags, err error) { + __rows, err := obj.driver.QueryContext(ctx, __stmt, __values...) + if err != nil { + return nil, err + } + defer __rows.Close() + + for __rows.Next() { + node_tags := &NodeTags{} + err = __rows.Scan(&node_tags.NodeId, &node_tags.Name, &node_tags.Value, &node_tags.SignedAt, &node_tags.Signer) + if err != nil { + return nil, err + } + rows = append(rows, node_tags) + } + if err := __rows.Err(); err != nil { + return nil, err + } + return rows, nil + }() + if err != nil { + if obj.shouldRetry(err) { + continue + } + return nil, obj.makeErr(err) + } + return rows, nil + } + +} + func (obj *pgxcockroachImpl) Get_StoragenodePaystub_By_NodeId_And_Period(ctx context.Context, storagenode_paystub_node_id StoragenodePaystub_NodeId_Field, storagenode_paystub_period StoragenodePaystub_Period_Field) ( @@ -28350,6 +28436,15 @@ func (rx *Rx) All_CoinpaymentsTransaction_By_UserId_OrderBy_Desc_CreatedAt(ctx c return tx.All_CoinpaymentsTransaction_By_UserId_OrderBy_Desc_CreatedAt(ctx, coinpayments_transaction_user_id) } +func (rx *Rx) All_NodeTags(ctx context.Context) ( + rows []*NodeTags, err error) { + var tx *Tx + if tx, err = rx.getTx(ctx); err != nil { + return + } + return tx.All_NodeTags(ctx) +} + func (rx *Rx) All_NodeTags_By_NodeId(ctx context.Context, node_tags_node_id NodeTags_NodeId_Field) ( rows []*NodeTags, err error) { @@ -30374,6 +30469,9 @@ type Methods interface { coinpayments_transaction_user_id CoinpaymentsTransaction_UserId_Field) ( rows []*CoinpaymentsTransaction, err error) + All_NodeTags(ctx context.Context) ( + rows []*NodeTags, err error) + All_NodeTags_By_NodeId(ctx context.Context, node_tags_node_id NodeTags_NodeId_Field) ( rows []*NodeTags, err error) diff --git a/satellite/satellitedb/overlaycache.go b/satellite/satellitedb/overlaycache.go index 9380976be..4459c42c4 100644 --- a/satellite/satellitedb/overlaycache.go +++ b/satellite/satellitedb/overlaycache.go @@ -48,6 +48,12 @@ func (cache *overlaycache) SelectAllStorageNodesUpload(ctx context.Context, sele } return reputable, new, err } + + err = cache.addNodeTags(ctx, append(reputable, new...)) + if err != nil { + return reputable, new, err + } + break } @@ -128,11 +134,17 @@ func (cache *overlaycache) SelectAllStorageNodesDownload(ctx context.Context, on for { nodes, err = cache.selectAllStorageNodesDownload(ctx, onlineWindow, asOf) if err != nil { + if cockroachutil.NeedsRetry(err) { continue } return nodes, err } + + err = cache.addNodeTags(ctx, nodes) + if err != nil { + return nodes, err + } break } @@ -1630,3 +1642,35 @@ func (cache *overlaycache) GetNodeTags(ctx context.Context, id storj.NodeID) (up } return tags, err } + +func (cache *overlaycache) addNodeTags(ctx context.Context, nodes []*uploadselection.SelectedNode) error { + rows, err := cache.db.All_NodeTags(ctx) + if err != nil { + return Error.Wrap(err) + } + + tagsByNode := map[storj.NodeID]uploadselection.NodeTags{} + for _, row := range rows { + nodeID, err := storj.NodeIDFromBytes(row.NodeId) + if err != nil { + return Error.New("Invalid nodeID in the database: %x", row.NodeId) + } + signerID, err := storj.NodeIDFromBytes(row.Signer) + if err != nil { + return Error.New("Invalid nodeID in the database: %x", row.NodeId) + } + tagsByNode[nodeID] = append(tagsByNode[nodeID], uploadselection.NodeTag{ + NodeID: nodeID, + Name: row.Name, + Value: row.Value, + SignedAt: row.SignedAt, + Signer: signerID, + }) + + } + + for _, node := range nodes { + node.Tags = tagsByNode[node.ID] + } + return nil +} diff --git a/satellite/satellitedb/overlaycache_test.go b/satellite/satellitedb/overlaycache_test.go index 71fbd6fbe..73f31b878 100644 --- a/satellite/satellitedb/overlaycache_test.go +++ b/satellite/satellitedb/overlaycache_test.go @@ -8,11 +8,14 @@ import ( "encoding/binary" "math/rand" "net" + "strconv" + "strings" "testing" "time" "github.com/stretchr/testify/require" + "storj.io/common/identity/testidentity" "storj.io/common/pb" "storj.io/common/storj" "storj.io/common/storj/location" @@ -363,6 +366,8 @@ func TestGetNodesNetwork(t *testing.T) { func TestOverlayCache_SelectAllStorageNodesDownloadUpload(t *testing.T) { satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) { + tagSigner := testidentity.MustPregeneratedIdentity(0, storj.LatestIDVersion()) + cache := db.OverlayCache() const netMask = 28 mask := net.CIDRMask(netMask, 32) @@ -385,6 +390,19 @@ func TestOverlayCache_SelectAllStorageNodesDownloadUpload(t *testing.T) { } err := cache.UpdateCheckIn(ctx, infos[n], time.Now().UTC(), overlay.NodeSelectionConfig{}) require.NoError(t, err) + + if n%2 == 0 { + err = cache.UpdateNodeTags(ctx, uploadselection.NodeTags{ + uploadselection.NodeTag{ + NodeID: id, + SignedAt: time.Now(), + Signer: tagSigner.ID, + Name: "even", + Value: []byte{1}, + }, + }) + require.NoError(t, err) + } } checkNodes := func(selectedNodes []*uploadselection.SelectedNode) { @@ -402,6 +420,16 @@ func TestOverlayCache_SelectAllStorageNodesDownloadUpload(t *testing.T) { require.Equal(t, info.CountryCode, selectedNode.CountryCode) require.Equal(t, info.LastIPPort, selectedNode.LastIPPort) require.Equal(t, info.LastNet, selectedNode.LastNet) + segments := strings.Split(selectedNode.Address.Address, ".") + origIndex, err := strconv.Atoi(segments[len(segments)-1]) + require.NoError(t, err) + if origIndex%2 == 0 { + require.Len(t, selectedNode.Tags, 1) + require.Equal(t, "even", selectedNode.Tags[0].Name) + require.Equal(t, []byte{1}, selectedNode.Tags[0].Value) + } else { + require.Len(t, selectedNode.Tags, 0) + } } }