storagenode approvedSatelliteIDs (#1116)

* add config fields for satellite restriction on psserver

* add whitelistedSatIDs to psserver Server struct

* check pbwa satellite ID against whitelist

* add whitelist to psserver tests

* reword help message, make approved() a method on server
This commit is contained in:
Cameron 2019-01-23 12:56:12 -05:00 committed by GitHub
parent 3fdb47e31c
commit 95d2d54fc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 23 deletions

View File

@ -18,6 +18,8 @@ var (
// Config contains everything necessary for a server // Config contains everything necessary for a server
type Config struct { type Config struct {
Path string `help:"path to store data in" default:"$CONFDIR/storage"` Path string `help:"path to store data in" default:"$CONFDIR/storage"`
WhitelistedSatelliteIDs string `help:"a comma-separated list of approved satellite node ids" default:""`
SatelliteIDRestriction bool `help:"if true, only allow data from approved satellites" default:"false"`
AllocatedDiskSpace memory.Size `user:"true" help:"total allocated disk space in bytes" default:"1TB"` AllocatedDiskSpace memory.Size `user:"true" help:"total allocated disk space in bytes" default:"1TB"`
AllocatedBandwidth memory.Size `user:"true" help:"total allocated bandwidth in bytes" default:"500GiB"` AllocatedBandwidth memory.Size `user:"true" help:"total allocated bandwidth in bytes" default:"500GiB"`
KBucketRefreshInterval time.Duration `help:"how frequently Kademlia bucket should be refreshed with node stats" default:"1h0m0s"` KBucketRefreshInterval time.Duration `help:"how frequently Kademlia bucket should be refreshed with node stats" default:"1h0m0s"`

View File

@ -81,6 +81,13 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer, bandwidt
return nil, err return nil, err
} }
// if whitelist does not contain PBA satellite ID, reject storage request
if len(s.whitelist) != 0 {
if !s.approved(pbaData.SatelliteId) {
return nil, StoreError.New("Satellite ID not approved")
}
}
// Update bandwidthallocation to be stored // Update bandwidthallocation to be stored
if deserializedData.GetTotal() > sr.currentTotal { if deserializedData.GetTotal() > sr.currentTotal {
sr.bandwidthAllocation = ba sr.bandwidthAllocation = ba

View File

@ -30,6 +30,7 @@ import (
pstore "storj.io/storj/pkg/piecestore" pstore "storj.io/storj/pkg/piecestore"
"storj.io/storj/pkg/piecestore/psserver/psdb" "storj.io/storj/pkg/piecestore/psserver/psdb"
"storj.io/storj/pkg/provider" "storj.io/storj/pkg/provider"
"storj.io/storj/pkg/storj"
) )
var ( var (
@ -64,6 +65,7 @@ type Server struct {
pkey crypto.PrivateKey pkey crypto.PrivateKey
totalAllocated int64 // TODO: use memory.Size totalAllocated int64 // TODO: use memory.Size
totalBwAllocated int64 // TODO: use memory.Size totalBwAllocated int64 // TODO: use memory.Size
whitelist []storj.NodeID
verifier auth.SignedMessageVerifier verifier auth.SignedMessageVerifier
kad *kademlia.Kademlia kad *kademlia.Kademlia
} }
@ -121,6 +123,18 @@ func NewEndpoint(log *zap.Logger, config Config, storage *pstore.Storage, db *ps
log.Warn("Disk space is less than requested. Allocating space", zap.Int64("bytes", allocatedDiskSpace)) log.Warn("Disk space is less than requested. Allocating space", zap.Int64("bytes", allocatedDiskSpace))
} }
// parse the comma separated list of approved satellite IDs into an array of storj.NodeIDs
var whitelist []storj.NodeID
if config.SatelliteIDRestriction {
idStrings := strings.Split(config.WhitelistedSatelliteIDs, ",")
for i, s := range idStrings {
whitelist[i], err = storj.NodeIDFromString(s)
if err != nil {
return nil, err
}
}
}
return &Server{ return &Server{
startTime: time.Now(), startTime: time.Now(),
log: log, log: log,
@ -129,6 +143,7 @@ func NewEndpoint(log *zap.Logger, config Config, storage *pstore.Storage, db *ps
pkey: pkey, pkey: pkey,
totalAllocated: allocatedDiskSpace, totalAllocated: allocatedDiskSpace,
totalBwAllocated: allocatedBandwidth, totalBwAllocated: allocatedBandwidth,
whitelist: whitelist,
verifier: auth.NewSignedMessageVerifier(), verifier: auth.NewSignedMessageVerifier(),
kad: k, kad: k,
}, nil }, nil
@ -297,6 +312,16 @@ func (s *Server) verifyPayerAllocation(pba *pb.PayerBandwidthAllocation_Data, ac
return nil return nil
} }
// approved returns true if a node ID exists in a list of approved node IDs
func (s *Server) approved(id storj.NodeID) bool {
for _, n := range s.whitelist {
if n == id {
return true
}
}
return false
}
func getBeginningOfMonth() time.Time { func getBeginningOfMonth() time.Time {
t := time.Now() t := time.Now()
y, m, _ := t.Date() y, m, _ := t.Date()

View File

@ -51,7 +51,7 @@ func TestPiece(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
TS := NewTestServer(t) TS := NewTestServer(t, []storj.NodeID{})
defer TS.Stop() defer TS.Stop()
if err := TS.writeFile("11111111111111111111"); err != nil { if err := TS.writeFile("11111111111111111111"); err != nil {
@ -138,7 +138,7 @@ func TestRetrieve(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
TS := NewTestServer(t) TS := NewTestServer(t, []storj.NodeID{})
defer TS.Stop() defer TS.Stop()
if err := TS.writeFile("11111111111111111111"); err != nil { if err := TS.writeFile("11111111111111111111"); err != nil {
@ -300,21 +300,22 @@ func TestStore(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
TS := NewTestServer(t) satID := teststorj.NodeIDFromString("satelliteid")
defer TS.Stop()
db := TS.s.DB.DB
tests := []struct { tests := []struct {
id string id string
satelliteID storj.NodeID
whitelist []storj.NodeID
ttl int64 ttl int64
content []byte content []byte
message string message string
totalReceived int64 totalReceived int64
err string err string
}{ }{
{ // should successfully store data { // should successfully store data with no approved satellites
id: "99999999999999999999", id: "99999999999999999999",
satelliteID: satID,
whitelist: []storj.NodeID{},
ttl: 9999999999, ttl: 9999999999,
content: []byte("xyzwq"), content: []byte("xyzwq"),
message: "OK", message: "OK",
@ -323,6 +324,8 @@ func TestStore(t *testing.T) {
}, },
{ // should err with invalid id length { // should err with invalid id length
id: "butts", id: "butts",
satelliteID: satID,
whitelist: []storj.NodeID{satID},
ttl: 9999999999, ttl: 9999999999,
content: []byte("xyzwq"), content: []byte("xyzwq"),
message: "", message: "",
@ -331,6 +334,8 @@ func TestStore(t *testing.T) {
}, },
{ // should err with piece ID not specified { // should err with piece ID not specified
id: "", id: "",
satelliteID: satID,
whitelist: []storj.NodeID{satID},
ttl: 9999999999, ttl: 9999999999,
content: []byte("xyzwq"), content: []byte("xyzwq"),
message: "", message: "",
@ -341,6 +346,10 @@ func TestStore(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run("should return expected PieceStoreSummary values", func(t *testing.T) { t.Run("should return expected PieceStoreSummary values", func(t *testing.T) {
TS := NewTestServer(t, tt.whitelist)
db := TS.s.DB.DB
defer TS.Stop()
assert := assert.New(t) assert := assert.New(t)
stream, err := TS.c.Store(ctx) stream, err := TS.c.Store(ctx)
assert.NoError(err) assert.NoError(err)
@ -350,7 +359,7 @@ func TestStore(t *testing.T) {
assert.NoError(err) assert.NoError(err)
pbad := &pb.PayerBandwidthAllocation_Data{ pbad := &pb.PayerBandwidthAllocation_Data{
SatelliteId: teststorj.NodeIDFromString("satelliteid"), SatelliteId: tt.satelliteID,
UplinkId: teststorj.NodeIDFromString("uplinkid"), UplinkId: teststorj.NodeIDFromString("uplinkid"),
Action: pb.PayerBandwidthAllocation_PUT, Action: pb.PayerBandwidthAllocation_PUT,
} }
@ -428,37 +437,64 @@ func TestPbaValidation(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
TS := NewTestServer(t)
defer TS.Stop()
tests := []struct { tests := []struct {
satelliteID storj.NodeID satelliteID storj.NodeID
uplinkID storj.NodeID uplinkID storj.NodeID
whitelist []storj.NodeID
action pb.PayerBandwidthAllocation_Action action pb.PayerBandwidthAllocation_Action
err string err string
}{ }{
{ // unapproved satellite id
satelliteID: teststorj.NodeIDFromString("bad-satellite"),
uplinkID: teststorj.NodeIDFromString("uplinkid"),
whitelist: []storj.NodeID{
teststorj.NodeIDFromString("satelliteid1"),
teststorj.NodeIDFromString("satelliteid2"),
teststorj.NodeIDFromString("satelliteid3"),
},
action: pb.PayerBandwidthAllocation_PUT,
err: "rpc error: code = Unknown desc = store error: Satellite ID not approved",
},
{ // missing satellite id { // missing satellite id
satelliteID: storj.NodeID{}, satelliteID: storj.NodeID{},
uplinkID: teststorj.NodeIDFromString("uplinkid"), uplinkID: teststorj.NodeIDFromString("uplinkid"),
action: pb.PayerBandwidthAllocation_PUT, whitelist: []storj.NodeID{
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id", teststorj.NodeIDFromString("satelliteid1"),
teststorj.NodeIDFromString("satelliteid2"),
teststorj.NodeIDFromString("satelliteid3"),
},
action: pb.PayerBandwidthAllocation_PUT,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing satellite id",
}, },
{ // missing uplink id { // missing uplink id
satelliteID: teststorj.NodeIDFromString("satelliteid"), satelliteID: teststorj.NodeIDFromString("satelliteid1"),
uplinkID: storj.NodeID{}, uplinkID: storj.NodeID{},
action: pb.PayerBandwidthAllocation_PUT, whitelist: []storj.NodeID{
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id", teststorj.NodeIDFromString("satelliteid1"),
teststorj.NodeIDFromString("satelliteid2"),
teststorj.NodeIDFromString("satelliteid3"),
},
action: pb.PayerBandwidthAllocation_PUT,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: missing uplink id",
}, },
{ // wrong action type { // wrong action type
satelliteID: teststorj.NodeIDFromString("satelliteid"), satelliteID: teststorj.NodeIDFromString("satelliteid1"),
uplinkID: teststorj.NodeIDFromString("uplinkid"), uplinkID: teststorj.NodeIDFromString("uplinkid"),
action: pb.PayerBandwidthAllocation_GET, whitelist: []storj.NodeID{
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET", teststorj.NodeIDFromString("satelliteid1"),
teststorj.NodeIDFromString("satelliteid2"),
teststorj.NodeIDFromString("satelliteid3"),
},
action: pb.PayerBandwidthAllocation_GET,
err: "rpc error: code = Unknown desc = store error: payer bandwidth allocation: invalid action GET",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run("should validate payer bandwidth allocation struct", func(t *testing.T) { t.Run("should validate payer bandwidth allocation struct", func(t *testing.T) {
TS := NewTestServer(t, tt.whitelist)
defer TS.Stop()
assert := assert.New(t) assert := assert.New(t)
stream, err := TS.c.Store(ctx) stream, err := TS.c.Store(ctx)
assert.NoError(err) assert.NoError(err)
@ -500,6 +536,8 @@ func TestPbaValidation(t *testing.T) {
_, err = stream.CloseAndRecv() _, err = stream.CloseAndRecv()
if err != nil { if err != nil {
//assert.NotNil(err) //assert.NotNil(err)
t.Log("Expected err string", tt.err)
t.Log("Actual err.Error:", err.Error())
assert.Equal(tt.err, err.Error()) assert.Equal(tt.err, err.Error())
return return
} }
@ -511,7 +549,7 @@ func TestDelete(t *testing.T) {
ctx := testcontext.New(t) ctx := testcontext.New(t)
defer ctx.Cleanup() defer ctx.Cleanup()
TS := NewTestServer(t) TS := NewTestServer(t, []storj.NodeID{})
defer TS.Stop() defer TS.Stop()
db := TS.s.DB.DB db := TS.s.DB.DB
@ -584,7 +622,7 @@ func TestDelete(t *testing.T) {
} }
} }
func newTestServerStruct(t *testing.T) (*Server, func()) { func newTestServerStruct(t *testing.T, ids []storj.NodeID) (*Server, func()) {
tmp, err := ioutil.TempDir("", "storj-piecestore") tmp, err := ioutil.TempDir("", "storj-piecestore")
if err != nil { if err != nil {
log.Fatalf("failed temp-dir: %v", err) log.Fatalf("failed temp-dir: %v", err)
@ -609,6 +647,7 @@ func newTestServerStruct(t *testing.T) (*Server, func()) {
verifier: verifier, verifier: verifier,
totalAllocated: math.MaxInt64, totalAllocated: math.MaxInt64,
totalBwAllocated: math.MaxInt64, totalBwAllocated: math.MaxInt64,
whitelist: ids,
} }
return server, func() { return server, func() {
if serr := server.Stop(context.TODO()); serr != nil { if serr := server.Stop(context.TODO()); serr != nil {
@ -642,7 +681,7 @@ type TestServer struct {
k crypto.PrivateKey k crypto.PrivateKey
} }
func NewTestServer(t *testing.T) *TestServer { func NewTestServer(t *testing.T, ids []storj.NodeID) *TestServer {
check := func(e error) { check := func(e error) {
if !assert.NoError(t, e) { if !assert.NoError(t, e) {
t.Fail() t.Fail()
@ -663,7 +702,7 @@ func NewTestServer(t *testing.T) *TestServer {
co, err := fiC.DialOption(storj.NodeID{}) co, err := fiC.DialOption(storj.NodeID{})
check(err) check(err)
s, cleanup := newTestServerStruct(t) s, cleanup := newTestServerStruct(t, ids)
grpcs := grpc.NewServer(so) grpcs := grpc.NewServer(so)
k, ok := fiC.Key.(*ecdsa.PrivateKey) k, ok := fiC.Key.(*ecdsa.PrivateKey)