2020-06-02 14:51:33 +01:00
|
|
|
// Copyright (C) 2020 Storj Labs, Inc.
|
|
|
|
// See LICENSE for copying information.
|
|
|
|
|
|
|
|
package orders
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/hex"
|
|
|
|
"fmt"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
"github.com/zeebo/errs"
|
2020-07-08 12:02:41 +01:00
|
|
|
"golang.org/x/crypto/nacl/secretbox"
|
2020-06-02 14:51:33 +01:00
|
|
|
|
2020-07-24 18:13:15 +01:00
|
|
|
"storj.io/common/pb"
|
2020-06-02 14:51:33 +01:00
|
|
|
"storj.io/common/storj"
|
|
|
|
)
|
|
|
|
|
|
|
|
// ErrEncryptionKey is error class used for keys.
|
|
|
|
var ErrEncryptionKey = errs.Class("order encryption key")
|
|
|
|
|
|
|
|
// EncryptionKeyID is an 8-byte value used to identify an encryption key.
type EncryptionKeyID [8]byte

// IsZero reports whether every byte of the identifier is zero.
func (key EncryptionKeyID) IsZero() bool {
	var zero EncryptionKeyID
	return key == zero
}
|
|
|
|
|
|
|
|
// EncryptionKeys contains a collection of keys.
//
// Can be used as a flag: it implements Set, String and Type, and accepts a
// comma delimited list of "hex(id)=hex(key)" definitions.
type EncryptionKeys struct {
	// Default is the first key that was successfully added via Add.
	// NOTE(review): presumably the key used to encrypt new data — confirm at
	// call sites.
	Default EncryptionKey
	// List holds every added key in insertion order, including Default.
	List []EncryptionKey
	// KeyByID maps each added key's identifier to its key material; Add keeps
	// it in sync with List.
	KeyByID map[EncryptionKeyID]storj.Key
}
|
|
|
|
|
2020-11-20 19:16:31 +00:00
|
|
|
// NewEncryptionKeys creates a new EncrytpionKeys object with the provided keys.
|
|
|
|
func NewEncryptionKeys(keys ...EncryptionKey) (*EncryptionKeys, error) {
|
|
|
|
var ekeys EncryptionKeys
|
|
|
|
for _, key := range keys {
|
|
|
|
if err := ekeys.Add(key); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return &ekeys, nil
|
|
|
|
}
|
|
|
|
|
2020-06-02 14:51:33 +01:00
|
|
|
// EncryptionKey contains an identifier and an encryption key that is used to
// encrypt transient metadata in orders.
//
// Can be used as a flag: it implements Set, String and Type using the
// "hex(id)=hex(key)" format.
type EncryptionKey struct {
	// ID identifies the key; Set and EncryptionKeys.Add reject a zero ID.
	ID EncryptionKeyID
	// Key is the secret key material; Encrypt and Decrypt convert it to a
	// 32-byte NaCl secretbox key.
	Key storj.Key
}
|
|
|
|
|
2020-07-08 12:02:41 +01:00
|
|
|
// When this fails to compile, then `serialToNonce` should be adjusted accordingly.
|
|
|
|
var _ = ([16]byte)(storj.SerialNumber{})
|
|
|
|
|
|
|
|
func serialToNonce(serial storj.SerialNumber) (x [24]byte) {
|
|
|
|
copy(x[:], serial[:])
|
|
|
|
return x
|
|
|
|
}
|
|
|
|
|
|
|
|
// Encrypt encrypts data and nonce using the key.
|
|
|
|
func (key *EncryptionKey) Encrypt(plaintext []byte, nonce storj.SerialNumber) []byte {
|
|
|
|
out := make([]byte, 0, len(plaintext)+secretbox.Overhead)
|
|
|
|
n := serialToNonce(nonce)
|
|
|
|
k := ([32]byte)(key.Key)
|
|
|
|
return secretbox.Seal(out, plaintext, &n, &k)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Decrypt decrypts data and nonce using the key.
|
|
|
|
func (key *EncryptionKey) Decrypt(ciphertext []byte, nonce storj.SerialNumber) ([]byte, error) {
|
|
|
|
out := make([]byte, 0, len(ciphertext)-secretbox.Overhead)
|
|
|
|
n := serialToNonce(nonce)
|
|
|
|
k := ([32]byte)(key.Key)
|
|
|
|
dec, ok := secretbox.Open(out, ciphertext, &n, &k)
|
|
|
|
if !ok {
|
|
|
|
return nil, ErrEncryptionKey.New("unable to decrypt")
|
|
|
|
}
|
|
|
|
return dec, nil
|
|
|
|
}
|
|
|
|
|
2020-07-24 18:13:15 +01:00
|
|
|
// EncryptMetadata encrypts order limit metadata.
|
|
|
|
func (key *EncryptionKey) EncryptMetadata(serial storj.SerialNumber, metadata *pb.OrderLimitMetadata) ([]byte, error) {
|
|
|
|
marshaled, err := pb.Marshal(metadata)
|
|
|
|
if err != nil {
|
|
|
|
return nil, ErrEncryptionKey.Wrap(err)
|
|
|
|
}
|
|
|
|
return key.Encrypt(marshaled, serial), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// DecryptMetadata decrypts order limit metadata.
|
|
|
|
func (key *EncryptionKey) DecryptMetadata(serial storj.SerialNumber, encrypted []byte) (*pb.OrderLimitMetadata, error) {
|
|
|
|
decrypted, err := key.Decrypt(encrypted, serial)
|
|
|
|
if err != nil {
|
|
|
|
return nil, ErrEncryptionKey.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
metadata := &pb.OrderLimitMetadata{}
|
|
|
|
err = pb.Unmarshal(decrypted, metadata)
|
|
|
|
if err != nil {
|
|
|
|
return nil, ErrEncryptionKey.Wrap(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return metadata, nil
|
|
|
|
}
|
|
|
|
|
2020-06-02 14:51:33 +01:00
|
|
|
// IsZero returns whether they key contains some data.
|
|
|
|
func (key *EncryptionKey) IsZero() bool {
|
|
|
|
return key.ID.IsZero() || key.Key.IsZero()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Type implements pflag.Value.
|
|
|
|
func (EncryptionKey) Type() string { return "orders.EncryptionKey" }
|
|
|
|
|
|
|
|
// String is required for pflag.Value.
|
|
|
|
func (key *EncryptionKey) String() string {
|
|
|
|
return hex.EncodeToString(key.ID[:]) + "=" + hex.EncodeToString(key.Key[:])
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set sets the value from an hex encoded string "hex(id)=hex(key)".
|
|
|
|
func (key *EncryptionKey) Set(s string) error {
|
|
|
|
tokens := strings.SplitN(s, "=", 2)
|
|
|
|
if len(tokens) != 2 {
|
|
|
|
return ErrEncryptionKey.New("invalid definition %q", s)
|
|
|
|
}
|
|
|
|
|
|
|
|
err := setHexEncodedArray(key.ID[:], tokens[0])
|
|
|
|
if err != nil {
|
|
|
|
return ErrEncryptionKey.New("invalid id %q: %v", tokens[0], err)
|
|
|
|
}
|
|
|
|
|
|
|
|
err = setHexEncodedArray(key.Key[:], tokens[1])
|
|
|
|
if err != nil {
|
|
|
|
return ErrEncryptionKey.New("invalid key %q: %v", tokens[1], err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if key.ID.IsZero() || key.Key.IsZero() {
|
|
|
|
return ErrEncryptionKey.New("neither identifier or key can be zero")
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// Type implements pflag.Value.
|
|
|
|
func (EncryptionKeys) Type() string { return "orders.EncryptionKeys" }
|
|
|
|
|
|
|
|
// Set adds the values from a comma delimited hex encoded strings "hex(id1)=hex(key1),hex(id2)=hex(key2)".
|
|
|
|
func (keys *EncryptionKeys) Set(s string) error {
|
2020-08-27 15:30:04 +01:00
|
|
|
if s == "" {
|
|
|
|
return nil
|
|
|
|
}
|
2020-06-02 14:51:33 +01:00
|
|
|
|
2020-11-23 19:24:22 +00:00
|
|
|
keys.Clear()
|
|
|
|
|
2020-06-02 14:51:33 +01:00
|
|
|
for _, x := range strings.Split(s, ",") {
|
|
|
|
x = strings.TrimSpace(x)
|
|
|
|
var ekey EncryptionKey
|
|
|
|
if err := ekey.Set(x); err != nil {
|
2020-11-23 19:24:22 +00:00
|
|
|
return ErrEncryptionKey.New("invalid keys %q: %w", s, err)
|
2020-06-02 14:51:33 +01:00
|
|
|
}
|
2020-11-20 19:16:31 +00:00
|
|
|
if err := keys.Add(ekey); err != nil {
|
|
|
|
return err
|
2020-06-02 14:51:33 +01:00
|
|
|
}
|
2020-11-20 19:16:31 +00:00
|
|
|
}
|
2020-06-02 14:51:33 +01:00
|
|
|
|
2020-11-20 19:16:31 +00:00
|
|
|
return nil
|
|
|
|
}
|
2020-06-02 14:51:33 +01:00
|
|
|
|
2020-11-20 19:16:31 +00:00
|
|
|
// Add adds an encryption key to EncryptionsKeys object.
|
|
|
|
func (keys *EncryptionKeys) Add(ekey EncryptionKey) error {
|
|
|
|
if keys.KeyByID == nil {
|
|
|
|
keys.KeyByID = map[EncryptionKeyID]storj.Key{}
|
|
|
|
}
|
|
|
|
if ekey.IsZero() {
|
|
|
|
return ErrEncryptionKey.New("key is zero")
|
|
|
|
}
|
|
|
|
|
|
|
|
if keys.Default.IsZero() {
|
|
|
|
keys.Default = ekey
|
|
|
|
}
|
2020-06-02 14:51:33 +01:00
|
|
|
|
2020-11-20 19:16:31 +00:00
|
|
|
if _, exists := keys.KeyByID[ekey.ID]; exists {
|
2020-11-23 19:24:22 +00:00
|
|
|
return ErrEncryptionKey.New("duplicate key identifier %q", ekey.String())
|
2020-06-02 14:51:33 +01:00
|
|
|
}
|
|
|
|
|
2020-11-20 19:16:31 +00:00
|
|
|
keys.List = append(keys.List, ekey)
|
|
|
|
keys.KeyByID[ekey.ID] = ekey.Key
|
2020-06-02 14:51:33 +01:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2020-11-23 19:24:22 +00:00
|
|
|
// Clear removes all keys.
|
|
|
|
func (keys *EncryptionKeys) Clear() {
|
|
|
|
keys.Default = EncryptionKey{}
|
|
|
|
keys.List = nil
|
|
|
|
keys.KeyByID = map[EncryptionKeyID]storj.Key{}
|
|
|
|
}
|
|
|
|
|
2020-06-02 14:51:33 +01:00
|
|
|
// String is required for pflag.Value.
|
|
|
|
func (keys *EncryptionKeys) String() string {
|
|
|
|
var s strings.Builder
|
|
|
|
if keys.Default.IsZero() {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
s.WriteString(keys.Default.String())
|
|
|
|
for _, key := range keys.List {
|
|
|
|
if key.ID == keys.Default.ID {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
s.WriteString(",")
|
|
|
|
s.WriteString(key.String())
|
|
|
|
}
|
|
|
|
|
|
|
|
return s.String()
|
|
|
|
}
|
|
|
|
|
|
|
|
// setHexEncodedArray decodes the hex string s (surrounding whitespace is
// ignored) into dst. The trimmed string must be exactly 2*len(dst)
// characters long. On any error dst is left untouched.
func setHexEncodedArray(dst []byte, s string) error {
	trimmed := strings.TrimSpace(s)
	want := 2 * len(dst)
	if got := len(trimmed); got != want {
		return fmt.Errorf("wrong hex length %d, expected %d", got, want)
	}

	// Decode into a scratch buffer so dst is only written on success.
	decoded := make([]byte, len(dst))
	if _, err := hex.Decode(decoded, []byte(trimmed)); err != nil {
		return err
	}
	copy(dst, decoded)
	return nil
}
|