diff --git a/pkg/piecestore/psserver/config.go b/pkg/piecestore/psserver/config.go index e15dedf9a..33aa38907 100644 --- a/pkg/piecestore/psserver/config.go +++ b/pkg/piecestore/psserver/config.go @@ -18,6 +18,8 @@ var ( // Config contains everything necessary for a server type Config struct { Path string `help:"path to store data in" default:"$CONFDIR/storage"` + WhitelistedSatelliteIDs string `help:"a comma-separated list of approved satellite node ids" default:""` + SatelliteIDRestriction bool `help:"if true, only allow data from approved satellites" default:"false"` AllocatedDiskSpace memory.Size `user:"true" help:"total allocated disk space in bytes" default:"1TB"` AllocatedBandwidth memory.Size `user:"true" help:"total allocated bandwidth in bytes" default:"500GiB"` KBucketRefreshInterval time.Duration `help:"how frequently Kademlia bucket should be refreshed with node stats" default:"1h0m0s"` diff --git a/pkg/piecestore/psserver/readerwriter.go b/pkg/piecestore/psserver/readerwriter.go index 5d503e45d..d270d64c3 100644 --- a/pkg/piecestore/psserver/readerwriter.go +++ b/pkg/piecestore/psserver/readerwriter.go @@ -81,6 +81,13 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer, bandwidt return nil, err } + // if whitelist does not contain PBA satellite ID, reject storage request + if len(s.whitelist) != 0 { + if !s.approved(pbaData.SatelliteId) { + return nil, StoreError.New("Satellite ID not approved") + } + } + // Update bandwidthallocation to be stored if deserializedData.GetTotal() > sr.currentTotal { sr.bandwidthAllocation = ba diff --git a/pkg/piecestore/psserver/server.go b/pkg/piecestore/psserver/server.go index 3967d9abe..1ec9569dc 100644 --- a/pkg/piecestore/psserver/server.go +++ b/pkg/piecestore/psserver/server.go @@ -30,6 +30,7 @@ import ( pstore "storj.io/storj/pkg/piecestore" "storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/provider" + "storj.io/storj/pkg/storj" ) var ( @@ -64,6 +65,7 @@ type Server struct { pkey crypto.PrivateKey totalAllocated int64 // TODO: use memory.Size totalBwAllocated int64 // TODO: use memory.Size + whitelist []storj.NodeID verifier auth.SignedMessageVerifier kad *kademlia.Kademlia } @@ -121,6 +123,18 @@ func NewEndpoint(log *zap.Logger, config Config, storage *pstore.Storage, db *ps log.Warn("Disk space is less than requested. Allocating space", zap.Int64("bytes", allocatedDiskSpace)) } + // parse the comma separated list of approved satellite IDs into an array of storj.NodeIDs + var whitelist []storj.NodeID + if config.SatelliteIDRestriction { + idStrings := strings.Split(config.WhitelistedSatelliteIDs, ",") + for i, s := range idStrings { + whitelist[i], err = storj.NodeIDFromString(s) + if err != nil { + return nil, err + } + } + } + return &Server{ startTime: time.Now(), log: log, @@ -129,6 +143,7 @@ func NewEndpoint(log *zap.Logger, config Config, storage *pstore.Storage, db *ps pkey: pkey, totalAllocated: allocatedDiskSpace, totalBwAllocated: allocatedBandwidth, + whitelist: whitelist, verifier: auth.NewSignedMessageVerifier(), kad: k, }, nil @@ -297,6 +312,16 @@ func (s *Server) verifyPayerAllocation(pba *pb.PayerBandwidthAllocation_Data, ac return nil } +// approved returns true if a node ID exists in a list of approved node IDs +func (s *Server) approved(id storj.NodeID) bool { + for _, n := range s.whitelist { + if n == id { + return true + } + } + return false +} + func getBeginningOfMonth() time.Time { t := time.Now() y, m, _ := t.Date() diff --git a/pkg/piecestore/psserver/server_test.go b/pkg/piecestore/psserver/server_test.go index b6e423b63..ad2dd1e60 100644 --- a/pkg/piecestore/psserver/server_test.go +++ b/pkg/piecestore/psserver/server_test.go @@ -51,7 +51,7 @@ func TestPiece(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - TS := NewTestServer(t) + TS := NewTestServer(t, []storj.NodeID{}) defer TS.Stop() if err := TS.writeFile("11111111111111111111"); err != nil { @@ -138,7 +138,7 @@ func TestRetrieve(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - TS := NewTestServer(t) + TS := NewTestServer(t, []storj.NodeID{}) defer TS.Stop() if err := TS.writeFile("11111111111111111111"); err != nil { @@ -300,21 +300,22 @@ func TestStore(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - TS := NewTestServer(t) - defer TS.Stop() - - db := TS.s.DB.DB + satID := teststorj.NodeIDFromString("satelliteid") tests := []struct { id string + satelliteID storj.NodeID + whitelist []storj.NodeID ttl int64 content []byte message string totalReceived int64 err string }{ - { // should successfully store data + { // should successfully store data with no approved satellites id: "99999999999999999999", + satelliteID: satID, + whitelist: []storj.NodeID{}, ttl: 9999999999, content: []byte("xyzwq"), message: "OK", @@ -323,6 +324,8 @@ func TestStore(t *testing.T) { }, { // should err with invalid id length id: "butts", + satelliteID: satID, + whitelist: []storj.NodeID{satID}, ttl: 9999999999, content: []byte("xyzwq"), message: "", @@ -331,6 +334,8 @@ func TestStore(t *testing.T) { }, { // should err with piece ID not specified id: "", + satelliteID: satID, + whitelist: []storj.NodeID{satID}, ttl: 9999999999, content: []byte("xyzwq"), message: "", @@ -341,6 +346,10 @@ func TestStore(t *testing.T) { for _, tt := range tests { t.Run("should return expected PieceStoreSummary values", func(t *testing.T) { + TS := NewTestServer(t, tt.whitelist) + db := TS.s.DB.DB + defer TS.Stop() + assert := assert.New(t) stream, err := TS.c.Store(ctx) assert.NoError(err) @@ -350,7 +359,7 @@ func TestStore(t *testing.T) { assert.NoError(err) pbad := &pb.PayerBandwidthAllocation_Data{ - SatelliteId: teststorj.NodeIDFromString("satelliteid"), + SatelliteId: tt.satelliteID, UplinkId: teststorj.NodeIDFromString("uplinkid"), Action: pb.PayerBandwidthAllocation_PUT, } @@ -428,37 +437,64 @@ func TestPbaValidation(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - TS := NewTestServer(t) - defer TS.Stop() - tests := []struct { satelliteID storj.NodeID uplinkID storj.NodeID + whitelist []storj.NodeID action pb.PayerBandwidthAllocation_Action err string }{ + { // unapproved satellite id + satelliteID: teststorj.NodeIDFromString("bad-satellite"), + uplinkID: teststorj.NodeIDFromString("uplinkid"), + whitelist: []storj.NodeID{ + teststorj.NodeIDFromString("satelliteid1"), + teststorj.NodeIDFromString("satelliteid2"), + teststorj.NodeIDFromString("satelliteid3"), + }, + action: pb.PayerBandwidthAllocation_PUT, + err: "rpc error: code = Unknown desc = store error: Satellite ID not approved", + }, { // missing satellite id satelliteID: storj.NodeID{}, uplinkID: teststorj.NodeIDFromString("uplinkid"), - action: pb.PayerBandwidthAllocation_PUT, - err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id", + whitelist: []storj.NodeID{ + teststorj.NodeIDFromString("satelliteid1"), + teststorj.NodeIDFromString("satelliteid2"), + teststorj.NodeIDFromString("satelliteid3"), + }, + action: pb.PayerBandwidthAllocation_PUT, + err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id", }, { // missing uplink id - satelliteID: teststorj.NodeIDFromString("satelliteid"), + satelliteID: teststorj.NodeIDFromString("satelliteid1"), uplinkID: storj.NodeID{}, - action: pb.PayerBandwidthAllocation_PUT, - err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id", + whitelist: []storj.NodeID{ + teststorj.NodeIDFromString("satelliteid1"), + teststorj.NodeIDFromString("satelliteid2"), + teststorj.NodeIDFromString("satelliteid3"), + }, + action: pb.PayerBandwidthAllocation_PUT, + err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id", }, { // wrong action type - satelliteID: teststorj.NodeIDFromString("satelliteid"), + satelliteID: teststorj.NodeIDFromString("satelliteid1"), uplinkID: teststorj.NodeIDFromString("uplinkid"), - action: pb.PayerBandwidthAllocation_GET, - err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET", + whitelist: []storj.NodeID{ + teststorj.NodeIDFromString("satelliteid1"), + teststorj.NodeIDFromString("satelliteid2"), + teststorj.NodeIDFromString("satelliteid3"), + }, + action: pb.PayerBandwidthAllocation_GET, + err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET", }, } for _, tt := range tests { t.Run("should validate payer bandwidth allocation struct", func(t *testing.T) { + TS := NewTestServer(t, tt.whitelist) + defer TS.Stop() + assert := assert.New(t) stream, err := TS.c.Store(ctx) assert.NoError(err) @@ -500,6 +536,8 @@ func TestPbaValidation(t *testing.T) { _, err = stream.CloseAndRecv() if err != nil { //assert.NotNil(err) + t.Log("Expected err string", tt.err) + t.Log("Actual err.Error:", err.Error()) assert.Equal(tt.err, err.Error()) return } @@ -511,7 +549,7 @@ func TestDelete(t *testing.T) { ctx := testcontext.New(t) defer ctx.Cleanup() - TS := NewTestServer(t) + TS := NewTestServer(t, []storj.NodeID{}) defer TS.Stop() db := TS.s.DB.DB @@ -584,7 +622,7 @@ func TestDelete(t *testing.T) { } } -func newTestServerStruct(t *testing.T) (*Server, func()) { +func newTestServerStruct(t *testing.T, ids []storj.NodeID) (*Server, func()) { tmp, err := ioutil.TempDir("", "storj-piecestore") if err != nil { log.Fatalf("failed temp-dir: %v", err) @@ -609,6 +647,7 @@ func newTestServerStruct(t *testing.T) (*Server, func()) { verifier: verifier, totalAllocated: math.MaxInt64, totalBwAllocated: math.MaxInt64, + whitelist: ids, } return server, func() { if serr := server.Stop(context.TODO()); serr != nil { @@ -642,7 +681,7 @@ type TestServer struct { k crypto.PrivateKey } -func NewTestServer(t *testing.T) *TestServer { +func NewTestServer(t *testing.T, ids []storj.NodeID) *TestServer { check := func(e error) { if !assert.NoError(t, e) { t.Fail() @@ -663,7 +702,7 @@ func NewTestServer(t *testing.T) *TestServer { co, err := fiC.DialOption(storj.NodeID{}) check(err) - s, cleanup := newTestServerStruct(t) + s, cleanup := newTestServerStruct(t, ids) grpcs := grpc.NewServer(so) k, ok := fiC.Key.(*ecdsa.PrivateKey)