storj/cmd/uplink/ulfs/handle_std_test.go

102 lines
1.9 KiB
Go
Raw Normal View History

// Copyright (C) 2022 Storj Labs, Inc.
// See LICENSE for copying information.
package ulfs
import (
"bytes"
"errors"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/common/testrand"
)
// writeThrottle is the pair of channels that gates a single Write call on a
// throttledWriter.
type writeThrottle struct {
	// entered is closed by Write as soon as the call begins, so a test can
	// wait for that moment by receiving from it.
	entered chan struct{}

	// release carries the error the gated Write call should report. Write
	// blocks until it receives a value (nil included) from this channel.
	release chan error
}
// throttledWriter is a test writer whose Write calls each pause on their own
// writeThrottle until the test releases them.
type throttledWriter struct {
	// writex counts the Write calls started so far and is updated through
	// sync/atomic. Keep it as the first field: 64-bit atomic operations need
	// 64-bit alignment on 32-bit platforms, which is only guaranteed for the
	// first word of an allocated struct.
	writex int64

	// write holds one throttle per expected Write call, indexed by the order
	// in which the calls start.
	write []writeThrottle

	// data collects the bytes handed to Write once a call has been released.
	data bytes.Buffer
}
// newThrottledWriter returns a throttledWriter with maxWrites throttle slots,
// one for each Write call the test intends to let through. A Write call beyond
// that count indexes past the slice and panics.
func newThrottledWriter(maxWrites int) *throttledWriter {
	throttles := make([]writeThrottle, maxWrites)
	for i := range throttles {
		// entered is a pure signal channel that Write closes. release has room
		// for one value so a result can be queued before the matching Write
		// call has started.
		throttles[i].entered = make(chan struct{})
		throttles[i].release = make(chan error, 1)
	}

	// writex is left at its zero value so the first Write claims slot 0.
	return &throttledWriter{write: throttles}
}
// Write has the io.Writer method signature. Each call claims the next throttle
// slot, signals that it has started by closing the slot's entered channel, and
// then blocks until an error value arrives on the slot's release channel.
//
// The bytes are appended to tw.data regardless of the released error. The
// returned error is the buffer's own write error if there is one, otherwise
// the error received from release (which may be nil).
func (tw *throttledWriter) Write(data []byte) (n int, _ error) {
	// AddInt64 returns the incremented counter; subtracting one yields the
	// zero-based slot for this call.
	slot := tw.write[atomic.AddInt64(&tw.writex, 1)-1]

	close(slot.entered)
	injected := <-slot.release

	n, err := tw.data.Write(data)
	if err == nil {
		err = injected
	}
	return n, err
}
// TestStdMultiWriteAbort drives two parts of a std multi-write handle against
// a throttled stdout. The first underlying write is released with an error;
// the test then expects the Write on both parts to fail and expects stdout to
// hold exactly the first part's bytes.
func TestStdMultiWriteAbort(t *testing.T) {
	ctx := testcontext.New(t)

	head := testrand.Bytes(256)
	tail := testrand.Bytes(256)

	stdout := newThrottledWriter(2)
	multi := newStdMultiWriteHandle(stdout)

	first, err := multi.NextPart(ctx, 256)
	require.NoError(t, err)

	ctx.Go(func() error {
		defer func() { _ = first.Abort() }()

		if _, err := first.Write(head); err == nil {
			return errors.New("expected an error")
		}
		return nil
	})

	second, err := multi.NextPart(ctx, 256)
	require.NoError(t, err)

	ctx.Go(func() error {
		defer func() { _ = second.Commit() }()

		// Hold back until the first part's write has reached stdout, so the
		// two parts are written in order.
		<-stdout.write[0].entered

		if _, err := second.Write(tail); err == nil {
			return errors.New("expected an error")
		}
		return nil
	})

	// Only the first underlying write is awaited here. The release channels
	// are buffered, so the value for the second slot can be queued whether or
	// not a second underlying write ever starts.
	<-stdout.write[0].entered

	stdout.write[0].release <- errors.New("fail 0")
	stdout.write[1].release <- nil

	ctx.Wait()

	// throttledWriter stores the bytes even when it reports the injected
	// error, so head is present; tail must not have been written.
	require.Equal(t, head, stdout.data.Bytes())
}