From 899e1e68f1423fbbdfdca9b617ad882712b41eba Mon Sep 17 00:00:00 2001 From: Alexander Leitner Date: Mon, 27 Aug 2018 14:35:27 -0400 Subject: [PATCH] Add functions for signing and verifying during bandwidth exchange (#246) * Added initial functions for signing and verifying * whoops * Get client up to speed * Added initial functions for signing and verifying * whoops * Get client up to speed * wip * wip * actual signatures in tests (cherry picked from commit 1464853b737f1d712d64fbf90147f535525c8fd9) * bugfixing * Generate private key in example * Generate signatures for pieceranger tests * Update examples to use TLS * Use private key from identity inside of example * Use crypto.PrivateKey interface * Change err name in defers * Pass tests * Pass identity Key to PSClient * Get tests passing on travis * Resolve linter complaints --- examples/piecestore-client/rpc/client/main.go | 22 ++++- pkg/miniogw/config.go | 2 +- pkg/piecestore/rpc/client/client.go | 18 +++- pkg/piecestore/rpc/client/pieceranger_test.go | 93 +++++++------------ pkg/piecestore/rpc/server/readerwriter.go | 2 +- pkg/piecestore/rpc/server/retrieve.go | 8 +- pkg/piecestore/rpc/server/server.go | 34 +++++-- pkg/piecestore/rpc/server/server_test.go | 83 ++++++++++++----- pkg/piecestore/rpc/server/store.go | 10 +- pkg/provider/identity.go | 24 ++++- pkg/storage/ec/client.go | 10 +- pkg/storage/ec/client_test.go | 12 ++- 12 files changed, 196 insertions(+), 122 deletions(-) diff --git a/examples/piecestore-client/rpc/client/main.go b/examples/piecestore-client/rpc/client/main.go index 9247d1916..fc303e9e1 100644 --- a/examples/piecestore-client/rpc/client/main.go +++ b/examples/piecestore-client/rpc/client/main.go @@ -5,6 +5,7 @@ package main import ( "context" + "crypto/ecdsa" "fmt" "io" "log" @@ -17,23 +18,39 @@ import ( "google.golang.org/grpc" "storj.io/storj/pkg/piecestore/rpc/client" + "storj.io/storj/pkg/provider" pb "storj.io/storj/protos/piecestore" ) +var ctx = context.Background() var argError = errs.Class("argError") func main() { app := cli.NewApp() + ca, err := provider.NewCA(ctx, 12, 4) + if err != nil { + log.Fatal(err) + } + identity, err := ca.NewIdentity() + if err != nil { + log.Fatal(err) + } + identOpt, err := identity.DialOption() + if err != nil { + log.Fatal(err) + } + // Set up connection with rpc server var conn *grpc.ClientConn - conn, err := grpc.Dial(":7777", grpc.WithInsecure()) + conn, err = grpc.Dial(":7777", identOpt) if err != nil { log.Fatalf("did not connect: %s", err) } defer conn.Close() - psClient, err := client.NewPSClient(conn, 1024*32) + + psClient, err := client.NewPSClient(conn, 1024*32, identity.Key.(*ecdsa.PrivateKey)) if err != nil { log.Fatalf("could not initialize PSClient: %s", err) } @@ -127,7 +144,6 @@ func main() { return err } - ctx := context.Background() rr, err := psClient.Get(ctx, client.PieceID(c.Args().Get(id)), pieceInfo.Size, &pb.PayerBandwidthAllocation{}) if err != nil { fmt.Printf("Failed to retrieve file of id: %s\n", c.Args().Get(id)) diff --git a/pkg/miniogw/config.go b/pkg/miniogw/config.go index 6196f41b5..f8bfd19f5 100644 --- a/pkg/miniogw/config.go +++ b/pkg/miniogw/config.go @@ -130,7 +130,7 @@ func (c Config) NewGateway(ctx context.Context, return nil, err } - ec := ecclient.NewClient(t, c.MaxBufferMem) + ec := ecclient.NewClient(identity, t, c.MaxBufferMem) fc, err := infectious.NewFEC(c.MinThreshold, c.MaxThreshold) if err != nil { return nil, err diff --git a/pkg/piecestore/rpc/client/client.go b/pkg/piecestore/rpc/client/client.go index 3c44690e5..8edcf4032 100644 --- a/pkg/piecestore/rpc/client/client.go +++ b/pkg/piecestore/rpc/client/client.go @@ -5,6 +5,8 @@ package client import ( "bufio" + "crypto" + "crypto/ecdsa" "flag" "fmt" "io" @@ -16,6 +18,8 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" + "github.com/gtank/cryptopasta" + "storj.io/storj/pkg/ranger" pb "storj.io/storj/protos/piecestore" ) @@ -45,11 +49,12 @@ type PSClient interface { type Client struct { route pb.PieceStoreRoutesClient conn *grpc.ClientConn + prikey crypto.PrivateKey bandwidthMsgSize int } // NewPSClient initilizes a PSClient -func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int) (PSClient, error) { +func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int, prikey crypto.PrivateKey) (PSClient, error) { if bandwidthMsgSize < 0 || bandwidthMsgSize > *maxBandwidthMsgSize { return nil, ClientError.New(fmt.Sprintf("Invalid Bandwidth Message Size: %v", bandwidthMsgSize)) } @@ -62,11 +67,12 @@ func NewPSClient(conn *grpc.ClientConn, bandwidthMsgSize int) (PSClient, error) conn: conn, route: pb.NewPieceStoreRoutesClient(conn), bandwidthMsgSize: bandwidthMsgSize, + prikey: prikey, }, nil } // NewCustomRoute creates new Client with custom route interface -func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int) (*Client, error) { +func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int, prikey crypto.PrivateKey) (*Client, error) { if bandwidthMsgSize < 0 || bandwidthMsgSize > *maxBandwidthMsgSize { return nil, ClientError.New(fmt.Sprintf("Invalid Bandwidth Message Size: %v", bandwidthMsgSize)) } @@ -78,6 +84,7 @@ func NewCustomRoute(route pb.PieceStoreRoutesClient, bandwidthMsgSize int) (*Cli return &Client{ route: route, bandwidthMsgSize: bandwidthMsgSize, + prikey: prikey, }, nil } @@ -152,7 +159,10 @@ func (client *Client) Delete(ctx context.Context, id PieceID) error { // sign a message using the clients private key func (client *Client) sign(msg []byte) (signature []byte, err error) { - // use c.pkey to sign msg + if client.prikey == nil { + return nil, ClientError.New("Failed to sign msg: Private Key not Set") + } - return signature, err + // use c.pkey to sign msg + return cryptopasta.Sign(msg, client.prikey.(*ecdsa.PrivateKey)) } diff --git a/pkg/piecestore/rpc/client/pieceranger_test.go b/pkg/piecestore/rpc/client/pieceranger_test.go index eba1476b6..284486dec 100644 --- a/pkg/piecestore/rpc/client/pieceranger_test.go +++ b/pkg/piecestore/rpc/client/pieceranger_test.go @@ -5,12 +5,14 @@ package client import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "fmt" "io" "io/ioutil" "testing" - "github.com/gogo/protobuf/proto" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -42,6 +44,9 @@ func TestPieceRanger(t *testing.T) { } { errTag := fmt.Sprintf("Test case #%d", i) + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.Nil(t, err) + route := pb.NewMockPieceStoreRoutesClient(ctrl) calls := []*gomock.Call{ @@ -54,46 +59,29 @@ func TestPieceRanger(t *testing.T) { pid := NewPieceID() if tt.offset >= 0 && tt.length > 0 && tt.offset+tt.length <= tt.size { + msg1 := &pb.PieceRetrieval{ + PieceData: &pb.PieceRetrieval_PieceData{ + Id: pid.String(), Size: tt.length, Offset: tt.offset, + }, + } + calls = append(calls, - stream.EXPECT().Send( - &pb.PieceRetrieval{ - PieceData: &pb.PieceRetrieval_PieceData{ - Id: pid.String(), Size: tt.length, Offset: tt.offset, - }, - }, - ).Return(nil), - stream.EXPECT().Send( - &pb.PieceRetrieval{ - Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Data: serializeData(&pb.RenterBandwidthAllocation_Data{ - PayerAllocation: &pb.PayerBandwidthAllocation{}, - Total: 32 * 1024, - }), - }, - }, - ).Return(nil), + stream.EXPECT().Send(msg1).Return(nil), + stream.EXPECT().Send(gomock.Any()).Return(nil), stream.EXPECT().Recv().Return( &pb.PieceRetrievalStream{ Size: tt.length, Content: []byte(tt.data)[tt.offset : tt.offset+tt.length], }, nil), - stream.EXPECT().Send( - &pb.PieceRetrieval{ - Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Data: serializeData(&pb.RenterBandwidthAllocation_Data{ - PayerAllocation: &pb.PayerBandwidthAllocation{}, - Total: 32 * 1024 * 2, - }), - }, - }, - ).Return(nil), + stream.EXPECT().Send(gomock.Any()).Return(nil), stream.EXPECT().Recv().Return(&pb.PieceRetrievalStream{}, io.EOF), ) } gomock.InOrder(calls...) ctx := context.Background() - c, err := NewCustomRoute(route, 32*1024) + + c, err := NewCustomRoute(route, 32*1024, priv) assert.NoError(t, err) rr, err := PieceRanger(ctx, c, stream, pid, &pb.PayerBandwidthAllocation{}) if assert.NoError(t, err, errTag) { @@ -142,45 +130,32 @@ func TestPieceRangerSize(t *testing.T) { stream := pb.NewMockPieceStoreRoutes_RetrieveClient(ctrl) + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.Nil(t, err) + if tt.offset >= 0 && tt.length > 0 && tt.offset+tt.length <= tt.size { + msg1 := &pb.PieceRetrieval{ + PieceData: &pb.PieceRetrieval_PieceData{ + Id: pid.String(), Size: tt.length, Offset: tt.offset, + }, + } + gomock.InOrder( - stream.EXPECT().Send( - &pb.PieceRetrieval{ - PieceData: &pb.PieceRetrieval_PieceData{ - Id: pid.String(), Size: tt.length, Offset: tt.offset, - }, - }, - ).Return(nil), - stream.EXPECT().Send( - &pb.PieceRetrieval{Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Data: serializeData(&pb.RenterBandwidthAllocation_Data{ - PayerAllocation: &pb.PayerBandwidthAllocation{}, - Total: 32 * 1024, - }), - }, - }, - ).Return(nil), + stream.EXPECT().Send(msg1).Return(nil), + stream.EXPECT().Send(gomock.Any()).Return(nil), stream.EXPECT().Recv().Return( &pb.PieceRetrievalStream{ Size: tt.length, Content: []byte(tt.data)[tt.offset : tt.offset+tt.length], }, nil), - stream.EXPECT().Send( - &pb.PieceRetrieval{ - Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Data: serializeData(&pb.RenterBandwidthAllocation_Data{ - PayerAllocation: &pb.PayerBandwidthAllocation{}, - Total: 32 * 1024 * 2, - }), - }, - }, - ).Return(nil), + stream.EXPECT().Send(gomock.Any()).Return(nil), stream.EXPECT().Recv().Return(&pb.PieceRetrievalStream{}, io.EOF), ) } ctx := context.Background() - c, err := NewCustomRoute(route, 32*1024) + + c, err := NewCustomRoute(route, 32*1024, priv) assert.NoError(t, err) rr := PieceRangerSize(c, stream, pid, tt.size, &pb.PayerBandwidthAllocation{}) assert.Equal(t, tt.size, rr.Size(), errTag) @@ -196,9 +171,3 @@ func TestPieceRangerSize(t *testing.T) { } } } - -func serializeData(ba *pb.RenterBandwidthAllocation_Data) []byte { - data, _ := proto.Marshal(ba) - - return data -} diff --git a/pkg/piecestore/rpc/server/readerwriter.go b/pkg/piecestore/rpc/server/readerwriter.go index b7878d6a1..b3e307600 100644 --- a/pkg/piecestore/rpc/server/readerwriter.go +++ b/pkg/piecestore/rpc/server/readerwriter.go @@ -51,7 +51,7 @@ func NewStreamReader(s *Server, stream pb.PieceStoreRoutes_StoreServer) *StreamR ba := recv.GetBandwidthallocation() if ba != nil { - if err = s.verifySignature(ba); err != nil { + if err = s.verifySignature(stream.Context(), ba); err != nil { return nil, err } } diff --git a/pkg/piecestore/rpc/server/retrieve.go b/pkg/piecestore/rpc/server/retrieve.go index c32a4b858..43ac55b62 100644 --- a/pkg/piecestore/rpc/server/retrieve.go +++ b/pkg/piecestore/rpc/server/retrieve.go @@ -87,9 +87,9 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re // Save latest bandwidth allocation even if we fail defer func() { - err := s.DB.WriteBandwidthAllocToDB(latestBA) - if latestBA != nil && err != nil { - log.Println("WriteBandwidthAllocToDB Error:", err) + baWriteErr := s.DB.WriteBandwidthAllocToDB(latestBA) + if latestBA != nil && baWriteErr != nil { + log.Println("WriteBandwidthAllocToDB Error:", baWriteErr) } }() @@ -108,7 +108,7 @@ func (s *Server) retrieveData(ctx context.Context, stream pb.PieceStoreRoutes_Re return am.Used, am.TotalAllocated, err } - if err = s.verifySignature(ba); err != nil { + if err = s.verifySignature(ctx, ba); err != nil { return am.Used, am.TotalAllocated, err } diff --git a/pkg/piecestore/rpc/server/server.go b/pkg/piecestore/rpc/server/server.go index 883c51685..0932d17e5 100644 --- a/pkg/piecestore/rpc/server/server.go +++ b/pkg/piecestore/rpc/server/server.go @@ -4,14 +4,19 @@ package server import ( + "crypto" + "crypto/ecdsa" "log" "os" "path/filepath" + "github.com/gtank/cryptopasta" + "github.com/zeebo/errs" "go.uber.org/zap" "golang.org/x/net/context" - monkit "gopkg.in/spacemonkeygo/monkit.v2" + "gopkg.in/spacemonkeygo/monkit.v2" + "storj.io/storj/pkg/peertls" "storj.io/storj/pkg/piecestore" "storj.io/storj/pkg/piecestore/rpc/server/psdb" "storj.io/storj/pkg/provider" @@ -20,6 +25,9 @@ import ( var ( mon = monkit.Package() + + // ServerError wraps errors returned from Server struct methods + ServerError = errs.Class("PSServer error") ) // Config contains everything necessary for a server @@ -31,7 +39,7 @@ type Config struct { func (c Config) Run(ctx context.Context, server *provider.Provider) (err error) { defer mon.Task()(&ctx)(&err) - s, err := Initialize(ctx, c) + s, err := Initialize(ctx, c, server.Identity().Key) if err != nil { return err } @@ -54,10 +62,11 @@ func (c Config) Run(ctx context.Context, server *provider.Provider) (err error) type Server struct { DataDir string DB *psdb.PSDB + pkey crypto.PrivateKey } // Initialize -- initializes a server struct -func Initialize(ctx context.Context, config Config) (*Server, error) { +func Initialize(ctx context.Context, config Config, pkey crypto.PrivateKey) (*Server, error) { dbPath := filepath.Join(config.Path, "piecestore.db") dataDir := filepath.Join(config.Path, "piece-store-data") @@ -66,7 +75,7 @@ func Initialize(ctx context.Context, config Config) (*Server, error) { return nil, err } - return &Server{DataDir: dataDir, DB: psDB}, nil + return &Server{DataDir: dataDir, DB: psDB, pkey: pkey}, nil } // Stop the piececstore node @@ -124,12 +133,19 @@ func (s *Server) deleteByID(id string) error { return nil } -func (s *Server) verifySignature(ba *pb.RenterBandwidthAllocation) error { - // TODO: verify signature +func (s *Server) verifySignature(ctx context.Context, ba *pb.RenterBandwidthAllocation) error { + pi, err := provider.PeerIdentityFromContext(ctx) + if err != nil { + return err + } - // data := ba.GetData() - // signature := ba.GetSignature() - log.Printf("Verified signature\n") + k, ok := pi.Leaf.PublicKey.(*ecdsa.PublicKey) + if !ok { + return peertls.ErrUnsupportedKey.New("%T", pi.Leaf.PublicKey) + } + if ok := cryptopasta.Verify(ba.GetData(), ba.GetSignature(), k); !ok { + return ServerError.New("Failed to verify Signature") + } return nil } diff --git a/pkg/piecestore/rpc/server/server_test.go b/pkg/piecestore/rpc/server/server_test.go index 6be5046c4..62588b1e0 100644 --- a/pkg/piecestore/rpc/server/server_test.go +++ b/pkg/piecestore/rpc/server/server_test.go @@ -5,6 +5,8 @@ package server import ( "bytes" + "crypto" + "crypto/ecdsa" "fmt" "io" "io/ioutil" @@ -19,15 +21,15 @@ import ( "time" "github.com/gogo/protobuf/proto" + "github.com/gtank/cryptopasta" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" - "google.golang.org/grpc" pstore "storj.io/storj/pkg/piecestore" "storj.io/storj/pkg/piecestore/rpc/server/psdb" + "storj.io/storj/pkg/provider" pb "storj.io/storj/protos/piecestore" ) @@ -48,7 +50,7 @@ func writeFileToDir(name, dir string) error { } func TestPiece(t *testing.T) { - TS := NewTestServer() + TS := NewTestServer(t) defer TS.Stop() if err := writeFileToDir("11111111111111111111", TS.s.DataDir); err != nil { @@ -118,7 +120,7 @@ func TestPiece(t *testing.T) { } func TestRetrieve(t *testing.T) { - TS := NewTestServer() + TS := NewTestServer(t) defer TS.Stop() // simulate piece stored with storagenode @@ -230,15 +232,21 @@ func TestRetrieve(t *testing.T) { for totalAllocated < tt.respSize { // Send bandwidth bandwidthAllocation totalAllocated += tt.allocSize + + ba := pb.RenterBandwidthAllocation{ + Data: serializeData(&pb.RenterBandwidthAllocation_Data{ + PayerAllocation: &pb.PayerBandwidthAllocation{}, + Total: totalAllocated, + }), + } + + s, err := cryptopasta.Sign(ba.Data, TS.k.(*ecdsa.PrivateKey)) + assert.NoError(err) + ba.Signature = s + err = stream.Send( &pb.PieceRetrieval{ - Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Signature: []byte{'A', 'B'}, - Data: serializeData(&pb.RenterBandwidthAllocation_Data{ - PayerAllocation: &pb.PayerBandwidthAllocation{}, - Total: totalAllocated, - }), - }, + Bandwidthallocation: &ba, }, ) assert.NoError(err) @@ -254,8 +262,8 @@ func TestRetrieve(t *testing.T) { return } - data = fmt.Sprintf("%s%s", data, string(resp.Content)) - totalRetrieved += resp.Size + data = fmt.Sprintf("%s%s", data, string(resp.GetContent())) + totalRetrieved += resp.GetSize() } assert.NoError(err) @@ -269,7 +277,7 @@ func TestRetrieve(t *testing.T) { } func TestStore(t *testing.T) { - TS := NewTestServer() + TS := NewTestServer(t) defer TS.Stop() db := TS.s.DB.DB @@ -322,7 +330,6 @@ func TestStore(t *testing.T) { msg := &pb.PieceStore{ Piecedata: &pb.PieceStore_PieceData{Content: tt.content}, Bandwidthallocation: &pb.RenterBandwidthAllocation{ - Signature: []byte{'A', 'B'}, Data: serializeData(&pb.RenterBandwidthAllocation_Data{ PayerAllocation: &pb.PayerBandwidthAllocation{}, Total: int64(len(tt.content)), @@ -330,9 +337,12 @@ func TestStore(t *testing.T) { }, } - // Write the buffer to the stream we opened earlier - err = stream.Send(msg) + s, err := cryptopasta.Sign(msg.Bandwidthallocation.Data, TS.k.(*ecdsa.PrivateKey)) assert.NoError(err) + msg.Bandwidthallocation.Signature = s + + // Write the buffer to the stream we opened earlier + stream.Send(msg) resp, err := stream.CloseAndRecv() if tt.err != "" { @@ -378,7 +388,7 @@ func TestStore(t *testing.T) { } func TestDelete(t *testing.T) { - TS := NewTestServer() + TS := NewTestServer(t) defer TS.Stop() db := TS.s.DB.DB @@ -463,8 +473,8 @@ func newTestServerStruct() *Server { return &Server{DataDir: tempDir, DB: psDB} } -func connect(addr string) (pb.PieceStoreRoutesClient, *grpc.ClientConn) { - conn, err := grpc.Dial(addr, grpc.WithInsecure()) +func connect(addr string, o ...grpc.DialOption) (pb.PieceStoreRoutesClient, *grpc.ClientConn) { + conn, err := grpc.Dial(addr, o...) if err != nil { log.Fatalf("did not connect: %v", err) } @@ -479,15 +489,38 @@ type TestServer struct { grpcs *grpc.Server conn *grpc.ClientConn c pb.PieceStoreRoutesClient + k crypto.PrivateKey } -func NewTestServer() *TestServer { - s := newTestServerStruct() - grpcs := grpc.NewServer() +func NewTestServer(t *testing.T) *TestServer { + check := func(e error) { + if !assert.NoError(t, e) { + t.Fail() + } + } - ts := &TestServer{s: s, grpcs: grpcs} + caS, err := provider.NewCA(context.Background(), 12, 4) + check(err) + fiS, err := caS.NewIdentity() + check(err) + so, err := fiS.ServerOption() + check(err) + + caC, err := provider.NewCA(context.Background(), 12, 4) + check(err) + fiC, err := caC.NewIdentity() + check(err) + co, err := fiC.DialOption() + check(err) + + s := newTestServerStruct() + grpcs := grpc.NewServer(so) + + k, ok := fiC.Key.(*ecdsa.PrivateKey) + assert.True(t, ok) + ts := &TestServer{s: s, grpcs: grpcs, k: k} addr := ts.start() - ts.c, ts.conn = connect(addr) + ts.c, ts.conn = connect(addr, co) return ts } diff --git a/pkg/piecestore/rpc/server/store.go b/pkg/piecestore/rpc/server/store.go index b0aa72de1..44bfcbd11 100644 --- a/pkg/piecestore/rpc/server/store.go +++ b/pkg/piecestore/rpc/server/store.go @@ -65,8 +65,8 @@ func (s *Server) storeData(ctx context.Context, stream pb.PieceStoreRoutes_Store // Delete data if we error defer func() { if err != nil && err != io.EOF { - if err = s.deleteByID(id); err != nil { - log.Printf("Failed on deleteByID in Store: %s", err.Error()) + if deleteErr := s.deleteByID(id); deleteErr != nil { + log.Printf("Failed on deleteByID in Store: %s", deleteErr.Error()) } } }() @@ -82,9 +82,9 @@ func (s *Server) storeData(ctx context.Context, stream pb.PieceStoreRoutes_Store reader := NewStreamReader(s, stream) defer func() { - err := s.DB.WriteBandwidthAllocToDB(reader.bandwidthAllocation) - if err != nil { - log.Printf("WriteBandwidthAllocToDB Error: %s\n", err.Error()) + baWriteErr := s.DB.WriteBandwidthAllocToDB(reader.bandwidthAllocation) + if baWriteErr != nil { + log.Printf("WriteBandwidthAllocToDB Error: %s\n", baWriteErr.Error()) } }() diff --git a/pkg/provider/identity.go b/pkg/provider/identity.go index db4fb3a00..c5c0750f9 100644 --- a/pkg/provider/identity.go +++ b/pkg/provider/identity.go @@ -16,6 +16,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" "encoding/base64" "fmt" @@ -136,8 +137,27 @@ func PeerIdentityFromCerts(leaf, ca *x509.Certificate) (*PeerIdentity, error) { }, nil } -// Status returns the status of the identity cert/key files for the config -func (is IdentitySetupConfig) Status() TLSFilesStatus { +// PeerIdentityFromContext loads a PeerIdentity from a ctx TLS credentials +func PeerIdentityFromContext(ctx context.Context) (*PeerIdentity, error) { + p, ok := peer.FromContext(ctx) + if !ok { + return nil, Error.New("unable to get grpc peer from contex") + } + tlsInfo := p.AuthInfo.(credentials.TLSInfo) + c := tlsInfo.State.PeerCertificates + if len(c) < 2 { + return nil, Error.New("invalid certificate chain") + } + pi, err := PeerIdentityFromCerts(c[0], c[1]) + if err != nil { + return nil, err + } + + return pi, nil +} + +// Stat returns the status of the identity cert/key files for the config +func (is IdentitySetupConfig) Stat() TLSFilesStatus { return statTLSFiles(is.CertPath, is.KeyPath) } diff --git a/pkg/storage/ec/client.go b/pkg/storage/ec/client.go index a788af0a5..25236fa56 100644 --- a/pkg/storage/ec/client.go +++ b/pkg/storage/ec/client.go @@ -15,6 +15,7 @@ import ( "storj.io/storj/pkg/eestream" "storj.io/storj/pkg/piecestore/rpc/client" + "storj.io/storj/pkg/provider" "storj.io/storj/pkg/ranger" "storj.io/storj/pkg/transport" "storj.io/storj/pkg/utils" @@ -38,7 +39,8 @@ type dialer interface { } type defaultDialer struct { - t transport.Client + t transport.Client + identity *provider.FullIdentity } func (d *defaultDialer) dial(ctx context.Context, node *proto.Node) (ps client.PSClient, err error) { @@ -48,7 +50,7 @@ func (d *defaultDialer) dial(ctx context.Context, node *proto.Node) (ps client.P return nil, err } - return client.NewPSClient(c, 0) + return client.NewPSClient(c, 0, d.identity.Key) } type ecClient struct { @@ -57,8 +59,8 @@ type ecClient struct { } // NewClient from the given TransportClient and max buffer memory -func NewClient(t transport.Client, mbm int) Client { - d := defaultDialer{t: t} +func NewClient(identity *provider.FullIdentity, t transport.Client, mbm int) Client { + d := defaultDialer{identity: identity, t: t} return &ecClient{d: &d, mbm: mbm} } diff --git a/pkg/storage/ec/client_test.go b/pkg/storage/ec/client_test.go index 96048e9f2..646aaf9ac 100644 --- a/pkg/storage/ec/client_test.go +++ b/pkg/storage/ec/client_test.go @@ -5,6 +5,8 @@ package ecclient import ( "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "errors" "fmt" @@ -18,6 +20,7 @@ import ( "storj.io/storj/pkg/eestream" "storj.io/storj/pkg/piecestore/rpc/client" + "storj.io/storj/pkg/provider" "storj.io/storj/pkg/ranger" proto "storj.io/storj/protos/overlay" ) @@ -59,7 +62,9 @@ func TestNewECClient(t *testing.T) { tc := NewMockClient(ctrl) mbm := 1234 - ec := NewClient(tc, mbm) + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + identity := &provider.FullIdentity{Key: privKey} + ec := NewClient(identity, tc, mbm) assert.NotNil(t, ec) ecc, ok := ec.(*ecClient) @@ -78,6 +83,9 @@ func TestDefaultDialer(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() + privKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + identity := &provider.FullIdentity{Key: privKey} + for i, tt := range []struct { err error errString string @@ -90,7 +98,7 @@ func TestDefaultDialer(t *testing.T) { tc := NewMockClient(ctrl) tc.EXPECT().DialNode(gomock.Any(), node0).Return(nil, tt.err) - dd := defaultDialer{t: tc} + dd := defaultDialer{t: tc, identity: identity} _, err := dd.dial(ctx, node0) if tt.errString != "" {