// Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. package bwagreement import ( "context" "time" "github.com/zeebo/errs" "go.uber.org/zap" monkit "gopkg.in/spacemonkeygo/monkit.v2" "storj.io/storj/pkg/auth" "storj.io/storj/pkg/identity" "storj.io/storj/pkg/pb" "storj.io/storj/pkg/storj" ) var ( // Error the default bwagreement errs class Error = errs.Class("bwagreement error") mon = monkit.Package() ) // Config is a configuration struct that is everything you need to start an // agreement receiver responsibility type Config struct { } // DB stores bandwidth agreements. type DB interface { // CreateAgreement adds a new bandwidth agreement. CreateAgreement(context.Context, *pb.RenterBandwidthAllocation) error // GetAgreements gets all bandwidth agreements. GetAgreements(context.Context) ([]Agreement, error) // GetAgreementsSince gets all bandwidth agreements since specific time. GetAgreementsSince(context.Context, time.Time) ([]Agreement, error) } // Server is an implementation of the pb.BandwidthServer interface type Server struct { db DB NodeID storj.NodeID logger *zap.Logger } // Agreement is a struct that contains a bandwidth agreement and the associated signature type Agreement struct { Agreement pb.RenterBandwidthAllocation CreatedAt time.Time } // NewServer creates instance of Server func NewServer(db DB, logger *zap.Logger, nodeID storj.NodeID) *Server { // TODO: reorder arguments, rename logger -> log return &Server{db: db, logger: logger, NodeID: nodeID} } // Close closes resources func (s *Server) Close() error { return nil } // BandwidthAgreements receives and stores bandwidth agreements from storage nodes func (s *Server) BandwidthAgreements(ctx context.Context, rba *pb.RenterBandwidthAllocation) (reply *pb.AgreementsSummary, err error) { defer mon.Task()(&ctx)(&err) s.logger.Debug("Received Agreement...") reply = &pb.AgreementsSummary{ Status: pb.AgreementsSummary_REJECTED, } pba := rba.PayerAllocation //verify message content pi, err := identity.PeerIdentityFromContext(ctx) if err != nil || rba.StorageNodeId != pi.ID { return reply, auth.ErrBadID.New("Storage Node ID: %s vs %s", rba.StorageNodeId, pi.ID) } //todo: use whitelist for uplinks? if pba.SatelliteId != s.NodeID { return reply, pb.ErrPayer.New("Satellite ID: %s vs %s", pba.SatelliteId, s.NodeID) } exp := time.Unix(pba.GetExpirationUnixSec(), 0).UTC() if exp.Before(time.Now().UTC()) { return reply, pb.ErrPayer.Wrap(auth.ErrExpired.New("%v vs %v", exp, time.Now().UTC())) } //verify message crypto if err := auth.VerifyMsg(rba, pba.UplinkId); err != nil { return reply, pb.ErrRenter.Wrap(err) } if err := auth.VerifyMsg(&pba, pba.SatelliteId); err != nil { return reply, pb.ErrPayer.Wrap(err) } //save and return rersults err = s.db.CreateAgreement(ctx, rba) if err != nil { //todo: better classify transport errors (AgreementsSummary_FAIL) vs logical (AgreementsSummary_REJECTED) return reply, pb.ErrPayer.Wrap(auth.ErrSerial.Wrap(err)) } reply.Status = pb.AgreementsSummary_OK s.logger.Debug("Stored Agreement...") return reply, nil }