diff --git a/Dockerfile.jenkins b/Dockerfile.jenkins index c73a99327..2e36f8cce 100644 --- a/Dockerfile.jenkins +++ b/Dockerfile.jenkins @@ -5,8 +5,7 @@ FROM golang:1.12.7 RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - RUN echo "deb http://apt.postgresql.org/pub/repos/apt/ stretch-pgdg main" | tee /etc/apt/sources.list.d/pgdg.list -RUN apt-get update -RUN apt-get install -y -qq postgresql-11 unzip +RUN apt-get update && apt-get install -y -qq postgresql-11 redis-server unzip RUN rm /etc/postgresql/11/main/pg_hba.conf; \ echo 'local all all trust' >> /etc/postgresql/11/main/pg_hba.conf; \ diff --git a/storage/boltdb/client.go b/storage/boltdb/client.go index c58ae988c..a1b62cd7a 100644 --- a/storage/boltdb/client.go +++ b/storage/boltdb/client.go @@ -122,7 +122,7 @@ func (client *Client) view(fn func(*bolt.Bucket) error) error { })) } -// Put adds a key/value to boltDB in a batch, where boltDB commits the batch to to disk every +// Put adds a key/value to boltDB in a batch, where boltDB commits the batch to disk every // 1000 operations or 10ms, whichever is first. The MaxBatchDelay are using default settings. // Ref: https://github.com/boltdb/bolt/blob/master/db.go#L160 // Note: when using this method, check if it need to be executed asynchronously @@ -346,3 +346,36 @@ func (cursor backward) SkipPrefix(prefix storage.Key) (key, value []byte) { func (cursor backward) Advance() (key, value []byte) { return cursor.Prev() } + +// CompareAndSwap atomically compares and swaps oldValue with newValue +func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + if key.IsZero() { + return storage.ErrEmptyKey.New("") + } + + return client.update(func(bucket *bolt.Bucket) error { + data := bucket.Get([]byte(key)) + if len(data) == 0 { + if oldValue != nil { + return storage.ErrKeyNotFound.New(key.String()) + } + + if newValue == nil { + return nil + } + + return Error.Wrap(bucket.Put(key, newValue)) + } + + if !bytes.Equal(storage.Value(data), oldValue) { + return storage.ErrValueChanged.New(key.String()) + } + + if newValue == nil { + return Error.Wrap(bucket.Delete(key)) + } + + return Error.Wrap(bucket.Put(key, newValue)) + }) +} diff --git a/storage/common.go b/storage/common.go index 902b4558f..afdb70d8a 100644 --- a/storage/common.go +++ b/storage/common.go @@ -17,12 +17,15 @@ var mon = monkit.Package() // Delimiter separates nested paths in storage const Delimiter = '/' -//ErrKeyNotFound used When something doesn't exist +//ErrKeyNotFound used when something doesn't exist var ErrKeyNotFound = errs.Class("key not found") -// ErrEmptyKey is returned when an empty key is used in Put +// ErrEmptyKey is returned when an empty key is used in Put or in CompareAndSwap var ErrEmptyKey = errs.Class("empty key") +// ErrValueChanged is returned when the current value of the key does not match the oldValue in CompareAndSwap +var ErrValueChanged = errs.Class("value changed") + // ErrEmptyQueue is returned when attempting to Dequeue from an empty queue var ErrEmptyQueue = errs.Class("empty queue") @@ -68,6 +71,8 @@ type KeyValueStore interface { List(ctx context.Context, start Key, limit int) (Keys, error) // Iterate iterates over items based on opts Iterate(ctx context.Context, opts IterateOptions, fn func(context.Context, Iterator) error) error + // CompareAndSwap atomically compares and swaps oldValue with newValue + CompareAndSwap(ctx context.Context, key Key, oldValue, newValue Value) error // Close closes the store Close() error } diff --git a/storage/postgreskv/client.go b/storage/postgreskv/client.go index 322613262..2df8dea94 100644 --- a/storage/postgreskv/client.go +++ b/storage/postgreskv/client.go @@ -316,3 +316,93 @@ func (client *Client) Iterate(ctx context.Context, opts storage.IterateOptions, return fn(ctx, opi) } + +// CompareAndSwap atomically compares and swaps oldValue with newValue +func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + return client.CompareAndSwapPath(ctx, storage.Key(defaultBucket), key, oldValue, newValue) +} + +// CompareAndSwapPath atomically compares and swaps oldValue with newValue in the given bucket +func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + if key.IsZero() { + return storage.ErrEmptyKey.New("") + } + + if oldValue == nil && newValue == nil { + q := "SELECT metadata FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA" + row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key)) + var val []byte + err = row.Scan(&val) + if err == sql.ErrNoRows { + return nil + } + if err != nil { + return Error.Wrap(err) + } + return storage.ErrValueChanged.New(key.String()) + } + + if oldValue == nil { + q := ` + INSERT INTO pathdata (bucket, fullpath, metadata) VALUES ($1::BYTEA, $2::BYTEA, $3::BYTEA) + ON CONFLICT DO NOTHING + RETURNING 1 + ` + row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(newValue)) + var val []byte + err = row.Scan(&val) + if err == sql.ErrNoRows { + return storage.ErrValueChanged.New(key.String()) + } + return Error.Wrap(err) + } + + var row *sql.Row + if newValue == nil { + q := ` + WITH matching_key AS ( + SELECT * FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA + ), updated AS ( + DELETE FROM pathdata + USING matching_key mk + WHERE pathdata.metadata = $3::BYTEA + AND pathdata.bucket = mk.bucket + AND pathdata.fullpath = mk.fullpath + RETURNING 1 + ) + SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated + ` + row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue)) + } else { + q := ` + WITH matching_key AS ( + SELECT * FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA + ), updated AS ( + UPDATE pathdata + SET metadata = $4::BYTEA + FROM matching_key mk + WHERE pathdata.metadata = $3::BYTEA + AND pathdata.bucket = mk.bucket + AND pathdata.fullpath = mk.fullpath + RETURNING 1 + ) + SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated; + ` + row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue), []byte(newValue)) + } + + var keyPresent, valueUpdated bool + err = row.Scan(&keyPresent, &valueUpdated) + if err != nil { + return Error.Wrap(err) + } + if !keyPresent { + return storage.ErrKeyNotFound.New(key.String()) + } + if !valueUpdated { + return storage.ErrValueChanged.New(key.String()) + } + return nil +} diff --git a/storage/redis/client.go b/storage/redis/client.go index 0c816388f..b7a6f219e 100644 --- a/storage/redis/client.go +++ b/storage/redis/client.go @@ -4,6 +4,7 @@ package redis import ( + "bytes" "context" "net/url" "sort" @@ -80,15 +81,7 @@ func (client *Client) Get(ctx context.Context, key storage.Key) (_ storage.Value if key.IsZero() { return nil, storage.ErrEmptyKey.New("") } - - value, err := client.db.Get(string(key)).Bytes() - if err == redis.Nil { - return nil, storage.ErrKeyNotFound.New(key.String()) - } - if err != nil { - return nil, Error.New("get error: %v", err) - } - return value, nil + return get(ctx, client.db, key) } // Put adds a value to the provided key in redis, returning an error on failure. @@ -97,12 +90,7 @@ func (client *Client) Put(ctx context.Context, key storage.Key, value storage.Va if key.IsZero() { return storage.ErrEmptyKey.New("") } - - err = client.db.Set(key.String(), []byte(value), client.TTL).Err() - if err != nil { - return Error.New("put error: %v", err) - } - return nil + return put(ctx, client.db, key, value, client.TTL) } // List returns either a list of keys for which boltdb has values or an error. @@ -117,12 +105,7 @@ func (client *Client) Delete(ctx context.Context, key storage.Key) (err error) { if key.IsZero() { return storage.ErrEmptyKey.New("") } - - err = client.db.Del(key.String()).Err() - if err != nil { - return Error.New("delete error: %v", err) - } - return nil + return delete(ctx, client.db, key) } // Close closes a redis client @@ -229,3 +212,82 @@ func (client *Client) allPrefixedItems(prefix, first, last storage.Key) (storage return all, nil } + +// CompareAndSwap atomically compares and swaps oldValue with newValue +func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + if key.IsZero() { + return storage.ErrEmptyKey.New("") + } + + txf := func(tx *redis.Tx) error { + value, err := get(ctx, tx, key) + if storage.ErrKeyNotFound.Has(err) { + if oldValue != nil { + return storage.ErrKeyNotFound.New(key.String()) + } + + if newValue == nil { + return nil + } + + // runs only if the watched keys remain unchanged + _, err = tx.Pipelined(func(pipe redis.Pipeliner) error { + return put(ctx, pipe, key, newValue, client.TTL) + }) + } + if err != nil { + return err + } + + if !bytes.Equal(value, oldValue) { + return storage.ErrValueChanged.New(key.String()) + } + + // runs only if the watched keys remain unchanged + _, err = tx.Pipelined(func(pipe redis.Pipeliner) error { + if newValue == nil { + return delete(ctx, pipe, key) + } + return put(ctx, pipe, key, newValue, client.TTL) + }) + + return err + } + + err = client.db.Watch(txf, key.String()) + if err == redis.TxFailedErr { + return storage.ErrValueChanged.New(key.String()) + } + return Error.Wrap(err) +} + +func get(ctx context.Context, cmdable redis.Cmdable, key storage.Key) (_ storage.Value, err error) { + defer mon.Task()(&ctx)(&err) + value, err := cmdable.Get(string(key)).Bytes() + if err == redis.Nil { + return nil, storage.ErrKeyNotFound.New(key.String()) + } + if err != nil && err != redis.TxFailedErr { + return nil, Error.New("get error: %v", err) + } + return value, errs.Wrap(err) +} + +func put(ctx context.Context, cmdable redis.Cmdable, key storage.Key, value storage.Value, ttl time.Duration) (err error) { + defer mon.Task()(&ctx)(&err) + err = cmdable.Set(key.String(), []byte(value), ttl).Err() + if err != nil && err != redis.TxFailedErr { + return Error.New("put error: %v", err) + } + return errs.Wrap(err) +} + +func delete(ctx context.Context, cmdable redis.Cmdable, key storage.Key) (err error) { + defer mon.Task()(&ctx)(&err) + err = cmdable.Del(key.String()).Err() + if err != nil && err != redis.TxFailedErr { + return Error.New("delete error: %v", err) + } + return errs.Wrap(err) +} diff --git a/storage/storelogger/logger.go b/storage/storelogger/logger.go index f9e35bf70..3172dcec6 100644 --- a/storage/storelogger/logger.go +++ b/storage/storelogger/logger.go @@ -97,6 +97,15 @@ func (store *Logger) Close() error { return store.store.Close() } +// CompareAndSwap atomically compares and swaps oldValue with newValue +func (store *Logger) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + store.log.Debug("CompareAndSwap", zap.ByteString("key", key), + zap.Int("old value length", len(oldValue)), zap.Int("new value length", len(newValue)), + zap.Binary("truncated old value", truncate(oldValue)), zap.Binary("truncated new value", truncate(newValue))) + return store.store.CompareAndSwap(ctx, key, oldValue, newValue) +} + func truncate(v storage.Value) (t []byte) { if len(v)-1 < 10 { t = []byte(v) diff --git a/storage/teststore/store.go b/storage/teststore/store.go index 2afb9dfad..bb57c3190 100644 --- a/storage/teststore/store.go +++ b/storage/teststore/store.go @@ -26,14 +26,15 @@ type Client struct { ForceError int CallCount struct { - Get int - Put int - List int - GetAll int - ReverseList int - Delete int - Close int - Iterate int + Get int + Put int + List int + GetAll int + ReverseList int + Delete int + Close int + Iterate int + CompareAndSwap int } version int @@ -89,13 +90,7 @@ func (store *Client) Put(ctx context.Context, key storage.Key, value storage.Val return nil } - store.Items = append(store.Items, storage.ListItem{}) - copy(store.Items[keyIndex+1:], store.Items[keyIndex:]) - store.Items[keyIndex] = storage.ListItem{ - Key: storage.CloneKey(key), - Value: storage.CloneValue(value), - } - + store.put(keyIndex, key, value) return nil } @@ -169,8 +164,7 @@ func (store *Client) Delete(ctx context.Context, key storage.Key) (err error) { return storage.ErrKeyNotFound.New(key.String()) } - copy(store.Items[keyIndex:], store.Items[keyIndex+1:]) - store.Items = store.Items[:len(store.Items)-1] + store.delete(keyIndex) return nil } @@ -430,3 +424,61 @@ func (cursor *cursor) prev() (*storage.ListItem, bool) { cursor.nextIndex-- return item, true } + +// CompareAndSwap atomically compares and swaps oldValue with newValue +func (store *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) { + defer mon.Task()(&ctx)(&err) + defer store.locked()() + + store.version++ + store.CallCount.CompareAndSwap++ + if store.forcedError() { + return errInternal + } + + if key.IsZero() { + return storage.ErrEmptyKey.New("") + } + + keyIndex, found := store.indexOf(key) + if !found { + if oldValue != nil { + return storage.ErrKeyNotFound.New(key.String()) + } + + if newValue == nil { + return nil + } + + store.put(keyIndex, key, newValue) + return nil + } + + kv := &store.Items[keyIndex] + if !bytes.Equal(kv.Value, oldValue) { + return storage.ErrValueChanged.New(key.String()) + } + + if newValue == nil { + store.delete(keyIndex) + return nil + } + + kv.Value = storage.CloneValue(newValue) + + return nil +} + +func (store *Client) put(keyIndex int, key storage.Key, value storage.Value) { + store.Items = append(store.Items, storage.ListItem{}) + copy(store.Items[keyIndex+1:], store.Items[keyIndex:]) + store.Items[keyIndex] = storage.ListItem{ + Key: storage.CloneKey(key), + Value: storage.CloneValue(value), + } +} + +func (store *Client) delete(keyIndex int) { + copy(store.Items[keyIndex:], store.Items[keyIndex+1:]) + store.Items = store.Items[:len(store.Items)-1] +} diff --git a/storage/testsuite/test.go b/storage/testsuite/test.go index aa5b9a686..2d85c65a0 100644 --- a/storage/testsuite/test.go +++ b/storage/testsuite/test.go @@ -4,10 +4,17 @@ package testsuite import ( + "bytes" + "encoding/gob" + "fmt" "strconv" "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "storj.io/storj/storage" ) @@ -85,4 +92,169 @@ func testConstraints(t *testing.T, store storage.KeyValueStore) { t.Fatalf("List LookupLimit+1 shouldn't fail: %v / got %d", err, len(keys)) } }) + + t.Run("CompareAndSwap Empty Key", func(t *testing.T) { + var key storage.Key + var val storage.Value + + err := store.CompareAndSwap(ctx, key, val, val) + require.Error(t, err, "putting empty key should fail") + }) + + t.Run("CompareAndSwap Empty Old Value", func(t *testing.T) { + key := storage.Key("test-key") + val := storage.Value("test-value") + defer func() { _ = store.Delete(ctx, key) }() + + err := store.CompareAndSwap(ctx, key, nil, val) + require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, nil, val, err) + + value, err := store.Get(ctx, key) + require.NoError(t, err, "failed to get %q = %v: %+v", key, val, err) + require.Equal(t, value, val, "invalid value for %q = %v: got %v", key, val, value) + }) + + t.Run("CompareAndSwap Empty New Value", func(t *testing.T) { + key := storage.Key("test-key") + val := storage.Value("test-value") + defer func() { _ = store.Delete(ctx, key) }() + + err := store.Put(ctx, key, val) + require.NoError(t, err, "failed to put %q = %v: %+v", key, val, err) + + err = store.CompareAndSwap(ctx, key, val, nil) + require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, val, nil, err) + + value, err := store.Get(ctx, key) + require.Error(t, err, "got deleted value %q = %v", key, value) + }) + + t.Run("CompareAndSwap Empty Both Empty Values", func(t *testing.T) { + key := storage.Key("test-key") + + err := store.CompareAndSwap(ctx, key, nil, nil) + require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, nil, nil, err) + + value, err := store.Get(ctx, key) + require.Error(t, err, "got unexpected value %q = %v", key, value) + }) + + t.Run("CompareAndSwap Missing Key", func(t *testing.T) { + for i, tt := range []struct { + old, new storage.Value + }{ + {storage.Value("old-value"), nil}, + {storage.Value("old-value"), storage.Value("new-value")}, + } { + errTag := fmt.Sprintf("%d. %+v", i, tt) + key := storage.Key("test-key") + + err := store.CompareAndSwap(ctx, key, tt.old, tt.new) + assert.True(t, storage.ErrKeyNotFound.Has(err), "%s: unexpected error: %+v", errTag, err) + } + }) + + t.Run("CompareAndSwap Value Changed", func(t *testing.T) { + for i, tt := range []struct { + old, new storage.Value + }{ + {nil, nil}, + {nil, storage.Value("new-value")}, + {storage.Value("old-value"), nil}, + {storage.Value("old-value"), storage.Value("new-value")}, + } { + errTag := fmt.Sprintf("%d. %+v", i, tt) + key := storage.Key("test-key") + val := storage.Value("test-value") + defer func() { _ = store.Delete(ctx, key) }() + + err := store.Put(ctx, key, val) + require.NoError(t, err, errTag) + + err = store.CompareAndSwap(ctx, key, tt.old, tt.new) + assert.True(t, storage.ErrValueChanged.Has(err), "%s: unexpected error: %+v", errTag, err) + } + }) + + t.Run("CompareAndSwap Concurrent", func(t *testing.T) { + const count = 100 + + key := storage.Key("test-key") + defer func() { _ = store.Delete(ctx, key) }() + + // Add concurrently all numbers from 1 to `count` in a set under test-key + var group errgroup.Group + for i := 0; i < count; i++ { + i := i + group.Go(func() error { + for { + set := make(map[int]bool) + + oldValue, err := store.Get(ctx, key) + if !storage.ErrKeyNotFound.Has(err) { + if err != nil { + return err + } + + set, err = decodeSet(oldValue) + if err != nil { + return err + } + } + + set[i] = true + newValue, err := encodeSet(set) + if err != nil { + return err + } + + err = store.CompareAndSwap(ctx, key, oldValue, storage.Value(newValue)) + if storage.ErrValueChanged.Has(err) { + // Another goroutine was faster. Make a new attempt. + continue + } + + return err + } + }) + } + err := group.Wait() + require.NoError(t, err) + + // Check that all numbers were added in the set + value, err := store.Get(ctx, key) + require.NoError(t, err) + + set, err := decodeSet(value) + require.NoError(t, err) + + for i := 0; i < count; i++ { + assert.Contains(t, set, i) + } + }) +} + +func encodeSet(set map[int]bool) ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + + err := enc.Encode(set) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func decodeSet(b []byte) (map[int]bool, error) { + buf := bytes.NewBuffer(b) + dec := gob.NewDecoder(buf) + + var set map[int]bool + err := dec.Decode(&set) + if err != nil { + return nil, err + } + + return set, nil } diff --git a/storage/testsuite/test_crud.go b/storage/testsuite/test_crud.go index 50e197e2e..dd2e6f744 100644 --- a/storage/testsuite/test_crud.go +++ b/storage/testsuite/test_crud.go @@ -70,9 +70,9 @@ func testCRUD(t *testing.T, store storage.KeyValueStore) { t.Run("Update", func(t *testing.T) { for i, item := range items { next := items[(i+1)%len(items)] - err := store.Put(ctx, item.Key, next.Value) + err := store.CompareAndSwap(ctx, item.Key, item.Value, next.Value) if err != nil { - t.Fatalf("failed to update %q = %v: %v", item.Key, next.Value, err) + t.Fatalf("failed to update %q: %v -> %v: %v", item.Key, item.Value, next.Value, err) } } diff --git a/storage/testsuite/test_parallel.go b/storage/testsuite/test_parallel.go index e07c65e52..3700e8683 100644 --- a/storage/testsuite/test_parallel.go +++ b/storage/testsuite/test_parallel.go @@ -53,7 +53,7 @@ func testParallel(t *testing.T, store storage.KeyValueStore) { // Update value nextValue := storage.Value(string(item.Value) + "X") - err = store.Put(ctx, item.Key, nextValue) + err = store.CompareAndSwap(ctx, item.Key, item.Value, nextValue) if err != nil { t.Fatalf("failed to update %q = %v: %v", item.Key, nextValue, err) }