121 lines
3.2 KiB
Go
121 lines
3.2 KiB
Go
|
// Copyright (C) 2018 Storj Labs, Inc.
|
||
|
// See LICENSE for copying information.
|
||
|
|
||
|
package satellitedb
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/zeebo/errs"
|
||
|
|
||
|
"storj.io/storj/pkg/identity"
|
||
|
"storj.io/storj/pkg/storj"
|
||
|
dbx "storj.io/storj/satellite/satellitedb/dbx"
|
||
|
)
|
||
|
|
||
|
type peerIdentities struct {
|
||
|
db *dbx.DB
|
||
|
}
|
||
|
|
||
|
// Set adds a peer identity entry
|
||
|
func (idents *peerIdentities) Set(ctx context.Context, nodeID storj.NodeID, ident *identity.PeerIdentity) (err error) {
|
||
|
defer mon.Task()(&ctx)(&err)
|
||
|
|
||
|
if ident == nil {
|
||
|
return Error.New("identitiy is nil")
|
||
|
}
|
||
|
|
||
|
tx, err := idents.db.Open(ctx)
|
||
|
if err != nil {
|
||
|
return Error.Wrap(err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err == nil {
|
||
|
err = tx.Commit()
|
||
|
} else {
|
||
|
err = errs.Combine(err, tx.Rollback())
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
serial, err := tx.Get_PeerIdentity_LeafSerialNumber_By_NodeId(ctx, dbx.PeerIdentity_NodeId(nodeID.Bytes()))
|
||
|
if serial == nil || err != nil {
|
||
|
if serial == nil || err == sql.ErrNoRows {
|
||
|
_, err = tx.Create_PeerIdentity(ctx,
|
||
|
dbx.PeerIdentity_NodeId(nodeID.Bytes()),
|
||
|
dbx.PeerIdentity_LeafSerialNumber(ident.Leaf.SerialNumber.Bytes()),
|
||
|
dbx.PeerIdentity_Chain(identity.EncodePeerIdentity(ident)),
|
||
|
)
|
||
|
return Error.Wrap(err)
|
||
|
}
|
||
|
return Error.Wrap(err)
|
||
|
}
|
||
|
if !bytes.Equal(serial.LeafSerialNumber, ident.Leaf.SerialNumber.Bytes()) {
|
||
|
_, err = tx.Update_PeerIdentity_By_NodeId(ctx,
|
||
|
dbx.PeerIdentity_NodeId(nodeID.Bytes()),
|
||
|
dbx.PeerIdentity_Update_Fields{
|
||
|
LeafSerialNumber: dbx.PeerIdentity_LeafSerialNumber(ident.Leaf.SerialNumber.Bytes()),
|
||
|
Chain: dbx.PeerIdentity_Chain(identity.EncodePeerIdentity(ident)),
|
||
|
},
|
||
|
)
|
||
|
}
|
||
|
|
||
|
return Error.Wrap(err)
|
||
|
}
|
||
|
|
||
|
// Get gets the peer identity based on the certificate's nodeID
|
||
|
func (idents *peerIdentities) Get(ctx context.Context, nodeID storj.NodeID) (_ *identity.PeerIdentity, err error) {
|
||
|
defer mon.Task()(&ctx)(&err)
|
||
|
dbxIdent, err := idents.db.Get_PeerIdentity_By_NodeId(ctx, dbx.PeerIdentity_NodeId(nodeID.Bytes()))
|
||
|
if err != nil {
|
||
|
return nil, Error.Wrap(err)
|
||
|
}
|
||
|
if dbxIdent == nil {
|
||
|
return nil, Error.New("missing node id: %v", nodeID)
|
||
|
}
|
||
|
|
||
|
ident, err := identity.DecodePeerIdentity(ctx, dbxIdent.Chain)
|
||
|
return ident, Error.Wrap(err)
|
||
|
}
|
||
|
|
||
|
// BatchGet gets the peer idenities based on the certificate's nodeID
|
||
|
func (idents *peerIdentities) BatchGet(ctx context.Context, nodeIDs storj.NodeIDList) (peerIdents []*identity.PeerIdentity, err error) {
|
||
|
defer mon.Task()(&ctx)(&err)
|
||
|
if len(nodeIDs) == 0 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
args := make([]interface{}, 0, nodeIDs.Len())
|
||
|
for _, nodeID := range nodeIDs {
|
||
|
args = append(args, nodeID)
|
||
|
}
|
||
|
|
||
|
// TODO: optimize using arrays like overlay
|
||
|
|
||
|
rows, err := idents.db.Query(idents.db.Rebind(`
|
||
|
SELECT chain FROM peer_identities WHERE node_id IN (?`+strings.Repeat(", ?", len(nodeIDs)-1)+`)`), args...)
|
||
|
if err != nil {
|
||
|
return nil, Error.Wrap(err)
|
||
|
}
|
||
|
defer func() {
|
||
|
err = errs.Combine(err, rows.Close())
|
||
|
}()
|
||
|
|
||
|
for rows.Next() {
|
||
|
var peerChain []byte
|
||
|
err := rows.Scan(&peerChain)
|
||
|
if err != nil {
|
||
|
return nil, Error.Wrap(err)
|
||
|
}
|
||
|
ident, err := identity.DecodePeerIdentity(ctx, peerChain)
|
||
|
if err != nil {
|
||
|
return nil, Error.Wrap(err)
|
||
|
}
|
||
|
peerIdents = append(peerIdents, ident)
|
||
|
}
|
||
|
return peerIdents, nil
|
||
|
}
|