Add functions for signing and verifying during bandwidth exchange (#246)
* Added initial functions for signing and verifying * whoops * Get client up to speed * Added initial functions for signing and verifying * whoops * Get client up to speed * wip * wip * actual signatures in tests (cherry picked from commit 1464853b737f1d712d64fbf90147f535525c8fd9) * bugfixing * Generate private key in example * Generate signatures for pieceranger tests * Update examples to use TLS * Use private key from identity inside of example * Use crypto.PrivateKey interface * Change err name in defers * Pass tests * Pass identity Key to PSClient * Get tests passing on travis * Resolve linter complaints
This commit is contained in:
parent
75de358740
commit
899e1e68f1
@ -5,6 +5,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@ -17,23 +18,39 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"storj.io/storj/pkg/piecestore/rpc/client"
|
||||
"storj.io/storj/pkg/provider"
|
||||
pb "storj.io/storj/protos/piecestore"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
var argError = errs.Class("argError")
|
||||
|
||||
func main() {
|
||||
|
||||
app := cli.NewApp()
|
||||
|
||||
ca, err := provider.NewCA(ctx, 12, 4)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
identity, err := ca.NewIdentity()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
identOpt, err := identity.DialOption()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Set up connection with rpc server
|
||||
var conn *grpc.ClientConn
|
||||
conn, err := grpc.Dial(":7777", grpc.WithInsecure())
|
||||
conn, err = grpc.Dial(":7777", identOpt)
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %s", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
psClient, err := client.NewPSClient(conn, 1024*32)
|
||||
|
||||
psClient, err := client.NewPSClient(conn, 1024*32, identity.Key.(*ecdsa.PrivateKey))
|
||||
if err != nil {
|
||||
log.Fatalf("could not initialize PSClient: %s", err)
|
||||
}
|
||||
@ -127,7 +144,6 @@ func main() {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
rr, err := psClient.Get(ctx, client.PieceID(c.Args().Get(id)), pieceInfo.Size, &pb.PayerBandwidthAllocation{})
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to retrieve file of id: %s\n", c.Args().Get(id))
|
||||
|
@ -130,7 +130,7 @@ func (c Config) NewGateway(ctx context.Context,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ec := ecclient.NewClient(t, c.MaxBufferMem)
|
||||
ec := ecclient.NewClient(identity, t, c.MaxBufferMem)
|
||||
fc, err := infectious.NewFEC(c.MinThreshold, c.MaxThreshold)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -5,6 +5,8 @@ package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -16,6 +18,8 @@ import (
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/gtank/cryptopasta"
|
||||
|
||||
"storj.io/storj/pkg/ranger"
|
||||
pb "storj.io/storj/protos/piecestore"
|
||||
)
|
||||
@ -45,11 +49,12 @@ type PSClient interface {
|
||||
type Client struct {
|
||||
route pb.PieceStoreRoutesClient
|
||||
conn *grpc.ClientConn
|
||||
prikey crypto.PrivateKey
|
||||
bandwidthMsgSize int
|
||||
}
|
||||
|
||||
// NewPSClient initilizes a PSClient
|
||||
func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int) (PSClient, error) {
|
||||
func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int, prikey crypto.PrivateKey) (PSClient, error) {
|
||||
if bandwidthMsgSize < 0 || bandwidthMsgSize > *maxBandwidthMsgSize {
|
||||
return nil, ClientError.New(fmt.Sprintf("Invalid Bandwidth Message Size: %v", bandwidthMsgSize))
|
||||
}
|
||||
@ -62,11 +67,12 @@ func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int) (PSClient, error)
|
||||
conn: conn,
|
||||
route: pb.NewPieceStoreRoutesClient(conn),
|
||||
bandwidthMsgSize: bandwidthMsgSize,
|
||||
prikey: prikey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewCustomRoute creates new Client with custom route interface
|
||||
func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int) (*Client, error) {
|
||||
func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int, prikey crypto.PrivateKey) (*Client, error) {
|
||||
if bandwidthMsgSize < 0 || bandwidthMsgSize > *maxBandwidthMsgSize {
|
||||
return nil, ClientError.New(fmt.Sprintf("Invalid Bandwidth Message Size: %v", bandwidthMsgSize))
|
||||
}
|
||||
@ -78,6 +84,7 @@ func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int) (*Cli
|
||||
return &Client{
|
||||
route: route,
|
||||
bandwidthMsgSize: bandwidthMsgSize,
|
||||
prikey: prikey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -152,7 +159,10 @@ func (client *Client) Delete(ctx context.Context, id PieceID) error {
|
||||
|
||||
// sign a message using the clients private key
|
||||
func (client *Client) sign(msg []byte) (signature []byte, err error) {
|
||||
// use c.pkey to sign msg
|
||||
|
||||
return signature, err
|
||||
if client.prikey == nil {
|
||||
return nil, ClientError.New("Failed to sign msg: Private Key not Set")
|
||||
}
|
||||
|
||||
// use c.pkey to sign msg
|
||||
return cryptopasta.Sign(msg, client.prikey.(*ecdsa.PrivateKey))
|
||||
}
|
||||
|
@ -5,12 +5,14 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@ -42,6 +44,9 @@ func TestPieceRanger(t *testing.T) {
|
||||
} {
|
||||
errTag := fmt.Sprintf("Test case #%d", i)
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
route := pb.NewMockPieceStoreRoutesClient(ctrl)
|
||||
|
||||
calls := []*gomock.Call{
|
||||
@ -54,46 +59,29 @@ func TestPieceRanger(t *testing.T) {
|
||||
pid := NewPieceID()
|
||||
|
||||
if tt.offset >= 0 && tt.length > 0 && tt.offset+tt.length <= tt.size {
|
||||
calls = append(calls,
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{
|
||||
msg1 := &pb.PieceRetrieval{
|
||||
PieceData: &pb.PieceRetrieval_PieceData{
|
||||
Id: pid.String(), Size: tt.length, Offset: tt.offset,
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{
|
||||
Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: 32 * 1024,
|
||||
}),
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
}
|
||||
|
||||
calls = append(calls,
|
||||
stream.EXPECT().Send(msg1).Return(nil),
|
||||
stream.EXPECT().Send(gomock.Any()).Return(nil),
|
||||
stream.EXPECT().Recv().Return(
|
||||
&pb.PieceRetrievalStream{
|
||||
Size: tt.length,
|
||||
Content: []byte(tt.data)[tt.offset : tt.offset+tt.length],
|
||||
}, nil),
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{
|
||||
Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: 32 * 1024 * 2,
|
||||
}),
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
stream.EXPECT().Send(gomock.Any()).Return(nil),
|
||||
stream.EXPECT().Recv().Return(&pb.PieceRetrievalStream{}, io.EOF),
|
||||
)
|
||||
}
|
||||
gomock.InOrder(calls...)
|
||||
|
||||
ctx := context.Background()
|
||||
c, err := NewCustomRoute(route, 32*1024)
|
||||
|
||||
c, err := NewCustomRoute(route, 32*1024, priv)
|
||||
assert.NoError(t, err)
|
||||
rr, err := PieceRanger(ctx, c, stream, pid, &pb.PayerBandwidthAllocation{})
|
||||
if assert.NoError(t, err, errTag) {
|
||||
@ -142,45 +130,32 @@ func TestPieceRangerSize(t *testing.T) {
|
||||
|
||||
stream := pb.NewMockPieceStoreRoutes_RetrieveClient(ctrl)
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if tt.offset >= 0 && tt.length > 0 && tt.offset+tt.length <= tt.size {
|
||||
gomock.InOrder(
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{
|
||||
msg1 := &pb.PieceRetrieval{
|
||||
PieceData: &pb.PieceRetrieval_PieceData{
|
||||
Id: pid.String(), Size: tt.length, Offset: tt.offset,
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: 32 * 1024,
|
||||
}),
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
}
|
||||
|
||||
gomock.InOrder(
|
||||
stream.EXPECT().Send(msg1).Return(nil),
|
||||
stream.EXPECT().Send(gomock.Any()).Return(nil),
|
||||
stream.EXPECT().Recv().Return(
|
||||
&pb.PieceRetrievalStream{
|
||||
Size: tt.length,
|
||||
Content: []byte(tt.data)[tt.offset : tt.offset+tt.length],
|
||||
}, nil),
|
||||
stream.EXPECT().Send(
|
||||
&pb.PieceRetrieval{
|
||||
Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: 32 * 1024 * 2,
|
||||
}),
|
||||
},
|
||||
},
|
||||
).Return(nil),
|
||||
stream.EXPECT().Send(gomock.Any()).Return(nil),
|
||||
stream.EXPECT().Recv().Return(&pb.PieceRetrievalStream{}, io.EOF),
|
||||
)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
c, err := NewCustomRoute(route, 32*1024)
|
||||
|
||||
c, err := NewCustomRoute(route, 32*1024, priv)
|
||||
assert.NoError(t, err)
|
||||
rr := PieceRangerSize(c, stream, pid, tt.size, &pb.PayerBandwidthAllocation{})
|
||||
assert.Equal(t, tt.size, rr.Size(), errTag)
|
||||
@ -196,9 +171,3 @@ func TestPieceRangerSize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serializeData(ba *pb.RenterBandwidthAllocation_Data) []byte {
|
||||
data, _ := proto.Marshal(ba)
|
||||
|
||||
return data
|
||||
}
|
||||
|
@ -51,7 +51,7 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer) *StreamR
|
||||
ba := recv.GetBandwidthallocation()
|
||||
|
||||
if ba != nil {
|
||||
if err = s.verifySignature(ba); err != nil {
|
||||
if err = s.verifySignature(stream.Context(), ba); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -87,9 +87,9 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re
|
||||
|
||||
// Save latest bandwidth allocation even if we fail
|
||||
defer func() {
|
||||
err := s.DB.WriteBandwidthAllocToDB(latestBA)
|
||||
if latestBA != nil && err != nil {
|
||||
log.Println("WriteBandwidthAllocToDB Error:", err)
|
||||
baWriteErr := s.DB.WriteBandwidthAllocToDB(latestBA)
|
||||
if latestBA != nil && baWriteErr != nil {
|
||||
log.Println("WriteBandwidthAllocToDB Error:", baWriteErr)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -108,7 +108,7 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re
|
||||
return am.Used, am.TotalAllocated, err
|
||||
}
|
||||
|
||||
if err = s.verifySignature(ba); err != nil {
|
||||
if err = s.verifySignature(ctx, ba); err != nil {
|
||||
return am.Used, am.TotalAllocated, err
|
||||
}
|
||||
|
||||
|
@ -4,14 +4,19 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/gtank/cryptopasta"
|
||||
"github.com/zeebo/errs"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/context"
|
||||
monkit "gopkg.in/spacemonkeygo/monkit.v2"
|
||||
"gopkg.in/spacemonkeygo/monkit.v2"
|
||||
|
||||
"storj.io/storj/pkg/peertls"
|
||||
"storj.io/storj/pkg/piecestore"
|
||||
"storj.io/storj/pkg/piecestore/rpc/server/psdb"
|
||||
"storj.io/storj/pkg/provider"
|
||||
@ -20,6 +25,9 @@ import (
|
||||
|
||||
var (
|
||||
mon = monkit.Package()
|
||||
|
||||
// ServerError wraps errors returned from Server struct methods
|
||||
ServerError = errs.Class("PSServer error")
|
||||
)
|
||||
|
||||
// Config contains everything necessary for a server
|
||||
@ -31,7 +39,7 @@ type Config struct {
|
||||
func (c Config) Run(ctx context.Context, server *provider.Provider) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
s, err := Initialize(ctx, c)
|
||||
s, err := Initialize(ctx, c, server.Identity().Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -54,10 +62,11 @@ func (c Config) Run(ctx context.Context, server *provider.Provider) (err error)
|
||||
type Server struct {
|
||||
DataDir string
|
||||
DB *psdb.PSDB
|
||||
pkey crypto.PrivateKey
|
||||
}
|
||||
|
||||
// Initialize -- initializes a server struct
|
||||
func Initialize(ctx context.Context, config Config) (*Server, error) {
|
||||
func Initialize(ctx context.Context, config Config, pkey crypto.PrivateKey) (*Server, error) {
|
||||
dbPath := filepath.Join(config.Path, "piecestore.db")
|
||||
dataDir := filepath.Join(config.Path, "piece-store-data")
|
||||
|
||||
@ -66,7 +75,7 @@ func Initialize(ctx context.Context, config Config) (*Server, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Server{DataDir: dataDir, DB: psDB}, nil
|
||||
return &Server{DataDir: dataDir, DB: psDB, pkey: pkey}, nil
|
||||
}
|
||||
|
||||
// Stop the piececstore node
|
||||
@ -124,12 +133,19 @@ func (s *Server) deleteByID(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) verifySignature(ba *pb.RenterBandwidthAllocation) error {
|
||||
// TODO: verify signature
|
||||
func (s *Server) verifySignature(ctx context.Context, ba *pb.RenterBandwidthAllocation) error {
|
||||
pi, err := provider.PeerIdentityFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// data := ba.GetData()
|
||||
// signature := ba.GetSignature()
|
||||
log.Printf("Verified signature\n")
|
||||
k, ok := pi.Leaf.PublicKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
return peertls.ErrUnsupportedKey.New("%T", pi.Leaf.PublicKey)
|
||||
}
|
||||
|
||||
if ok := cryptopasta.Verify(ba.GetData(), ba.GetSignature(), k); !ok {
|
||||
return ServerError.New("Failed to verify Signature")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -19,15 +21,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/gtank/cryptopasta"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
pstore "storj.io/storj/pkg/piecestore"
|
||||
"storj.io/storj/pkg/piecestore/rpc/server/psdb"
|
||||
"storj.io/storj/pkg/provider"
|
||||
pb "storj.io/storj/protos/piecestore"
|
||||
)
|
||||
|
||||
@ -48,7 +50,7 @@ func writeFileToDir(name, dir string) error {
|
||||
}
|
||||
|
||||
func TestPiece(t *testing.T) {
|
||||
TS := NewTestServer()
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
||||
if err := writeFileToDir("11111111111111111111", TS.s.DataDir); err != nil {
|
||||
@ -118,7 +120,7 @@ func TestPiece(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRetrieve(t *testing.T) {
|
||||
TS := NewTestServer()
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
||||
// simulate piece stored with storagenode
|
||||
@ -230,15 +232,21 @@ func TestRetrieve(t *testing.T) {
|
||||
for totalAllocated < tt.respSize {
|
||||
// Send bandwidth bandwidthAllocation
|
||||
totalAllocated += tt.allocSize
|
||||
err = stream.Send(
|
||||
&pb.PieceRetrieval{
|
||||
Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Signature: []byte{'A', 'B'},
|
||||
|
||||
ba := pb.RenterBandwidthAllocation{
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: totalAllocated,
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
s, err := cryptopasta.Sign(ba.Data, TS.k.(*ecdsa.PrivateKey))
|
||||
assert.NoError(err)
|
||||
ba.Signature = s
|
||||
|
||||
err = stream.Send(
|
||||
&pb.PieceRetrieval{
|
||||
Bandwidthallocation: &ba,
|
||||
},
|
||||
)
|
||||
assert.NoError(err)
|
||||
@ -254,8 +262,8 @@ func TestRetrieve(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
data = fmt.Sprintf("%s%s", data, string(resp.Content))
|
||||
totalRetrieved += resp.Size
|
||||
data = fmt.Sprintf("%s%s", data, string(resp.GetContent()))
|
||||
totalRetrieved += resp.GetSize()
|
||||
}
|
||||
|
||||
assert.NoError(err)
|
||||
@ -269,7 +277,7 @@ func TestRetrieve(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
TS := NewTestServer()
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
||||
db := TS.s.DB.DB
|
||||
@ -322,7 +330,6 @@ func TestStore(t *testing.T) {
|
||||
msg := &pb.PieceStore{
|
||||
Piecedata: &pb.PieceStore_PieceData{Content: tt.content},
|
||||
Bandwidthallocation: &pb.RenterBandwidthAllocation{
|
||||
Signature: []byte{'A', 'B'},
|
||||
Data: serializeData(&pb.RenterBandwidthAllocation_Data{
|
||||
PayerAllocation: &pb.PayerBandwidthAllocation{},
|
||||
Total: int64(len(tt.content)),
|
||||
@ -330,9 +337,12 @@ func TestStore(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Write the buffer to the stream we opened earlier
|
||||
err = stream.Send(msg)
|
||||
s, err := cryptopasta.Sign(msg.Bandwidthallocation.Data, TS.k.(*ecdsa.PrivateKey))
|
||||
assert.NoError(err)
|
||||
msg.Bandwidthallocation.Signature = s
|
||||
|
||||
// Write the buffer to the stream we opened earlier
|
||||
stream.Send(msg)
|
||||
|
||||
resp, err := stream.CloseAndRecv()
|
||||
if tt.err != "" {
|
||||
@ -378,7 +388,7 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
TS := NewTestServer()
|
||||
TS := NewTestServer(t)
|
||||
defer TS.Stop()
|
||||
|
||||
db := TS.s.DB.DB
|
||||
@ -463,8 +473,8 @@ func newTestServerStruct() *Server {
|
||||
return &Server{DataDir: tempDir, DB: psDB}
|
||||
}
|
||||
|
||||
func connect(addr string) (pb.PieceStoreRoutesClient, *grpc.ClientConn) {
|
||||
conn, err := grpc.Dial(addr, grpc.WithInsecure())
|
||||
func connect(addr string, o ...grpc.DialOption) (pb.PieceStoreRoutesClient, *grpc.ClientConn) {
|
||||
conn, err := grpc.Dial(addr, o...)
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
@ -479,15 +489,38 @@ type TestServer struct {
|
||||
grpcs *grpc.Server
|
||||
conn *grpc.ClientConn
|
||||
c pb.PieceStoreRoutesClient
|
||||
k crypto.PrivateKey
|
||||
}
|
||||
|
||||
func NewTestServer() *TestServer {
|
||||
s := newTestServerStruct()
|
||||
grpcs := grpc.NewServer()
|
||||
func NewTestServer(t *testing.T) *TestServer {
|
||||
check := func(e error) {
|
||||
if !assert.NoError(t, e) {
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
|
||||
ts := &TestServer{s: s, grpcs: grpcs}
|
||||
caS, err := provider.NewCA(context.Background(), 12, 4)
|
||||
check(err)
|
||||
fiS, err := caS.NewIdentity()
|
||||
check(err)
|
||||
so, err := fiS.ServerOption()
|
||||
check(err)
|
||||
|
||||
caC, err := provider.NewCA(context.Background(), 12, 4)
|
||||
check(err)
|
||||
fiC, err := caC.NewIdentity()
|
||||
check(err)
|
||||
co, err := fiC.DialOption()
|
||||
check(err)
|
||||
|
||||
s := newTestServerStruct()
|
||||
grpcs := grpc.NewServer(so)
|
||||
|
||||
k, ok := fiC.Key.(*ecdsa.PrivateKey)
|
||||
assert.True(t, ok)
|
||||
ts := &TestServer{s: s, grpcs: grpcs, k: k}
|
||||
addr := ts.start()
|
||||
ts.c, ts.conn = connect(addr)
|
||||
ts.c, ts.conn = connect(addr, co)
|
||||
|
||||
return ts
|
||||
}
|
||||
|
@ -65,8 +65,8 @@ func (s *Server) storeData(ctx context.Context, stream pb.PieceStoreRoutes_Store
|
||||
// Delete data if we error
|
||||
defer func() {
|
||||
if err != nil && err != io.EOF {
|
||||
if err = s.deleteByID(id); err != nil {
|
||||
log.Printf("Failed on deleteByID in Store: %s", err.Error())
|
||||
if deleteErr := s.deleteByID(id); deleteErr != nil {
|
||||
log.Printf("Failed on deleteByID in Store: %s", deleteErr.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -82,9 +82,9 @@ func (s *Server) storeData(ctx context.Context, stream pb.PieceStoreRoutes_Store
|
||||
reader := NewStreamReader(s, stream)
|
||||
|
||||
defer func() {
|
||||
err := s.DB.WriteBandwidthAllocToDB(reader.bandwidthAllocation)
|
||||
if err != nil {
|
||||
log.Printf("WriteBandwidthAllocToDB Error: %s\n", err.Error())
|
||||
baWriteErr := s.DB.WriteBandwidthAllocToDB(reader.bandwidthAllocation)
|
||||
if baWriteErr != nil {
|
||||
log.Printf("WriteBandwidthAllocToDB Error: %s\n", baWriteErr.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
@ -136,8 +137,27 @@ func PeerIdentityFromCerts(leaf, ca *x509.Certificate) (*PeerIdentity, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Status returns the status of the identity cert/key files for the config
|
||||
func (is IdentitySetupConfig) Status() TLSFilesStatus {
|
||||
// PeerIdentityFromContext loads a PeerIdentity from a ctx TLS credentials
|
||||
func PeerIdentityFromContext(ctx context.Context) (*PeerIdentity, error) {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return nil, Error.New("unable to get grpc peer from contex")
|
||||
}
|
||||
tlsInfo := p.AuthInfo.(credentials.TLSInfo)
|
||||
c := tlsInfo.State.PeerCertificates
|
||||
if len(c) < 2 {
|
||||
return nil, Error.New("invalid certificate chain")
|
||||
}
|
||||
pi, err := PeerIdentityFromCerts(c[0], c[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pi, nil
|
||||
}
|
||||
|
||||
// Stat returns the status of the identity cert/key files for the config
|
||||
func (is IdentitySetupConfig) Stat() TLSFilesStatus {
|
||||
return statTLSFiles(is.CertPath, is.KeyPath)
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"storj.io/storj/pkg/eestream"
|
||||
"storj.io/storj/pkg/piecestore/rpc/client"
|
||||
"storj.io/storj/pkg/provider"
|
||||
"storj.io/storj/pkg/ranger"
|
||||
"storj.io/storj/pkg/transport"
|
||||
"storj.io/storj/pkg/utils"
|
||||
@ -39,6 +40,7 @@ type dialer interface {
|
||||
|
||||
type defaultDialer struct {
|
||||
t transport.Client
|
||||
identity *provider.FullIdentity
|
||||
}
|
||||
|
||||
func (d *defaultDialer) dial(ctx context.Context, node *proto.Node) (ps client.PSClient, err error) {
|
||||
@ -48,7 +50,7 @@ func (d *defaultDialer) dial(ctx context.Context, node *proto.Node) (ps client.P
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client.NewPSClient(c, 0)
|
||||
return client.NewPSClient(c, 0, d.identity.Key)
|
||||
}
|
||||
|
||||
type ecClient struct {
|
||||
@ -57,8 +59,8 @@ type ecClient struct {
|
||||
}
|
||||
|
||||
// NewClient from the given TransportClient and max buffer memory
|
||||
func NewClient(t transport.Client, mbm int) Client {
|
||||
d := defaultDialer{t: t}
|
||||
func NewClient(identity *provider.FullIdentity, t transport.Client, mbm int) Client {
|
||||
d := defaultDialer{identity: identity, t: t}
|
||||
return &ecClient{d: &d, mbm: mbm}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,8 @@ package ecclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -18,6 +20,7 @@ import (
|
||||
|
||||
"storj.io/storj/pkg/eestream"
|
||||
"storj.io/storj/pkg/piecestore/rpc/client"
|
||||
"storj.io/storj/pkg/provider"
|
||||
"storj.io/storj/pkg/ranger"
|
||||
proto "storj.io/storj/protos/overlay"
|
||||
)
|
||||
@ -59,7 +62,9 @@ func TestNewECClient(t *testing.T) {
|
||||
tc := NewMockClient(ctrl)
|
||||
mbm := 1234
|
||||
|
||||
ec := NewClient(tc, mbm)
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
identity := &provider.FullIdentity{Key: privKey}
|
||||
ec := NewClient(identity, tc, mbm)
|
||||
assert.NotNil(t, ec)
|
||||
|
||||
ecc, ok := ec.(*ecClient)
|
||||
@ -78,6 +83,9 @@ func TestDefaultDialer(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
identity := &provider.FullIdentity{Key: privKey}
|
||||
|
||||
for i, tt := range []struct {
|
||||
err error
|
||||
errString string
|
||||
@ -90,7 +98,7 @@ func TestDefaultDialer(t *testing.T) {
|
||||
tc := NewMockClient(ctrl)
|
||||
tc.EXPECT().DialNode(gomock.Any(), node0).Return(nil, tt.err)
|
||||
|
||||
dd := defaultDialer{t: tc}
|
||||
dd := defaultDialer{t: tc, identity: identity}
|
||||
_, err := dd.dial(ctx, node0)
|
||||
|
||||
if tt.errString != "" {
|
||||
|
Loading…
Reference in New Issue
Block a user