CompareAndSwap in KeyValueStore (#2602)
This commit is contained in:
parent
b24e60a33f
commit
0e1cb7bfb8
@ -5,8 +5,7 @@ FROM golang:1.12.7
|
||||
RUN curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
|
||||
RUN echo "deb http://apt.postgresql.org/pub/repos/apt/ stretch-pgdg main" | tee /etc/apt/sources.list.d/pgdg.list
|
||||
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y -qq postgresql-11 unzip
|
||||
RUN apt-get update && apt-get install -y -qq postgresql-11 redis-server unzip
|
||||
|
||||
RUN rm /etc/postgresql/11/main/pg_hba.conf; \
|
||||
echo 'local all all trust' >> /etc/postgresql/11/main/pg_hba.conf; \
|
||||
|
@ -122,7 +122,7 @@ func (client *Client) view(fn func(*bolt.Bucket) error) error {
|
||||
}))
|
||||
}
|
||||
|
||||
// Put adds a key/value to boltDB in a batch, where boltDB commits the batch to to disk every
|
||||
// Put adds a key/value to boltDB in a batch, where boltDB commits the batch to disk every
|
||||
// 1000 operations or 10ms, whichever is first. The MaxBatchDelay are using default settings.
|
||||
// Ref: https://github.com/boltdb/bolt/blob/master/db.go#L160
|
||||
// Note: when using this method, check if it need to be executed asynchronously
|
||||
@ -346,3 +346,36 @@ func (cursor backward) SkipPrefix(prefix storage.Key) (key, value []byte) {
|
||||
func (cursor backward) Advance() (key, value []byte) {
|
||||
return cursor.Prev()
|
||||
}
|
||||
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
return client.update(func(bucket *bolt.Bucket) error {
|
||||
data := bucket.Get([]byte(key))
|
||||
if len(data) == 0 {
|
||||
if oldValue != nil {
|
||||
return storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
|
||||
if newValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return Error.Wrap(bucket.Put(key, newValue))
|
||||
}
|
||||
|
||||
if !bytes.Equal(storage.Value(data), oldValue) {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
|
||||
if newValue == nil {
|
||||
return Error.Wrap(bucket.Delete(key))
|
||||
}
|
||||
|
||||
return Error.Wrap(bucket.Put(key, newValue))
|
||||
})
|
||||
}
|
||||
|
@ -17,12 +17,15 @@ var mon = monkit.Package()
|
||||
// Delimiter separates nested paths in storage
|
||||
const Delimiter = '/'
|
||||
|
||||
//ErrKeyNotFound used When something doesn't exist
|
||||
//ErrKeyNotFound used when something doesn't exist
|
||||
var ErrKeyNotFound = errs.Class("key not found")
|
||||
|
||||
// ErrEmptyKey is returned when an empty key is used in Put
|
||||
// ErrEmptyKey is returned when an empty key is used in Put or in CompareAndSwap
|
||||
var ErrEmptyKey = errs.Class("empty key")
|
||||
|
||||
// ErrValueChanged is returned when the current value of the key does not match the oldValue in CompareAndSwap
|
||||
var ErrValueChanged = errs.Class("value changed")
|
||||
|
||||
// ErrEmptyQueue is returned when attempting to Dequeue from an empty queue
|
||||
var ErrEmptyQueue = errs.Class("empty queue")
|
||||
|
||||
@ -68,6 +71,8 @@ type KeyValueStore interface {
|
||||
List(ctx context.Context, start Key, limit int) (Keys, error)
|
||||
// Iterate iterates over items based on opts
|
||||
Iterate(ctx context.Context, opts IterateOptions, fn func(context.Context, Iterator) error) error
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
CompareAndSwap(ctx context.Context, key Key, oldValue, newValue Value) error
|
||||
// Close closes the store
|
||||
Close() error
|
||||
}
|
||||
|
@ -316,3 +316,93 @@ func (client *Client) Iterate(ctx context.Context, opts storage.IterateOptions,
|
||||
|
||||
return fn(ctx, opi)
|
||||
}
|
||||
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
return client.CompareAndSwapPath(ctx, storage.Key(defaultBucket), key, oldValue, newValue)
|
||||
}
|
||||
|
||||
// CompareAndSwapPath atomically compares and swaps oldValue with newValue in the given bucket
|
||||
func (client *Client) CompareAndSwapPath(ctx context.Context, bucket, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
if oldValue == nil && newValue == nil {
|
||||
q := "SELECT metadata FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA"
|
||||
row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key))
|
||||
var val []byte
|
||||
err = row.Scan(&val)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
|
||||
if oldValue == nil {
|
||||
q := `
|
||||
INSERT INTO pathdata (bucket, fullpath, metadata) VALUES ($1::BYTEA, $2::BYTEA, $3::BYTEA)
|
||||
ON CONFLICT DO NOTHING
|
||||
RETURNING 1
|
||||
`
|
||||
row := client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(newValue))
|
||||
var val []byte
|
||||
err = row.Scan(&val)
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
var row *sql.Row
|
||||
if newValue == nil {
|
||||
q := `
|
||||
WITH matching_key AS (
|
||||
SELECT * FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA
|
||||
), updated AS (
|
||||
DELETE FROM pathdata
|
||||
USING matching_key mk
|
||||
WHERE pathdata.metadata = $3::BYTEA
|
||||
AND pathdata.bucket = mk.bucket
|
||||
AND pathdata.fullpath = mk.fullpath
|
||||
RETURNING 1
|
||||
)
|
||||
SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated
|
||||
`
|
||||
row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue))
|
||||
} else {
|
||||
q := `
|
||||
WITH matching_key AS (
|
||||
SELECT * FROM pathdata WHERE bucket = $1::BYTEA AND fullpath = $2::BYTEA
|
||||
), updated AS (
|
||||
UPDATE pathdata
|
||||
SET metadata = $4::BYTEA
|
||||
FROM matching_key mk
|
||||
WHERE pathdata.metadata = $3::BYTEA
|
||||
AND pathdata.bucket = mk.bucket
|
||||
AND pathdata.fullpath = mk.fullpath
|
||||
RETURNING 1
|
||||
)
|
||||
SELECT EXISTS(SELECT 1 FROM matching_key) AS key_present, EXISTS(SELECT 1 FROM updated) AS value_updated;
|
||||
`
|
||||
row = client.pgConn.QueryRow(q, []byte(bucket), []byte(key), []byte(oldValue), []byte(newValue))
|
||||
}
|
||||
|
||||
var keyPresent, valueUpdated bool
|
||||
err = row.Scan(&keyPresent, &valueUpdated)
|
||||
if err != nil {
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
if !keyPresent {
|
||||
return storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
if !valueUpdated {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/url"
|
||||
"sort"
|
||||
@ -80,15 +81,7 @@ func (client *Client) Get(ctx context.Context, key storage.Key) (_ storage.Value
|
||||
if key.IsZero() {
|
||||
return nil, storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
value, err := client.db.Get(string(key)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, Error.New("get error: %v", err)
|
||||
}
|
||||
return value, nil
|
||||
return get(ctx, client.db, key)
|
||||
}
|
||||
|
||||
// Put adds a value to the provided key in redis, returning an error on failure.
|
||||
@ -97,12 +90,7 @@ func (client *Client) Put(ctx context.Context, key storage.Key, value storage.Va
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
err = client.db.Set(key.String(), []byte(value), client.TTL).Err()
|
||||
if err != nil {
|
||||
return Error.New("put error: %v", err)
|
||||
}
|
||||
return nil
|
||||
return put(ctx, client.db, key, value, client.TTL)
|
||||
}
|
||||
|
||||
// List returns either a list of keys for which boltdb has values or an error.
|
||||
@ -117,12 +105,7 @@ func (client *Client) Delete(ctx context.Context, key storage.Key) (err error) {
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
err = client.db.Del(key.String()).Err()
|
||||
if err != nil {
|
||||
return Error.New("delete error: %v", err)
|
||||
}
|
||||
return nil
|
||||
return delete(ctx, client.db, key)
|
||||
}
|
||||
|
||||
// Close closes a redis client
|
||||
@ -229,3 +212,82 @@ func (client *Client) allPrefixedItems(prefix, first, last storage.Key) (storage
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
func (client *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
txf := func(tx *redis.Tx) error {
|
||||
value, err := get(ctx, tx, key)
|
||||
if storage.ErrKeyNotFound.Has(err) {
|
||||
if oldValue != nil {
|
||||
return storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
|
||||
if newValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// runs only if the watched keys remain unchanged
|
||||
_, err = tx.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
return put(ctx, pipe, key, newValue, client.TTL)
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(value, oldValue) {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
|
||||
// runs only if the watched keys remain unchanged
|
||||
_, err = tx.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
if newValue == nil {
|
||||
return delete(ctx, pipe, key)
|
||||
}
|
||||
return put(ctx, pipe, key, newValue, client.TTL)
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.db.Watch(txf, key.String())
|
||||
if err == redis.TxFailedErr {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
return Error.Wrap(err)
|
||||
}
|
||||
|
||||
func get(ctx context.Context, cmdable redis.Cmdable, key storage.Key) (_ storage.Value, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
value, err := cmdable.Get(string(key)).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
if err != nil && err != redis.TxFailedErr {
|
||||
return nil, Error.New("get error: %v", err)
|
||||
}
|
||||
return value, errs.Wrap(err)
|
||||
}
|
||||
|
||||
func put(ctx context.Context, cmdable redis.Cmdable, key storage.Key, value storage.Value, ttl time.Duration) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
err = cmdable.Set(key.String(), []byte(value), ttl).Err()
|
||||
if err != nil && err != redis.TxFailedErr {
|
||||
return Error.New("put error: %v", err)
|
||||
}
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
||||
func delete(ctx context.Context, cmdable redis.Cmdable, key storage.Key) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
err = cmdable.Del(key.String()).Err()
|
||||
if err != nil && err != redis.TxFailedErr {
|
||||
return Error.New("delete error: %v", err)
|
||||
}
|
||||
return errs.Wrap(err)
|
||||
}
|
||||
|
@ -97,6 +97,15 @@ func (store *Logger) Close() error {
|
||||
return store.store.Close()
|
||||
}
|
||||
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
func (store *Logger) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
store.log.Debug("CompareAndSwap", zap.ByteString("key", key),
|
||||
zap.Int("old value length", len(oldValue)), zap.Int("new value length", len(newValue)),
|
||||
zap.Binary("truncated old value", truncate(oldValue)), zap.Binary("truncated new value", truncate(newValue)))
|
||||
return store.store.CompareAndSwap(ctx, key, oldValue, newValue)
|
||||
}
|
||||
|
||||
func truncate(v storage.Value) (t []byte) {
|
||||
if len(v)-1 < 10 {
|
||||
t = []byte(v)
|
||||
|
@ -26,14 +26,15 @@ type Client struct {
|
||||
ForceError int
|
||||
|
||||
CallCount struct {
|
||||
Get int
|
||||
Put int
|
||||
List int
|
||||
GetAll int
|
||||
ReverseList int
|
||||
Delete int
|
||||
Close int
|
||||
Iterate int
|
||||
Get int
|
||||
Put int
|
||||
List int
|
||||
GetAll int
|
||||
ReverseList int
|
||||
Delete int
|
||||
Close int
|
||||
Iterate int
|
||||
CompareAndSwap int
|
||||
}
|
||||
|
||||
version int
|
||||
@ -89,13 +90,7 @@ func (store *Client) Put(ctx context.Context, key storage.Key, value storage.Val
|
||||
return nil
|
||||
}
|
||||
|
||||
store.Items = append(store.Items, storage.ListItem{})
|
||||
copy(store.Items[keyIndex+1:], store.Items[keyIndex:])
|
||||
store.Items[keyIndex] = storage.ListItem{
|
||||
Key: storage.CloneKey(key),
|
||||
Value: storage.CloneValue(value),
|
||||
}
|
||||
|
||||
store.put(keyIndex, key, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -169,8 +164,7 @@ func (store *Client) Delete(ctx context.Context, key storage.Key) (err error) {
|
||||
return storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
|
||||
copy(store.Items[keyIndex:], store.Items[keyIndex+1:])
|
||||
store.Items = store.Items[:len(store.Items)-1]
|
||||
store.delete(keyIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -430,3 +424,61 @@ func (cursor *cursor) prev() (*storage.ListItem, bool) {
|
||||
cursor.nextIndex--
|
||||
return item, true
|
||||
}
|
||||
|
||||
// CompareAndSwap atomically compares and swaps oldValue with newValue
|
||||
func (store *Client) CompareAndSwap(ctx context.Context, key storage.Key, oldValue, newValue storage.Value) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
defer store.locked()()
|
||||
|
||||
store.version++
|
||||
store.CallCount.CompareAndSwap++
|
||||
if store.forcedError() {
|
||||
return errInternal
|
||||
}
|
||||
|
||||
if key.IsZero() {
|
||||
return storage.ErrEmptyKey.New("")
|
||||
}
|
||||
|
||||
keyIndex, found := store.indexOf(key)
|
||||
if !found {
|
||||
if oldValue != nil {
|
||||
return storage.ErrKeyNotFound.New(key.String())
|
||||
}
|
||||
|
||||
if newValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
store.put(keyIndex, key, newValue)
|
||||
return nil
|
||||
}
|
||||
|
||||
kv := &store.Items[keyIndex]
|
||||
if !bytes.Equal(kv.Value, oldValue) {
|
||||
return storage.ErrValueChanged.New(key.String())
|
||||
}
|
||||
|
||||
if newValue == nil {
|
||||
store.delete(keyIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
kv.Value = storage.CloneValue(newValue)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *Client) put(keyIndex int, key storage.Key, value storage.Value) {
|
||||
store.Items = append(store.Items, storage.ListItem{})
|
||||
copy(store.Items[keyIndex+1:], store.Items[keyIndex:])
|
||||
store.Items[keyIndex] = storage.ListItem{
|
||||
Key: storage.CloneKey(key),
|
||||
Value: storage.CloneValue(value),
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Client) delete(keyIndex int) {
|
||||
copy(store.Items[keyIndex:], store.Items[keyIndex+1:])
|
||||
store.Items = store.Items[:len(store.Items)-1]
|
||||
}
|
||||
|
@ -4,10 +4,17 @@
|
||||
package testsuite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"storj.io/storj/storage"
|
||||
)
|
||||
|
||||
@ -85,4 +92,169 @@ func testConstraints(t *testing.T, store storage.KeyValueStore) {
|
||||
t.Fatalf("List LookupLimit+1 shouldn't fail: %v / got %d", err, len(keys))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Empty Key", func(t *testing.T) {
|
||||
var key storage.Key
|
||||
var val storage.Value
|
||||
|
||||
err := store.CompareAndSwap(ctx, key, val, val)
|
||||
require.Error(t, err, "putting empty key should fail")
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Empty Old Value", func(t *testing.T) {
|
||||
key := storage.Key("test-key")
|
||||
val := storage.Value("test-value")
|
||||
defer func() { _ = store.Delete(ctx, key) }()
|
||||
|
||||
err := store.CompareAndSwap(ctx, key, nil, val)
|
||||
require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, nil, val, err)
|
||||
|
||||
value, err := store.Get(ctx, key)
|
||||
require.NoError(t, err, "failed to get %q = %v: %+v", key, val, err)
|
||||
require.Equal(t, value, val, "invalid value for %q = %v: got %v", key, val, value)
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Empty New Value", func(t *testing.T) {
|
||||
key := storage.Key("test-key")
|
||||
val := storage.Value("test-value")
|
||||
defer func() { _ = store.Delete(ctx, key) }()
|
||||
|
||||
err := store.Put(ctx, key, val)
|
||||
require.NoError(t, err, "failed to put %q = %v: %+v", key, val, err)
|
||||
|
||||
err = store.CompareAndSwap(ctx, key, val, nil)
|
||||
require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, val, nil, err)
|
||||
|
||||
value, err := store.Get(ctx, key)
|
||||
require.Error(t, err, "got deleted value %q = %v", key, value)
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Empty Both Empty Values", func(t *testing.T) {
|
||||
key := storage.Key("test-key")
|
||||
|
||||
err := store.CompareAndSwap(ctx, key, nil, nil)
|
||||
require.NoError(t, err, "failed to update %q: %v -> %v: %+v", key, nil, nil, err)
|
||||
|
||||
value, err := store.Get(ctx, key)
|
||||
require.Error(t, err, "got unexpected value %q = %v", key, value)
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Missing Key", func(t *testing.T) {
|
||||
for i, tt := range []struct {
|
||||
old, new storage.Value
|
||||
}{
|
||||
{storage.Value("old-value"), nil},
|
||||
{storage.Value("old-value"), storage.Value("new-value")},
|
||||
} {
|
||||
errTag := fmt.Sprintf("%d. %+v", i, tt)
|
||||
key := storage.Key("test-key")
|
||||
|
||||
err := store.CompareAndSwap(ctx, key, tt.old, tt.new)
|
||||
assert.True(t, storage.ErrKeyNotFound.Has(err), "%s: unexpected error: %+v", errTag, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Value Changed", func(t *testing.T) {
|
||||
for i, tt := range []struct {
|
||||
old, new storage.Value
|
||||
}{
|
||||
{nil, nil},
|
||||
{nil, storage.Value("new-value")},
|
||||
{storage.Value("old-value"), nil},
|
||||
{storage.Value("old-value"), storage.Value("new-value")},
|
||||
} {
|
||||
errTag := fmt.Sprintf("%d. %+v", i, tt)
|
||||
key := storage.Key("test-key")
|
||||
val := storage.Value("test-value")
|
||||
defer func() { _ = store.Delete(ctx, key) }()
|
||||
|
||||
err := store.Put(ctx, key, val)
|
||||
require.NoError(t, err, errTag)
|
||||
|
||||
err = store.CompareAndSwap(ctx, key, tt.old, tt.new)
|
||||
assert.True(t, storage.ErrValueChanged.Has(err), "%s: unexpected error: %+v", errTag, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CompareAndSwap Concurrent", func(t *testing.T) {
|
||||
const count = 100
|
||||
|
||||
key := storage.Key("test-key")
|
||||
defer func() { _ = store.Delete(ctx, key) }()
|
||||
|
||||
// Add concurrently all numbers from 1 to `count` in a set under test-key
|
||||
var group errgroup.Group
|
||||
for i := 0; i < count; i++ {
|
||||
i := i
|
||||
group.Go(func() error {
|
||||
for {
|
||||
set := make(map[int]bool)
|
||||
|
||||
oldValue, err := store.Get(ctx, key)
|
||||
if !storage.ErrKeyNotFound.Has(err) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
set, err = decodeSet(oldValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
set[i] = true
|
||||
newValue, err := encodeSet(set)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = store.CompareAndSwap(ctx, key, oldValue, storage.Value(newValue))
|
||||
if storage.ErrValueChanged.Has(err) {
|
||||
// Another goroutine was faster. Make a new attempt.
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
})
|
||||
}
|
||||
err := group.Wait()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that all numbers were added in the set
|
||||
value, err := store.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
set, err := decodeSet(value)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
assert.Contains(t, set, i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func encodeSet(set map[int]bool) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
|
||||
err := enc.Encode(set)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func decodeSet(b []byte) (map[int]bool, error) {
|
||||
buf := bytes.NewBuffer(b)
|
||||
dec := gob.NewDecoder(buf)
|
||||
|
||||
var set map[int]bool
|
||||
err := dec.Decode(&set)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
@ -70,9 +70,9 @@ func testCRUD(t *testing.T, store storage.KeyValueStore) {
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
for i, item := range items {
|
||||
next := items[(i+1)%len(items)]
|
||||
err := store.Put(ctx, item.Key, next.Value)
|
||||
err := store.CompareAndSwap(ctx, item.Key, item.Value, next.Value)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update %q = %v: %v", item.Key, next.Value, err)
|
||||
t.Fatalf("failed to update %q: %v -> %v: %v", item.Key, item.Value, next.Value, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ func testParallel(t *testing.T, store storage.KeyValueStore) {
|
||||
|
||||
// Update value
|
||||
nextValue := storage.Value(string(item.Value) + "X")
|
||||
err = store.Put(ctx, item.Key, nextValue)
|
||||
err = store.CompareAndSwap(ctx, item.Key, item.Value, nextValue)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update %q = %v: %v", item.Key, nextValue, err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user