From 098cbc9c679524caf9b66bdfce6a50c771bfed2a Mon Sep 17 00:00:00 2001 From: Jeff Wendling Date: Wed, 18 Sep 2019 22:46:39 -0600 Subject: [PATCH] all: use pkg/rpc instead of pkg/transport all of the packages and tests work with both grpc and drpc. we'll probably need to do some jenkins pipelines to run the tests with drpc as well. most of the changes are really due to a bit of cleanup of the pkg/transport.Client api into an rpc.Dialer in the spirit of a net.Dialer. now that we don't need observers, we can pass around stateless configuration to everything rather than stateful things that issue observations. it also adds a DialAddressID for the case where we don't have a pb.Node, but we do have an address and want to assert some ID. this happened pretty frequently, and now there's no more weird contortions creating custom tls options, etc. a lot of the other changes are being consistent/using the abstractions in the rpc package to do rpc style things like finding peer information, or checking status codes. Change-Id: Ief62875e21d80a21b3c56a5a37f45887679f9412 --- bootstrap/peer.go | 14 +- cmd/identity/main.go | 11 +- cmd/inspector/main.go | 25 +- cmd/storagenode/dashboard.go | 16 +- go.mod | 2 +- go.sum | 4 +- internal/errs2/ignore.go | 6 +- internal/errs2/ignore_test.go | 21 +- internal/errs2/rpc.go | 8 +- internal/testcontext/compile.go | 20 +- internal/testcontext/compile_drpc.go | 8 + internal/testcontext/compile_nodrpc.go | 8 + internal/testplanet/planet_test.go | 4 +- internal/testplanet/uplink.go | 12 +- internal/testplanet/uplink_test.go | 4 +- lib/uplink/project.go | 6 +- lib/uplink/uplink.go | 29 +- lib/uplink/uplink_test.go | 13 +- pkg/auth/grpcauth/apikey_test.go | 4 +- .../authorization/authorizations.go | 4 +- .../authorization/authorizations_test.go | 42 +-- pkg/certificate/certificateclient/client.go | 27 +- pkg/certificate/endpoint.go | 15 +- pkg/certificate/peer.go | 4 +- pkg/certificate/peer_test.go | 21 +- pkg/identity/identity.go | 26 +- pkg/kademlia/endpoint.go | 25 +- pkg/kademlia/kademlia.go | 17 +- pkg/kademlia/kademlia_test.go | 6 +- pkg/kademlia/kademliaclient/kademliaclient.go | 134 ++++---- pkg/miniogw/gateway_test.go | 2 +- pkg/peertls/tlsopts/options_test.go | 29 +- pkg/rpc/common_drpc.go | 14 + pkg/rpc/common_grpc.go | 14 + pkg/rpc/dial.go | 4 +- pkg/rpc/dial_grpc.go | 9 +- .../transport_test.go => rpc/rpc_test.go} | 66 +--- pkg/rpc/rpcpeer/peer.go | 37 +++ pkg/rpc/rpcpeer/peer_drpc.go | 31 ++ pkg/rpc/rpcpeer/peer_grpc.go | 31 ++ pkg/rpc/rpcstatus/status_drpc.go | 63 ++++ pkg/rpc/rpcstatus/status_grpc.go | 50 +++ pkg/server/config.go | 4 +- pkg/server/interceptors.go | 4 +- pkg/server/server.go | 27 +- pkg/transport/common.go | 25 -- pkg/transport/fetchidentity.go | 116 ------- pkg/transport/insecure.go | 32 -- pkg/transport/slowtransport.go | 143 --------- pkg/transport/timeout.go | 56 ---- pkg/transport/transport.go | 181 ----------- satellite/accounting/projectusage_test.go | 18 +- satellite/audit/reverify_test.go | 22 +- satellite/audit/verifier.go | 28 +- satellite/audit/verifier_test.go | 62 ++-- satellite/contact/client.go | 33 +- satellite/contact/contact_test.go | 13 +- satellite/contact/endpoint.go | 42 +-- satellite/contact/service.go | 46 +-- satellite/gc/service.go | 10 +- satellite/metainfo/metainfo.go | 299 +++++++++--------- satellite/metainfo/metainfo_test.go | 11 +- satellite/metainfo/validation.go | 9 +- satellite/nodestats/endpoint.go | 17 +- satellite/orders/endpoint.go | 15 +- satellite/peer.go | 18 +- satellite/repair/repairer/ec.go | 10 +- satellite/repair/repairer/segments.go | 6 +- satellite/vouchers/vouchers_test.go | 6 +- storagenode/contact/chore.go | 33 +- storagenode/contact/contact_test.go | 20 +- storagenode/contact/endpoint.go | 21 +- storagenode/contact/kademlia.go | 9 +- storagenode/nodestats/service.go | 49 +-- storagenode/orders/service.go | 51 ++- storagenode/peer.go | 30 +- storagenode/piecestore/endpoint.go | 59 ++-- storagenode/piecestore/endpoint_test.go | 10 +- storagenode/piecestore/verification.go | 45 ++- storagenode/trust/service.go | 26 +- uplink/ecclient/client.go | 10 +- uplink/ecclient/client_planet_test.go | 2 +- uplink/metainfo/client.go | 35 +- uplink/metainfo/kvmetainfo/buckets_test.go | 2 +- uplink/piecestore/client.go | 13 +- uplink/piecestore/download.go | 20 +- uplink/piecestore/upload.go | 19 +- uplink/piecestore/verification.go | 6 +- uplink/storage/streams/store_test.go | 2 +- 89 files changed, 1083 insertions(+), 1518 deletions(-) create mode 100644 internal/testcontext/compile_drpc.go create mode 100644 internal/testcontext/compile_nodrpc.go create mode 100644 pkg/rpc/common_drpc.go create mode 100644 pkg/rpc/common_grpc.go rename pkg/{transport/transport_test.go => rpc/rpc_test.go} (75%) create mode 100644 pkg/rpc/rpcpeer/peer.go create mode 100644 pkg/rpc/rpcpeer/peer_drpc.go create mode 100644 pkg/rpc/rpcpeer/peer_grpc.go create mode 100644 pkg/rpc/rpcstatus/status_drpc.go create mode 100644 pkg/rpc/rpcstatus/status_grpc.go delete mode 100644 pkg/transport/common.go delete mode 100644 pkg/transport/fetchidentity.go delete mode 100644 pkg/transport/insecure.go delete mode 100644 pkg/transport/slowtransport.go delete mode 100644 pkg/transport/timeout.go delete mode 100644 pkg/transport/transport.go diff --git a/bootstrap/peer.go b/bootstrap/peer.go index 13e6a6779..66716e051 100644 --- a/bootstrap/peer.go +++ b/bootstrap/peer.go @@ -20,9 +20,9 @@ import ( "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/extensions" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/server" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/overlay" "storj.io/storj/storage" ) @@ -62,7 +62,7 @@ type Peer struct { Identity *identity.FullIdentity DB DB - Transport transport.Client + Dialer rpc.Dialer Server *server.Server @@ -106,14 +106,14 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revDB extensions.R { // setup listener and server sc := config.Server - options, err := tlsopts.NewOptions(peer.Identity, sc.Config, revDB) + tlsOptions, err := tlsopts.NewOptions(peer.Identity, sc.Config, revDB) if err != nil { return nil, errs.Combine(err, peer.Close()) } - peer.Transport = transport.NewClient(options) + peer.Dialer = rpc.NewDefaultDialer(tlsOptions) - peer.Server, err = server.New(log.Named("server"), options, sc.Address, sc.PrivateAddress, nil) + peer.Server, err = server.New(log.Named("server"), tlsOptions, sc.Address, sc.PrivateAddress, nil) if err != nil { return nil, errs.Combine(err, peer.Close()) } @@ -153,9 +153,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revDB extensions.R return nil, errs.Combine(err, peer.Close()) } - peer.Transport = peer.Transport.WithObservers(peer.Kademlia.RoutingTable) - - peer.Kademlia.Service, err = kademlia.NewService(peer.Log.Named("kademlia"), peer.Transport, peer.Kademlia.RoutingTable, config) + peer.Kademlia.Service, err = kademlia.NewService(peer.Log.Named("kademlia"), peer.Dialer, peer.Kademlia.RoutingTable, config) if err != nil { return nil, errs.Combine(err, peer.Close()) } diff --git a/cmd/identity/main.go b/cmd/identity/main.go index 425358b16..6770965d5 100644 --- a/cmd/identity/main.go +++ b/cmd/identity/main.go @@ -24,7 +24,7 @@ import ( "storj.io/storj/pkg/pkcrypto" "storj.io/storj/pkg/process" "storj.io/storj/pkg/revocation" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" ) const ( @@ -196,17 +196,16 @@ func cmdAuthorize(cmd *cobra.Command, args []string) (err error) { err = errs.Combine(err, revocationDB.Close()) }() - tlsOpts, err := tlsopts.NewOptions(ident, config.Signer.TLS, nil) + tlsOptions, err := tlsopts.NewOptions(ident, config.Signer.TLS, nil) if err != nil { return err } - client, err := certificateclient.New(ctx, transport.NewClient(tlsOpts), config.Signer.Address) + + client, err := certificateclient.New(ctx, rpc.NewDefaultDialer(tlsOptions), config.Signer.Address) if err != nil { return err } - defer func() { - err = errs.Combine(err, client.Close()) - }() + defer func() { err = errs.Combine(err, client.Close()) }() signedChainBytes, err := client.Sign(ctx, authToken) if err != nil { diff --git a/cmd/inspector/main.go b/cmd/inspector/main.go index 9b93aa132..6da57bc52 100644 --- a/cmd/inspector/main.go +++ b/cmd/inspector/main.go @@ -20,8 +20,8 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/process" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/uplink/eestream" ) @@ -83,10 +83,11 @@ var ( // Inspector gives access to overlay. type Inspector struct { + conn *rpc.Conn identity *identity.FullIdentity - overlayclient pb.OverlayInspectorClient - irrdbclient pb.IrreparableInspectorClient - healthclient pb.HealthInspectorClient + overlayclient rpc.OverlayInspectorClient + irrdbclient rpc.IrreparableInspectorClient + healthclient rpc.HealthInspectorClient } // NewInspector creates a new gRPC inspector client for access to overlay. @@ -101,19 +102,23 @@ func NewInspector(address, path string) (*Inspector, error) { return nil, ErrIdentity.Wrap(err) } - conn, err := transport.DialAddressInsecure(ctx, address) + conn, err := rpc.NewDefaultDialer(nil).DialAddressUnencrypted(ctx, address) if err != nil { return &Inspector{}, ErrInspectorDial.Wrap(err) } return &Inspector{ + conn: conn, identity: id, - overlayclient: pb.NewOverlayInspectorClient(conn), - irrdbclient: pb.NewIrreparableInspectorClient(conn), - healthclient: pb.NewHealthInspectorClient(conn), + overlayclient: conn.OverlayInspectorClient(), + irrdbclient: conn.IrreparableInspectorClient(), + healthclient: conn.HealthInspectorClient(), }, nil } +// Close closes the inspector. +func (i *Inspector) Close() error { return i.conn.Close() } + // ObjectHealth gets information about the health of an object on the network func ObjectHealth(cmd *cobra.Command, args []string) (err error) { ctx := context.Background() @@ -122,6 +127,7 @@ func ObjectHealth(cmd *cobra.Command, args []string) (err error) { if err != nil { return ErrArgs.Wrap(err) } + defer func() { err = errs.Combine(err, i.Close()) }() startAfterSegment := int64(0) // start from first segment endBeforeSegment := int64(0) // No end, so we stop when we've hit limit or arrived at the last segment @@ -201,6 +207,7 @@ func SegmentHealth(cmd *cobra.Command, args []string) (err error) { if err != nil { return ErrArgs.Wrap(err) } + defer func() { err = errs.Combine(err, i.Close()) }() segmentIndex, err := strconv.ParseInt(args[1], 10, 64) if err != nil { @@ -363,6 +370,8 @@ func getSegments(cmd *cobra.Command, args []string) error { if err != nil { return ErrInspectorDial.Wrap(err) } + defer func() { err = errs.Combine(err, i.Close()) }() + var lastSeenSegmentPath = []byte{} // query DB and paginate results diff --git a/cmd/storagenode/dashboard.go b/cmd/storagenode/dashboard.go index 3bc3f425a..345fe72e2 100644 --- a/cmd/storagenode/dashboard.go +++ b/cmd/storagenode/dashboard.go @@ -17,36 +17,30 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/spf13/cobra" "go.uber.org/zap" - "google.golang.org/grpc" "storj.io/storj/internal/memory" "storj.io/storj/internal/version" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/process" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" ) const contactWindow = time.Minute * 10 type dashboardClient struct { - client pb.PieceStoreInspectorClient - conn *grpc.ClientConn + conn *rpc.Conn } func dialDashboardClient(ctx context.Context, address string) (*dashboardClient, error) { - conn, err := transport.DialAddressInsecure(ctx, address) + conn, err := rpc.NewDefaultDialer(nil).DialAddressUnencrypted(ctx, address) if err != nil { return &dashboardClient{}, err } - - return &dashboardClient{ - client: pb.NewPieceStoreInspectorClient(conn), - conn: conn, - }, nil + return &dashboardClient{conn: conn}, nil } func (dash *dashboardClient) dashboard(ctx context.Context) (*pb.DashboardResponse, error) { - return dash.client.Dashboard(ctx, &pb.DashboardRequest{}) + return dash.conn.PieceStoreInspectorClient().Dashboard(ctx, &pb.DashboardRequest{}) } func (dash *dashboardClient) close() error { diff --git a/go.mod b/go.mod index 67416e2fc..4540fd549 100644 --- a/go.mod +++ b/go.mod @@ -126,5 +126,5 @@ require ( gopkg.in/olivere/elastic.v5 v5.0.76 // indirect gopkg.in/spacemonkeygo/monkit.v2 v2.0.0-20190612171030-cf5a9e6f8fd2 gopkg.in/yaml.v2 v2.2.2 - storj.io/drpc v0.0.3 + storj.io/drpc v0.0.4 ) diff --git a/go.sum b/go.sum index a46807398..9d168a3f0 100644 --- a/go.sum +++ b/go.sum @@ -533,5 +533,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -storj.io/drpc v0.0.3 h1:W7y0HeMA9VpiWtl2efNbavD9WF+jzfTyP9Od+fSYd6s= -storj.io/drpc v0.0.3/go.mod h1:/ascUDbzNAv0A3Jj7wUIKFBH2JdJ2uJIBO/b9+2yHgQ= +storj.io/drpc v0.0.4 h1:raJ8r2PKU/KUuoghR6WVdMWpds8uE8GOK/WN9NEMPsk= +storj.io/drpc v0.0.4/go.mod h1:/ascUDbzNAv0A3Jj7wUIKFBH2JdJ2uJIBO/b9+2yHgQ= diff --git a/internal/errs2/ignore.go b/internal/errs2/ignore.go index 6d4ce0ce1..be53f49e5 100644 --- a/internal/errs2/ignore.go +++ b/internal/errs2/ignore.go @@ -9,8 +9,8 @@ import ( "github.com/zeebo/errs" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + + "storj.io/storj/pkg/rpc/rpcstatus" ) // IsCanceled returns true, when the error is a cancellation. @@ -19,7 +19,7 @@ func IsCanceled(err error) bool { return err == context.Canceled || err == grpc.ErrServerStopped || err == http.ErrServerClosed || - status.Code(err) == codes.Canceled + rpcstatus.Code(err) == rpcstatus.Canceled }) } diff --git a/internal/errs2/ignore_test.go b/internal/errs2/ignore_test.go index 68db8fe8e..e685452fb 100644 --- a/internal/errs2/ignore_test.go +++ b/internal/errs2/ignore_test.go @@ -10,10 +10,9 @@ import ( "github.com/stretchr/testify/require" "github.com/zeebo/errs" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "storj.io/storj/internal/errs2" + "storj.io/storj/pkg/rpc/rpcstatus" ) func TestIsCanceled(t *testing.T) { @@ -35,20 +34,20 @@ func TestIsCanceled(t *testing.T) { require.True(t, errs2.IsCanceled(parentErr)) require.True(t, errs2.IsCanceled(childErr)) - // grpc errors - grpcErr := status.Error(codes.Canceled, context.Canceled.Error()) + // rpc errors + rpcErr := rpcstatus.Error(rpcstatus.Canceled, context.Canceled.Error()) - require.NotEqual(t, grpcErr, context.Canceled) - require.True(t, errs2.IsCanceled(grpcErr)) + require.NotEqual(t, rpcErr, context.Canceled) + require.True(t, errs2.IsCanceled(rpcErr)) // nested errors nestedParentErr := nestedErr.Wrap(parentErr) nestedChildErr := nestedErr.Wrap(childErr) - nestedGRPCErr := nestedErr.Wrap(grpcErr) + nestedRPCErr := nestedErr.Wrap(rpcErr) require.NotEqual(t, nestedParentErr, context.Canceled) require.NotEqual(t, nestedChildErr, context.Canceled) - require.NotEqual(t, nestedGRPCErr, context.Canceled) + require.NotEqual(t, nestedRPCErr, context.Canceled) require.True(t, errs2.IsCanceled(nestedParentErr)) require.True(t, errs2.IsCanceled(nestedChildErr)) @@ -57,13 +56,13 @@ func TestIsCanceled(t *testing.T) { // combined errors combinedParentErr := errs.Combine(combinedErr, parentErr) combinedChildErr := errs.Combine(combinedErr, childErr) - combinedGRPCErr := errs.Combine(combinedErr, childErr) + combinedRPCErr := errs.Combine(combinedErr, childErr) require.NotEqual(t, combinedParentErr, context.Canceled) require.NotEqual(t, combinedChildErr, context.Canceled) - require.NotEqual(t, combinedGRPCErr, context.Canceled) + require.NotEqual(t, combinedRPCErr, context.Canceled) require.True(t, errs2.IsCanceled(combinedParentErr)) require.True(t, errs2.IsCanceled(combinedChildErr)) - require.True(t, errs2.IsCanceled(combinedGRPCErr)) + require.True(t, errs2.IsCanceled(combinedRPCErr)) } diff --git a/internal/errs2/rpc.go b/internal/errs2/rpc.go index 060b65c16..759a885f9 100644 --- a/internal/errs2/rpc.go +++ b/internal/errs2/rpc.go @@ -5,13 +5,13 @@ package errs2 import ( "github.com/zeebo/errs" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + + "storj.io/storj/pkg/rpc/rpcstatus" ) // IsRPC checks if err contains an RPC error with the given status code. -func IsRPC(err error, code codes.Code) bool { +func IsRPC(err error, code rpcstatus.StatusCode) bool { return errs.IsFunc(err, func(err error) bool { - return status.Code(err) == code + return rpcstatus.Code(err) == code }) } diff --git a/internal/testcontext/compile.go b/internal/testcontext/compile.go index 3f242ede4..2484070b9 100644 --- a/internal/testcontext/compile.go +++ b/internal/testcontext/compile.go @@ -28,12 +28,16 @@ func (ctx *Context) Compile(pkg string) string { exe := ctx.File("build", path.Base(pkg)+".exe") - var cmd *exec.Cmd + args := []string{"build"} if raceEnabled { - cmd = exec.Command("go", "build", "-race", "-o", exe, pkg) - } else { - cmd = exec.Command("go", "build", "-o", exe, pkg) + args = append(args, "-race") } + if drpcEnabled { + args = append(args, "-tags=drpc") + } + args = append(args, "-o", exe, pkg) + + cmd := exec.Command("go", args...) ctx.test.Log("exec:", cmd.Args) out, err := cmd.CombinedOutput() @@ -53,8 +57,14 @@ func (ctx *Context) CompileShared(t *testing.T, name string, pkg string) Include base := ctx.File("build", name) + args := []string{"build", "-buildmode", "c-shared"} + if drpcEnabled { + args = append(args, "-tags=drpc") + } + args = append(args, "-o", base+".so", pkg) + // not using race detector for c-shared - cmd := exec.Command("go", "build", "-buildmode", "c-shared", "-o", base+".so", pkg) + cmd := exec.Command("go", args...) t.Log("exec:", cmd.Args) out, err := cmd.CombinedOutput() diff --git a/internal/testcontext/compile_drpc.go b/internal/testcontext/compile_drpc.go new file mode 100644 index 000000000..a936fa05d --- /dev/null +++ b/internal/testcontext/compile_drpc.go @@ -0,0 +1,8 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build drpc + +package testcontext + +const drpcEnabled = true diff --git a/internal/testcontext/compile_nodrpc.go b/internal/testcontext/compile_nodrpc.go new file mode 100644 index 000000000..18a2d05e6 --- /dev/null +++ b/internal/testcontext/compile_nodrpc.go @@ -0,0 +1,8 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build !drpc + +package testcontext + +const drpcEnabled = false diff --git a/internal/testplanet/planet_test.go b/internal/testplanet/planet_test.go index b95826184..bf75013c8 100644 --- a/internal/testplanet/planet_test.go +++ b/internal/testplanet/planet_test.go @@ -42,9 +42,9 @@ func TestBasic(t *testing.T) { satellite := sat.Local().Node for _, sn := range planet.StorageNodes { node := sn.Local() - conn, err := sn.Transport.DialNode(ctx, &satellite) + conn, err := sn.Dialer.DialNode(ctx, &satellite) require.NoError(t, err) - _, err = pb.NewNodeClient(conn).CheckIn(ctx, &pb.CheckInRequest{ + _, err = conn.NodeClient().CheckIn(ctx, &pb.CheckInRequest{ Address: node.GetAddress().GetAddress(), Version: &node.Version, Capacity: &node.Capacity, diff --git a/internal/testplanet/uplink.go b/internal/testplanet/uplink.go index 9fe6e871a..9e1c65b2c 100644 --- a/internal/testplanet/uplink.go +++ b/internal/testplanet/uplink.go @@ -23,8 +23,8 @@ import ( "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/console" "storj.io/storj/uplink" "storj.io/storj/uplink/metainfo" @@ -36,7 +36,7 @@ type Uplink struct { Log *zap.Logger Info pb.Node Identity *identity.FullIdentity - Transport transport.Client + Dialer rpc.Dialer StorageNodeCount int APIKey map[storj.NodeID]*macaroon.APIKey @@ -64,7 +64,7 @@ func (planet *Planet) newUplink(name string, storageNodeCount int) (*Uplink, err return nil, err } - tlsOpts, err := tlsopts.NewOptions(identity, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(identity, tlsopts.Config{ PeerIDVersions: strconv.Itoa(int(planet.config.IdentityVersion.Number)), }, nil) if err != nil { @@ -81,7 +81,7 @@ func (planet *Planet) newUplink(name string, storageNodeCount int) (*Uplink, err uplink.Log.Debug("id=" + identity.ID.String()) - uplink.Transport = transport.NewClient(tlsOpts) + uplink.Dialer = rpc.NewDefaultDialer(tlsOptions) uplink.Info = pb.Node{ Id: uplink.Identity.ID, @@ -149,13 +149,13 @@ func (client *Uplink) Shutdown() error { return nil } // DialMetainfo dials destination with apikey and returns metainfo Client func (client *Uplink) DialMetainfo(ctx context.Context, destination Peer, apikey *macaroon.APIKey) (*metainfo.Client, error) { - return metainfo.Dial(ctx, client.Transport, destination.Addr(), apikey) + return metainfo.Dial(ctx, client.Dialer, destination.Addr(), apikey) } // DialPiecestore dials destination storagenode and returns a piecestore client. func (client *Uplink) DialPiecestore(ctx context.Context, destination Peer) (*piecestore.Client, error) { node := destination.Local() - return piecestore.Dial(ctx, client.Transport, &node.Node, client.Log.Named("uplink>piecestore"), piecestore.DefaultConfig) + return piecestore.Dial(ctx, client.Dialer, &node.Node, client.Log.Named("uplink>piecestore"), piecestore.DefaultConfig) } // Upload data to specific satellite diff --git a/internal/testplanet/uplink_test.go b/internal/testplanet/uplink_test.go index b20df38cc..19087c384 100644 --- a/internal/testplanet/uplink_test.go +++ b/internal/testplanet/uplink_test.go @@ -233,10 +233,10 @@ func TestDownloadFromUnresponsiveNode(t *testing.T) { revocationDB, err := revocation.NewDBFromCfg(tlscfg) require.NoError(t, err) - options, err := tlsopts.NewOptions(storageNode.Identity, tlscfg, revocationDB) + tlsOptions, err := tlsopts.NewOptions(storageNode.Identity, tlscfg, revocationDB) require.NoError(t, err) - server, err := server.New(storageNode.Log.Named("mock-server"), options, storageNode.Addr(), storageNode.PrivateAddr(), nil) + server, err := server.New(storageNode.Log.Named("mock-server"), tlsOptions, storageNode.Addr(), storageNode.PrivateAddr(), nil) require.NoError(t, err) pb.RegisterPiecestoreServer(server.GRPC(), &piecestoreMock{}) go func() { diff --git a/lib/uplink/project.go b/lib/uplink/project.go index 26c0314a4..66f3b8e17 100644 --- a/lib/uplink/project.go +++ b/lib/uplink/project.go @@ -11,8 +11,8 @@ import ( "storj.io/storj/internal/memory" "storj.io/storj/pkg/encryption" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/uplink/ecclient" "storj.io/storj/uplink/eestream" "storj.io/storj/uplink/metainfo" @@ -24,7 +24,7 @@ import ( // Project represents a specific project access session. type Project struct { uplinkCfg *Config - tc transport.Client + dialer rpc.Dialer metainfo *metainfo.Client project *kvmetainfo.Project maxInlineSize memory.Size @@ -179,7 +179,7 @@ func (p *Project) OpenBucket(ctx context.Context, bucketName string, access *Enc } encryptionParameters := cfg.EncryptionParameters - ec := ecclient.NewClient(p.uplinkCfg.Volatile.Log.Named("ecclient"), p.tc, p.uplinkCfg.Volatile.MaxMemory.Int()) + ec := ecclient.NewClient(p.uplinkCfg.Volatile.Log.Named("ecclient"), p.dialer, p.uplinkCfg.Volatile.MaxMemory.Int()) fc, err := infectious.NewFEC(int(cfg.Volatile.RedundancyScheme.RequiredShares), int(cfg.Volatile.RedundancyScheme.TotalShares)) if err != nil { return nil, err diff --git a/lib/uplink/uplink.go b/lib/uplink/uplink.go index a10853eb7..f399f47fa 100644 --- a/lib/uplink/uplink.go +++ b/lib/uplink/uplink.go @@ -12,7 +12,7 @@ import ( "storj.io/storj/internal/memory" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/peertls/tlsopts" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" "storj.io/storj/uplink/metainfo" "storj.io/storj/uplink/metainfo/kvmetainfo" ) @@ -107,9 +107,9 @@ func (cfg *Config) setDefaults(ctx context.Context) error { // a specific Satellite and caches connections and resources, allowing one to // create sessions delineated by specific access controls. type Uplink struct { - ident *identity.FullIdentity - tc transport.Client - cfg *Config + ident *identity.FullIdentity + dialer rpc.Dialer + cfg *Config } // NewUplink creates a new Uplink. This is the first step to create an uplink @@ -137,21 +137,20 @@ func NewUplink(ctx context.Context, cfg *Config) (_ *Uplink, err error) { PeerCAWhitelistPath: cfg.Volatile.TLS.PeerCAWhitelistPath, PeerIDVersions: "0", } - tlsOpts, err := tlsopts.NewOptions(ident, tlsConfig, nil) + + tlsOptions, err := tlsopts.NewOptions(ident, tlsConfig, nil) if err != nil { return nil, err } - timeouts := transport.Timeouts{ - Dial: cfg.Volatile.DialTimeout, - Request: cfg.Volatile.RequestTimeout, - } - tc := transport.NewClientWithTimeouts(tlsOpts, timeouts) + dialer := rpc.NewDefaultDialer(tlsOptions) + dialer.DialTimeout = cfg.Volatile.DialTimeout + dialer.RequestTimeout = cfg.Volatile.RequestTimeout return &Uplink{ - ident: ident, - tc: tc, - cfg: cfg, + ident: ident, + dialer: dialer, + cfg: cfg, }, nil } @@ -161,7 +160,7 @@ func NewUplink(ctx context.Context, cfg *Config) (_ *Uplink, err error) { func (u *Uplink) OpenProject(ctx context.Context, satelliteAddr string, apiKey APIKey) (p *Project, err error) { defer mon.Task()(&ctx)(&err) - m, err := metainfo.Dial(ctx, u.tc, satelliteAddr, apiKey.key) + m, err := metainfo.Dial(ctx, u.dialer, satelliteAddr, apiKey.key) if err != nil { return nil, err } @@ -173,7 +172,7 @@ func (u *Uplink) OpenProject(ctx context.Context, satelliteAddr string, apiKey A return &Project{ uplinkCfg: u.cfg, - tc: u.tc, + dialer: u.dialer, metainfo: m, project: project, maxInlineSize: u.cfg.Volatile.MaxInlineSize, diff --git a/lib/uplink/uplink_test.go b/lib/uplink/uplink_test.go index 8368fa5a8..28d5febfc 100644 --- a/lib/uplink/uplink_test.go +++ b/lib/uplink/uplink_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "storj.io/storj/internal/testcontext" - "storj.io/storj/pkg/transport" ) // TestUplinkConfigDefaults tests that the uplink configuration gets the correct defaults applied @@ -28,10 +27,8 @@ func TestUplinkConfigDefaultTimeouts(t *testing.T) { assert.Equal(t, 20*time.Second, client.cfg.Volatile.RequestTimeout) // Assert the values propagate correctly all the way down to the transport layer. - trans, ok := client.tc.(*transport.Transport) - assert.Equal(t, true, ok) - assert.Equal(t, 20*time.Second, trans.Timeouts().Dial) - assert.Equal(t, 20*time.Second, trans.Timeouts().Request) + assert.Equal(t, 20*time.Second, client.dialer.DialTimeout) + assert.Equal(t, 20*time.Second, client.dialer.RequestTimeout) } // TestUplinkConfigSetTimeouts tests that the uplink configuration settings properly override @@ -60,8 +57,6 @@ func TestUplinkConfigSetTimeouts(t *testing.T) { assert.Equal(t, 120*time.Second, client.cfg.Volatile.RequestTimeout) // Assert the values propagate correctly all the way down to the transport layer. - trans, ok := client.tc.(*transport.Transport) - assert.Equal(t, true, ok) - assert.Equal(t, 120*time.Second, trans.Timeouts().Dial) - assert.Equal(t, 120*time.Second, trans.Timeouts().Request) + assert.Equal(t, 120*time.Second, client.dialer.DialTimeout) + assert.Equal(t, 120*time.Second, client.dialer.RequestTimeout) } diff --git a/pkg/auth/grpcauth/apikey_test.go b/pkg/auth/grpcauth/apikey_test.go index 1e4c52d61..8bc040b8a 100644 --- a/pkg/auth/grpcauth/apikey_test.go +++ b/pkg/auth/grpcauth/apikey_test.go @@ -63,10 +63,10 @@ func TestAPIKey(t *testing.T) { if test.expected == codes.OK { require.NoError(t, err) - require.Equal(t, response.Message, "Hello Me") + require.Equal(t, "Hello Me", response.Message) } else { require.Error(t, err) - require.Equal(t, status.Code(err), test.expected) + require.Equal(t, test.expected, status.Code(err)) } require.NoError(t, conn.Close()) diff --git a/pkg/certificate/authorization/authorizations.go b/pkg/certificate/authorization/authorizations.go index f60147837..ec6fca36d 100644 --- a/pkg/certificate/authorization/authorizations.go +++ b/pkg/certificate/authorization/authorizations.go @@ -16,11 +16,11 @@ import ( "github.com/btcsuite/btcutil/base58" "github.com/zeebo/errs" - "google.golang.org/grpc/peer" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcpeer" ) const ( @@ -69,7 +69,7 @@ type Token struct { // ClaimOpts hold parameters for claiming an authorization. type ClaimOpts struct { Req *pb.SigningRequest - Peer *peer.Peer + Peer *rpcpeer.Peer ChainBytes [][]byte MinDifficulty uint16 } diff --git a/pkg/certificate/authorization/authorizations_test.go b/pkg/certificate/authorization/authorizations_test.go index 506ea57c4..0e55e33df 100644 --- a/pkg/certificate/authorization/authorizations_test.go +++ b/pkg/certificate/authorization/authorizations_test.go @@ -17,8 +17,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zeebo/errs" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testidentity" @@ -26,8 +24,9 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" + "storj.io/storj/pkg/rpc/rpcpeer" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/storage" ) @@ -230,12 +229,10 @@ func TestAuthorizationDB_Claim_Valid(t *testing.T) { IP: net.ParseIP("1.2.3.4"), Port: 5, } - grpcPeer := &peer.Peer{ + peer := &rpcpeer.Peer{ Addr: addr, - AuthInfo: credentials.TLSInfo{ - State: tls.ConnectionState{ - PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, - }, + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, }, } @@ -249,7 +246,7 @@ func TestAuthorizationDB_Claim_Valid(t *testing.T) { err = authDB.Claim(ctx, &ClaimOpts{ Req: req, - Peer: grpcPeer, + Peer: peer, ChainBytes: [][]byte{ident.CA.Raw}, MinDifficulty: difficulty, }) @@ -263,7 +260,7 @@ func TestAuthorizationDB_Claim_Valid(t *testing.T) { require.NotNil(t, updatedAuths[0].Claim) claim := updatedAuths[0].Claim - assert.Equal(t, grpcPeer.Addr.String(), claim.Addr) + assert.Equal(t, peer.Addr.String(), claim.Addr) assert.Equal(t, [][]byte{ident.CA.Raw}, claim.SignedChainBytes) assert.Condition(t, func() bool { return now-MaxClaimDelaySeconds < claim.Timestamp && @@ -314,12 +311,10 @@ func TestAuthorizationDB_Claim_Invalid(t *testing.T) { IP: net.ParseIP("1.2.3.4"), Port: 5, } - grpcPeer := &peer.Peer{ + peer := &rpcpeer.Peer{ Addr: addr, - AuthInfo: credentials.TLSInfo{ - State: tls.ConnectionState{ - PeerCertificates: []*x509.Certificate{ident2.Leaf, ident2.CA}, - }, + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ident2.Leaf, ident2.CA}, }, } @@ -332,7 +327,7 @@ func TestAuthorizationDB_Claim_Invalid(t *testing.T) { AuthToken: auths[claimedIndex].Token.String(), Timestamp: time.Now().Unix(), }, - Peer: grpcPeer, + Peer: peer, ChainBytes: [][]byte{ident2.CA.Raw}, MinDifficulty: difficulty2, }) @@ -361,7 +356,7 @@ func TestAuthorizationDB_Claim_Invalid(t *testing.T) { // NB: 1 day ago Timestamp: time.Now().Unix() - 86400, }, - Peer: grpcPeer, + Peer: peer, ChainBytes: [][]byte{ident2.CA.Raw}, MinDifficulty: difficulty2, }) @@ -385,7 +380,7 @@ func TestAuthorizationDB_Claim_Invalid(t *testing.T) { AuthToken: auths[unclaimedIndex].Token.String(), Timestamp: time.Now().Unix(), }, - Peer: grpcPeer, + Peer: peer, ChainBytes: [][]byte{ident2.CA.Raw}, MinDifficulty: difficulty2 + 1, }) @@ -605,10 +600,10 @@ func TestNewClient(t *testing.T) { tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{}, nil) require.NoError(t, err) - clientTransport := transport.NewClient(tlsOptions) + dialer := rpc.NewDefaultDialer(tlsOptions) t.Run("Basic", func(t *testing.T) { - client, err := certificateclient.New(ctx, clientTransport, listener.Addr().String()) + client, err := certificateclient.New(ctx, dialer, listener.Addr().String()) assert.NoError(t, err) assert.NotNil(t, client) @@ -616,16 +611,13 @@ func TestNewClient(t *testing.T) { }) t.Run("ClientFrom", func(t *testing.T) { - conn, err := clientTransport.DialAddress(ctx, listener.Addr().String()) + conn, err := dialer.DialAddressInsecure(ctx, listener.Addr().String()) require.NoError(t, err) require.NotNil(t, conn) defer ctx.Check(conn.Close) - pbClient := pb.NewCertificatesClient(conn) - require.NotNil(t, pbClient) - - client := certificateclient.NewClientFrom(pbClient) + client := certificateclient.NewClientFrom(conn.CertificatesClient()) assert.NoError(t, err) assert.NotNil(t, client) diff --git a/pkg/certificate/certificateclient/client.go b/pkg/certificate/certificateclient/client.go index 20b2d510b..05abc3ed4 100644 --- a/pkg/certificate/certificateclient/client.go +++ b/pkg/certificate/certificateclient/client.go @@ -8,13 +8,12 @@ import ( "time" "github.com/zeebo/errs" - "google.golang.org/grpc" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" ) var mon = monkit.Package() @@ -25,30 +24,30 @@ type Config struct { TLS tlsopts.Config } -// Client implements pb.CertificateClient +// Client implements rpc.CertificatesClient type Client struct { - conn *grpc.ClientConn - client pb.CertificatesClient + conn *rpc.Conn + client rpc.CertificatesClient } -// New creates a new certificate signing grpc client. -func New(ctx context.Context, tc transport.Client, address string) (_ *Client, err error) { +// New creates a new certificate signing rpc client. +func New(ctx context.Context, dialer rpc.Dialer, address string) (_ *Client, err error) { defer mon.Task()(&ctx, address)(&err) - conn, err := tc.DialAddress(ctx, address) + conn, err := dialer.DialAddressInsecure(ctx, address) if err != nil { return nil, err } return &Client{ conn: conn, - client: pb.NewCertificatesClient(conn), + client: conn.CertificatesClient(), }, nil } // NewClientFrom creates a new certificate signing gRPC client from an existing // grpc cert signing client. -func NewClientFrom(client pb.CertificatesClient) *Client { +func NewClientFrom(client rpc.CertificatesClient) *Client { return &Client{ client: client, } @@ -58,17 +57,15 @@ func NewClientFrom(client pb.CertificatesClient) *Client { func (config Config) Sign(ctx context.Context, ident *identity.FullIdentity, authToken string) (_ [][]byte, err error) { defer mon.Task()(&ctx)(&err) - tlsOpts, err := tlsopts.NewOptions(ident, config.TLS, nil) + tlsOptions, err := tlsopts.NewOptions(ident, config.TLS, nil) if err != nil { return nil, err } - client, err := New(ctx, transport.NewClient(tlsOpts), config.Address) + client, err := New(ctx, rpc.NewDefaultDialer(tlsOptions), config.Address) if err != nil { return nil, err } - defer func() { - err = errs.Combine(err, client.Close()) - }() + defer func() { err = errs.Combine(err, client.Close()) }() return client.Sign(ctx, authToken) } diff --git a/pkg/certificate/endpoint.go b/pkg/certificate/endpoint.go index e017e2379..5809d263f 100644 --- a/pkg/certificate/endpoint.go +++ b/pkg/certificate/endpoint.go @@ -7,13 +7,12 @@ import ( "context" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" "storj.io/storj/pkg/certificate/authorization" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcpeer" + "storj.io/storj/pkg/rpc/rpcstatus" ) // Endpoint implements pb.CertificatesServer. @@ -38,14 +37,14 @@ func NewEndpoint(log *zap.Logger, ca *identity.FullCertificateAuthority, authori // Returns a certificate chain consisting of the remote peer's CA followed by the CA chain. func (endpoint Endpoint) Sign(ctx context.Context, req *pb.SigningRequest) (_ *pb.SigningResponse, err error) { defer mon.Task()(&ctx)(&err) - grpcPeer, ok := peer.FromContext(ctx) - if !ok { + peer, err := rpcpeer.FromContext(ctx) + if err != nil { msg := "error getting peer from context" endpoint.log.Error(msg, zap.Error(err)) return nil, internalErr(msg) } - peerIdent, err := identity.PeerIdentityFromPeer(grpcPeer) + peerIdent, err := identity.PeerIdentityFromPeer(peer) if err != nil { msg := "error getting peer identity" endpoint.log.Error(msg, zap.Error(err)) @@ -63,7 +62,7 @@ func (endpoint Endpoint) Sign(ctx context.Context, req *pb.SigningRequest) (_ *p signedChainBytes = append(signedChainBytes, endpoint.ca.RawRestChain()...) err = endpoint.authorizationDB.Claim(ctx, &authorization.ClaimOpts{ Req: req, - Peer: grpcPeer, + Peer: peer, ChainBytes: signedChainBytes, MinDifficulty: endpoint.minDifficulty, }) @@ -100,5 +99,5 @@ func (endpoint Endpoint) Sign(ctx context.Context, req *pb.SigningRequest) (_ *p } func internalErr(msg string) error { - return status.Error(codes.Internal, Error.New(msg).Error()) + return rpcstatus.Error(rpcstatus.Internal, Error.New(msg).Error()) } diff --git a/pkg/certificate/peer.go b/pkg/certificate/peer.go index 2026faf8e..dd2b2c4e4 100644 --- a/pkg/certificate/peer.go +++ b/pkg/certificate/peer.go @@ -72,12 +72,12 @@ func New(log *zap.Logger, ident *identity.FullIdentity, ca *identity.FullCertifi log.Debug("Starting listener and server") sc := config.Server - options, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) + tlsOptions, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) if err != nil { return nil, Error.Wrap(errs.Combine(err, peer.Close())) } - peer.Server, err = server.New(log.Named("server"), options, sc.Address, sc.PrivateAddress, nil) + peer.Server, err = server.New(log.Named("server"), tlsOptions, sc.Address, sc.PrivateAddress, nil) if err != nil { return nil, Error.Wrap(err) } diff --git a/pkg/certificate/peer_test.go b/pkg/certificate/peer_test.go index 0fe7180fc..58debecf7 100644 --- a/pkg/certificate/peer_test.go +++ b/pkg/certificate/peer_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testidentity" @@ -25,9 +23,10 @@ import ( "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc" + "storj.io/storj/pkg/rpc/rpcpeer" "storj.io/storj/pkg/server" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) // TODO: test sad path @@ -82,14 +81,14 @@ func TestCertificateSigner_Sign_E2E(t *testing.T) { }) defer ctx.Check(peer.Close) - clientOpts, err := tlsopts.NewOptions(clientIdent, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(clientIdent, tlsopts.Config{ PeerIDVersions: "*", }, nil) require.NoError(t, err) - clientTransport := transport.NewClient(clientOpts) + dialer := rpc.NewDefaultDialer(tlsOptions) - client, err := certificateclient.New(ctx, clientTransport, peer.Server.Addr().String()) + client, err := certificateclient.New(ctx, dialer, peer.Server.Addr().String()) require.NoError(t, err) require.NotNil(t, client) defer ctx.Check(client.Close) @@ -163,15 +162,13 @@ func TestCertificateSigner_Sign(t *testing.T) { IP: net.ParseIP("1.2.3.4"), Port: 5, } - grpcPeer := &peer.Peer{ + peer := &rpcpeer.Peer{ Addr: expectedAddr, - AuthInfo: credentials.TLSInfo{ - State: tls.ConnectionState{ - PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, - }, + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, }, } - peerCtx := peer.NewContext(ctx, grpcPeer) + peerCtx := rpcpeer.NewContext(ctx, peer) certSigner := certificate.NewEndpoint(zaptest.NewLogger(t), ca, authDB, 0) req := pb.SigningRequest{ diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go index 465392e8e..547f58862 100644 --- a/pkg/identity/identity.go +++ b/pkg/identity/identity.go @@ -18,12 +18,11 @@ import ( "time" "github.com/zeebo/errs" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" "storj.io/storj/pkg/peertls" "storj.io/storj/pkg/peertls/extensions" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc/rpcpeer" "storj.io/storj/pkg/storj" ) @@ -197,17 +196,8 @@ func PeerIdentityFromChain(chain []*x509.Certificate) (*PeerIdentity, error) { } // PeerIdentityFromPeer loads a PeerIdentity from a peer connection. -func PeerIdentityFromPeer(peer *peer.Peer) (*PeerIdentity, error) { - if peer.AuthInfo == nil { - return nil, Error.New("peer AuthInfo is nil") - } - - tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo) - if !ok { - return nil, Error.New("peer AuthInfo is not credentials.TLSInfo") - } - - chain := tlsInfo.State.PeerCertificates +func PeerIdentityFromPeer(peer *rpcpeer.Peer) (*PeerIdentity, error) { + chain := peer.State.PeerCertificates if len(chain)-1 < peertls.CAIndex { return nil, Error.New("invalid certificate chain") } @@ -215,18 +205,16 @@ func PeerIdentityFromPeer(peer *peer.Peer) (*PeerIdentity, error) { if err != nil { return nil, err } - return pi, nil } // PeerIdentityFromContext loads a PeerIdentity from a ctx TLS credentials. func PeerIdentityFromContext(ctx context.Context) (*PeerIdentity, error) { - p, ok := peer.FromContext(ctx) - if !ok { - return nil, Error.New("unable to get grpc peer from contex") + peer, err := rpcpeer.FromContext(ctx) + if err != nil { + return nil, err } - - return PeerIdentityFromPeer(p) + return PeerIdentityFromPeer(peer) } // NodeIDFromCertPath loads a node ID from a certificate file path. diff --git a/pkg/kademlia/endpoint.go b/pkg/kademlia/endpoint.go index 87d01363b..e663fc0ee 100644 --- a/pkg/kademlia/endpoint.go +++ b/pkg/kademlia/endpoint.go @@ -10,12 +10,11 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcpeer" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" ) @@ -105,16 +104,16 @@ func (endpoint *Endpoint) Ping(ctx context.Context, req *pb.PingRequest) (_ *pb. // NOTE: this code is very similar to that in storagenode/contact.(*Endpoint).PingNode(). // That other will be used going forward, and this will soon be gutted and deprecated. The // code similarity will only exist until the transition away from Kademlia is complete. - p, ok := peer.FromContext(ctx) - if !ok { - return nil, status.Error(codes.Internal, "unable to get grpc peer from context") - } - peerID, err := identity.PeerIdentityFromPeer(p) + peer, err := rpcpeer.FromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) + } + peerID, err := identity.PeerIdentityFromPeer(peer) + if err != nil { + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } if endpoint.pingStats != nil { - endpoint.pingStats.WasPinged(time.Now(), peerID.ID, p.Addr.String()) + endpoint.pingStats.WasPinged(time.Now(), peerID.ID, peer.Addr.String()) } return &pb.PingResponse{}, nil } @@ -126,17 +125,17 @@ func (endpoint *Endpoint) RequestInfo(ctx context.Context, req *pb.InfoRequest) if self.Type == pb.NodeType_STORAGE { if endpoint.trust == nil { - return nil, status.Error(codes.Internal, "missing trust") + return nil, rpcstatus.Error(rpcstatus.Internal, "missing trust") } peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.trust.VerifySatelliteID(ctx, peer.ID) if err != nil { - return nil, status.Errorf(codes.PermissionDenied, "untrusted peer %v", peer.ID) + return nil, rpcstatus.Errorf(rpcstatus.PermissionDenied, "untrusted peer %v", peer.ID) } } diff --git a/pkg/kademlia/kademlia.go b/pkg/kademlia/kademlia.go index f97616e25..fd4aba437 100644 --- a/pkg/kademlia/kademlia.go +++ b/pkg/kademlia/kademlia.go @@ -17,8 +17,8 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/kademlia/kademliaclient" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/overlay" "storj.io/storj/storage" ) @@ -53,7 +53,7 @@ type Kademlia struct { } // NewService returns a newly configured Kademlia instance -func NewService(log *zap.Logger, transport transport.Client, rt *RoutingTable, config Config) (*Kademlia, error) { +func NewService(log *zap.Logger, dialer rpc.Dialer, rt *RoutingTable, config Config) (*Kademlia, error) { k := &Kademlia{ log: log, alpha: config.Alpha, @@ -61,7 +61,7 @@ func NewService(log *zap.Logger, transport transport.Client, rt *RoutingTable, c bootstrapNodes: config.BootstrapNodes(), bootstrapBackoffMax: config.BootstrapBackoffMax, bootstrapBackoffBase: config.BootstrapBackoffBase, - dialer: kademliaclient.NewDialer(log.Named("dialer"), transport), + dialer: kademliaclient.NewDialer(log.Named("dialer"), dialer, rt), refreshThreshold: int64(time.Minute), } @@ -147,17 +147,6 @@ func (k *Kademlia) Bootstrap(ctx context.Context) (err error) { continue } - // FetchPeerIdentityUnverified uses transport.DialAddress, which should be - // enough to have the TransportObservers find out about this node. Unfortunately, - // getting DialAddress to be able to grab the node id seems challenging with gRPC. - // The way FetchPeerIdentityUnverified does is is to do a basic ping request, which - // we have now done. Let's tell all the transport observers now. - // TODO: remove the explicit transport observer notification - k.dialer.AlertSuccess(ctx, &pb.Node{ - Id: ident.ID, - Address: node.Address, - }) - k.routingTable.mutex.Lock() node.Id = ident.ID k.bootstrapNodes[i] = node diff --git a/pkg/kademlia/kademlia_test.go b/pkg/kademlia/kademlia_test.go index 9cae85f76..28faf0e99 100644 --- a/pkg/kademlia/kademlia_test.go +++ b/pkg/kademlia/kademlia_test.go @@ -24,8 +24,8 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/overlay" "storj.io/storj/storage/teststore" ) @@ -429,15 +429,13 @@ func newKademlia(log *zap.Logger, nodeType pb.NodeType, bootstrapNodes []pb.Node return nil, err } - transportClient := transport.NewClient(tlsOptions, rt) - kadConfig := Config{ BootstrapBackoffMax: 10 * time.Second, BootstrapBackoffBase: 1 * time.Second, Alpha: alpha, } - kad, err := NewService(log, transportClient, rt, kadConfig) + kad, err := NewService(log, rpc.NewDefaultDialer(tlsOptions), rt, kadConfig) if err != nil { return nil, err } diff --git a/pkg/kademlia/kademliaclient/kademliaclient.go b/pkg/kademlia/kademliaclient/kademliaclient.go index 501292552..c67eee965 100644 --- a/pkg/kademlia/kademliaclient/kademliaclient.go +++ b/pkg/kademlia/kademliaclient/kademliaclient.go @@ -8,40 +8,52 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/peer" monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/internal/sync2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) var mon = monkit.Package() -// Dialer sends requests to kademlia endpoints on storage nodes -type Dialer struct { - log *zap.Logger - transport transport.Client - limit sync2.Semaphore -} - // Conn represents a connection type Conn struct { - conn *grpc.ClientConn - client pb.NodesClient + conn *rpc.Conn + client rpc.NodesClient +} + +// Close closes this connection. +func (conn *Conn) Close() error { + return conn.conn.Close() +} + +// Dialer sends requests to kademlia endpoints on storage nodes +type Dialer struct { + log *zap.Logger + dialer rpc.Dialer + obs Observer + limit sync2.Semaphore +} + +// Observer implements the ConnSuccess and ConnFailure methods +// for Discovery and other services to use +type Observer interface { + ConnSuccess(ctx context.Context, node *pb.Node) + ConnFailure(ctx context.Context, node *pb.Node, err error) } // NewDialer creates a new kademlia dialer. -func NewDialer(log *zap.Logger, transport transport.Client) *Dialer { - dialer := &Dialer{ - log: log, - transport: transport, +func NewDialer(log *zap.Logger, dialer rpc.Dialer, obs Observer) *Dialer { + d := &Dialer{ + log: log, + dialer: dialer, + obs: obs, } - dialer.limit.Init(32) // TODO: limit should not be hardcoded - return dialer + d.limit.Init(32) // TODO: limit should not be hardcoded + return d } // Close closes the pool resources and prevents new connections to be made. @@ -72,9 +84,7 @@ func (dialer *Dialer) Lookup(ctx context.Context, self *pb.Node, ask pb.Node, fi if err != nil { return nil, err } - defer func() { - err = errs.Combine(err, conn.disconnect()) - }() + defer func() { err = errs.Combine(err, conn.Close()) }() resp, err := conn.client.Query(ctx, &req) if err != nil { @@ -96,10 +106,10 @@ func (dialer *Dialer) PingNode(ctx context.Context, target pb.Node) (_ bool, err if err != nil { return false, err } + defer func() { err = errs.Combine(err, conn.Close()) }() _, err = conn.client.Ping(ctx, &pb.PingRequest{}) - - return err == nil, errs.Combine(err, conn.disconnect()) + return err == nil, err } // FetchPeerIdentity connects to a node and returns its peer identity @@ -114,18 +124,13 @@ func (dialer *Dialer) FetchPeerIdentity(ctx context.Context, target pb.Node) (_ if err != nil { return nil, err } - defer func() { - err = errs.Combine(err, conn.disconnect()) - }() + defer func() { err = errs.Combine(err, conn.Close()) }() - p := &peer.Peer{} - _, err = conn.client.Ping(ctx, &pb.PingRequest{}, grpc.Peer(p)) - ident, errFromPeer := identity.PeerIdentityFromPeer(p) - return ident, errs.Combine(err, errFromPeer) + return conn.conn.PeerIdentity() } // FetchPeerIdentityUnverified connects to an address and returns its peer identity (no node ID verification). -func (dialer *Dialer) FetchPeerIdentityUnverified(ctx context.Context, address string, opts ...grpc.CallOption) (_ *identity.PeerIdentity, err error) { +func (dialer *Dialer) FetchPeerIdentityUnverified(ctx context.Context, address string) (_ *identity.PeerIdentity, err error) { defer mon.Task()(&ctx)(&err) if !dialer.limit.Lock() { return nil, context.Canceled @@ -136,14 +141,9 @@ func (dialer *Dialer) FetchPeerIdentityUnverified(ctx context.Context, address s if err != nil { return nil, err } - defer func() { - err = errs.Combine(err, conn.disconnect()) - }() + defer func() { err = errs.Combine(err, conn.Close()) }() - p := &peer.Peer{} - _, err = conn.client.Ping(ctx, &pb.PingRequest{}, grpc.Peer(p)) - ident, errFromPeer := identity.PeerIdentityFromPeer(p) - return ident, errs.Combine(err, errFromPeer) + return conn.conn.PeerIdentity() } // FetchInfo connects to a node and returns its node info. @@ -158,38 +158,58 @@ func (dialer *Dialer) FetchInfo(ctx context.Context, target pb.Node) (_ *pb.Info if err != nil { return nil, err } + defer func() { err = errs.Combine(err, conn.Close()) }() resp, err := conn.client.RequestInfo(ctx, &pb.InfoRequest{}) + if err != nil { + return nil, err + } - return resp, errs.Combine(err, conn.disconnect()) -} - -// AlertSuccess alerts the transport observers of a successful connection -func (dialer *Dialer) AlertSuccess(ctx context.Context, node *pb.Node) { - dialer.transport.AlertSuccess(ctx, node) + return resp, nil } // dialNode dials the specified node. func (dialer *Dialer) dialNode(ctx context.Context, target pb.Node) (_ *Conn, err error) { defer mon.Task()(&ctx)(&err) - grpcconn, err := dialer.transport.DialNode(ctx, &target) + + conn, err := dialer.dialer.DialNode(ctx, &target) + if err != nil { + if dialer.obs != nil { + dialer.obs.ConnFailure(ctx, &target, err) + } + return nil, err + } + if dialer.obs != nil { + dialer.obs.ConnSuccess(ctx, &target) + } + return &Conn{ - conn: grpcconn, - client: pb.NewNodesClient(grpcconn), - }, err + conn: conn, + client: conn.NodesClient(), + }, nil } // dialAddress dials the specified node by address (no node ID verification) func (dialer *Dialer) dialAddress(ctx context.Context, address string) (_ *Conn, err error) { defer mon.Task()(&ctx)(&err) - grpcconn, err := dialer.transport.DialAddress(ctx, address) - return &Conn{ - conn: grpcconn, - client: pb.NewNodesClient(grpcconn), - }, err -} -// disconnect disconnects this connection. -func (conn *Conn) disconnect() error { - return conn.conn.Close() + conn, err := dialer.dialer.DialAddressInsecure(ctx, address) + if err != nil { + // TODO: can't get an id here because we failed to dial + return nil, err + } + if ident, err := conn.PeerIdentity(); err == nil && dialer.obs != nil { + dialer.obs.ConnSuccess(ctx, &pb.Node{ + Id: ident.ID, + Address: &pb.NodeAddress{ + Transport: pb.NodeTransport_TCP_TLS_GRPC, + Address: address, + }, + }) + } + + return &Conn{ + conn: conn, + client: conn.NodesClient(), + }, nil } diff --git a/pkg/miniogw/gateway_test.go b/pkg/miniogw/gateway_test.go index 11c8d0d67..488bec91a 100644 --- a/pkg/miniogw/gateway_test.go +++ b/pkg/miniogw/gateway_test.go @@ -690,7 +690,7 @@ func initEnv(ctx context.Context, t *testing.T, planet *testplanet.Planet) (mini } // TODO(leak): close m metainfo.Client somehow - ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Transport, 0) + ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Dialer, 0) fc, err := infectious.NewFEC(2, 4) if err != nil { return nil, nil, nil, err diff --git a/pkg/peertls/tlsopts/options_test.go b/pkg/peertls/tlsopts/options_test.go index d47b1f871..205c6492f 100644 --- a/pkg/peertls/tlsopts/options_test.go +++ b/pkg/peertls/tlsopts/options_test.go @@ -19,8 +19,8 @@ import ( "storj.io/storj/pkg/peertls/extensions" "storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/revocation" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) func TestNewOptions(t *testing.T) { @@ -110,12 +110,12 @@ func TestNewOptions(t *testing.T) { revocationDB, err := revocation.NewDBFromCfg(c.config) require.NoError(t, err) - opts, err := tlsopts.NewOptions(fi, c.config, revocationDB) + tlsOptions, err := tlsopts.NewOptions(fi, c.config, revocationDB) assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(fi, opts.Ident)) - assert.Equal(t, c.config, opts.Config) - assert.Len(t, opts.VerificationFuncs.Client(), c.clientVerificationFuncsLen) - assert.Len(t, opts.VerificationFuncs.Server(), c.serverVerificationFuncsLen) + assert.True(t, reflect.DeepEqual(fi, tlsOptions.Ident)) + assert.Equal(t, c.config, tlsOptions.Config) + assert.Len(t, tlsOptions.VerificationFuncs.Client(), c.clientVerificationFuncsLen) + assert.Len(t, tlsOptions.VerificationFuncs.Server(), c.serverVerificationFuncsLen) require.NoError(t, revocationDB.Close()) } @@ -133,17 +133,14 @@ func TestOptions_ServerOption_Peer_CA_Whitelist(t *testing.T) { target := planet.StorageNodes[1].Local() testidentity.CompleteIdentityVersionsTest(t, func(t *testing.T, version storj.IDVersion, ident *identity.FullIdentity) { - opts, err := tlsopts.NewOptions(ident, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{ PeerIDVersions: "*", }, nil) require.NoError(t, err) - dialOption, err := opts.DialOption(target.Id) - require.NoError(t, err) + dialer := rpc.NewDefaultDialer(tlsOptions) - transportClient := transport.NewClient(opts) - - conn, err := transportClient.DialNode(ctx, &target.Node, dialOption) + conn, err := dialer.DialNode(ctx, &target.Node) assert.NotNil(t, conn) assert.NoError(t, err) @@ -153,12 +150,12 @@ func TestOptions_ServerOption_Peer_CA_Whitelist(t *testing.T) { func TestOptions_DialOption_error_on_empty_ID(t *testing.T) { testidentity.CompleteIdentityVersionsTest(t, func(t *testing.T, version storj.IDVersion, ident *identity.FullIdentity) { - opts, err := tlsopts.NewOptions(ident, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{ PeerIDVersions: "*", }, nil) require.NoError(t, err) - dialOption, err := opts.DialOption(storj.NodeID{}) + dialOption, err := tlsOptions.DialOption(storj.NodeID{}) assert.Nil(t, dialOption) assert.Error(t, err) }) @@ -166,12 +163,12 @@ func TestOptions_DialOption_error_on_empty_ID(t *testing.T) { func TestOptions_DialUnverifiedIDOption(t *testing.T) { testidentity.CompleteIdentityVersionsTest(t, func(t *testing.T, version storj.IDVersion, ident *identity.FullIdentity) { - opts, err := tlsopts.NewOptions(ident, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{ PeerIDVersions: "*", }, nil) require.NoError(t, err) - dialOption := opts.DialUnverifiedIDOption() + dialOption := tlsOptions.DialUnverifiedIDOption() assert.NotNil(t, dialOption) }) } diff --git a/pkg/rpc/common_drpc.go b/pkg/rpc/common_drpc.go new file mode 100644 index 000000000..515050b11 --- /dev/null +++ b/pkg/rpc/common_drpc.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build drpc + +package rpc + +const ( + // IsDRPC is true if drpc is being used. + IsDRPC = true + + // IsGRPC is true if grpc is being used. + IsGRPC = false +) diff --git a/pkg/rpc/common_grpc.go b/pkg/rpc/common_grpc.go new file mode 100644 index 000000000..5ac7f0c36 --- /dev/null +++ b/pkg/rpc/common_grpc.go @@ -0,0 +1,14 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build !drpc + +package rpc + +const ( + // IsDRPC is true if drpc is being used. + IsDRPC = false + + // IsGRPC is true if grpc is being used. + IsGRPC = true +) diff --git a/pkg/rpc/dial.go b/pkg/rpc/dial.go index e10edcb20..923285efc 100644 --- a/pkg/rpc/dial.go +++ b/pkg/rpc/dial.go @@ -67,7 +67,9 @@ func (d Dialer) dialContext(ctx context.Context, address string) (net.Conn, erro conn, err := new(net.Dialer).DialContext(ctx, "tcp", address) if err != nil { - return nil, Error.Wrap(err) + // N.B. this error is not wrapped on purpose! grpc code cares about inspecting + // it and it's not smart enough to attempt to do any unwrapping. :( + return nil, err } return &timedConn{ diff --git a/pkg/rpc/dial_grpc.go b/pkg/rpc/dial_grpc.go index 105dbc5cc..fad93963e 100644 --- a/pkg/rpc/dial_grpc.go +++ b/pkg/rpc/dial_grpc.go @@ -61,12 +61,13 @@ type captureStateCreds struct { credentials.TransportCredentials once sync.Once state tls.ConnectionState + ok bool } // Get returns the stored tls connection state. func (c *captureStateCreds) Get() (state tls.ConnectionState, ok bool) { - c.once.Do(func() { ok = true }) - return c.state, ok + c.once.Do(func() {}) + return c.state, c.ok } // ClientHandshake dispatches to the underlying credentials and tries to store the @@ -76,7 +77,7 @@ func (c *captureStateCreds) ClientHandshake(ctx context.Context, authority strin conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, authority, rawConn) if tlsInfo, ok := auth.(credentials.TLSInfo); ok { - c.once.Do(func() { c.state = tlsInfo.State }) + c.once.Do(func() { c.state, c.ok = tlsInfo.State, true }) } return conn, auth, err } @@ -88,7 +89,7 @@ func (c *captureStateCreds) ServerHandshake(rawConn net.Conn) ( conn, auth, err := c.TransportCredentials.ServerHandshake(rawConn) if tlsInfo, ok := auth.(credentials.TLSInfo); ok { - c.once.Do(func() { c.state = tlsInfo.State }) + c.once.Do(func() { c.state, c.ok = tlsInfo.State, true }) } return conn, auth, err } diff --git a/pkg/transport/transport_test.go b/pkg/rpc/rpc_test.go similarity index 75% rename from pkg/transport/transport_test.go rename to pkg/rpc/rpc_test.go index cc2fbd5f1..0c61456e1 100644 --- a/pkg/transport/transport_test.go +++ b/pkg/rpc/rpc_test.go @@ -1,7 +1,7 @@ // Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. -package transport_test +package rpc_test import ( "context" @@ -18,6 +18,7 @@ import ( "storj.io/storj/internal/testplanet" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" ) @@ -36,26 +37,28 @@ func TestDialNode(t *testing.T) { planet.Start(ctx) - client := planet.StorageNodes[0].Transport - unsignedIdent, err := testidentity.PregeneratedIdentity(0, storj.LatestIDVersion()) require.NoError(t, err) signedIdent, err := testidentity.PregeneratedSignedIdentity(0, storj.LatestIDVersion()) require.NoError(t, err) - opts, err := tlsopts.NewOptions(signedIdent, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(signedIdent, tlsopts.Config{ UsePeerCAWhitelist: true, PeerCAWhitelistPath: whitelistPath, PeerIDVersions: "*", }, nil) require.NoError(t, err) + dialer := rpc.NewDefaultDialer(tlsOptions) + unsignedClientOpts, err := tlsopts.NewOptions(unsignedIdent, tlsopts.Config{ PeerIDVersions: "*", }, nil) require.NoError(t, err) + unsignedDialer := rpc.NewDefaultDialer(unsignedClientOpts) + t.Run("DialNode with invalid targets", func(t *testing.T) { targets := []*pb.Node{ { @@ -88,32 +91,13 @@ func TestDialNode(t *testing.T) { tag := fmt.Sprintf("%+v", target) timedCtx, cancel := context.WithTimeout(ctx, time.Second) - conn, err := client.DialNode(timedCtx, target) + conn, err := dialer.DialNode(timedCtx, target) cancel() assert.Error(t, err, tag) assert.Nil(t, conn, tag) } }) - t.Run("DialNode with valid target", func(t *testing.T) { - target := &pb.Node{ - Id: planet.StorageNodes[1].ID(), - Address: &pb.NodeAddress{ - Transport: pb.NodeTransport_TCP_TLS_GRPC, - Address: planet.StorageNodes[1].Addr(), - }, - } - - timedCtx, cancel := context.WithTimeout(ctx, time.Second) - conn, err := client.DialNode(timedCtx, target) - cancel() - - assert.NoError(t, err) - require.NotNil(t, conn) - - assert.NoError(t, conn.Close()) - }) - t.Run("DialNode with valid signed target", func(t *testing.T) { target := &pb.Node{ Id: planet.StorageNodes[1].ID(), @@ -123,11 +107,8 @@ func TestDialNode(t *testing.T) { }, } - dialOption, err := opts.DialOption(target.Id) - require.NoError(t, err) - timedCtx, cancel := context.WithTimeout(ctx, time.Second) - conn, err := client.DialNode(timedCtx, target, dialOption) + conn, err := dialer.DialNode(timedCtx, target) cancel() assert.NoError(t, err) @@ -146,12 +127,7 @@ func TestDialNode(t *testing.T) { } timedCtx, cancel := context.WithTimeout(ctx, time.Second) - dialOption, err := unsignedClientOpts.DialOption(target.Id) - require.NoError(t, err) - - conn, err := client.DialNode( - timedCtx, target, dialOption, - ) + conn, err := unsignedDialer.DialNode(timedCtx, target) cancel() assert.NotNil(t, conn) @@ -161,10 +137,7 @@ func TestDialNode(t *testing.T) { t.Run("DialAddress with unsigned identity", func(t *testing.T) { timedCtx, cancel := context.WithTimeout(ctx, time.Second) - dialOption := unsignedClientOpts.DialUnverifiedIDOption() - conn, err := client.DialAddress( - timedCtx, planet.StorageNodes[1].Addr(), dialOption, - ) + conn, err := unsignedDialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr()) cancel() assert.NotNil(t, conn) @@ -174,7 +147,7 @@ func TestDialNode(t *testing.T) { t.Run("DialAddress with valid address", func(t *testing.T) { timedCtx, cancel := context.WithTimeout(ctx, time.Second) - conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr()) + conn, err := dialer.DialAddressInsecure(timedCtx, planet.StorageNodes[1].Addr()) cancel() assert.NoError(t, err) @@ -207,16 +180,17 @@ func TestDialNode_BadServerCertificate(t *testing.T) { planet.Start(ctx) - client := planet.StorageNodes[0].Transport ident, err := testidentity.PregeneratedSignedIdentity(0, storj.LatestIDVersion()) require.NoError(t, err) - opts, err := tlsopts.NewOptions(ident, tlsopts.Config{ + tlsOptions, err := tlsopts.NewOptions(ident, tlsopts.Config{ UsePeerCAWhitelist: true, PeerCAWhitelistPath: whitelistPath, }, nil) require.NoError(t, err) + dialer := rpc.NewDefaultDialer(tlsOptions) + t.Run("DialNode with bad server certificate", func(t *testing.T) { target := &pb.Node{ Id: planet.StorageNodes[1].ID(), @@ -227,10 +201,7 @@ func TestDialNode_BadServerCertificate(t *testing.T) { } timedCtx, cancel := context.WithTimeout(ctx, time.Second) - dialOption, err := opts.DialOption(target.Id) - require.NoError(t, err) - - conn, err := client.DialNode(timedCtx, target, dialOption) + conn, err := dialer.DialNode(timedCtx, target) cancel() tag := fmt.Sprintf("%+v", target) @@ -241,10 +212,7 @@ func TestDialNode_BadServerCertificate(t *testing.T) { t.Run("DialAddress with bad server certificate", func(t *testing.T) { timedCtx, cancel := context.WithTimeout(ctx, time.Second) - dialOption, err := opts.DialOption(planet.StorageNodes[1].ID()) - require.NoError(t, err) - - conn, err := client.DialAddress(timedCtx, planet.StorageNodes[1].Addr(), dialOption) + conn, err := dialer.DialAddressID(timedCtx, planet.StorageNodes[1].Addr(), planet.StorageNodes[1].ID()) cancel() assert.Nil(t, conn) diff --git a/pkg/rpc/rpcpeer/peer.go b/pkg/rpc/rpcpeer/peer.go new file mode 100644 index 000000000..2902478db --- /dev/null +++ b/pkg/rpc/rpcpeer/peer.go @@ -0,0 +1,37 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +package rpcpeer + +import ( + "context" + "crypto/tls" + "net" + + "github.com/zeebo/errs" +) + +// Error is the class of errors returned by this package. +var Error = errs.Class("rpcpeer") + +// Peer represents an rpc peer. +type Peer struct { + Addr net.Addr + State tls.ConnectionState +} + +// peerKey is used as a unique value for context keys. +type peerKey struct{} + +// NewContext returns a new context with the peer associated as a value. +func NewContext(ctx context.Context, peer *Peer) context.Context { + return context.WithValue(ctx, peerKey{}, peer) +} + +// FromContext returns the peer that was previously associated by NewContext. +func FromContext(ctx context.Context) (*Peer, error) { + if peer, ok := ctx.Value(peerKey{}).(*Peer); ok { + return peer, nil + } + return internalFromContext(ctx) +} diff --git a/pkg/rpc/rpcpeer/peer_drpc.go b/pkg/rpc/rpcpeer/peer_drpc.go new file mode 100644 index 000000000..de72fe844 --- /dev/null +++ b/pkg/rpc/rpcpeer/peer_drpc.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build drpc + +package rpcpeer + +import ( + "context" + "crypto/tls" + + "storj.io/drpc/drpcctx" +) + +// internalFromContext returns a peer from the context using drpc. +func internalFromContext(ctx context.Context) (*Peer, error) { + tr, ok := drpcctx.Transport(ctx) + if !ok { + return nil, Error.New("unable to get drpc peer from context") + } + + conn, ok := tr.(*tls.Conn) + if !ok { + return nil, Error.New("drpc transport is not a *tls.Conn") + } + + return &Peer{ + Addr: conn.RemoteAddr(), + State: conn.ConnectionState(), + }, nil +} diff --git a/pkg/rpc/rpcpeer/peer_grpc.go b/pkg/rpc/rpcpeer/peer_grpc.go new file mode 100644 index 000000000..58fbdd469 --- /dev/null +++ b/pkg/rpc/rpcpeer/peer_grpc.go @@ -0,0 +1,31 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build !drpc + +package rpcpeer + +import ( + "context" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +// internalFromContext returns a peer from the context using grpc. +func internalFromContext(ctx context.Context) (*Peer, error) { + peer, ok := peer.FromContext(ctx) + if !ok { + return nil, Error.New("unable to get grpc peer from context") + } + + tlsInfo, ok := peer.AuthInfo.(credentials.TLSInfo) + if !ok { + return nil, Error.New("peer AuthInfo is not credentials.TLSInfo") + } + + return &Peer{ + Addr: peer.Addr, + State: tlsInfo.State, + }, nil +} diff --git a/pkg/rpc/rpcstatus/status_drpc.go b/pkg/rpc/rpcstatus/status_drpc.go new file mode 100644 index 000000000..db0f8050b --- /dev/null +++ b/pkg/rpc/rpcstatus/status_drpc.go @@ -0,0 +1,63 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build drpc + +package rpcstatus + +import ( + "context" + "errors" + "fmt" + + "storj.io/drpc/drpcerr" +) + +// StatusCode is the type of status codes for drpc. +type StatusCode uint64 + +// These constants are all the rpc error codes. +const ( + Unknown StatusCode = iota + OK + Canceled + InvalidArgument + DeadlineExceeded + NotFound + AlreadyExists + PermissionDenied + ResourceExhausted + FailedPrecondition + Aborted + OutOfRange + Unimplemented + Internal + Unavailable + DataLoss + Unauthenticated +) + +// Code returns the status code associated with the error. +func Code(err error) StatusCode { + // special case: if the error is context canceled or deadline exceeded, the code + // must be those. + switch err { + case context.Canceled: + return Canceled + case context.DeadlineExceeded: + return DeadlineExceeded + default: + return drpcerr.Code(err) + } + +} + +// Error wraps the message with a status code into an error. +func Error(code StatusCode, msg string) error { + return drpcerr.WithCode(errors.New(msg), uint64(code)) +} + +// Errorf : Error :: fmt.Sprintf : fmt.Sprint +func Errorf(code StatusCode, format string, a ...interface{}) error { + return drpcerr.WithCode(fmt.Errorf(format, a...), uint64(code)) +} diff --git a/pkg/rpc/rpcstatus/status_grpc.go b/pkg/rpc/rpcstatus/status_grpc.go new file mode 100644 index 000000000..6b8c80cbb --- /dev/null +++ b/pkg/rpc/rpcstatus/status_grpc.go @@ -0,0 +1,50 @@ +// Copyright (C) 2019 Storj Labs, Inc. +// See LICENSE for copying information. + +// +build !drpc + +package rpcstatus + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// StatusCode is the type of status codes for grpc. +type StatusCode = codes.Code + +// These constants are all the rpc error codes. +const ( + OK = codes.OK + Canceled = codes.Canceled + Unknown = codes.Unknown + InvalidArgument = codes.InvalidArgument + DeadlineExceeded = codes.DeadlineExceeded + NotFound = codes.NotFound + AlreadyExists = codes.AlreadyExists + PermissionDenied = codes.PermissionDenied + ResourceExhausted = codes.ResourceExhausted + FailedPrecondition = codes.FailedPrecondition + Aborted = codes.Aborted + OutOfRange = codes.OutOfRange + Unimplemented = codes.Unimplemented + Internal = codes.Internal + Unavailable = codes.Unavailable + DataLoss = codes.DataLoss + Unauthenticated = codes.Unauthenticated +) + +// Code returns the status code associated with the error. +func Code(err error) StatusCode { + return status.Code(err) +} + +// Error wraps the message with a status code into an error. +func Error(code StatusCode, msg string) error { + return status.Error(code, msg) +} + +// Errorf : Error :: fmt.Sprintf : fmt.Sprint +func Errorf(code StatusCode, format string, a ...interface{}) error { + return status.Errorf(code, format, a...) +} diff --git a/pkg/server/config.go b/pkg/server/config.go index 83e6b4ec8..72de3f6d7 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -32,12 +32,12 @@ func (sc Config) Run(ctx context.Context, log *zap.Logger, identity *identity.Fu return Error.New("revDB cannot be nil in call to Run") } - opts, err := tlsopts.NewOptions(identity, sc.Config, revDB) + tlsOptions, err := tlsopts.NewOptions(identity, sc.Config, revDB) if err != nil { return err } - server, err := New(log.Named("server"), opts, sc.Address, sc.PrivateAddress, interceptor, services...) + server, err := New(log.Named("server"), tlsOptions, sc.Address, sc.PrivateAddress, interceptor, services...) if err != nil { return err } diff --git a/pkg/server/interceptors.go b/pkg/server/interceptors.go index b174eded7..d970afc48 100644 --- a/pkg/server/interceptors.go +++ b/pkg/server/interceptors.go @@ -12,10 +12,10 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" "google.golang.org/grpc/status" "storj.io/storj/pkg/identity" + "storj.io/storj/pkg/rpc/rpcpeer" "storj.io/storj/storage" ) @@ -73,7 +73,7 @@ func prepareRequestLog(ctx context.Context, req, server interface{}, methodName PeerAddress: "", Msg: req, } - if peer, ok := peer.FromContext(ctx); ok { + if peer, err := rpcpeer.FromContext(ctx); err == nil { reqLog.PeerAddress = peer.Addr.String() if peerIdentity, err := identity.PeerIdentityFromPeer(peer); err == nil { reqLog.PeerNodeID = peerIdentity.ID.String() diff --git a/pkg/server/server.go b/pkg/server/server.go index 65411efba..0974c653e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -5,6 +5,7 @@ package server import ( "context" + "crypto/tls" "net" "sync" @@ -41,11 +42,11 @@ type private struct { // Server represents a bundle of services defined by a specific ID. // Examples of servers are the satellite, the storagenode, and the uplink. type Server struct { - log *zap.Logger - public public - private private - next []Service - identity *identity.FullIdentity + log *zap.Logger + public public + private private + next []Service + tlsOptions *tlsopts.Options mu sync.Mutex wg sync.WaitGroup @@ -55,12 +56,12 @@ type Server struct { // New creates a Server out of an Identity, a net.Listener, // a UnaryServerInterceptor, and a set of services. -func New(log *zap.Logger, opts *tlsopts.Options, publicAddr, privateAddr string, interceptor grpc.UnaryServerInterceptor, services ...Service) (*Server, error) { +func New(log *zap.Logger, tlsOptions *tlsopts.Options, publicAddr, privateAddr string, interceptor grpc.UnaryServerInterceptor, services ...Service) (*Server, error) { server := &Server{ - log: log, - next: services, - identity: opts.Ident, - done: make(chan struct{}), + log: log, + next: services, + tlsOptions: tlsOptions, + done: make(chan struct{}), } unaryInterceptor := server.logOnErrorUnaryInterceptor @@ -78,7 +79,7 @@ func New(log *zap.Logger, opts *tlsopts.Options, publicAddr, privateAddr string, grpc: grpc.NewServer( grpc.StreamInterceptor(server.logOnErrorStreamInterceptor), grpc.UnaryInterceptor(unaryInterceptor), - opts.ServerOption(), + tlsOptions.ServerOption(), ), } @@ -96,7 +97,7 @@ func New(log *zap.Logger, opts *tlsopts.Options, publicAddr, privateAddr string, } // Identity returns the server's identity -func (p *Server) Identity() *identity.FullIdentity { return p.identity } +func (p *Server) Identity() *identity.FullIdentity { return p.tlsOptions.Ident } // Addr returns the server's public listener address func (p *Server) Addr() net.Addr { return p.public.listener.Addr() } @@ -167,7 +168,7 @@ func (p *Server) Run(ctx context.Context) (err error) { const drpcHeader = "DRPC!!!1" publicMux := listenmux.New(p.public.listener, len(drpcHeader)) - publicDRPCListener := publicMux.Route(drpcHeader) + publicDRPCListener := tls.NewListener(publicMux.Route(drpcHeader), p.tlsOptions.ServerTLSConfig()) privateMux := listenmux.New(p.private.listener, len(drpcHeader)) privateDRPCListener := privateMux.Route(drpcHeader) diff --git a/pkg/transport/common.go b/pkg/transport/common.go deleted file mode 100644 index 844a3785b..000000000 --- a/pkg/transport/common.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "time" - - "github.com/zeebo/errs" - monkit "gopkg.in/spacemonkeygo/monkit.v2" -) - -var ( - mon = monkit.Package() - //Error is the errs class of standard Transport Client errors - Error = errs.Class("transport error") -) - -const ( - // defaultTransportDialTimeout is the default time to wait for a connection to be established. - defaultTransportDialTimeout = 20 * time.Second - - // defaultTransportRequestTimeout is the default time to wait for a response. - defaultTransportRequestTimeout = 10 * time.Minute -) diff --git a/pkg/transport/fetchidentity.go b/pkg/transport/fetchidentity.go deleted file mode 100644 index 512a2d452..000000000 --- a/pkg/transport/fetchidentity.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "context" - "net" - "sync" - - "github.com/zeebo/errs" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - - "storj.io/storj/pkg/identity" - "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/peertls" -) - -// handshakeCapture implements a credentials.TransportCredentials for capturing handshake information. -type handshakeCapture struct { - credentials.TransportCredentials - - mu sync.Mutex - authInfo credentials.AuthInfo -} - -// ClientHandshake does the authentication handshake specified by the corresponding -// authentication protocol on conn for clients. It returns the authenticated -// connection and the corresponding auth information about the connection. -func (capture *handshakeCapture) ClientHandshake(ctx context.Context, s string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - conn, auth, err := capture.TransportCredentials.ClientHandshake(ctx, s, conn) - if err == nil { - capture.mu.Lock() - capture.authInfo = auth - capture.mu.Unlock() - } - return conn, auth, err -} - -// ServerHandshake does the authentication handshake for servers. It returns -// the authenticated connection and the corresponding auth information about -// the connection. -func (capture *handshakeCapture) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { - conn, auth, err := capture.TransportCredentials.ServerHandshake(conn) - if err == nil { - capture.mu.Lock() - capture.authInfo = auth - capture.mu.Unlock() - } - return conn, auth, err -} - -// FetchPeerIdentity dials the node and fetches the identity -func (transport *Transport) FetchPeerIdentity(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (_ *identity.PeerIdentity, err error) { - defer mon.Task()(&ctx, "node: "+node.Id.String()[0:8])(&err) - - if node.Address == nil || node.Address.Address == "" { - return nil, Error.New("no address") - } - tlsConfig := transport.tlsOpts.ClientTLSConfig(node.Id) - - capture := &handshakeCapture{ - TransportCredentials: credentials.NewTLS(tlsConfig), - } - - options := append([]grpc.DialOption{ - grpc.WithTransportCredentials(capture), - grpc.WithBlock(), - grpc.FailOnNonTempDialError(true), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - return &timeoutConn{conn: conn, timeout: transport.timeouts.Request}, nil - }), - }, opts...) - - timedCtx, cancel := context.WithTimeout(ctx, transport.timeouts.Dial) - defer cancel() - - conn, err := grpc.DialContext(timedCtx, node.GetAddress().Address, options...) - if err != nil { - if err == context.Canceled { - return nil, err - } - transport.AlertFail(timedCtx, node, err) - return nil, Error.Wrap(err) - } - defer func() { - err = errs.Combine(err, conn.Close()) - }() - transport.AlertSuccess(timedCtx, node) - - capture.mu.Lock() - authinfo := capture.authInfo - capture.mu.Unlock() - - switch info := authinfo.(type) { - case credentials.TLSInfo: - chain := info.State.PeerCertificates - if len(chain)-1 < peertls.CAIndex { - return nil, Error.New("invalid certificate chain") - } - - pi, err := identity.PeerIdentityFromChain(chain) - if err != nil { - return nil, err - } - - return pi, nil - default: - return nil, Error.New("unknown capture info %T", authinfo) - } -} diff --git a/pkg/transport/insecure.go b/pkg/transport/insecure.go deleted file mode 100644 index 52df43cd0..000000000 --- a/pkg/transport/insecure.go +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "context" - - "google.golang.org/grpc" -) - -// DialAddressInsecure returns an insecure grpc connection without tls to a node. -// -// Use this method for communication with localhost. For example, with the inspector or debugging services. -// Otherwise in most cases DialNode should be used for communicating with nodes since it is secure. -func DialAddressInsecure(ctx context.Context, address string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) { - defer mon.Task()(&ctx)(&err) - options := append([]grpc.DialOption{ - grpc.WithInsecure(), - grpc.WithBlock(), - grpc.FailOnNonTempDialError(true), - }, opts...) - - timedCtx, cf := context.WithTimeout(ctx, defaultTransportDialTimeout) - defer cf() - - conn, err = grpc.DialContext(timedCtx, address, options...) - if err == context.Canceled { - return nil, err - } - return conn, Error.Wrap(err) -} diff --git a/pkg/transport/slowtransport.go b/pkg/transport/slowtransport.go deleted file mode 100644 index a7538d2dc..000000000 --- a/pkg/transport/slowtransport.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "context" - "net" - "time" - - "google.golang.org/grpc" - - "storj.io/storj/internal/memory" - "storj.io/storj/pkg/identity" - "storj.io/storj/pkg/pb" -) - -// SimulatedNetwork allows creating connections that try to simulated realistic network conditions. -type SimulatedNetwork struct { - DialLatency time.Duration - BytesPerSecond memory.Size -} - -// NewClient wraps an exiting client with the simulated network params. -func (network *SimulatedNetwork) NewClient(client Client) Client { - return &slowTransport{ - client: client, - network: network, - } -} - -// slowTransport is a slow version of transport -type slowTransport struct { - client Client - network *SimulatedNetwork -} - -// DialNode dials a node with latency -func (client *slowTransport) DialNode(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (_ *grpc.ClientConn, err error) { - defer mon.Task()(&ctx)(&err) - return client.client.DialNode(ctx, node, append(client.network.DialOptions(), opts...)...) -} - -// DialAddress dials an address with latency -func (client *slowTransport) DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (_ *grpc.ClientConn, err error) { - defer mon.Task()(&ctx)(&err) - return client.client.DialAddress(ctx, address, append(client.network.DialOptions(), opts...)...) -} - -// FetchPeerIdentity dials the node and fetches the identity. -func (client *slowTransport) FetchPeerIdentity(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (_ *identity.PeerIdentity, err error) { - defer mon.Task()(&ctx)(&err) - return client.client.FetchPeerIdentity(ctx, node, append(client.network.DialOptions(), opts...)...) -} - -// Identity for slowTransport -func (client *slowTransport) Identity() *identity.FullIdentity { - return client.client.Identity() -} - -// WithObservers calls WithObservers for slowTransport -func (client *slowTransport) WithObservers(obs ...Observer) Client { - return &slowTransport{client.client.WithObservers(obs...), client.network} -} - -// AlertSuccess implements the transport.Client interface -func (client *slowTransport) AlertSuccess(ctx context.Context, node *pb.Node) { - defer mon.Task()(&ctx)(nil) - client.client.AlertSuccess(ctx, node) -} - -// AlertFail implements the transport.Client interface -func (client *slowTransport) AlertFail(ctx context.Context, node *pb.Node, err error) { - defer mon.Task()(&ctx)(nil) - client.client.AlertFail(ctx, node, err) -} - -// DialOptions returns options such that it will use simulated network parameters -func (network *SimulatedNetwork) DialOptions() []grpc.DialOption { - return []grpc.DialOption{grpc.WithContextDialer(network.GRPCDialContext)} -} - -// GRPCDialContext implements DialContext that is suitable for `grpc.WithContextDialer` -func (network *SimulatedNetwork) GRPCDialContext(ctx context.Context, address string) (_ net.Conn, err error) { - defer mon.Task()(&ctx)(&err) - timer := time.NewTimer(network.DialLatency) - defer timer.Stop() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - } - - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", address) - if err != nil { - return conn, err - } - - if network.BytesPerSecond == 0 { - return conn, err - } - - return &simulatedConn{network, conn}, nil -} - -// simulatedConn implements slow reading and writing to the connection -// -// This does not handle read deadline and write deadline properly. -type simulatedConn struct { - network *SimulatedNetwork - net.Conn -} - -// delay sleeps specified amount of time -func (conn *simulatedConn) delay(actualWait time.Duration, bytes int) { - expectedWait := time.Duration(bytes * int(time.Second) / conn.network.BytesPerSecond.Int()) - if actualWait < expectedWait { - time.Sleep(expectedWait - actualWait) - } -} - -// Read reads data from the connection. -func (conn *simulatedConn) Read(b []byte) (n int, err error) { - start := time.Now() - n, err = conn.Conn.Read(b) - if err == context.Canceled { - return n, err - } - conn.delay(time.Since(start), n) - return n, err -} - -// Write writes data to the connection. -func (conn *simulatedConn) Write(b []byte) (n int, err error) { - start := time.Now() - n, err = conn.Conn.Write(b) - if err == context.Canceled { - return n, err - } - conn.delay(time.Since(start), n) - return n, err -} diff --git a/pkg/transport/timeout.go b/pkg/transport/timeout.go deleted file mode 100644 index d19ace8e6..000000000 --- a/pkg/transport/timeout.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "net" - "time" -) - -type timeoutConn struct { - conn net.Conn - timeout time.Duration -} - -func (tc *timeoutConn) Read(b []byte) (n int, err error) { - // deadline needs to be set before each read operation - err = tc.SetReadDeadline(time.Now().Add(tc.timeout)) - if err != nil { - return 0, err - } - return tc.conn.Read(b) -} - -func (tc *timeoutConn) Write(b []byte) (n int, err error) { - // deadline needs to be set before each write operation - err = tc.SetWriteDeadline(time.Now().Add(tc.timeout)) - if err != nil { - return 0, err - } - return tc.conn.Write(b) -} - -func (tc *timeoutConn) Close() error { - return tc.conn.Close() -} - -func (tc *timeoutConn) LocalAddr() net.Addr { - return tc.conn.LocalAddr() -} - -func (tc *timeoutConn) RemoteAddr() net.Addr { - return tc.conn.RemoteAddr() -} - -func (tc *timeoutConn) SetDeadline(t time.Time) error { - return tc.conn.SetDeadline(t) -} - -func (tc *timeoutConn) SetReadDeadline(t time.Time) error { - return tc.conn.SetReadDeadline(t) -} - -func (tc *timeoutConn) SetWriteDeadline(t time.Time) error { - return tc.conn.SetWriteDeadline(t) -} diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go deleted file mode 100644 index 8db327611..000000000 --- a/pkg/transport/transport.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (C) 2019 Storj Labs, Inc. -// See LICENSE for copying information. - -package transport - -import ( - "context" - "net" - "time" - - "google.golang.org/grpc" - - "storj.io/storj/pkg/identity" - "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/peertls/tlsopts" -) - -// Observer implements the ConnSuccess and ConnFailure methods -// for Discovery and other services to use -type Observer interface { - ConnSuccess(ctx context.Context, node *pb.Node) - ConnFailure(ctx context.Context, node *pb.Node, err error) -} - -// Client defines the interface to an transport client. -type Client interface { - DialNode(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (*grpc.ClientConn, error) - DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) - FetchPeerIdentity(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (*identity.PeerIdentity, error) - Identity() *identity.FullIdentity - WithObservers(obs ...Observer) Client - AlertSuccess(ctx context.Context, node *pb.Node) - AlertFail(ctx context.Context, node *pb.Node, err error) -} - -// Timeouts contains all of the timeouts configurable for a transport -type Timeouts struct { - Request time.Duration - Dial time.Duration -} - -// Transport interface structure -type Transport struct { - tlsOpts *tlsopts.Options - observers []Observer - timeouts Timeouts -} - -// NewClient returns a transport client with a default timeout for requests -func NewClient(tlsOpts *tlsopts.Options, obs ...Observer) Client { - return NewClientWithTimeouts(tlsOpts, Timeouts{}, obs...) -} - -// NewClientWithTimeouts returns a transport client with a specified timeout for requests -func NewClientWithTimeouts(tlsOpts *tlsopts.Options, timeouts Timeouts, obs ...Observer) Client { - if timeouts.Request == 0 { - timeouts.Request = defaultTransportRequestTimeout - } - if timeouts.Dial == 0 { - timeouts.Dial = defaultTransportDialTimeout - } - - return &Transport{ - tlsOpts: tlsOpts, - timeouts: timeouts, - observers: obs, - } -} - -// DialNode returns a grpc connection with tls to a node. -// -// Use this method for communicating with nodes as it is more secure than -// DialAddress. The connection will be established successfully only if the -// target node has the private key for the requested node ID. -func (transport *Transport) DialNode(ctx context.Context, node *pb.Node, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) { - defer mon.Task()(&ctx, "node: "+node.Id.String()[0:8])(&err) - - if node.Address == nil || node.Address.Address == "" { - return nil, Error.New("no address") - } - dialOption, err := transport.tlsOpts.DialOption(node.Id) - if err != nil { - return nil, err - } - - options := append([]grpc.DialOption{ - dialOption, - grpc.WithBlock(), - grpc.FailOnNonTempDialError(true), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - return &timeoutConn{conn: conn, timeout: transport.timeouts.Request}, nil - }), - }, opts...) - - timedCtx, cancel := context.WithTimeout(ctx, transport.timeouts.Dial) - defer cancel() - - conn, err = grpc.DialContext(timedCtx, node.GetAddress().Address, options...) - if err != nil { - if err == context.Canceled { - return nil, err - } - transport.AlertFail(timedCtx, node, err) - return nil, Error.Wrap(err) - } - - transport.AlertSuccess(timedCtx, node) - - return conn, nil -} - -// DialAddress returns a grpc connection with tls to an IP address. -// -// Do not use this method unless having a good reason. In most cases DialNode -// should be used for communicating with nodes as it is more secure than -// DialAddress. -func (transport *Transport) DialAddress(ctx context.Context, address string, opts ...grpc.DialOption) (conn *grpc.ClientConn, err error) { - defer mon.Task()(&ctx)(&err) - - options := append([]grpc.DialOption{ - transport.tlsOpts.DialUnverifiedIDOption(), - grpc.WithBlock(), - grpc.FailOnNonTempDialError(true), - grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - return &timeoutConn{conn: conn, timeout: transport.timeouts.Request}, nil - }), - }, opts...) - - timedCtx, cancel := context.WithTimeout(ctx, transport.timeouts.Dial) - defer cancel() - - // TODO: this should also call alertFail or alertSuccess with the node id. We should be able - // to get gRPC to give us the node id after dialing? - conn, err = grpc.DialContext(timedCtx, address, options...) - if err == context.Canceled { - return nil, err - } - return conn, Error.Wrap(err) -} - -// Identity is a getter for the transport's identity -func (transport *Transport) Identity() *identity.FullIdentity { - return transport.tlsOpts.Ident -} - -// WithObservers returns a new transport including the listed observers. -func (transport *Transport) WithObservers(obs ...Observer) Client { - tr := &Transport{tlsOpts: transport.tlsOpts, timeouts: transport.timeouts} - tr.observers = append(tr.observers, transport.observers...) - tr.observers = append(tr.observers, obs...) - return tr -} - -// AlertFail alerts any subscribed observers of the failure 'err' for 'node' -func (transport *Transport) AlertFail(ctx context.Context, node *pb.Node, err error) { - defer mon.Task()(&ctx)(nil) - for _, o := range transport.observers { - o.ConnFailure(ctx, node, err) - } -} - -// AlertSuccess alerts any subscribed observers of success for 'node' -func (transport *Transport) AlertSuccess(ctx context.Context, node *pb.Node) { - defer mon.Task()(&ctx)(nil) - for _, o := range transport.observers { - o.ConnSuccess(ctx, node) - } -} - -// Timeouts returns the timeout values for dialing and requests. -func (transport *Transport) Timeouts() Timeouts { - return transport.timeouts -} diff --git a/satellite/accounting/projectusage_test.go b/satellite/accounting/projectusage_test.go index 7f8ad16a3..b89719cda 100644 --- a/satellite/accounting/projectusage_test.go +++ b/satellite/accounting/projectusage_test.go @@ -12,11 +12,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "storj.io/storj/internal/errs2" "storj.io/storj/internal/memory" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testplanet" "storj.io/storj/internal/testrand" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" "storj.io/storj/satellite" "storj.io/storj/satellite/accounting" @@ -29,10 +31,10 @@ func TestProjectUsageStorage(t *testing.T) { name string expectedExceeded bool expectedResource string - expectedErrMsg string + expectedStatus rpcstatus.StatusCode }{ - {name: "doesn't exceed storage or bandwidth project limit", expectedExceeded: false, expectedErrMsg: ""}, - {name: "exceeds storage project limit", expectedExceeded: true, expectedResource: "storage", expectedErrMsg: "segment error: metainfo error: rpc error: code = ResourceExhausted desc = Exceeded Usage Limit; segment error: metainfo error: rpc error: code = ResourceExhausted desc = Exceeded Usage Limit"}, + {name: "doesn't exceed storage or bandwidth project limit", expectedExceeded: false, expectedStatus: 0}, + {name: "exceeds storage project limit", expectedExceeded: true, expectedResource: "storage", expectedStatus: rpcstatus.ResourceExhausted}, } testplanet.Run(t, testplanet.Config{ @@ -69,7 +71,7 @@ func TestProjectUsageStorage(t *testing.T) { // Execute test: check that the uplink gets an error when they have exceeded storage limits and try to upload a file actualErr := planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "testbucket", "test/path", expectedData) if testCase.expectedResource == "storage" { - assert.EqualError(t, actualErr, testCase.expectedErrMsg) + require.True(t, errs2.IsRPC(actualErr, testCase.expectedStatus)) } else { require.NoError(t, actualErr) } @@ -83,10 +85,10 @@ func TestProjectUsageBandwidth(t *testing.T) { name string expectedExceeded bool expectedResource string - expectedErrMsg string + expectedStatus rpcstatus.StatusCode }{ - {name: "doesn't exceed storage or bandwidth project limit", expectedExceeded: false, expectedErrMsg: ""}, - {name: "exceeds bandwidth project limit", expectedExceeded: true, expectedResource: "bandwidth", expectedErrMsg: "segment error: metainfo error: rpc error: code = ResourceExhausted desc = Exceeded Usage Limit"}, + {name: "doesn't exceed storage or bandwidth project limit", expectedExceeded: false, expectedStatus: 0}, + {name: "exceeds bandwidth project limit", expectedExceeded: true, expectedResource: "bandwidth", expectedStatus: rpcstatus.ResourceExhausted}, } for _, tt := range cases { @@ -128,7 +130,7 @@ func TestProjectUsageBandwidth(t *testing.T) { // Execute test: check that the uplink gets an error when they have exceeded bandwidth limits and try to download a file _, actualErr := planet.Uplinks[0].Download(ctx, planet.Satellites[0], bucketName, filePath) if testCase.expectedResource == "bandwidth" { - assert.EqualError(t, actualErr, testCase.expectedErrMsg) + require.True(t, errs2.IsRPC(actualErr, testCase.expectedStatus)) } else { require.NoError(t, actualErr) } diff --git a/satellite/audit/reverify_test.go b/satellite/audit/reverify_test.go index 45d34e40d..f59b28887 100644 --- a/satellite/audit/reverify_test.go +++ b/satellite/audit/reverify_test.go @@ -16,8 +16,8 @@ import ( "storj.io/storj/internal/testrand" "storj.io/storj/pkg/peertls/tlsopts" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/audit" "storj.io/storj/uplink" ) @@ -338,20 +338,14 @@ func TestReverifyOfflineDialTimeout(t *testing.T) { randomIndex, err := audit.GetRandomStripe(ctx, pointer) require.NoError(t, err) - network := &transport.SimulatedNetwork{ - DialLatency: 200 * time.Second, - BytesPerSecond: 1 * memory.KiB, - } - - tlsOpts, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) + tlsOptions, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) require.NoError(t, err) - newTransport := transport.NewClientWithTimeouts(tlsOpts, transport.Timeouts{ - Dial: 20 * time.Millisecond, - }) - - slowClient := network.NewClient(newTransport) - require.NotNil(t, slowClient) + dialer := rpc.NewDefaultDialer(tlsOptions) + dialer.RequestTimeout = 0 + dialer.DialTimeout = 20 * time.Millisecond + dialer.DialLatency = 200 * time.Second + dialer.TransferRate = 1 * memory.KB // This config value will create a very short timeframe allowed for receiving // data from storage nodes. This will cause context to cancel and start @@ -361,7 +355,7 @@ func TestReverifyOfflineDialTimeout(t *testing.T) { verifier := audit.NewVerifier( satellite.Log.Named("verifier"), satellite.Metainfo.Service, - slowClient, + dialer, satellite.Overlay.Service, satellite.DB.Containment(), satellite.Orders.Service, diff --git a/satellite/audit/verifier.go b/satellite/audit/verifier.go index 52c49bf50..014cfc2d3 100644 --- a/satellite/audit/verifier.go +++ b/satellite/audit/verifier.go @@ -13,7 +13,6 @@ import ( "github.com/vivint/infectious" "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/internal/errs2" @@ -21,8 +20,9 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/metainfo" "storj.io/storj/satellite/orders" "storj.io/storj/satellite/overlay" @@ -56,7 +56,7 @@ type Verifier struct { metainfo *metainfo.Service orders *orders.Service auditor *identity.PeerIdentity - transport transport.Client + dialer rpc.Dialer overlay *overlay.Service containment Containment minBytesPerSecond memory.Size @@ -66,13 +66,13 @@ type Verifier struct { } // NewVerifier creates a Verifier -func NewVerifier(log *zap.Logger, metainfo *metainfo.Service, transport transport.Client, overlay *overlay.Service, containment Containment, orders *orders.Service, id *identity.FullIdentity, minBytesPerSecond memory.Size, minDownloadTimeout time.Duration) *Verifier { +func NewVerifier(log *zap.Logger, metainfo *metainfo.Service, dialer rpc.Dialer, overlay *overlay.Service, containment Containment, orders *orders.Service, id *identity.FullIdentity, minBytesPerSecond memory.Size, minDownloadTimeout time.Duration) *Verifier { return &Verifier{ log: log, metainfo: metainfo, orders: orders, auditor: id.PeerIdentity(), - transport: transport, + dialer: dialer, overlay: overlay, containment: containment, minBytesPerSecond: minBytesPerSecond, @@ -137,14 +137,14 @@ func (verifier *Verifier) Verify(ctx context.Context, path storj.Path, skip map[ sharesToAudit[pieceNum] = share continue } - if transport.Error.Has(share.Error) { + if rpc.Error.Has(share.Error) { if errs.Is(share.Error, context.DeadlineExceeded) { // dial timeout offlineNodes = append(offlineNodes, share.NodeID) verifier.log.Debug("Verify: dial timeout (offline)", zap.String("Segment Path", path), zap.Stringer("Node ID", share.NodeID), zap.Error(share.Error)) continue } - if errs2.IsRPC(share.Error, codes.Unknown) { + if errs2.IsRPC(share.Error, rpcstatus.Unknown) { // dial failed -- offline node offlineNodes = append(offlineNodes, share.NodeID) verifier.log.Debug("Verify: dial failed (offline)", zap.String("Segment Path", path), zap.Stringer("Node ID", share.NodeID), zap.Error(share.Error)) @@ -155,14 +155,14 @@ func (verifier *Verifier) Verify(ctx context.Context, path storj.Path, skip map[ verifier.log.Debug("Verify: unknown transport error (contained)", zap.String("Segment Path", path), zap.Stringer("Node ID", share.NodeID), zap.Error(share.Error)) } - if errs2.IsRPC(share.Error, codes.NotFound) { + if errs2.IsRPC(share.Error, rpcstatus.NotFound) { // missing share failedNodes = append(failedNodes, share.NodeID) verifier.log.Debug("Verify: piece not found (audit failed)", zap.String("Segment Path", path), zap.Stringer("Node ID", share.NodeID), zap.Error(share.Error)) continue } - if errs2.IsRPC(share.Error, codes.DeadlineExceeded) { + if errs2.IsRPC(share.Error, rpcstatus.DeadlineExceeded) { // dial successful, but download timed out containedNodes[pieceNum] = share.NodeID verifier.log.Debug("Verify: download timeout (contained)", zap.String("Segment Path", path), zap.Stringer("Node ID", share.NodeID), zap.Error(share.Error)) @@ -401,14 +401,14 @@ func (verifier *Verifier) Reverify(ctx context.Context, path storj.Path) (report // analyze the error from GetShare if err != nil { - if transport.Error.Has(err) { + if rpc.Error.Has(err) { if errs.Is(err, context.DeadlineExceeded) { // dial timeout ch <- result{nodeID: pending.NodeID, status: offline} verifier.log.Debug("Reverify: dial timeout (offline)", zap.String("Segment Path", path), zap.Stringer("Node ID", pending.NodeID), zap.Error(err)) return } - if errs2.IsRPC(err, codes.Unknown) { + if errs2.IsRPC(err, rpcstatus.Unknown) { // dial failed -- offline node verifier.log.Debug("Reverify: dial failed (offline)", zap.String("Segment Path", path), zap.Stringer("Node ID", pending.NodeID), zap.Error(err)) ch <- result{nodeID: pending.NodeID, status: offline} @@ -419,7 +419,7 @@ func (verifier *Verifier) Reverify(ctx context.Context, path storj.Path) (report verifier.log.Debug("Reverify: unknown transport error (contained)", zap.String("Segment Path", path), zap.Stringer("Node ID", pending.NodeID), zap.Error(err)) return } - if errs2.IsRPC(err, codes.NotFound) { + if errs2.IsRPC(err, rpcstatus.NotFound) { // Get the original segment pointer in the metainfo oldPtr, err := verifier.checkIfSegmentAltered(ctx, pending.Path, pendingPointer) if err != nil { @@ -437,7 +437,7 @@ func (verifier *Verifier) Reverify(ctx context.Context, path storj.Path) (report verifier.log.Debug("Reverify: piece not found (audit failed)", zap.String("Segment Path", path), zap.Stringer("Node ID", pending.NodeID), zap.Error(err)) return } - if errs2.IsRPC(err, codes.DeadlineExceeded) { + if errs2.IsRPC(err, rpcstatus.DeadlineExceeded) { // dial successful, but download timed out ch <- result{nodeID: pending.NodeID, status: contained, pendingAudit: pending} verifier.log.Debug("Reverify: download timeout (contained)", zap.String("Segment Path", path), zap.Stringer("Node ID", pending.NodeID), zap.Error(err)) @@ -526,7 +526,7 @@ func (verifier *Verifier) GetShare(ctx context.Context, limit *pb.AddressedOrder log := verifier.log.Named(storageNodeID.String()) target := &pb.Node{Id: storageNodeID, Address: limit.GetStorageNodeAddress()} - ps, err := piecestore.Dial(timedCtx, verifier.transport, target, log, piecestore.DefaultConfig) + ps, err := piecestore.Dial(timedCtx, verifier.dialer, target, log, piecestore.DefaultConfig) if err != nil { return Share{}, Error.Wrap(err) } diff --git a/satellite/audit/verifier_test.go b/satellite/audit/verifier_test.go index b81eea15d..b1f528023 100644 --- a/satellite/audit/verifier_test.go +++ b/satellite/audit/verifier_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" "storj.io/storj/internal/errs2" "storj.io/storj/internal/memory" @@ -21,8 +20,9 @@ import ( "storj.io/storj/internal/testplanet" "storj.io/storj/internal/testrand" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/audit" "storj.io/storj/storagenode" ) @@ -77,7 +77,7 @@ func TestDownloadSharesHappyPath(t *testing.T) { // TestDownloadSharesOfflineNode checks that the Share.Error field of the // shares returned by the DownloadShares method for offline nodes contain an // error that: -// - has the transport.Error class +// - has the rpc.Error class // - is not a context.DeadlineExceeded error // - is not an RPC error // @@ -128,9 +128,9 @@ func TestDownloadSharesOfflineNode(t *testing.T) { for _, share := range shares { if share.NodeID == stoppedNodeID { - assert.True(t, transport.Error.Has(share.Error), "unexpected error: %+v", share.Error) + assert.True(t, rpc.Error.Has(share.Error), "unexpected error: %+v", share.Error) assert.False(t, errs.Is(share.Error, context.DeadlineExceeded), "unexpected error: %+v", share.Error) - assert.True(t, errs2.IsRPC(share.Error, codes.Unknown), "unexpected error: %+v", share.Error) + assert.True(t, errs2.IsRPC(share.Error, rpcstatus.Unknown), "unexpected error: %+v", share.Error) } else { assert.NoError(t, share.Error) } @@ -187,7 +187,7 @@ func TestDownloadSharesMissingPiece(t *testing.T) { require.NoError(t, err) for _, share := range shares { - assert.True(t, errs2.IsRPC(share.Error, codes.NotFound), "unexpected error: %+v", share.Error) + assert.True(t, errs2.IsRPC(share.Error, rpcstatus.NotFound), "unexpected error: %+v", share.Error) } }) } @@ -195,7 +195,7 @@ func TestDownloadSharesMissingPiece(t *testing.T) { // TestDownloadSharesDialTimeout checks that the Share.Error field of the // shares returned by the DownloadShares method for nodes that time out on // dialing contain an error that: -// - has the transport.Error class +// - has the rpc.Error class // - is a context.DeadlineExceeded error // - is not an RPC error // @@ -232,20 +232,13 @@ func TestDownloadSharesDialTimeout(t *testing.T) { bucketID := []byte(storj.JoinPaths(projects[0].ID.String(), "testbucket")) - network := &transport.SimulatedNetwork{ - DialLatency: 200 * time.Second, - BytesPerSecond: 1 * memory.KiB, - } - - tlsOpts, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) + tlsOptions, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) require.NoError(t, err) - newTransport := transport.NewClientWithTimeouts(tlsOpts, transport.Timeouts{ - Dial: 20 * time.Millisecond, - }) - - slowClient := network.NewClient(newTransport) - require.NotNil(t, slowClient) + dialer := rpc.NewDefaultDialer(tlsOptions) + dialer.DialTimeout = 20 * time.Millisecond + dialer.DialLatency = 200 * time.Second + dialer.TransferRate = 1 * memory.KB // This config value will create a very short timeframe allowed for receiving // data from storage nodes. This will cause context to cancel with timeout. @@ -254,7 +247,7 @@ func TestDownloadSharesDialTimeout(t *testing.T) { verifier := audit.NewVerifier( satellite.Log.Named("verifier"), satellite.Metainfo.Service, - slowClient, + dialer, satellite.Overlay.Service, satellite.DB.Containment(), satellite.Orders.Service, @@ -270,7 +263,7 @@ func TestDownloadSharesDialTimeout(t *testing.T) { require.NoError(t, err) for _, share := range shares { - assert.True(t, transport.Error.Has(share.Error), "unexpected error: %+v", share.Error) + assert.True(t, rpc.Error.Has(share.Error), "unexpected error: %+v", share.Error) assert.True(t, errs.Is(share.Error, context.DeadlineExceeded), "unexpected error: %+v", share.Error) } }) @@ -280,7 +273,7 @@ func TestDownloadSharesDialTimeout(t *testing.T) { // shares returned by the DownloadShares method for nodes that are successfully // dialed, but time out during the download of the share contain an error that: // - is an RPC error with code DeadlineExceeded -// - does not have the transport.Error class +// - does not have the rpc.Error class // // If this test fails, this most probably means we made a backward-incompatible // change that affects the audit service. @@ -329,7 +322,7 @@ func TestDownloadSharesDownloadTimeout(t *testing.T) { verifier := audit.NewVerifier( satellite.Log.Named("verifier"), satellite.Metainfo.Service, - satellite.Transport, + satellite.Dialer, satellite.Overlay.Service, satellite.DB.Containment(), satellite.Orders.Service, @@ -350,8 +343,8 @@ func TestDownloadSharesDownloadTimeout(t *testing.T) { require.Len(t, shares, 1) share := shares[0] - assert.True(t, errs2.IsRPC(share.Error, codes.DeadlineExceeded), "unexpected error: %+v", share.Error) - assert.False(t, transport.Error.Has(share.Error), "unexpected error: %+v", share.Error) + assert.True(t, errs2.IsRPC(share.Error, rpcstatus.DeadlineExceeded), "unexpected error: %+v", share.Error) + assert.False(t, rpc.Error.Has(share.Error), "unexpected error: %+v", share.Error) }) } @@ -491,20 +484,13 @@ func TestVerifierDialTimeout(t *testing.T) { pointer, err := satellite.Metainfo.Service.Get(ctx, path) require.NoError(t, err) - network := &transport.SimulatedNetwork{ - DialLatency: 200 * time.Second, - BytesPerSecond: 1 * memory.KiB, - } - - tlsOpts, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) + tlsOptions, err := tlsopts.NewOptions(satellite.Identity, tlsopts.Config{}, nil) require.NoError(t, err) - newTransport := transport.NewClientWithTimeouts(tlsOpts, transport.Timeouts{ - Dial: 20 * time.Millisecond, - }) - - slowClient := network.NewClient(newTransport) - require.NotNil(t, slowClient) + dialer := rpc.NewDefaultDialer(tlsOptions) + dialer.DialTimeout = 20 * time.Millisecond + dialer.DialLatency = 200 * time.Second + dialer.TransferRate = 1 * memory.KB // This config value will create a very short timeframe allowed for receiving // data from storage nodes. This will cause context to cancel with timeout. @@ -513,7 +499,7 @@ func TestVerifierDialTimeout(t *testing.T) { verifier := audit.NewVerifier( satellite.Log.Named("verifier"), satellite.Metainfo.Service, - slowClient, + dialer, satellite.Overlay.Service, satellite.DB.Containment(), satellite.Orders.Service, diff --git a/satellite/contact/client.go b/satellite/contact/client.go index 1e360f941..dcec3752b 100644 --- a/satellite/contact/client.go +++ b/satellite/contact/client.go @@ -6,46 +6,35 @@ package contact import ( "context" - "google.golang.org/grpc" - "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) type client struct { - conn *grpc.ClientConn - client pb.ContactClient + conn *rpc.Conn + client rpc.ContactClient } // newClient dials the target contact endpoint -func newClient(ctx context.Context, transport transport.Client, targetAddress string, peerIDFromContext storj.NodeID) (*client, error) { - opts, err := tlsopts.NewOptions(transport.Identity(), tlsopts.Config{PeerIDVersions: "latest"}, nil) +func newClient(ctx context.Context, dialer rpc.Dialer, address string, id storj.NodeID) (*client, error) { + conn, err := dialer.DialAddressID(ctx, address, id) if err != nil { - return nil, Error.Wrap(err) - } - dialOption, err := opts.DialOption(peerIDFromContext) - if err != nil { - return nil, Error.Wrap(err) - } - conn, err := transport.DialAddress(ctx, targetAddress, dialOption) - if err != nil { - return nil, Error.Wrap(err) + return nil, err } return &client{ conn: conn, - client: pb.NewContactClient(conn), + client: conn.ContactClient(), }, nil } // pingNode pings a node -func (client *client) pingNode(ctx context.Context, req *pb.ContactPingRequest, opt grpc.CallOption) (*pb.ContactPingResponse, error) { - return client.client.PingNode(ctx, req, opt) +func (client *client) pingNode(ctx context.Context, req *pb.ContactPingRequest) (*pb.ContactPingResponse, error) { + return client.client.PingNode(ctx, req) } -// close closes the connection -func (client *client) close() error { +// Close closes the connection +func (client *client) Close() error { return client.conn.Close() } diff --git a/satellite/contact/contact_test.go b/satellite/contact/contact_test.go index 122ba4622..64ad14fb5 100644 --- a/satellite/contact/contact_test.go +++ b/satellite/contact/contact_test.go @@ -10,12 +10,11 @@ import ( "testing" "github.com/stretchr/testify/require" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/peer" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testplanet" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcpeer" ) func TestSatelliteContactEndpoint(t *testing.T) { @@ -25,18 +24,16 @@ func TestSatelliteContactEndpoint(t *testing.T) { nodeDossier := planet.StorageNodes[0].Local() ident := planet.StorageNodes[0].Identity - grpcPeer := peer.Peer{ + peer := rpcpeer.Peer{ Addr: &net.TCPAddr{ IP: net.ParseIP(nodeDossier.Address.GetAddress()), Port: 5, }, - AuthInfo: credentials.TLSInfo{ - State: tls.ConnectionState{ - PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, - }, + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA}, }, } - peerCtx := peer.NewContext(ctx, &grpcPeer) + peerCtx := rpcpeer.NewContext(ctx, &peer) resp, err := planet.Satellites[0].Contact.Endpoint.CheckIn(peerCtx, &pb.CheckInRequest{ Address: nodeDossier.Address.GetAddress(), Version: &nodeDossier.Version, diff --git a/satellite/contact/endpoint.go b/satellite/contact/endpoint.go index 46899b567..c8617b265 100644 --- a/satellite/contact/endpoint.go +++ b/satellite/contact/endpoint.go @@ -10,13 +10,10 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" "storj.io/storj/satellite/overlay" ) @@ -43,25 +40,25 @@ func NewEndpoint(log *zap.Logger, service *Service) *Endpoint { func (endpoint *Endpoint) CheckIn(ctx context.Context, req *pb.CheckInRequest) (_ *pb.CheckInResponse, err error) { defer mon.Task()(&ctx)(&err) - peerID, err := peerIDFromContext(ctx) + peerID, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Internal, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error()) } nodeID := peerID.ID err = endpoint.service.peerIDs.Set(ctx, nodeID, peerID) if err != nil { - return nil, status.Error(codes.Internal, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error()) } lastIP, err := overlay.GetNetwork(ctx, req.Address) if err != nil { - return nil, status.Error(codes.Internal, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error()) } pingNodeSuccess, pingErrorMessage, err := endpoint.pingBack(ctx, req, nodeID) if err != nil { - return nil, status.Error(codes.Internal, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error()) } nodeInfo := overlay.NodeCheckInInfo{ NodeID: peerID.ID, @@ -77,7 +74,7 @@ func (endpoint *Endpoint) CheckIn(ctx context.Context, req *pb.CheckInRequest) ( } err = endpoint.service.overlay.UpdateCheckIn(ctx, nodeInfo) if err != nil { - return nil, status.Error(codes.Internal, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error()) } endpoint.log.Debug("checking in", zap.String("node addr", req.Address), zap.Bool("ping node succes", pingNodeSuccess)) @@ -88,11 +85,7 @@ func (endpoint *Endpoint) CheckIn(ctx context.Context, req *pb.CheckInRequest) ( } func (endpoint *Endpoint) pingBack(ctx context.Context, req *pb.CheckInRequest, peerID storj.NodeID) (bool, string, error) { - client, err := newClient(ctx, - endpoint.service.transport, - req.Address, - peerID, - ) + client, err := newClient(ctx, endpoint.service.dialer, req.Address, peerID) if err != nil { // if this is a network error, then return the error otherwise just report internal error _, ok := err.(net.Error) @@ -102,15 +95,12 @@ func (endpoint *Endpoint) pingBack(ctx context.Context, req *pb.CheckInRequest, endpoint.log.Info("pingBack internal error", zap.String("error", err.Error())) return false, "", Error.New("couldn't connect to client at addr: %s due to internal error.", req.Address) } - defer func() { - err = errs.Combine(err, client.close()) - }() + defer func() { err = errs.Combine(err, client.Close()) }() pingNodeSuccess := true var pingErrorMessage string - p := &peer.Peer{} - _, err = client.pingNode(ctx, &pb.ContactPingRequest{}, grpc.Peer(p)) + _, err = client.pingNode(ctx, &pb.ContactPingRequest{}) if err != nil { pingNodeSuccess = false pingErrorMessage = "erroring while trying to pingNode due to internal error" @@ -122,15 +112,3 @@ func (endpoint *Endpoint) pingBack(ctx context.Context, req *pb.CheckInRequest, return pingNodeSuccess, pingErrorMessage, err } - -func peerIDFromContext(ctx context.Context) (*identity.PeerIdentity, error) { - p, ok := peer.FromContext(ctx) - if !ok { - return nil, Error.New("unable to get grpc peer from context") - } - peerIdentity, err := identity.PeerIdentityFromPeer(p) - if err != nil { - return nil, err - } - return peerIdentity, nil -} diff --git a/satellite/contact/service.go b/satellite/contact/service.go index 0d9f81fac..e17bb8908 100644 --- a/satellite/contact/service.go +++ b/satellite/contact/service.go @@ -9,11 +9,10 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" "storj.io/storj/satellite/overlay" ) @@ -29,8 +28,8 @@ type Config struct { // Conn represents a connection type Conn struct { - conn *grpc.ClientConn - client pb.NodesClient + conn *rpc.Conn + client rpc.NodesClient } // Service is the contact service between storage nodes and satellites. @@ -44,19 +43,19 @@ type Service struct { mutex sync.Mutex self *overlay.NodeDossier - overlay *overlay.Service - peerIDs overlay.PeerIdentities - transport transport.Client + overlay *overlay.Service + peerIDs overlay.PeerIdentities + dialer rpc.Dialer } // NewService creates a new contact service. -func NewService(log *zap.Logger, self *overlay.NodeDossier, overlay *overlay.Service, peerIDs overlay.PeerIdentities, transport transport.Client) *Service { +func NewService(log *zap.Logger, self *overlay.NodeDossier, overlay *overlay.Service, peerIDs overlay.PeerIdentities, dialer rpc.Dialer) *Service { return &Service{ - log: log, - self: self, - overlay: overlay, - peerIDs: peerIDs, - transport: transport, + log: log, + self: self, + overlay: overlay, + peerIDs: peerIDs, + dialer: dialer, } } @@ -73,24 +72,33 @@ func (service *Service) FetchInfo(ctx context.Context, target pb.Node) (_ *pb.In if err != nil { return nil, err } + defer func() { err = errs.Combine(err, conn.Close()) }() resp, err := conn.client.RequestInfo(ctx, &pb.InfoRequest{}) + if err != nil { + return nil, err + } - return resp, errs.Combine(err, conn.close()) + return resp, nil } // dialNode dials the specified node. func (service *Service) dialNode(ctx context.Context, target pb.Node) (_ *Conn, err error) { defer mon.Task()(&ctx)(&err) - grpcconn, err := service.transport.DialNode(ctx, &target) + + conn, err := service.dialer.DialNode(ctx, &target) + if err != nil { + return nil, err + } + return &Conn{ - conn: grpcconn, - client: pb.NewNodesClient(grpcconn), + conn: conn, + client: conn.NodesClient(), }, err } -// close disconnects this connection. -func (conn *Conn) close() error { +// Close disconnects this connection. +func (conn *Conn) Close() error { return conn.conn.Close() } diff --git a/satellite/gc/service.go b/satellite/gc/service.go index fb7db98c4..a35b22fb7 100644 --- a/satellite/gc/service.go +++ b/satellite/gc/service.go @@ -14,8 +14,8 @@ import ( "storj.io/storj/internal/sync2" "storj.io/storj/pkg/bloomfilter" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/metainfo" "storj.io/storj/satellite/overlay" "storj.io/storj/uplink/piecestore" @@ -45,7 +45,7 @@ type Service struct { config Config Loop sync2.Cycle - transport transport.Client + dialer rpc.Dialer overlay overlay.DB metainfoLoop *metainfo.Loop } @@ -58,13 +58,13 @@ type RetainInfo struct { } // NewService creates a new instance of the gc service -func NewService(log *zap.Logger, config Config, transport transport.Client, overlay overlay.DB, loop *metainfo.Loop) *Service { +func NewService(log *zap.Logger, config Config, dialer rpc.Dialer, overlay overlay.DB, loop *metainfo.Loop) *Service { return &Service{ log: log, config: config, Loop: *sync2.NewCycle(config.Interval), - transport: transport, + dialer: dialer, overlay: overlay, metainfoLoop: loop, } @@ -146,7 +146,7 @@ func (service *Service) sendRetainRequest(ctx context.Context, id storj.NodeID, return Error.Wrap(err) } - client, err := piecestore.Dial(ctx, service.transport, &dossier.Node, log, piecestore.DefaultConfig) + client, err := piecestore.Dial(ctx, service.dialer, &dossier.Node, log, piecestore.DefaultConfig) if err != nil { return Error.Wrap(err) } diff --git a/satellite/metainfo/metainfo.go b/satellite/metainfo/metainfo.go index ec959fb28..b012941db 100644 --- a/satellite/metainfo/metainfo.go +++ b/satellite/metainfo/metainfo.go @@ -14,13 +14,12 @@ import ( "github.com/skyrings/skyring-common/tools/uuid" "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" "storj.io/storj/satellite/accounting" @@ -119,12 +118,12 @@ func (endpoint *Endpoint) SegmentInfoOld(ctx context.Context, req *pb.SegmentInf Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } pointer, _, err := endpoint.getPointer(ctx, keyInfo.ProjectID, req.Segment, req.Bucket, req.Path) @@ -146,21 +145,21 @@ func (endpoint *Endpoint) CreateSegmentOld(ctx context.Context, req *pb.SegmentW Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } if !req.Expiration.IsZero() && !req.Expiration.After(time.Now()) { - return nil, status.Error(codes.InvalidArgument, "Invalid expiration time") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "Invalid expiration time") } err = endpoint.validateRedundancy(ctx, req.Redundancy) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } exceeded, limit, err := endpoint.projectUsage.ExceedsStorageUsage(ctx, keyInfo.ProjectID) @@ -171,7 +170,7 @@ func (endpoint *Endpoint) CreateSegmentOld(ctx context.Context, req *pb.SegmentW endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for storage for projectID %s", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } redundancy, err := eestream.NewRedundancyStrategyFromProto(req.GetRedundancy()) @@ -188,7 +187,7 @@ func (endpoint *Endpoint) CreateSegmentOld(ctx context.Context, req *pb.SegmentW } nodes, err := endpoint.overlay.FindStorageNodes(ctx, request) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } bucketID := createBucketID(keyInfo.ProjectID, req.Bucket) @@ -234,38 +233,38 @@ func (endpoint *Endpoint) CommitSegmentOld(ctx context.Context, req *pb.SegmentC Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } err = endpoint.validateCommitSegment(ctx, req) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = endpoint.filterValidPieces(ctx, req.Pointer, req.OriginalLimits) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } path, err := CreatePath(ctx, keyInfo.ProjectID, req.Segment, req.Bucket, req.Path) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } exceeded, limit, err := endpoint.projectUsage.ExceedsStorageUsage(ctx, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if exceeded { endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for storage for projectID %s.", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } // clear hashes so we don't store them @@ -282,7 +281,7 @@ func (endpoint *Endpoint) CommitSegmentOld(ctx context.Context, req *pb.SegmentC //We cannot have more redundancy than total/min if float64(remoteUsed) > (float64(req.Pointer.SegmentSize)/float64(req.Pointer.Remote.Redundancy.MinReq))*float64(req.Pointer.Remote.Redundancy.Total) { endpoint.log.Sugar().Debugf("data size mismatch, got segment: %d, pieces: %d, RS Min, Total: %d,%d", req.Pointer.SegmentSize, remoteUsed, req.Pointer.Remote.Redundancy.MinReq, req.Pointer.Remote.Redundancy.Total) - return nil, status.Error(codes.InvalidArgument, "mismatched segment size and piece usage") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "mismatched segment size and piece usage") } } @@ -294,20 +293,20 @@ func (endpoint *Endpoint) CommitSegmentOld(ctx context.Context, req *pb.SegmentC err = endpoint.metainfo.Put(ctx, path, req.Pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if req.Pointer.Type == pb.Pointer_INLINE { // TODO or maybe use pointer.SegmentSize ?? err = endpoint.orders.UpdatePutInlineOrder(ctx, keyInfo.ProjectID, req.Bucket, int64(len(req.Pointer.InlineSegment))) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } } pointer, err := endpoint.metainfo.Get(ctx, path) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if len(req.OriginalLimits) > 0 { @@ -328,12 +327,12 @@ func (endpoint *Endpoint) DownloadSegmentOld(ctx context.Context, req *pb.Segmen Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } bucketID := createBucketID(keyInfo.ProjectID, req.Bucket) @@ -346,7 +345,7 @@ func (endpoint *Endpoint) DownloadSegmentOld(ctx context.Context, req *pb.Segmen endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for bandwidth for projectID %s.", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } pointer, _, err := endpoint.getPointer(ctx, keyInfo.ProjectID, req.Segment, req.Bucket, req.Path) @@ -358,13 +357,13 @@ func (endpoint *Endpoint) DownloadSegmentOld(ctx context.Context, req *pb.Segmen // TODO or maybe use pointer.SegmentSize ?? err := endpoint.orders.UpdateGetInlineOrder(ctx, keyInfo.ProjectID, req.Bucket, int64(len(pointer.InlineSegment))) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentDownloadResponseOld{Pointer: pointer}, nil } else if pointer.Type == pb.Pointer_REMOTE && pointer.Remote != nil { limits, privateKey, err := endpoint.orders.CreateGetOrderLimits(ctx, bucketID, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentDownloadResponseOld{Pointer: pointer, AddressedLimits: limits, PrivateKey: privateKey}, nil } @@ -383,46 +382,46 @@ func (endpoint *Endpoint) DeleteSegmentOld(ctx context.Context, req *pb.SegmentD Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } path, err := CreatePath(ctx, keyInfo.ProjectID, req.Segment, req.Bucket, req.Path) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } // TODO refactor to use []byte directly pointer, err := endpoint.metainfo.Get(ctx, path) if err != nil { if storage.ErrKeyNotFound.Has(err) { - return nil, status.Error(codes.NotFound, err.Error()) + return nil, rpcstatus.Error(rpcstatus.NotFound, err.Error()) } - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = endpoint.metainfo.Delete(ctx, path) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if pointer.Type == pb.Pointer_REMOTE && pointer.Remote != nil { for _, piece := range pointer.GetRemote().GetRemotePieces() { _, err := endpoint.containment.Delete(ctx, piece.NodeId) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } } bucketID := createBucketID(keyInfo.ProjectID, req.Bucket) limits, privateKey, err := endpoint.orders.CreateDeleteOrderLimits(ctx, bucketID, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentDeleteResponseOld{AddressedLimits: limits, PrivateKey: privateKey}, nil @@ -442,17 +441,17 @@ func (endpoint *Endpoint) ListSegmentsOld(ctx context.Context, req *pb.ListSegme Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } prefix, err := CreatePath(ctx, keyInfo.ProjectID, -1, req.Bucket, req.Prefix) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } items, more, err := endpoint.metainfo.List(ctx, prefix, string(req.StartAfter), string(req.EndBefore), req.Recursive, req.Limit, req.MetaFlags) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } segmentItems := make([]*pb.ListSegmentsResponseOld_Item, len(items)) @@ -619,7 +618,7 @@ func (endpoint *Endpoint) ProjectInfo(ctx context.Context, req *pb.ProjectInfoRe Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } salt := sha256.Sum256(keyInfo.ProjectID[:]) @@ -639,15 +638,15 @@ func (endpoint *Endpoint) GetBucket(ctx context.Context, req *pb.BucketGetReques Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } bucket, err := endpoint.metainfo.GetBucket(ctx, req.GetName(), keyInfo.ProjectID) if err != nil { if storj.ErrBucketNotFound.Has(err) { - return nil, status.Error(codes.NotFound, err.Error()) + return nil, rpcstatus.Error(rpcstatus.NotFound, err.Error()) } - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } convBucket, err := convertBucketToProto(ctx, bucket) @@ -670,19 +669,19 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Name) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } // TODO set default Redundancy if not set err = endpoint.validateRedundancy(ctx, req.GetDefaultRedundancyScheme()) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } // checks if bucket exists before updates it or makes a new entry @@ -691,12 +690,12 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate var partnerID uuid.UUID err = partnerID.UnmarshalJSON(req.GetPartnerId()) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } // partnerID not set if partnerID.IsZero() { - return resp, status.Error(codes.AlreadyExists, "Bucket already exists") + return resp, rpcstatus.Error(rpcstatus.AlreadyExists, "Bucket already exists") } //update the bucket @@ -705,7 +704,7 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate pbBucket, err := convertBucketToProto(ctx, bucket) if err != nil { - return resp, status.Error(codes.Internal, err.Error()) + return resp, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.BucketCreateResponse{ @@ -717,7 +716,7 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate if storj.ErrBucketNotFound.Has(err) { bucket, err := convertProtoToBucket(req, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } bucket, err = endpoint.metainfo.CreateBucket(ctx, bucket) @@ -747,17 +746,17 @@ func (endpoint *Endpoint) DeleteBucket(ctx context.Context, req *pb.BucketDelete Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Name) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } err = endpoint.metainfo.DeleteBucket(ctx, req.Name, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.BucketDeleteResponse{}, nil @@ -772,7 +771,7 @@ func (endpoint *Endpoint) ListBuckets(ctx context.Context, req *pb.BucketListReq } keyInfo, err := endpoint.validateAuth(ctx, req.Header, action) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } allowedBuckets, err := getAllowedBuckets(ctx, req.Header, action) @@ -807,11 +806,11 @@ func (endpoint *Endpoint) ListBuckets(ctx context.Context, req *pb.BucketListReq func getAllowedBuckets(ctx context.Context, header *pb.RequestHeader, action macaroon.Action) (_ macaroon.AllowedBuckets, err error) { key, err := getAPIKey(ctx, header) if err != nil { - return macaroon.AllowedBuckets{}, status.Errorf(codes.InvalidArgument, "Invalid API credentials: %v", err) + return macaroon.AllowedBuckets{}, rpcstatus.Errorf(rpcstatus.InvalidArgument, "Invalid API credentials: %v", err) } allowedBuckets, err := key.GetAllowedBuckets(ctx, action) if err != nil { - return macaroon.AllowedBuckets{}, status.Errorf(codes.Internal, "GetAllowedBuckets: %v", err) + return macaroon.AllowedBuckets{}, rpcstatus.Errorf(rpcstatus.Internal, "GetAllowedBuckets: %v", err) } return allowedBuckets, err } @@ -833,12 +832,12 @@ func (endpoint *Endpoint) setBucketAttribution(ctx context.Context, header *pb.R Time: time.Now(), }) if err != nil { - return status.Error(codes.Unauthenticated, err.Error()) + return rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } partnerID, err := bytesToUUID(parterID) if err != nil { - return status.Errorf(codes.InvalidArgument, "unable to parse partner ID: %v", err) + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "unable to parse partner ID: %v", err) } // check if attribution is set for given bucket @@ -851,22 +850,22 @@ func (endpoint *Endpoint) setBucketAttribution(ctx context.Context, header *pb.R if !attribution.ErrBucketNotAttributed.Has(err) { // try only to set the attribution, when it's missing endpoint.log.Error("error while getting attribution from DB", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } prefix, err := CreatePath(ctx, keyInfo.ProjectID, -1, bucketName, []byte{}) if err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } items, _, err := endpoint.metainfo.List(ctx, prefix, "", "", true, 1, 0) if err != nil { endpoint.log.Error("error while listing segments", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } if len(items) > 0 { - return status.Errorf(codes.AlreadyExists, "Bucket %q is not empty, PartnerID %q cannot be attributed", bucketName, partnerID) + return rpcstatus.Errorf(rpcstatus.AlreadyExists, "Bucket %q is not empty, PartnerID %q cannot be attributed", bucketName, partnerID) } _, err = endpoint.partnerinfo.Insert(ctx, &attribution.Info{ @@ -876,7 +875,7 @@ func (endpoint *Endpoint) setBucketAttribution(ctx context.Context, header *pb.R }) if err != nil { endpoint.log.Error("error while inserting attribution to DB", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } return nil } @@ -925,7 +924,7 @@ func convertBucketToProto(ctx context.Context, bucket storj.Bucket) (pbBucket *p rs := bucket.DefaultRedundancyScheme partnerID, err := bucket.PartnerID.MarshalJSON() if err != nil { - return pbBucket, status.Error(codes.Internal, "UUID marshal error") + return pbBucket, rpcstatus.Error(rpcstatus.Internal, "UUID marshal error") } return &pb.Bucket{ Name: []byte(bucket.Name), @@ -959,17 +958,17 @@ func (endpoint *Endpoint) BeginObject(ctx context.Context, req *pb.ObjectBeginRe Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } bucket, err := endpoint.metainfo.GetBucket(ctx, req.Bucket, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } // take bucket RS values if not set in request @@ -1010,7 +1009,7 @@ func (endpoint *Endpoint) BeginObject(ctx context.Context, req *pb.ObjectBeginRe ExpirationDate: req.ExpiresAt, }) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.ObjectBeginResponse{ @@ -1030,16 +1029,16 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit streamID := &pb.SatStreamID{} err = proto.Unmarshal(req.StreamId, streamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = signing.VerifyStreamID(ctx, endpoint.satellite, streamID) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } if streamID.CreationDate.Before(time.Now().Add(-satIDExpiration)) { - return nil, status.Error(codes.InvalidArgument, "stream ID expired") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "stream ID expired") } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1049,7 +1048,7 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } segmentIndex := int64(0) @@ -1058,7 +1057,7 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit for { path, err := CreatePath(ctx, keyInfo.ProjectID, segmentIndex, streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to create segment path: %v", err) + return nil, rpcstatus.Errorf(rpcstatus.InvalidArgument, "unable to create segment path: %v", err) } pointer, err := endpoint.metainfo.Get(ctx, path) @@ -1066,7 +1065,7 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit if storage.ErrKeyNotFound.Has(err) { break } - return nil, status.Errorf(codes.Internal, "unable to create get segment: %v", err) + return nil, rpcstatus.Errorf(rpcstatus.Internal, "unable to create get segment: %v", err) } lastSegmentPointer = pointer @@ -1074,7 +1073,7 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit segmentIndex++ } if lastSegmentPointer == nil { - return nil, status.Errorf(codes.NotFound, "unable to find object: %q/%q", streamID.Bucket, streamID.EncryptedPath) + return nil, rpcstatus.Errorf(rpcstatus.NotFound, "unable to find object: %q/%q", streamID.Bucket, streamID.EncryptedPath) } if lastSegmentPointer.Remote == nil { @@ -1086,17 +1085,17 @@ func (endpoint *Endpoint) CommitObject(ctx context.Context, req *pb.ObjectCommit err = endpoint.metainfo.Delete(ctx, lastSegmentPath) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } lastSegmentPath, err = CreatePath(ctx, keyInfo.ProjectID, -1, streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } err = endpoint.metainfo.Put(ctx, lastSegmentPath, lastSegmentPointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.ObjectCommitResponse{}, nil @@ -1113,12 +1112,12 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } pointer, _, err := endpoint.getPointer(ctx, keyInfo.ProjectID, -1, req.Bucket, req.EncryptedPath) @@ -1129,7 +1128,7 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques streamMeta := &pb.StreamMeta{} err = proto.Unmarshal(pointer.Metadata, streamMeta) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } streamID, err := endpoint.packStreamID(ctx, &pb.SatStreamID{ @@ -1139,7 +1138,7 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques CreationDate: time.Now(), }) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } object := &pb.Object{ @@ -1174,7 +1173,7 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques path, err := CreatePath(ctx, keyInfo.ProjectID, index, req.Bucket, req.EncryptedPath) if err != nil { endpoint.log.Error("unable to get pointer path", zap.Error(err)) - return nil, status.Error(codes.Internal, "unable to get object") + return nil, rpcstatus.Error(rpcstatus.Internal, "unable to get object") } pointer, err = endpoint.metainfo.Get(ctx, path) @@ -1184,7 +1183,7 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques } endpoint.log.Error("unable to get pointer", zap.Error(err)) - return nil, status.Error(codes.Internal, "unable to get object") + return nil, rpcstatus.Error(rpcstatus.Internal, "unable to get object") } if pointer.Remote != nil { object.RedundancyScheme = pointer.Remote.Redundancy @@ -1210,17 +1209,17 @@ func (endpoint *Endpoint) ListObjects(ctx context.Context, req *pb.ObjectListReq Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } prefix, err := CreatePath(ctx, keyInfo.ProjectID, -1, req.Bucket, req.EncryptedPrefix) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } metaflags := meta.All @@ -1228,7 +1227,7 @@ func (endpoint *Endpoint) ListObjects(ctx context.Context, req *pb.ObjectListReq // TODO find out how EncryptedCursor -> startAfter/endAfter segments, more, err := endpoint.metainfo.List(ctx, prefix, string(req.EncryptedCursor), "", req.Recursive, req.Limit, metaflags) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } items := make([]*pb.ObjectListItem, len(segments)) @@ -1260,12 +1259,12 @@ func (endpoint *Endpoint) BeginDeleteObject(ctx context.Context, req *pb.ObjectB Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.validateBucket(ctx, req.Bucket) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } satStreamID := &pb.SatStreamID{ @@ -1277,17 +1276,17 @@ func (endpoint *Endpoint) BeginDeleteObject(ctx context.Context, req *pb.ObjectB satStreamID, err = signing.SignStreamID(ctx, endpoint.satellite, satStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } encodedStreamID, err := proto.Marshal(satStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } streamID, err := storj.StreamIDFromBytes(encodedStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } _, _, err = endpoint.getPointer(ctx, keyInfo.ProjectID, -1, satStreamID.Bucket, satStreamID.EncryptedPath) @@ -1307,16 +1306,16 @@ func (endpoint *Endpoint) FinishDeleteObject(ctx context.Context, req *pb.Object streamID := &pb.SatStreamID{} err = proto.Unmarshal(req.StreamId, streamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = signing.VerifyStreamID(ctx, endpoint.satellite, streamID) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } if streamID.CreationDate.Before(time.Now().Add(-satIDExpiration)) { - return nil, status.Error(codes.InvalidArgument, "stream ID expired") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "stream ID expired") } _, err = endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1326,7 +1325,7 @@ func (endpoint *Endpoint) FinishDeleteObject(ctx context.Context, req *pb.Object Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } // we don't need to do anything for shim implementation @@ -1340,7 +1339,7 @@ func (endpoint *Endpoint) BeginSegment(ctx context.Context, req *pb.SegmentBegin streamID, err := endpoint.unmarshalSatStreamID(ctx, req.StreamId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1350,13 +1349,13 @@ func (endpoint *Endpoint) BeginSegment(ctx context.Context, req *pb.SegmentBegin Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } // no need to validate streamID fields because it was validated during BeginObject if req.Position.Index < 0 { - return nil, status.Error(codes.InvalidArgument, "segment index must be greater then 0") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "segment index must be greater then 0") } exceeded, limit, err := endpoint.projectUsage.ExceedsStorageUsage(ctx, keyInfo.ProjectID) @@ -1367,12 +1366,12 @@ func (endpoint *Endpoint) BeginSegment(ctx context.Context, req *pb.SegmentBegin endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for storage for projectID %s", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } redundancy, err := eestream.NewRedundancyStrategyFromProto(streamID.Redundancy) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } maxPieceSize := eestream.CalcPieceSize(req.MaxOrderLimit, redundancy) @@ -1384,13 +1383,13 @@ func (endpoint *Endpoint) BeginSegment(ctx context.Context, req *pb.SegmentBegin } nodes, err := endpoint.overlay.FindStorageNodes(ctx, request) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } bucketID := createBucketID(keyInfo.ProjectID, streamID.Bucket) rootPieceID, addressedLimits, piecePrivateKey, err := endpoint.orders.CreatePutOrderLimits(ctx, bucketID, nodes, streamID.ExpirationDate, maxPieceSize) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } segmentID, err := endpoint.packSegmentID(ctx, &pb.SatSegmentID{ @@ -1414,7 +1413,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm segmentID, err := endpoint.unmarshalSatSegmentID(ctx, req.SegmentId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } streamID := segmentID.StreamId @@ -1426,7 +1425,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } pieces := make([]*pb.RemotePiece, len(req.UploadResult)) @@ -1449,7 +1448,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm }) if err != nil { endpoint.log.Error("unable to marshal segment metadata", zap.Error(err)) - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } pointer := &pb.Pointer{ @@ -1471,28 +1470,28 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm err = endpoint.validatePointer(ctx, pointer, orderLimits) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } err = endpoint.filterValidPieces(ctx, pointer, orderLimits) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } path, err := CreatePath(ctx, keyInfo.ProjectID, int64(segmentID.Index), streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } exceeded, limit, err := endpoint.projectUsage.ExceedsStorageUsage(ctx, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if exceeded { endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for storage for projectID %s.", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } // clear hashes so we don't store them @@ -1508,7 +1507,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm //We cannot have more redundancy than total/min if float64(remoteUsed) > (float64(pointer.SegmentSize)/float64(pointer.Remote.Redundancy.MinReq))*float64(pointer.Remote.Redundancy.Total) { endpoint.log.Sugar().Debugf("data size mismatch, got segment: %d, pieces: %d, RS Min, Total: %d,%d", pointer.SegmentSize, remoteUsed, pointer.Remote.Redundancy.MinReq, pointer.Remote.Redundancy.Total) - return nil, status.Error(codes.InvalidArgument, "mismatched segment size and piece usage") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "mismatched segment size and piece usage") } } @@ -1520,7 +1519,7 @@ func (endpoint *Endpoint) CommitSegment(ctx context.Context, req *pb.SegmentComm err = endpoint.metainfo.Put(ctx, path, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentCommitResponse{ @@ -1534,7 +1533,7 @@ func (endpoint *Endpoint) MakeInlineSegment(ctx context.Context, req *pb.Segment streamID, err := endpoint.unmarshalSatStreamID(ctx, req.StreamId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1544,27 +1543,27 @@ func (endpoint *Endpoint) MakeInlineSegment(ctx context.Context, req *pb.Segment Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } if req.Position.Index < 0 { - return nil, status.Error(codes.InvalidArgument, "segment index must be greater then 0") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "segment index must be greater then 0") } path, err := CreatePath(ctx, keyInfo.ProjectID, int64(req.Position.Index), streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } exceeded, limit, err := endpoint.projectUsage.ExceedsStorageUsage(ctx, keyInfo.ProjectID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if exceeded { endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for storage for projectID %s.", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } inlineUsed := int64(len(req.EncryptedInlineData)) @@ -1591,12 +1590,12 @@ func (endpoint *Endpoint) MakeInlineSegment(ctx context.Context, req *pb.Segment err = endpoint.metainfo.Put(ctx, path, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = endpoint.orders.UpdatePutInlineOrder(ctx, keyInfo.ProjectID, streamID.Bucket, inlineUsed) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentMakeInlineResponse{}, nil @@ -1608,7 +1607,7 @@ func (endpoint *Endpoint) BeginDeleteSegment(ctx context.Context, req *pb.Segmen streamID, err := endpoint.unmarshalSatStreamID(ctx, req.StreamId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1618,7 +1617,7 @@ func (endpoint *Endpoint) BeginDeleteSegment(ctx context.Context, req *pb.Segmen Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } pointer, path, err := endpoint.getPointer(ctx, keyInfo.ProjectID, int64(req.Position.Index), streamID.Bucket, streamID.EncryptedPath) @@ -1632,7 +1631,7 @@ func (endpoint *Endpoint) BeginDeleteSegment(ctx context.Context, req *pb.Segmen bucketID := createBucketID(keyInfo.ProjectID, streamID.Bucket) limits, privateKey, err = endpoint.orders.CreateDeleteOrderLimits(ctx, bucketID, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } } @@ -1641,13 +1640,13 @@ func (endpoint *Endpoint) BeginDeleteSegment(ctx context.Context, req *pb.Segmen for _, piece := range pointer.GetRemote().GetRemotePieces() { _, err := endpoint.containment.Delete(ctx, piece.NodeId) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } } err = endpoint.metainfo.Delete(ctx, path) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } segmentID, err := endpoint.packSegmentID(ctx, &pb.SatSegmentID{ @@ -1670,7 +1669,7 @@ func (endpoint *Endpoint) FinishDeleteSegment(ctx context.Context, req *pb.Segme segmentID, err := endpoint.unmarshalSatSegmentID(ctx, req.SegmentId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } streamID := segmentID.StreamId @@ -1682,7 +1681,7 @@ func (endpoint *Endpoint) FinishDeleteSegment(ctx context.Context, req *pb.Segme Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } // at the moment logic is in BeginDeleteSegment @@ -1696,7 +1695,7 @@ func (endpoint *Endpoint) ListSegments(ctx context.Context, req *pb.SegmentListR streamID, err := endpoint.unmarshalSatStreamID(ctx, req.StreamId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1706,7 +1705,7 @@ func (endpoint *Endpoint) ListSegments(ctx context.Context, req *pb.SegmentListR Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } limit := req.Limit @@ -1716,21 +1715,21 @@ func (endpoint *Endpoint) ListSegments(ctx context.Context, req *pb.SegmentListR path, err := CreatePath(ctx, keyInfo.ProjectID, lastSegment, streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } pointer, err := endpoint.metainfo.Get(ctx, path) if err != nil { if storage.ErrKeyNotFound.Has(err) { - return nil, status.Error(codes.NotFound, err.Error()) + return nil, rpcstatus.Error(rpcstatus.NotFound, err.Error()) } - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } streamMeta := &pb.StreamMeta{} err = proto.Unmarshal(pointer.Metadata, streamMeta) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if streamMeta.NumberOfSegments > 0 { @@ -1790,14 +1789,14 @@ func (endpoint *Endpoint) listSegmentsManually(ctx context.Context, projectID uu for { path, err := CreatePath(ctx, projectID, index, streamID.Bucket, streamID.EncryptedPath) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } _, err = endpoint.metainfo.Get(ctx, path) if err != nil { if storage.ErrKeyNotFound.Has(err) { break } - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } if limit == int32(len(segmentItems)) { more = true @@ -1834,7 +1833,7 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo streamID, err := endpoint.unmarshalSatStreamID(ctx, req.StreamId) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } keyInfo, err := endpoint.validateAuth(ctx, req.Header, macaroon.Action{ @@ -1844,7 +1843,7 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo Time: time.Now(), }) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } bucketID := createBucketID(keyInfo.ProjectID, streamID.Bucket) @@ -1857,7 +1856,7 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo endpoint.log.Sugar().Errorf("monthly project limits are %s of storage and bandwidth usage. This limit has been exceeded for bandwidth for projectID %s.", limit, keyInfo.ProjectID, ) - return nil, status.Error(codes.ResourceExhausted, "Exceeded Usage Limit") + return nil, rpcstatus.Error(rpcstatus.ResourceExhausted, "Exceeded Usage Limit") } pointer, _, err := endpoint.getPointer(ctx, keyInfo.ProjectID, int64(req.CursorPosition.Index), streamID.Bucket, streamID.EncryptedPath) @@ -1875,21 +1874,21 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo streamMeta := &pb.StreamMeta{} err = proto.Unmarshal(pointer.Metadata, streamMeta) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } segmentMeta = streamMeta.LastSegmentMeta } else { segmentMeta = &pb.SegmentMeta{} err = proto.Unmarshal(pointer.Metadata, segmentMeta) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } } if segmentMeta != nil { encryptedKeyNonce, err = storj.NonceFromBytes(segmentMeta.KeyNonce) if err != nil { endpoint.log.Error("unable to get encryption key nonce from metadata", zap.Error(err)) - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } encryptedKey = segmentMeta.EncryptedKey @@ -1899,7 +1898,7 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo if pointer.Type == pb.Pointer_INLINE { err := endpoint.orders.UpdateGetInlineOrder(ctx, keyInfo.ProjectID, streamID.Bucket, int64(len(pointer.InlineSegment))) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.SegmentDownloadResponse{ SegmentId: segmentID, @@ -1912,7 +1911,7 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo } else if pointer.Type == pb.Pointer_REMOTE && pointer.Remote != nil { limits, privateKey, err := endpoint.orders.CreateGetOrderLimits(ctx, bucketID, pointer) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } limits = sortLimits(limits, pointer) @@ -1935,21 +1934,21 @@ func (endpoint *Endpoint) DownloadSegment(ctx context.Context, req *pb.SegmentDo }, nil } - return &pb.SegmentDownloadResponse{}, status.Error(codes.Internal, "invalid type of pointer") + return &pb.SegmentDownloadResponse{}, rpcstatus.Error(rpcstatus.Internal, "invalid type of pointer") } func (endpoint *Endpoint) getPointer(ctx context.Context, projectID uuid.UUID, segmentIndex int64, bucket, encryptedPath []byte) (*pb.Pointer, string, error) { path, err := CreatePath(ctx, projectID, segmentIndex, bucket, encryptedPath) if err != nil { - return nil, "", status.Error(codes.InvalidArgument, err.Error()) + return nil, "", rpcstatus.Error(rpcstatus.InvalidArgument, err.Error()) } pointer, err := endpoint.metainfo.Get(ctx, path) if err != nil { if storage.ErrKeyNotFound.Has(err) { - return nil, "", status.Error(codes.NotFound, err.Error()) + return nil, "", rpcstatus.Error(rpcstatus.NotFound, err.Error()) } - return nil, "", status.Error(codes.Internal, err.Error()) + return nil, "", rpcstatus.Error(rpcstatus.Internal, err.Error()) } return pointer, path, nil } @@ -1977,17 +1976,17 @@ func (endpoint *Endpoint) packStreamID(ctx context.Context, satStreamID *pb.SatS signedStreamID, err := signing.SignStreamID(ctx, endpoint.satellite, satStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } encodedStreamID, err := proto.Marshal(signedStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } streamID, err = storj.StreamIDFromBytes(encodedStreamID) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return streamID, nil } diff --git a/satellite/metainfo/metainfo_test.go b/satellite/metainfo/metainfo_test.go index 8dae83d86..d7c419616 100644 --- a/satellite/metainfo/metainfo_test.go +++ b/satellite/metainfo/metainfo_test.go @@ -13,11 +13,9 @@ import ( "github.com/gogo/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "storj.io/storj/internal/errs2" "storj.io/storj/internal/memory" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testidentity" @@ -26,6 +24,7 @@ import ( "storj.io/storj/pkg/identity" "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" "storj.io/storj/satellite" @@ -194,10 +193,8 @@ func assertUnauthenticated(t *testing.T, err error, allowed bool) { // If it's allowed, we allow any non-unauthenticated error because // some calls error after authentication checks. - if err, ok := status.FromError(errs.Unwrap(err)); ok { - assert.Equal(t, codes.Unauthenticated == err.Code(), !allowed) - } else if !allowed { - assert.Fail(t, "got unexpected error", "%T", err) + if !allowed { + assert.True(t, errs2.IsRPC(err, rpcstatus.Unauthenticated)) } } diff --git a/satellite/metainfo/validation.go b/satellite/metainfo/validation.go index b05ecacc1..4cf25cf53 100644 --- a/satellite/metainfo/validation.go +++ b/satellite/metainfo/validation.go @@ -13,13 +13,12 @@ import ( "github.com/gogo/protobuf/proto" "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "storj.io/storj/pkg/auth" "storj.io/storj/pkg/encryption" "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" "storj.io/storj/satellite/console" @@ -141,20 +140,20 @@ func (endpoint *Endpoint) validateAuth(ctx context.Context, header *pb.RequestHe key, err := getAPIKey(ctx, header) if err != nil { endpoint.log.Debug("invalid request", zap.Error(err)) - return nil, status.Error(codes.InvalidArgument, "Invalid API credentials") + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "Invalid API credentials") } keyInfo, err := endpoint.apiKeys.GetByHead(ctx, key.Head()) if err != nil { endpoint.log.Debug("unauthorized request", zap.Error(err)) - return nil, status.Error(codes.PermissionDenied, "Unauthorized API credentials") + return nil, rpcstatus.Error(rpcstatus.PermissionDenied, "Unauthorized API credentials") } // Revocations are currently handled by just deleting the key. err = key.Check(ctx, keyInfo.Secret, action, nil) if err != nil { endpoint.log.Debug("unauthorized request", zap.Error(err)) - return nil, status.Error(codes.PermissionDenied, "Unauthorized API credentials") + return nil, rpcstatus.Error(rpcstatus.PermissionDenied, "Unauthorized API credentials") } return keyInfo, nil diff --git a/satellite/nodestats/endpoint.go b/satellite/nodestats/endpoint.go index 1cf133a39..1cd4de7a0 100644 --- a/satellite/nodestats/endpoint.go +++ b/satellite/nodestats/endpoint.go @@ -7,12 +7,11 @@ import ( "context" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/satellite/accounting" "storj.io/storj/satellite/overlay" ) @@ -45,15 +44,15 @@ func (e *Endpoint) GetStats(ctx context.Context, req *pb.GetStatsRequest) (_ *pb peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } node, err := e.overlay.Get(ctx, peer.ID) if err != nil { if overlay.ErrNodeNotFound.Has(err) { - return nil, status.Error(codes.PermissionDenied, err.Error()) + return nil, rpcstatus.Error(rpcstatus.PermissionDenied, err.Error()) } e.log.Error("overlay.Get failed", zap.Error(err)) - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } uptimeScore := calculateReputationScore( @@ -89,21 +88,21 @@ func (e *Endpoint) DailyStorageUsage(ctx context.Context, req *pb.DailyStorageUs peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } node, err := e.overlay.Get(ctx, peer.ID) if err != nil { if overlay.ErrNodeNotFound.Has(err) { - return nil, status.Error(codes.PermissionDenied, err.Error()) + return nil, rpcstatus.Error(rpcstatus.PermissionDenied, err.Error()) } e.log.Error("overlay.Get failed", zap.Error(err)) - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } nodeSpaceUsages, err := e.accounting.QueryStorageNodeUsage(ctx, node.Id, req.GetFrom(), req.GetTo()) if err != nil { e.log.Error("accounting.QueryStorageNodeUsage failed", zap.Error(err)) - return nil, status.Error(codes.Internal, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } return &pb.DailyStorageUsageResponse{ diff --git a/satellite/orders/endpoint.go b/satellite/orders/endpoint.go index 91c17136d..1400523d1 100644 --- a/satellite/orders/endpoint.go +++ b/satellite/orders/endpoint.go @@ -11,12 +11,11 @@ import ( "github.com/skyrings/skyring-common/tools/uuid" "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" ) @@ -144,14 +143,14 @@ func (endpoint *Endpoint) doSettlement(stream settlementStream) (err error) { peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return status.Error(codes.Unauthenticated, err.Error()) + return rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } formatError := func(err error) error { if err == io.EOF { return nil } - return status.Error(codes.Unknown, err.Error()) + return rpcstatus.Error(rpcstatus.Unknown, err.Error()) } log := endpoint.log.Named(peer.ID.String()) @@ -175,20 +174,20 @@ func (endpoint *Endpoint) doSettlement(stream settlementStream) (err error) { } if request == nil { - return status.Error(codes.InvalidArgument, "request missing") + return rpcstatus.Error(rpcstatus.InvalidArgument, "request missing") } if request.Limit == nil { - return status.Error(codes.InvalidArgument, "order limit missing") + return rpcstatus.Error(rpcstatus.InvalidArgument, "order limit missing") } if request.Order == nil { - return status.Error(codes.InvalidArgument, "order missing") + return rpcstatus.Error(rpcstatus.InvalidArgument, "order missing") } orderLimit := request.Limit order := request.Order if orderLimit.StorageNodeId != peer.ID { - return status.Error(codes.Unauthenticated, "only specified storage node can settle order") + return rpcstatus.Error(rpcstatus.Unauthenticated, "only specified storage node can settle order") } rejectErr := func() error { diff --git a/satellite/peer.go b/satellite/peer.go index af689745b..57b2001c4 100644 --- a/satellite/peer.go +++ b/satellite/peer.go @@ -23,10 +23,10 @@ import ( "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/extensions" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/server" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/accounting" "storj.io/storj/satellite/accounting/live" "storj.io/storj/satellite/accounting/rollup" @@ -145,7 +145,7 @@ type Peer struct { Identity *identity.FullIdentity DB DB - Transport transport.Client + Dialer rpc.Dialer Server *server.Server @@ -262,18 +262,18 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten log.Debug("Starting listener and server") sc := config.Server - options, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) + tlsOptions, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) if err != nil { return nil, errs.Combine(err, peer.Close()) } - peer.Transport = transport.NewClient(options) + peer.Dialer = rpc.NewDefaultDialer(tlsOptions) unaryInterceptor := grpcauth.NewAPIKeyInterceptor() if sc.DebugLogTraffic { unaryInterceptor = server.CombineInterceptors(unaryInterceptor, server.UnaryMessageLoggingInterceptor(log)) } - peer.Server, err = server.New(log.Named("server"), options, sc.Address, sc.PrivateAddress, unaryInterceptor) + peer.Server, err = server.New(log.Named("server"), tlsOptions, sc.Address, sc.PrivateAddress, unaryInterceptor) if err != nil { return nil, errs.Combine(err, peer.Close()) } @@ -312,7 +312,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten Type: pb.NodeType_SATELLITE, Version: *pbVersion, } - peer.Contact.Service = contact.NewService(peer.Log.Named("contact:service"), self, peer.Overlay.Service, peer.DB.PeerIdentities(), peer.Transport) + peer.Contact.Service = contact.NewService(peer.Log.Named("contact:service"), self, peer.Overlay.Service, peer.DB.PeerIdentities(), peer.Dialer) peer.Contact.Endpoint = contact.NewEndpoint(peer.Log.Named("contact:endpoint"), peer.Contact.Service) peer.Contact.KEndpoint = contact.NewKademliaEndpoint(peer.Log.Named("contact:nodes_service_endpoint")) pb.RegisterNodeServer(peer.Server.GRPC(), peer.Contact.Endpoint) @@ -426,7 +426,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten peer.Metainfo.Service, peer.Orders.Service, peer.Overlay.Service, - peer.Transport, + peer.Dialer, config.Repairer.Timeout, config.Repairer.MaxExcessRateOptimalThreshold, signing.SigneeFromPeerIdentity(peer.Identity.PeerIdentity()), @@ -452,7 +452,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten peer.Audit.Verifier = audit.NewVerifier(log.Named("audit:verifier"), peer.Metainfo.Service, - peer.Transport, + peer.Dialer, peer.Overlay.Service, peer.DB.Containment(), peer.Orders.Service, @@ -492,7 +492,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten peer.GarbageCollection.Service = gc.NewService( peer.Log.Named("garbage collection"), config.GarbageCollection, - peer.Transport, + peer.Dialer, peer.Overlay.DB, peer.Metainfo.Loop, ) diff --git a/satellite/repair/repairer/ec.go b/satellite/repair/repairer/ec.go index 33c0b8f09..920463f95 100644 --- a/satellite/repair/repairer/ec.go +++ b/satellite/repair/repairer/ec.go @@ -21,9 +21,9 @@ import ( "storj.io/storj/internal/sync2" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/uplink/eestream" "storj.io/storj/uplink/piecestore" ) @@ -34,22 +34,22 @@ var ErrPieceHashVerifyFailed = errs.Class("piece hashes don't match") // ECRepairer allows the repairer to download, verify, and upload pieces from storagenodes. type ECRepairer struct { log *zap.Logger - transport transport.Client + dialer rpc.Dialer satelliteSignee signing.Signee } // NewECRepairer creates a new repairer for interfacing with storagenodes. -func NewECRepairer(log *zap.Logger, tc transport.Client, satelliteSignee signing.Signee) *ECRepairer { +func NewECRepairer(log *zap.Logger, dialer rpc.Dialer, satelliteSignee signing.Signee) *ECRepairer { return &ECRepairer{ log: log, - transport: tc, + dialer: dialer, satelliteSignee: satelliteSignee, } } func (ec *ECRepairer) dialPiecestore(ctx context.Context, n *pb.Node) (*piecestore.Client, error) { logger := ec.log.Named(n.Id.String()) - return piecestore.Dial(ctx, ec.transport, n, logger, piecestore.DefaultConfig) + return piecestore.Dial(ctx, ec.dialer, n, logger, piecestore.DefaultConfig) } // Get downloads pieces from storagenodes using the provided order limits, and decodes those pieces into a segment. diff --git a/satellite/repair/repairer/segments.go b/satellite/repair/repairer/segments.go index 39fde2264..ba3ad6eb1 100644 --- a/satellite/repair/repairer/segments.go +++ b/satellite/repair/repairer/segments.go @@ -12,9 +12,9 @@ import ( "go.uber.org/zap" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/metainfo" "storj.io/storj/satellite/orders" "storj.io/storj/satellite/overlay" @@ -46,7 +46,7 @@ type SegmentRepairer struct { // when negative, 0 is applied. func NewSegmentRepairer( log *zap.Logger, metainfo *metainfo.Service, orders *orders.Service, - overlay *overlay.Service, tc transport.Client, timeout time.Duration, + overlay *overlay.Service, dialer rpc.Dialer, timeout time.Duration, excessOptimalThreshold float64, satelliteSignee signing.Signee, ) *SegmentRepairer { @@ -59,7 +59,7 @@ func NewSegmentRepairer( metainfo: metainfo, orders: orders, overlay: overlay, - ec: NewECRepairer(log.Named("ec repairer"), tc, satelliteSignee), + ec: NewECRepairer(log.Named("ec repairer"), dialer, satelliteSignee), timeout: timeout, multiplierOptimalThreshold: 1 + excessOptimalThreshold, } diff --git a/satellite/vouchers/vouchers_test.go b/satellite/vouchers/vouchers_test.go index be118a8c5..6b185f087 100644 --- a/satellite/vouchers/vouchers_test.go +++ b/satellite/vouchers/vouchers_test.go @@ -19,14 +19,14 @@ func TestVouchers(t *testing.T) { }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) { satellite := planet.Satellites[0].Local().Node - conn, err := planet.StorageNodes[0].Transport.DialNode(ctx, &satellite) + conn, err := planet.StorageNodes[0].Dialer.DialNode(ctx, &satellite) require.NoError(t, err) defer ctx.Check(conn.Close) - client := pb.NewVouchersClient(conn) + client := conn.VouchersClient() resp, err := client.Request(ctx, &pb.VoucherRequest{}) require.Nil(t, resp) - require.EqualError(t, err, "rpc error: code = Unknown desc = Vouchers endpoint is deprecated. Please upgrade your storage node to the latest version.") + require.Error(t, err, "Vouchers endpoint is deprecated. Please upgrade your storage node to the latest version.") }) } diff --git a/storagenode/contact/chore.go b/storagenode/contact/chore.go index 0d8de6d1c..b2ba8a354 100644 --- a/storagenode/contact/chore.go +++ b/storagenode/contact/chore.go @@ -14,7 +14,7 @@ import ( "storj.io/storj/internal/sync2" "storj.io/storj/pkg/pb" - "storj.io/storj/pkg/transport" + "storj.io/storj/pkg/rpc" "storj.io/storj/storagenode/trust" ) @@ -22,9 +22,9 @@ import ( // // architecture: Chore type Chore struct { - log *zap.Logger - service *Service - transport transport.Client + log *zap.Logger + service *Service + dialer rpc.Dialer trust *trust.Pool @@ -33,11 +33,11 @@ type Chore struct { } // NewChore creates a new contact chore -func NewChore(log *zap.Logger, interval time.Duration, maxSleep time.Duration, trust *trust.Pool, transport transport.Client, service *Service) *Chore { +func NewChore(log *zap.Logger, interval time.Duration, maxSleep time.Duration, trust *trust.Pool, dialer rpc.Dialer, service *Service) *Chore { return &Chore{ - log: log, - service: service, - transport: transport, + log: log, + service: service, + dialer: dialer, trust: trust, @@ -75,22 +75,13 @@ func (chore *Chore) pingSatellites(ctx context.Context) (err error) { continue } group.Go(func() error { - conn, err := chore.transport.DialNode(ctx, &pb.Node{ - Id: satellite, - Address: &pb.NodeAddress{ - Transport: pb.NodeTransport_TCP_TLS_GRPC, - Address: addr, - }, - }) + conn, err := chore.dialer.DialAddressID(ctx, addr, satellite) if err != nil { return err } - defer func() { - if cerr := conn.Close(); cerr != nil { - err = errs.Combine(err, cerr) - } - }() - _, err = pb.NewNodeClient(conn).CheckIn(ctx, &pb.CheckInRequest{ + defer func() { err = errs.Combine(err, conn.Close()) }() + + _, err = conn.NodeClient().CheckIn(ctx, &pb.CheckInRequest{ Address: self.Address.GetAddress(), Version: &self.Version, Capacity: &self.Capacity, diff --git a/storagenode/contact/contact_test.go b/storagenode/contact/contact_test.go index d6bdce10d..02152329e 100644 --- a/storagenode/contact/contact_test.go +++ b/storagenode/contact/contact_test.go @@ -7,12 +7,12 @@ import ( "testing" "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "storj.io/storj/internal/errs2" "storj.io/storj/internal/testcontext" "storj.io/storj/internal/testplanet" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/storagenode" ) @@ -23,17 +23,17 @@ func TestStoragenodeContactEndpoint(t *testing.T) { nodeDossier := planet.StorageNodes[0].Local() pingStats := planet.StorageNodes[0].Contact.PingStats - conn, err := planet.Satellites[0].Transport.DialNode(ctx, &nodeDossier.Node) + conn, err := planet.Satellites[0].Dialer.DialNode(ctx, &nodeDossier.Node) require.NoError(t, err) defer ctx.Check(conn.Close) - resp, err := pb.NewContactClient(conn).PingNode(ctx, &pb.ContactPingRequest{}) + resp, err := conn.ContactClient().PingNode(ctx, &pb.ContactPingRequest{}) require.NotNil(t, resp) require.NoError(t, err) firstPing, _, _ := pingStats.WhenLastPinged() - resp, err = pb.NewContactClient(conn).PingNode(ctx, &pb.ContactPingRequest{}) + resp, err = conn.ContactClient().PingNode(ctx, &pb.ContactPingRequest{}) require.NotNil(t, resp) require.NoError(t, err) @@ -85,11 +85,11 @@ func TestRequestInfoEndpointTrustedSatellite(t *testing.T) { nodeDossier := planet.StorageNodes[0].Local() // Satellite Trusted - conn, err := planet.Satellites[0].Transport.DialNode(ctx, &nodeDossier.Node) + conn, err := planet.Satellites[0].Dialer.DialNode(ctx, &nodeDossier.Node) require.NoError(t, err) defer ctx.Check(conn.Close) - resp, err := pb.NewNodesClient(conn).RequestInfo(ctx, &pb.InfoRequest{}) + resp, err := conn.NodesClient().RequestInfo(ctx, &pb.InfoRequest{}) require.NotNil(t, resp) require.NoError(t, err) require.Equal(t, nodeDossier.Type, resp.Type) @@ -111,13 +111,13 @@ func TestRequestInfoEndpointUntrustedSatellite(t *testing.T) { nodeDossier := planet.StorageNodes[0].Local() // Satellite Untrusted - conn, err := planet.Satellites[0].Transport.DialNode(ctx, &nodeDossier.Node) + conn, err := planet.Satellites[0].Dialer.DialNode(ctx, &nodeDossier.Node) require.NoError(t, err) defer ctx.Check(conn.Close) - resp, err := pb.NewNodesClient(conn).RequestInfo(ctx, &pb.InfoRequest{}) + resp, err := conn.NodesClient().RequestInfo(ctx, &pb.InfoRequest{}) require.Nil(t, resp) require.Error(t, err) - require.Equal(t, status.Errorf(codes.PermissionDenied, "untrusted peer %v", planet.Satellites[0].Local().Id), err) + require.True(t, errs2.IsRPC(err, rpcstatus.PermissionDenied)) }) } diff --git a/storagenode/contact/endpoint.go b/storagenode/contact/endpoint.go index 9fedb409a..c9089bac4 100644 --- a/storagenode/contact/endpoint.go +++ b/storagenode/contact/endpoint.go @@ -9,12 +9,11 @@ import ( "time" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/peer" - "google.golang.org/grpc/status" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcpeer" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" ) @@ -45,16 +44,16 @@ func NewEndpoint(log *zap.Logger, pingStats *PingStats) *Endpoint { // PingNode provides an easy way to verify a node is online and accepting requests func (endpoint *Endpoint) PingNode(ctx context.Context, req *pb.ContactPingRequest) (_ *pb.ContactPingResponse, err error) { defer mon.Task()(&ctx)(&err) - p, ok := peer.FromContext(ctx) - if !ok { - return nil, status.Error(codes.Internal, "unable to get grpc peer from context") - } - peerID, err := identity.PeerIdentityFromPeer(p) + peer, err := rpcpeer.FromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Internal, err.Error()) } - endpoint.log.Debug("pinged", zap.Stringer("by", peerID.ID), zap.Stringer("srcAddr", p.Addr)) - endpoint.pingStats.WasPinged(time.Now(), peerID.ID, p.Addr.String()) + peerID, err := identity.PeerIdentityFromPeer(peer) + if err != nil { + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) + } + endpoint.log.Debug("pinged", zap.Stringer("by", peerID.ID), zap.Stringer("srcAddr", peer.Addr)) + endpoint.pingStats.WasPinged(time.Now(), peerID.ID, peer.Addr.String()) return &pb.ContactPingResponse{}, nil } diff --git a/storagenode/contact/kademlia.go b/storagenode/contact/kademlia.go index 071bd4a56..e2f71cf47 100644 --- a/storagenode/contact/kademlia.go +++ b/storagenode/contact/kademlia.go @@ -7,11 +7,10 @@ import ( "context" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" ) @@ -54,17 +53,17 @@ func (endpoint *KademliaEndpoint) RequestInfo(ctx context.Context, req *pb.InfoR self := endpoint.service.Local() if endpoint.trust == nil { - return nil, status.Error(codes.Internal, "missing trust") + return nil, rpcstatus.Error(rpcstatus.Internal, "missing trust") } peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, err.Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, err.Error()) } err = endpoint.trust.VerifySatelliteID(ctx, peer.ID) if err != nil { - return nil, status.Errorf(codes.PermissionDenied, "untrusted peer %v", peer.ID) + return nil, rpcstatus.Errorf(rpcstatus.PermissionDenied, "untrusted peer %v", peer.ID) } return &pb.InfoResponse{ diff --git a/storagenode/nodestats/service.go b/storagenode/nodestats/service.go index f38e20f77..27ee2856f 100644 --- a/storagenode/nodestats/service.go +++ b/storagenode/nodestats/service.go @@ -9,12 +9,11 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc" "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/storagenode/reputation" "storj.io/storj/storagenode/storageusage" "storj.io/storj/storagenode/trust" @@ -31,8 +30,8 @@ var ( // // architecture: Client type Client struct { - conn *grpc.ClientConn - pb.NodeStatsClient + conn *rpc.Conn + rpc.NodeStatsClient } // Close closes underlying client connection @@ -40,22 +39,22 @@ func (c *Client) Close() error { return c.conn.Close() } -// Service retrieves info from satellites using GRPC client +// Service retrieves info from satellites using an rpc client // // architecture: Service type Service struct { log *zap.Logger - transport transport.Client - trust *trust.Pool + dialer rpc.Dialer + trust *trust.Pool } // NewService creates new instance of service -func NewService(log *zap.Logger, transport transport.Client, trust *trust.Pool) *Service { +func NewService(log *zap.Logger, dialer rpc.Dialer, trust *trust.Pool) *Service { return &Service{ - log: log, - transport: transport, - trust: trust, + log: log, + dialer: dialer, + trust: trust, } } @@ -67,12 +66,7 @@ func (s *Service) GetReputationStats(ctx context.Context, satelliteID storj.Node if err != nil { return nil, NodeStatsServiceErr.Wrap(err) } - - defer func() { - if cerr := client.Close(); cerr != nil { - err = errs.Combine(err, NodeStatsServiceErr.New("failed to close connection: %v", cerr)) - } - }() + defer func() { err = errs.Combine(err, client.Close()) }() resp, err := client.GetStats(ctx, &pb.GetStatsRequest{}) if err != nil { @@ -111,12 +105,7 @@ func (s *Service) GetDailyStorageUsage(ctx context.Context, satelliteID storj.No if err != nil { return nil, NodeStatsServiceErr.Wrap(err) } - - defer func() { - if cerr := client.Close(); cerr != nil { - err = errs.Combine(err, NodeStatsServiceErr.New("failed to close connection: %v", cerr)) - } - }() + defer func() { err = errs.Combine(err, client.Close()) }() resp, err := client.DailyStorageUsage(ctx, &pb.DailyStorageUsageRequest{From: from, To: to}) if err != nil { @@ -126,7 +115,7 @@ func (s *Service) GetDailyStorageUsage(ctx context.Context, satelliteID storj.No return fromSpaceUsageResponse(resp, satelliteID), nil } -// dial dials GRPC NodeStats client for the satellite by id +// dial dials the NodeStats client for the satellite by id func (s *Service) dial(ctx context.Context, satelliteID storj.NodeID) (_ *Client, err error) { defer mon.Task()(&ctx)(&err) @@ -135,22 +124,14 @@ func (s *Service) dial(ctx context.Context, satelliteID storj.NodeID) (_ *Client return nil, errs.New("unable to find satellite %s: %v", satelliteID, err) } - satellite := pb.Node{ - Id: satelliteID, - Address: &pb.NodeAddress{ - Transport: pb.NodeTransport_TCP_TLS_GRPC, - Address: address, - }, - } - - conn, err := s.transport.DialNode(ctx, &satellite) + conn, err := s.dialer.DialAddressID(ctx, address, satelliteID) if err != nil { return nil, errs.New("unable to connect to the satellite %s: %v", satelliteID, err) } return &Client{ conn: conn, - NodeStatsClient: pb.NewNodeStatsClient(conn), + NodeStatsClient: conn.NodeStatsClient(), }, nil } diff --git a/storagenode/orders/service.go b/storagenode/orders/service.go index 4037b0280..ee9964a42 100644 --- a/storagenode/orders/service.go +++ b/storagenode/orders/service.go @@ -15,8 +15,8 @@ import ( "storj.io/storj/internal/sync2" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/storagenode/trust" ) @@ -97,22 +97,22 @@ type Service struct { log *zap.Logger config Config - transport transport.Client - orders DB - trust *trust.Pool + dialer rpc.Dialer + orders DB + trust *trust.Pool Sender sync2.Cycle Cleanup sync2.Cycle } // NewService creates an order service. -func NewService(log *zap.Logger, transport transport.Client, orders DB, trust *trust.Pool, config Config) *Service { +func NewService(log *zap.Logger, dialer rpc.Dialer, orders DB, trust *trust.Pool, config Config) *Service { return &Service{ - log: log, - transport: transport, - orders: orders, - config: config, - trust: trust, + log: log, + dialer: dialer, + orders: orders, + config: config, + trust: trust, Sender: *sync2.NewCycle(config.SenderInterval), Cleanup: *sync2.NewCycle(config.CleanupInterval), @@ -251,25 +251,14 @@ func (service *Service) settle(ctx context.Context, log *zap.Logger, satelliteID if err != nil { return OrderError.New("unable to get satellite address: %v", err) } - satellite := pb.Node{ - Id: satelliteID, - Address: &pb.NodeAddress{ - Transport: pb.NodeTransport_TCP_TLS_GRPC, - Address: address, - }, - } - conn, err := service.transport.DialNode(ctx, &satellite) + conn, err := service.dialer.DialAddressID(ctx, address, satelliteID) if err != nil { return OrderError.New("unable to connect to the satellite: %v", err) } - defer func() { - if cerr := conn.Close(); cerr != nil { - err = errs.Combine(err, OrderError.New("failed to close connection: %v", cerr)) - } - }() + defer func() { err = errs.Combine(err, conn.Close()) }() - client, err := pb.NewOrdersClient(conn).Settlement(ctx) + stream, err := conn.OrdersClient().Settlement(ctx) if err != nil { return OrderError.New("failed to start settlement: %v", err) } @@ -283,10 +272,10 @@ func (service *Service) settle(ctx context.Context, log *zap.Logger, satelliteID Limit: order.Limit, Order: order.Order, } - err := client.Send(&req) + err := stream.Send(&req) if err != nil { err = OrderError.New("sending settlement agreements returned an error: %v", err) - log.Error("gRPC client when sending new orders settlements", + log.Error("rpc client when sending new orders settlements", zap.Error(err), zap.Any("request", req), ) @@ -295,10 +284,10 @@ func (service *Service) settle(ctx context.Context, log *zap.Logger, satelliteID } } - err := client.CloseSend() + err := stream.CloseSend() if err != nil { err = OrderError.New("CloseSend settlement agreements returned an error: %v", err) - log.Error("gRPC client error when closing sender ", zap.Error(err)) + log.Error("rpc client error when closing sender ", zap.Error(err)) sendErrors.Add(err) } @@ -307,14 +296,14 @@ func (service *Service) settle(ctx context.Context, log *zap.Logger, satelliteID var errList errs.Group for { - response, err := client.Recv() + response, err := stream.Recv() if err != nil { if err == io.EOF { break } err = OrderError.New("failed to receive settlement response: %v", err) - log.Error("gRPC client error when receiveing new order settlements", zap.Error(err)) + log.Error("rpc client error when receiveing new order settlements", zap.Error(err)) errList.Add(err) break } @@ -327,7 +316,7 @@ func (service *Service) settle(ctx context.Context, log *zap.Logger, satelliteID status = StatusRejected default: err := OrderError.New("unexpected settlement status response: %d", response.Status) - log.Error("gRPC client received a unexpected new orders setlement status", + log.Error("rpc client received a unexpected new orders setlement status", zap.Error(err), zap.Any("response", response), ) errList.Add(err) diff --git a/storagenode/peer.go b/storagenode/peer.go index 100a7bfdf..98754b8e3 100644 --- a/storagenode/peer.go +++ b/storagenode/peer.go @@ -19,10 +19,10 @@ import ( "storj.io/storj/pkg/pb" "storj.io/storj/pkg/peertls/extensions" "storj.io/storj/pkg/peertls/tlsopts" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/server" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/satellite/overlay" "storj.io/storj/storage" "storj.io/storj/storagenode/bandwidth" @@ -106,7 +106,7 @@ type Peer struct { Identity *identity.FullIdentity DB DB - Transport transport.Client + Dialer rpc.Dialer Server *server.Server @@ -175,21 +175,21 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten { // setup listener and server sc := config.Server - options, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) + tlsOptions, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) if err != nil { return nil, errs.Combine(err, peer.Close()) } - peer.Transport = transport.NewClient(options) + peer.Dialer = rpc.NewDefaultDialer(tlsOptions) - peer.Server, err = server.New(log.Named("server"), options, sc.Address, sc.PrivateAddress, nil) + peer.Server, err = server.New(log.Named("server"), tlsOptions, sc.Address, sc.PrivateAddress, nil) if err != nil { return nil, errs.Combine(err, peer.Close()) } } { // setup trust pool - peer.Storage2.Trust, err = trust.NewPool(peer.Transport, config.Storage.WhitelistedSatellites) + peer.Storage2.Trust, err = trust.NewPool(peer.Dialer, config.Storage.WhitelistedSatellites) if err != nil { return nil, errs.Combine(err, peer.Close()) } @@ -222,7 +222,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten } peer.Contact.PingStats = new(contact.PingStats) peer.Contact.Service = contact.NewService(peer.Log.Named("contact:service"), self) - peer.Contact.Chore = contact.NewChore(peer.Log.Named("contact:chore"), config.Contact.Interval, config.Contact.MaxSleep, peer.Storage2.Trust, peer.Transport, peer.Contact.Service) + peer.Contact.Chore = contact.NewChore(peer.Log.Named("contact:chore"), config.Contact.Interval, config.Contact.MaxSleep, peer.Storage2.Trust, peer.Dialer, peer.Contact.Service) peer.Contact.Endpoint = contact.NewEndpoint(peer.Log.Named("contact:endpoint"), peer.Contact.PingStats) peer.Contact.KEndpoint = contact.NewKademliaEndpoint(peer.Log.Named("contact:nodes_service_endpoint"), peer.Contact.Service, peer.Storage2.Trust) pb.RegisterContactServer(peer.Server.GRPC(), peer.Contact.Endpoint) @@ -284,21 +284,21 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten pb.RegisterPiecestoreServer(peer.Server.GRPC(), peer.Storage2.Endpoint) pb.DRPCRegisterPiecestore(peer.Server.DRPC(), peer.Storage2.Endpoint.DRPC()) + // TODO workaround for custom timeout for order sending request (read/write) sc := config.Server - options, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) + + tlsOptions, err := tlsopts.NewOptions(peer.Identity, sc.Config, revocationDB) if err != nil { return nil, errs.Combine(err, peer.Close()) } - // TODO workaround for custom timeout for order sending request (read/write) - ordersTransport := transport.NewClientWithTimeouts(options, transport.Timeouts{ - Dial: config.Storage2.Orders.SenderDialTimeout, - Request: config.Storage2.Orders.SenderRequestTimeout, - }) + dialer := rpc.NewDefaultDialer(tlsOptions) + dialer.DialTimeout = config.Storage2.Orders.SenderDialTimeout + dialer.RequestTimeout = config.Storage2.Orders.SenderRequestTimeout peer.Storage2.Orders = orders.NewService( log.Named("orders"), - ordersTransport, + dialer, peer.DB.Orders(), peer.Storage2.Trust, config.Storage2.Orders, @@ -308,7 +308,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB, revocationDB exten { // setup node stats service peer.NodeStats.Service = nodestats.NewService( peer.Log.Named("nodestats:service"), - peer.Transport, + peer.Dialer, peer.Storage2.Trust) peer.NodeStats.Cache = nodestats.NewCache( diff --git a/storagenode/piecestore/endpoint.go b/storagenode/piecestore/endpoint.go index 48a72e281..ad03ae358 100644 --- a/storagenode/piecestore/endpoint.go +++ b/storagenode/piecestore/endpoint.go @@ -13,8 +13,6 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/internal/errs2" @@ -23,6 +21,7 @@ import ( "storj.io/storj/pkg/bloomfilter" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" "storj.io/storj/storagenode/bandwidth" @@ -125,11 +124,11 @@ func (endpoint *Endpoint) Delete(ctx context.Context, delete *pb.PieceDeleteRequ defer atomic.AddInt32(&endpoint.liveRequests, -1) if delete.Limit.Action != pb.PieceAction_DELETE { - return nil, Error.New("expected delete action got %v", delete.Limit.Action) // TODO: report grpc status unauthorized or bad request + return nil, Error.New("expected delete action got %v", delete.Limit.Action) // TODO: report rpc status unauthorized or bad request } if err := endpoint.verifyOrderLimit(ctx, delete.Limit); err != nil { - // TODO: report grpc status unauthorized or bad request + // TODO: report rpc status unauthorized or bad request return nil, Error.Wrap(err) } @@ -137,7 +136,7 @@ func (endpoint *Endpoint) Delete(ctx context.Context, delete *pb.PieceDeleteRequ // explicitly ignoring error because the errors // TODO: add more debug info endpoint.log.Error("delete failed", zap.Stringer("Piece ID", delete.Limit.PieceId), zap.Error(err)) - // TODO: report internal server internal or missing error using grpc status, + // TODO: report rpc status of internal server error or missing error, // e.g. missing might happen when we get a deletion request after garbage collection has deleted it } else { endpoint.log.Info("deleted", zap.Stringer("Piece ID", delete.Limit.PieceId)) @@ -174,7 +173,7 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { if int(liveRequests) > endpoint.config.MaxConcurrentRequests { endpoint.log.Error("upload rejected, too many requests", zap.Int32("live requests", liveRequests)) - return status.Error(codes.Unavailable, "storage node overloaded") + return rpcstatus.Error(rpcstatus.Unavailable, "storage node overloaded") } startTime := time.Now().UTC() @@ -199,7 +198,7 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { // TODO: verify that we have have expected amount of storage before continuing if limit.Action != pb.PieceAction_PUT && limit.Action != pb.PieceAction_PUT_REPAIR { - return ErrProtocol.New("expected put or put repair action got %v", limit.Action) // TODO: report grpc status unauthorized or bad request + return ErrProtocol.New("expected put or put repair action got %v", limit.Action) // TODO: report rpc status unauthorized or bad request } if err := endpoint.verifyOrderLimit(ctx, limit); err != nil { @@ -237,7 +236,7 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { pieceWriter, err = endpoint.store.Writer(ctx, limit.SatelliteId, limit.PieceId) if err != nil { - return ErrInternal.Wrap(err) // TODO: report grpc status internal server error + return ErrInternal.Wrap(err) // TODO: report rpc status internal server error } defer func() { // cancel error if it hasn't been committed @@ -264,13 +263,13 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { if err == io.EOF { return ErrProtocol.New("unexpected EOF") } else if err != nil { - return ErrProtocol.Wrap(err) // TODO: report grpc status bad message + return ErrProtocol.Wrap(err) // TODO: report rpc status bad message } if message == nil { - return ErrProtocol.New("expected a message") // TODO: report grpc status bad message + return ErrProtocol.New("expected a message") // TODO: report rpc status bad message } if message.Order == nil && message.Chunk == nil && message.Done == nil { - return ErrProtocol.New("expected a message") // TODO: report grpc status bad message + return ErrProtocol.New("expected a message") // TODO: report rpc status bad message } if message.Order != nil { @@ -282,13 +281,13 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { if message.Chunk != nil { if message.Chunk.Offset != pieceWriter.Size() { - return ErrProtocol.New("chunk out of order") // TODO: report grpc status bad message + return ErrProtocol.New("chunk out of order") // TODO: report rpc status bad message } chunkSize := int64(len(message.Chunk.Data)) if largestOrder.Amount < pieceWriter.Size()+chunkSize { // TODO: should we write currently and give a chance for uplink to remedy the situation? - return ErrProtocol.New("not enough allocated, allocated=%v writing=%v", largestOrder.Amount, pieceWriter.Size()+int64(len(message.Chunk.Data))) // TODO: report grpc status ? + return ErrProtocol.New("not enough allocated, allocated=%v writing=%v", largestOrder.Amount, pieceWriter.Size()+int64(len(message.Chunk.Data))) // TODO: report rpc status ? } availableBandwidth -= chunkSize @@ -301,14 +300,14 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { } if _, err := pieceWriter.Write(message.Chunk.Data); err != nil { - return ErrInternal.Wrap(err) // TODO: report grpc status internal server error + return ErrInternal.Wrap(err) // TODO: report rpc status internal server error } } if message.Done != nil { calculatedHash := pieceWriter.Hash() if err := endpoint.VerifyPieceHash(ctx, limit, message.Done, calculatedHash); err != nil { - return err // TODO: report grpc status internal server error + return err // TODO: report rpc status internal server error } if message.Done.PieceSize != pieceWriter.Size() { return ErrProtocol.New("Size of finished piece does not match size declared by uplink! %d != %d", @@ -323,12 +322,12 @@ func (endpoint *Endpoint) doUpload(stream uploadStream) (err error) { OrderLimit: *limit, } if err := pieceWriter.Commit(ctx, info); err != nil { - return ErrInternal.Wrap(err) // TODO: report grpc status internal server error + return ErrInternal.Wrap(err) // TODO: report rpc status internal server error } if !limit.PieceExpiration.IsZero() { err := endpoint.store.SetExpiration(ctx, limit.SatelliteId, limit.PieceId, limit.PieceExpiration) if err != nil { - return ErrInternal.Wrap(err) // TODO: report grpc status internal server error + return ErrInternal.Wrap(err) // TODO: report rpc status internal server error } } } @@ -397,7 +396,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { endpoint.log.Info("download started", zap.Stringer("Piece ID", limit.PieceId), zap.Stringer("SatelliteID", limit.SatelliteId), zap.Stringer("Action", limit.Action)) if limit.Action != pb.PieceAction_GET && limit.Action != pb.PieceAction_GET_REPAIR && limit.Action != pb.PieceAction_GET_AUDIT { - return ErrProtocol.New("expected get or get repair or audit action got %v", limit.Action) // TODO: report grpc status unauthorized or bad request + return ErrProtocol.New("expected get or get repair or audit action got %v", limit.Action) // TODO: report rpc status unauthorized or bad request } if chunk.ChunkSize > limit.Limit { @@ -405,7 +404,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { } if err := endpoint.verifyOrderLimit(ctx, limit); err != nil { - return Error.Wrap(err) // TODO: report grpc status unauthorized or bad request + return Error.Wrap(err) // TODO: report rpc status unauthorized or bad request } var pieceReader *pieces.Reader @@ -439,9 +438,9 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { pieceReader, err = endpoint.store.Reader(ctx, limit.SatelliteId, limit.PieceId) if err != nil { if os.IsNotExist(err) { - return status.Error(codes.NotFound, err.Error()) + return rpcstatus.Error(rpcstatus.NotFound, err.Error()) } - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } defer func() { err := pieceReader.Close() // similarly how transcation Rollback works @@ -462,7 +461,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { info, err := endpoint.store.GetV0PieceInfoDB().Get(ctx, limit.SatelliteId, limit.PieceId) if err != nil { endpoint.log.Error("error getting piece from v0 pieceinfo db", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } orderLimit = *info.OrderLimit pieceHash = *info.UplinkPieceHash @@ -471,7 +470,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { header, err := pieceReader.GetPieceHeader() if err != nil { endpoint.log.Error("error getting header from piecereader", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } orderLimit = header.OrderLimit pieceHash = pb.PieceHash{ @@ -486,7 +485,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { err = stream.Send(&pb.PieceDownloadResponse{Hash: &pieceHash, Limit: &orderLimit}) if err != nil { endpoint.log.Error("error sending hash and order limit", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } } @@ -498,7 +497,7 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { availableBandwidth, err := endpoint.monitor.AvailableBandwidth(ctx) if err != nil { endpoint.log.Error("error getting available bandwidth", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } throttle := sync2.NewThrottle() @@ -524,14 +523,14 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) { _, err = pieceReader.Seek(currentOffset, io.SeekStart) if err != nil { endpoint.log.Error("error seeking on piecereader", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } // ReadFull is required to ensure we are sending the right amount of data. _, err = io.ReadFull(pieceReader, chunkData) if err != nil { endpoint.log.Error("error reading from piecereader", zap.Error(err)) - return status.Error(codes.Internal, err.Error()) + return rpcstatus.Error(rpcstatus.Internal, err.Error()) } err = stream.Send(&pb.PieceDownloadResponse{ @@ -633,17 +632,17 @@ func (endpoint *Endpoint) Retain(ctx context.Context, retainReq *pb.RetainReques peer, err := identity.PeerIdentityFromContext(ctx) if err != nil { - return nil, status.Error(codes.Unauthenticated, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.Unauthenticated, Error.Wrap(err).Error()) } err = endpoint.trust.VerifySatelliteID(ctx, peer.ID) if err != nil { - return nil, status.Error(codes.PermissionDenied, Error.New("retain called with untrusted ID").Error()) + return nil, rpcstatus.Error(rpcstatus.PermissionDenied, Error.New("retain called with untrusted ID").Error()) } filter, err := bloomfilter.NewFromBytes(retainReq.GetFilter()) if err != nil { - return nil, status.Error(codes.InvalidArgument, Error.Wrap(err).Error()) + return nil, rpcstatus.Error(rpcstatus.InvalidArgument, Error.Wrap(err).Error()) } // the queue function will update the created before time based on the configurable retain buffer diff --git a/storagenode/piecestore/endpoint_test.go b/storagenode/piecestore/endpoint_test.go index 1ce94a701..048f3fb23 100644 --- a/storagenode/piecestore/endpoint_test.go +++ b/storagenode/piecestore/endpoint_test.go @@ -16,7 +16,6 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" "storj.io/storj/internal/errs2" "storj.io/storj/internal/memory" @@ -25,6 +24,7 @@ import ( "storj.io/storj/internal/testrand" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/pkcrypto" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" "storj.io/storj/storagenode" @@ -497,7 +497,7 @@ func TestTooManyRequests(t *testing.T) { config := piecestore.DefaultConfig config.UploadBufferSize = 0 // disable buffering so we can detect write error early - client, err := piecestore.Dial(ctx, uplink.Transport, &storageNode.Node, uplink.Log, config) + client, err := piecestore.Dial(ctx, uplink.Dialer, &storageNode.Node, uplink.Log, config) if err != nil { return err } @@ -531,7 +531,7 @@ func TestTooManyRequests(t *testing.T) { upload, err := client.Upload(ctx, orderLimit, piecePrivateKey) if err != nil { - if errs2.IsRPC(err, codes.Unavailable) { + if errs2.IsRPC(err, rpcstatus.Unavailable) { if atomic.AddInt64(&failedCount, -1) == 0 { close(doneWaiting) } @@ -543,7 +543,7 @@ func TestTooManyRequests(t *testing.T) { _, err = upload.Write(make([]byte, orderLimit.Limit)) if err != nil { - if errs2.IsRPC(err, codes.Unavailable) { + if errs2.IsRPC(err, rpcstatus.Unavailable) { if atomic.AddInt64(&failedCount, -1) == 0 { close(doneWaiting) } @@ -555,7 +555,7 @@ func TestTooManyRequests(t *testing.T) { _, err = upload.Commit(ctx) if err != nil { - if errs2.IsRPC(err, codes.Unavailable) { + if errs2.IsRPC(err, rpcstatus.Unavailable) { if atomic.AddInt64(&failedCount, -1) == 0 { close(doneWaiting) } diff --git a/storagenode/piecestore/verification.go b/storagenode/piecestore/verification.go index 6d5693394..bd3d26d54 100644 --- a/storagenode/piecestore/verification.go +++ b/storagenode/piecestore/verification.go @@ -9,11 +9,10 @@ import ( "time" "github.com/zeebo/errs" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "storj.io/storj/internal/errs2" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/signing" ) @@ -35,35 +34,35 @@ func (endpoint *Endpoint) verifyOrderLimit(ctx context.Context, limit *pb.OrderL now := time.Now() switch { case limit.Limit < 0: - return status.Error(codes.InvalidArgument, "order limit is negative") + return rpcstatus.Error(rpcstatus.InvalidArgument, "order limit is negative") case endpoint.signer.ID() != limit.StorageNodeId: - return status.Errorf(codes.InvalidArgument, "order intended for other storagenode: %v", limit.StorageNodeId) + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "order intended for other storagenode: %v", limit.StorageNodeId) case endpoint.IsExpired(limit.PieceExpiration): - return status.Errorf(codes.InvalidArgument, "piece expired: %v", limit.PieceExpiration) + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "piece expired: %v", limit.PieceExpiration) case endpoint.IsExpired(limit.OrderExpiration): - return status.Errorf(codes.InvalidArgument, "order expired: %v", limit.OrderExpiration) + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "order expired: %v", limit.OrderExpiration) case now.Sub(limit.OrderCreation) > endpoint.config.OrderLimitGracePeriod: - return status.Errorf(codes.InvalidArgument, "order created too long ago: %v", limit.OrderCreation) + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "order created too long ago: %v", limit.OrderCreation) case limit.SatelliteId.IsZero(): - return status.Errorf(codes.InvalidArgument, "missing satellite id") + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "missing satellite id") case limit.UplinkPublicKey.IsZero(): - return status.Errorf(codes.InvalidArgument, "missing uplink public key") + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "missing uplink public key") case len(limit.SatelliteSignature) == 0: - return status.Errorf(codes.InvalidArgument, "missing satellite signature") + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "missing satellite signature") case limit.PieceId.IsZero(): - return status.Errorf(codes.InvalidArgument, "missing piece id") + return rpcstatus.Errorf(rpcstatus.InvalidArgument, "missing piece id") } if err := endpoint.trust.VerifySatelliteID(ctx, limit.SatelliteId); err != nil { - return status.Errorf(codes.PermissionDenied, "untrusted: %+v", err) + return rpcstatus.Errorf(rpcstatus.PermissionDenied, "untrusted: %+v", err) } if err := endpoint.VerifyOrderLimitSignature(ctx, limit); err != nil { if errs2.IsCanceled(err) { - return status.Error(codes.Canceled, "context has been canceled") + return rpcstatus.Error(rpcstatus.Canceled, "context has been canceled") } - return status.Errorf(codes.Unauthenticated, "untrusted: %+v", err) + return rpcstatus.Errorf(rpcstatus.Unauthenticated, "untrusted: %+v", err) } serialExpiration := limit.OrderExpiration @@ -74,7 +73,7 @@ func (endpoint *Endpoint) verifyOrderLimit(ctx context.Context, limit *pb.OrderL } if err := endpoint.usedSerials.Add(ctx, limit.SatelliteId, limit.SerialNumber, serialExpiration); err != nil { - return status.Errorf(codes.Unauthenticated, "serial number is already used: %+v", err) + return rpcstatus.Errorf(rpcstatus.Unauthenticated, "serial number is already used: %+v", err) } return nil @@ -85,14 +84,14 @@ func (endpoint *Endpoint) VerifyOrder(ctx context.Context, limit *pb.OrderLimit, defer mon.Task()(&ctx)(&err) if order.SerialNumber != limit.SerialNumber { - return ErrProtocol.New("order serial number changed during upload") // TODO: report grpc status bad message + return ErrProtocol.New("order serial number changed during upload") // TODO: report rpc status bad message } // TODO: add check for minimum allocation step if order.Amount < largestOrderAmount { - return ErrProtocol.New("order contained smaller amount=%v, previous=%v", order.Amount, largestOrderAmount) // TODO: report grpc status bad message + return ErrProtocol.New("order contained smaller amount=%v, previous=%v", order.Amount, largestOrderAmount) // TODO: report rpc status bad message } if order.Amount > limit.Limit { - return ErrProtocol.New("order exceeded allowed amount=%v, limit=%v", order.Amount, limit.Limit) // TODO: report grpc status bad message + return ErrProtocol.New("order exceeded allowed amount=%v, limit=%v", order.Amount, limit.Limit) // TODO: report rpc status bad message } if err := signing.VerifyUplinkOrderSignature(ctx, limit.UplinkPublicKey, order); err != nil { @@ -110,14 +109,14 @@ func (endpoint *Endpoint) VerifyPieceHash(ctx context.Context, limit *pb.OrderLi return ErrProtocol.New("invalid arguments") } if limit.PieceId != hash.PieceId { - return ErrProtocol.New("piece id changed") // TODO: report grpc status bad message + return ErrProtocol.New("piece id changed") // TODO: report rpc status bad message } if !bytes.Equal(hash.Hash, expectedHash) { - return ErrProtocol.New("hashes don't match") // TODO: report grpc status bad message + return ErrProtocol.New("hashes don't match") // TODO: report rpc status bad message } if err := signing.VerifyUplinkPieceHashSignature(ctx, limit.UplinkPublicKey, hash); err != nil { - return ErrVerifyUntrusted.New("invalid piece hash signature") // TODO: report grpc status bad message + return ErrVerifyUntrusted.New("invalid piece hash signature") // TODO: report rpc status bad message } return nil @@ -132,11 +131,11 @@ func (endpoint *Endpoint) VerifyOrderLimitSignature(ctx context.Context, limit * if errs2.IsCanceled(err) { return err } - return ErrVerifyUntrusted.New("unable to get signee: %v", err) // TODO: report grpc status bad message + return ErrVerifyUntrusted.New("unable to get signee: %v", err) // TODO: report rpc status bad message } if err := signing.VerifyOrderLimitSignature(ctx, signee, limit); err != nil { - return ErrVerifyUntrusted.New("invalid order limit signature: %v", err) // TODO: report grpc status bad message + return ErrVerifyUntrusted.New("invalid order limit signature: %v", err) // TODO: report rpc status bad message } return nil diff --git a/storagenode/trust/service.go b/storagenode/trust/service.go index df7a03c97..e71691875 100644 --- a/storagenode/trust/service.go +++ b/storagenode/trust/service.go @@ -11,10 +11,9 @@ import ( monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/identity" - "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/signing" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) // Error is the default error class @@ -26,8 +25,8 @@ var mon = monkit.Package() // // architecture: Service type Pool struct { - mu sync.RWMutex - transport transport.Client + mu sync.RWMutex + dialer rpc.Dialer trustedSatellites map[storj.NodeID]*satelliteInfoCache } @@ -40,7 +39,7 @@ type satelliteInfoCache struct { } // NewPool creates a new trust pool of the specified list of trusted satellites. -func NewPool(transport transport.Client, trustedSatellites storj.NodeURLs) (*Pool, error) { +func NewPool(dialer rpc.Dialer, trustedSatellites storj.NodeURLs) (*Pool, error) { // TODO: preload all satellite peer identities // parse the comma separated list of approved satellite IDs into an array of storj.NodeIDs @@ -51,7 +50,7 @@ func NewPool(transport transport.Client, trustedSatellites storj.NodeURLs) (*Poo } return &Pool{ - transport: transport, + dialer: dialer, trustedSatellites: trusted, }, nil } @@ -100,14 +99,13 @@ func (pool *Pool) GetSignee(ctx context.Context, id storj.NodeID) (_ signing.Sig // FetchPeerIdentity dials the url and fetches the identity. func (pool *Pool) FetchPeerIdentity(ctx context.Context, url storj.NodeURL) (_ *identity.PeerIdentity, err error) { - identity, err := pool.transport.FetchPeerIdentity(ctx, &pb.Node{ - Id: url.ID, - Address: &pb.NodeAddress{ - Transport: pb.NodeTransport_TCP_TLS_GRPC, - Address: url.Address, - }, - }) - return identity, Error.Wrap(err) + conn, err := pool.dialer.DialAddressID(ctx, url.Address, url.ID) + if err != nil { + return nil, err + } + defer func() { err = errs.Combine(err, conn.Close()) }() + + return conn.PeerIdentity() } // GetSatellites returns a slice containing all trusted satellites diff --git a/uplink/ecclient/client.go b/uplink/ecclient/client.go index 222ab3151..ff756b2b0 100644 --- a/uplink/ecclient/client.go +++ b/uplink/ecclient/client.go @@ -19,8 +19,8 @@ import ( "storj.io/storj/internal/sync2" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/ranger" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/uplink/eestream" "storj.io/storj/uplink/piecestore" ) @@ -39,16 +39,16 @@ type dialPiecestoreFunc func(context.Context, *pb.Node) (*piecestore.Client, err type ecClient struct { log *zap.Logger - transport transport.Client + dialer rpc.Dialer memoryLimit int forceErrorDetection bool } // NewClient from the given identity and max buffer memory -func NewClient(log *zap.Logger, tc transport.Client, memoryLimit int) Client { +func NewClient(log *zap.Logger, dialer rpc.Dialer, memoryLimit int) Client { return &ecClient{ log: log, - transport: tc, + dialer: dialer, memoryLimit: memoryLimit, } } @@ -60,7 +60,7 @@ func (ec *ecClient) WithForceErrorDetection(force bool) Client { func (ec *ecClient) dialPiecestore(ctx context.Context, n *pb.Node) (*piecestore.Client, error) { logger := ec.log.Named(n.Id.String()) - return piecestore.Dial(ctx, ec.transport, n, logger, piecestore.DefaultConfig) + return piecestore.Dial(ctx, ec.dialer, n, logger, piecestore.DefaultConfig) } func (ec *ecClient) Put(ctx context.Context, limits []*pb.AddressedOrderLimit, privateKey storj.PiecePrivateKey, rs eestream.RedundancyStrategy, data io.Reader, expiration time.Time) (successfulNodes []*pb.Node, successfulHashes []*pb.PieceHash, err error) { diff --git a/uplink/ecclient/client_planet_test.go b/uplink/ecclient/client_planet_test.go index 0989221e1..f205db937 100644 --- a/uplink/ecclient/client_planet_test.go +++ b/uplink/ecclient/client_planet_test.go @@ -44,7 +44,7 @@ func TestECClient(t *testing.T) { planet.Start(ctx) - ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Transport, 0) + ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Dialer, 0) k := storageNodes / 2 n := storageNodes diff --git a/uplink/metainfo/client.go b/uplink/metainfo/client.go index 8430eab6d..1f7aa9d92 100644 --- a/uplink/metainfo/client.go +++ b/uplink/metainfo/client.go @@ -10,15 +10,14 @@ import ( "github.com/skyrings/skyring-common/tools/uuid" "github.com/zeebo/errs" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "gopkg.in/spacemonkeygo/monkit.v2" + "storj.io/storj/internal/errs2" "storj.io/storj/pkg/macaroon" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" + "storj.io/storj/pkg/rpc/rpcstatus" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" "storj.io/storj/storage" ) @@ -31,8 +30,8 @@ var ( // Client creates a grpcClient type Client struct { - client pb.MetainfoClient - conn *grpc.ClientConn + conn *rpc.Conn + client rpc.MetainfoClient apiKeyRaw []byte } @@ -44,7 +43,7 @@ type ListItem struct { } // New used as a public function -func New(client pb.MetainfoClient, apiKey *macaroon.APIKey) *Client { +func New(client rpc.MetainfoClient, apiKey *macaroon.APIKey) *Client { return &Client{ client: client, apiKeyRaw: apiKey.SerializeRaw(), @@ -52,15 +51,15 @@ func New(client pb.MetainfoClient, apiKey *macaroon.APIKey) *Client { } // Dial dials to metainfo endpoint with the specified api key. -func Dial(ctx context.Context, tc transport.Client, address string, apiKey *macaroon.APIKey) (*Client, error) { - conn, err := tc.DialAddress(ctx, address) +func Dial(ctx context.Context, dialer rpc.Dialer, address string, apiKey *macaroon.APIKey) (*Client, error) { + conn, err := dialer.DialAddressInsecure(ctx, address) if err != nil { return nil, Error.Wrap(err) } return &Client{ - client: pb.NewMetainfoClient(conn), conn: conn, + client: conn.MetainfoClient(), apiKeyRaw: apiKey.SerializeRaw(), }, nil } @@ -129,7 +128,7 @@ func (client *Client) SegmentInfo(ctx context.Context, bucket string, path storj Segment: segmentIndex, }) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return nil, storage.ErrKeyNotFound.Wrap(err) } return nil, Error.Wrap(err) @@ -149,7 +148,7 @@ func (client *Client) ReadSegment(ctx context.Context, bucket string, path storj Segment: segmentIndex, }) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return nil, nil, piecePrivateKey, storage.ErrKeyNotFound.Wrap(err) } return nil, nil, piecePrivateKey, Error.Wrap(err) @@ -187,7 +186,7 @@ func (client *Client) DeleteSegment(ctx context.Context, bucket string, path sto Segment: segmentIndex, }) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return nil, piecePrivateKey, storage.ErrKeyNotFound.Wrap(err) } return nil, piecePrivateKey, Error.Wrap(err) @@ -368,7 +367,7 @@ func (client *Client) GetBucket(ctx context.Context, params GetBucketParams) (re resp, err := client.client.GetBucket(ctx, params.toRequest(client.header())) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return storj.Bucket{}, storj.ErrBucketNotFound.Wrap(err) } return storj.Bucket{}, Error.Wrap(err) @@ -407,7 +406,7 @@ func (client *Client) DeleteBucket(ctx context.Context, params DeleteBucketParam defer mon.Task()(&ctx)(&err) _, err = client.client.DeleteBucket(ctx, params.toRequest(client.header())) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return storj.ErrBucketNotFound.Wrap(err) } return Error.Wrap(err) @@ -713,7 +712,7 @@ func (client *Client) GetObject(ctx context.Context, params GetObjectParams) (_ response, err := client.client.GetObject(ctx, params.toRequest(client.header())) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return storj.ObjectInfo{}, storj.ErrObjectNotFound.Wrap(err) } return storj.ObjectInfo{}, Error.Wrap(err) @@ -765,7 +764,7 @@ func (client *Client) BeginDeleteObject(ctx context.Context, params BeginDeleteO response, err := client.client.BeginDeleteObject(ctx, params.toRequest(client.header())) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return storj.StreamID{}, storj.ErrObjectNotFound.Wrap(err) } return storj.StreamID{}, Error.Wrap(err) @@ -1236,7 +1235,7 @@ func (client *Client) ListSegmentsNew(ctx context.Context, params ListSegmentsPa response, err := client.client.ListSegments(ctx, params.toRequest(client.header())) if err != nil { - if status.Code(err) == codes.NotFound { + if errs2.IsRPC(err, rpcstatus.NotFound) { return []storj.SegmentListItem{}, false, storj.ErrObjectNotFound.Wrap(err) } return []storj.SegmentListItem{}, false, Error.Wrap(err) diff --git a/uplink/metainfo/kvmetainfo/buckets_test.go b/uplink/metainfo/kvmetainfo/buckets_test.go index 0d80a5ee1..65bf6eb56 100644 --- a/uplink/metainfo/kvmetainfo/buckets_test.go +++ b/uplink/metainfo/kvmetainfo/buckets_test.go @@ -256,7 +256,7 @@ func newMetainfoParts(planet *testplanet.Planet) (*kvmetainfo.DB, streams.Store, } // TODO(leak): call metainfo.Close somehow - ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Transport, 0) + ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Dialer, 0) fc, err := infectious.NewFEC(2, 4) if err != nil { return nil, nil, err diff --git a/uplink/piecestore/client.go b/uplink/piecestore/client.go index 3e85f44f1..91e5685b2 100644 --- a/uplink/piecestore/client.go +++ b/uplink/piecestore/client.go @@ -9,12 +9,11 @@ import ( "github.com/zeebo/errs" "go.uber.org/zap" - "google.golang.org/grpc" "storj.io/storj/internal/memory" "storj.io/storj/pkg/pb" + "storj.io/storj/pkg/rpc" "storj.io/storj/pkg/storj" - "storj.io/storj/pkg/transport" ) // Error is the default error class for piecestore client. @@ -41,21 +40,21 @@ var DefaultConfig = Config{ // Client implements uploading, downloading and deleting content from a piecestore. type Client struct { log *zap.Logger - client pb.PiecestoreClient - conn *grpc.ClientConn + client rpc.PiecestoreClient + conn *rpc.Conn config Config } // Dial dials the target piecestore endpoint. -func Dial(ctx context.Context, transport transport.Client, target *pb.Node, log *zap.Logger, config Config) (*Client, error) { - conn, err := transport.DialNode(ctx, target) +func Dial(ctx context.Context, dialer rpc.Dialer, target *pb.Node, log *zap.Logger, config Config) (*Client, error) { + conn, err := dialer.DialNode(ctx, target) if err != nil { return nil, Error.Wrap(err) } return &Client{ log: log, - client: pb.NewPiecestoreClient(conn), + client: conn.PiecestoreClient(), conn: conn, config: config, }, nil diff --git a/uplink/piecestore/download.go b/uplink/piecestore/download.go index c97fd33a9..9fc5891b1 100644 --- a/uplink/piecestore/download.go +++ b/uplink/piecestore/download.go @@ -32,7 +32,7 @@ type Download struct { limit *pb.OrderLimit privateKey storj.PiecePrivateKey peer *identity.PeerIdentity - stream pb.Piecestore_DownloadClient + stream downloadStream ctx context.Context read int64 // how much data we have read so far @@ -53,22 +53,26 @@ type Download struct { closingError error } +type downloadStream interface { + CloseSend() error + Send(*pb.PieceDownloadRequest) error + Recv() (*pb.PieceDownloadResponse, error) +} + // Download starts a new download using the specified order limit at the specified offset and size. func (client *Client) Download(ctx context.Context, limit *pb.OrderLimit, piecePrivateKey storj.PiecePrivateKey, offset, size int64) (_ Downloader, err error) { defer mon.Task()(&ctx)(&err) + peer, err := client.conn.PeerIdentity() + if err != nil { + return nil, ErrInternal.Wrap(err) + } + stream, err := client.client.Download(ctx) if err != nil { return nil, err } - peer, err := identity.PeerIdentityFromContext(stream.Context()) - if err != nil { - closeErr := stream.CloseSend() - _, recvErr := stream.Recv() - return nil, ErrInternal.Wrap(errs.Combine(err, ignoreEOF(closeErr), ignoreEOF(recvErr))) - } - err = stream.Send(&pb.PieceDownloadRequest{ Limit: limit, Chunk: &pb.PieceDownloadRequest_Chunk{ diff --git a/uplink/piecestore/upload.go b/uplink/piecestore/upload.go index b2a4f7219..8767ca5a0 100644 --- a/uplink/piecestore/upload.go +++ b/uplink/piecestore/upload.go @@ -36,7 +36,7 @@ type Upload struct { limit *pb.OrderLimit privateKey storj.PiecePrivateKey peer *identity.PeerIdentity - stream pb.Piecestore_UploadClient + stream uploadStream ctx context.Context hash hash.Hash // TODO: use concrete implementation @@ -48,20 +48,27 @@ type Upload struct { sendError error } +type uploadStream interface { + Context() context.Context + CloseSend() error + Send(*pb.PieceUploadRequest) error + CloseAndRecv() (*pb.PieceUploadResponse, error) +} + // Upload initiates an upload to the storage node. func (client *Client) Upload(ctx context.Context, limit *pb.OrderLimit, piecePrivateKey storj.PiecePrivateKey) (_ Uploader, err error) { defer mon.Task()(&ctx, "node: "+limit.StorageNodeId.String()[0:8])(&err) + peer, err := client.conn.PeerIdentity() + if err != nil { + return nil, ErrInternal.Wrap(err) + } + stream, err := client.client.Upload(ctx) if err != nil { return nil, err } - peer, err := identity.PeerIdentityFromContext(stream.Context()) - if err != nil { - return nil, ErrInternal.Wrap(err) - } - err = stream.Send(&pb.PieceUploadRequest{ Limit: limit, }) diff --git a/uplink/piecestore/verification.go b/uplink/piecestore/verification.go index f31e904d1..503f786b3 100644 --- a/uplink/piecestore/verification.go +++ b/uplink/piecestore/verification.go @@ -30,14 +30,14 @@ func (client *Client) VerifyPieceHash(ctx context.Context, peer *identity.PeerId return ErrProtocol.New("invalid arguments") } if limit.PieceId != hash.PieceId { - return ErrProtocol.New("piece id changed") // TODO: report grpc status bad message + return ErrProtocol.New("piece id changed") // TODO: report rpc status bad message } if !bytes.Equal(hash.Hash, expectedHash) { - return ErrVerifyUntrusted.New("hashes don't match") // TODO: report grpc status bad message + return ErrVerifyUntrusted.New("hashes don't match") // TODO: report rpc status bad message } if err := signing.VerifyPieceHashSignature(ctx, signing.SigneeFromPeerIdentity(peer), hash); err != nil { - return ErrVerifyUntrusted.New("invalid hash signature: %v", err) // TODO: report grpc status bad message + return ErrVerifyUntrusted.New("invalid hash signature: %v", err) // TODO: report rpc status bad message } return nil diff --git a/uplink/storage/streams/store_test.go b/uplink/storage/streams/store_test.go index 305dc5686..d2731b8f2 100644 --- a/uplink/storage/streams/store_test.go +++ b/uplink/storage/streams/store_test.go @@ -275,7 +275,7 @@ func storeTestSetup(t *testing.T, ctx *testcontext.Context, planet *testplanet.P metainfo, err := planet.Uplinks[0].DialMetainfo(context.Background(), planet.Satellites[0], TestAPIKey) require.NoError(t, err) - ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Transport, 0) + ec := ecclient.NewClient(planet.Uplinks[0].Log.Named("ecclient"), planet.Uplinks[0].Dialer, 0) cfg := planet.Uplinks[0].GetConfig(planet.Satellites[0]) rs, err := eestream.NewRedundancyStrategyFromStorj(cfg.GetRedundancyScheme())