storj/satellite/orders/encryptionkey.go

235 lines
5.9 KiB
Go
Raw Normal View History

// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package orders
import (
"encoding/hex"
"fmt"
"strings"
"github.com/zeebo/errs"
"golang.org/x/crypto/nacl/secretbox"
"storj.io/common/pb"
"storj.io/common/storj"
)
// ErrEncryptionKey is error class used for keys.
var ErrEncryptionKey = errs.Class("order encryption key")
// EncryptionKeyID is used to identify an encryption key.
type EncryptionKeyID [8]byte
// IsZero returns whether the key contains no data.
func (key EncryptionKeyID) IsZero() bool { return key == EncryptionKeyID{} }
// EncryptionKeys contains a collection of keys.
//
// Can be used as a flag.
type EncryptionKeys struct {
Default EncryptionKey
List []EncryptionKey
KeyByID map[EncryptionKeyID]storj.Key
}
// NewEncryptionKeys creates a new EncrytpionKeys object with the provided keys.
func NewEncryptionKeys(keys ...EncryptionKey) (*EncryptionKeys, error) {
var ekeys EncryptionKeys
for _, key := range keys {
if err := ekeys.Add(key); err != nil {
return nil, err
}
}
return &ekeys, nil
}
// EncryptionKey contains an identifier and an encryption key that is used to
// encrypt transient metadata in orders.
//
// Can be used as a flag.
type EncryptionKey struct {
ID EncryptionKeyID
Key storj.Key
}
// When this fails to compile, then `serialToNonce` should be adjusted accordingly.
var _ = ([16]byte)(storj.SerialNumber{})
func serialToNonce(serial storj.SerialNumber) (x [24]byte) {
copy(x[:], serial[:])
return x
}
// Encrypt encrypts data and nonce using the key.
func (key *EncryptionKey) Encrypt(plaintext []byte, nonce storj.SerialNumber) []byte {
out := make([]byte, 0, len(plaintext)+secretbox.Overhead)
n := serialToNonce(nonce)
k := ([32]byte)(key.Key)
return secretbox.Seal(out, plaintext, &n, &k)
}
// Decrypt decrypts data and nonce using the key.
func (key *EncryptionKey) Decrypt(ciphertext []byte, nonce storj.SerialNumber) ([]byte, error) {
out := make([]byte, 0, len(ciphertext)-secretbox.Overhead)
n := serialToNonce(nonce)
k := ([32]byte)(key.Key)
dec, ok := secretbox.Open(out, ciphertext, &n, &k)
if !ok {
return nil, ErrEncryptionKey.New("unable to decrypt")
}
return dec, nil
}
// EncryptMetadata encrypts order limit metadata.
func (key *EncryptionKey) EncryptMetadata(serial storj.SerialNumber, metadata *pb.OrderLimitMetadata) ([]byte, error) {
marshaled, err := pb.Marshal(metadata)
if err != nil {
return nil, ErrEncryptionKey.Wrap(err)
}
return key.Encrypt(marshaled, serial), nil
}
// DecryptMetadata decrypts order limit metadata.
func (key *EncryptionKey) DecryptMetadata(serial storj.SerialNumber, encrypted []byte) (*pb.OrderLimitMetadata, error) {
decrypted, err := key.Decrypt(encrypted, serial)
if err != nil {
return nil, ErrEncryptionKey.Wrap(err)
}
metadata := &pb.OrderLimitMetadata{}
err = pb.Unmarshal(decrypted, metadata)
if err != nil {
return nil, ErrEncryptionKey.Wrap(err)
}
return metadata, nil
}
// IsZero returns whether they key contains some data.
func (key *EncryptionKey) IsZero() bool {
return key.ID.IsZero() || key.Key.IsZero()
}
// Type implements pflag.Value.
func (EncryptionKey) Type() string { return "orders.EncryptionKey" }
// String is required for pflag.Value.
func (key *EncryptionKey) String() string {
return hex.EncodeToString(key.ID[:]) + "=" + hex.EncodeToString(key.Key[:])
}
// Set sets the value from an hex encoded string "hex(id)=hex(key)".
func (key *EncryptionKey) Set(s string) error {
tokens := strings.SplitN(s, "=", 2)
if len(tokens) != 2 {
return ErrEncryptionKey.New("invalid definition %q", s)
}
err := setHexEncodedArray(key.ID[:], tokens[0])
if err != nil {
return ErrEncryptionKey.New("invalid id %q: %v", tokens[0], err)
}
err = setHexEncodedArray(key.Key[:], tokens[1])
if err != nil {
return ErrEncryptionKey.New("invalid key %q: %v", tokens[1], err)
}
if key.ID.IsZero() || key.Key.IsZero() {
return ErrEncryptionKey.New("neither identifier or key can be zero")
}
return nil
}
// Type implements pflag.Value.
func (EncryptionKeys) Type() string { return "orders.EncryptionKeys" }
// Set adds the values from a comma delimited hex encoded strings "hex(id1)=hex(key1),hex(id2)=hex(key2)".
func (keys *EncryptionKeys) Set(s string) error {
if s == "" {
return nil
}
keys.Clear()
for _, x := range strings.Split(s, ",") {
x = strings.TrimSpace(x)
var ekey EncryptionKey
if err := ekey.Set(x); err != nil {
return ErrEncryptionKey.New("invalid keys %q: %w", s, err)
}
if err := keys.Add(ekey); err != nil {
return err
}
}
return nil
}
// Add adds an encryption key to EncryptionsKeys object.
func (keys *EncryptionKeys) Add(ekey EncryptionKey) error {
if keys.KeyByID == nil {
keys.KeyByID = map[EncryptionKeyID]storj.Key{}
}
if ekey.IsZero() {
return ErrEncryptionKey.New("key is zero")
}
if keys.Default.IsZero() {
keys.Default = ekey
}
if _, exists := keys.KeyByID[ekey.ID]; exists {
return ErrEncryptionKey.New("duplicate key identifier %q", ekey.String())
}
keys.List = append(keys.List, ekey)
keys.KeyByID[ekey.ID] = ekey.Key
return nil
}
// Clear removes all keys.
func (keys *EncryptionKeys) Clear() {
keys.Default = EncryptionKey{}
keys.List = nil
keys.KeyByID = map[EncryptionKeyID]storj.Key{}
}
// String is required for pflag.Value.
func (keys *EncryptionKeys) String() string {
var s strings.Builder
if keys.Default.IsZero() {
return ""
}
s.WriteString(keys.Default.String())
for _, key := range keys.List {
if key.ID == keys.Default.ID {
continue
}
s.WriteString(",")
s.WriteString(key.String())
}
return s.String()
}
// setHexEncodedArray sets dst bytes to hex decoded s, verify that the result matches dst.
func setHexEncodedArray(dst []byte, s string) error {
s = strings.TrimSpace(s)
if len(s) != len(dst)*2 {
return fmt.Errorf("wrong hex length %d, expected %d", len(s), len(dst)*2)
}
bytes, err := hex.DecodeString(s)
if err != nil {
return err
}
copy(dst, bytes)
return nil
}