private/context2: add WithoutCancellation
Change-Id: I38557c16f41b8983886f256353cc6afb7634d9e6
This commit is contained in:
parent
19d318ea9d
commit
08f63614be
49
private/context2/nocancel.go
Normal file
49
private/context2/nocancel.go
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
// Package context2 contains utilities for contexts.
|
||||
package context2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WithoutCancellation returns a context that does not propagate Done message
|
||||
// down to children. However, Values are propagated.
|
||||
func WithoutCancellation(ctx context.Context) context.Context {
|
||||
return noCancelContext{ctx}
|
||||
}
|
||||
|
||||
type noCancelContext struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Deadline returns the time when work done on behalf of this context
|
||||
// should be canceled.
|
||||
func (noCancelContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// Done returns empty channel.
|
||||
func (noCancelContext) Done() <-chan struct{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Err always returns nil
|
||||
func (noCancelContext) Err() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns string.
|
||||
func (ctx noCancelContext) String() string {
|
||||
return fmt.Sprintf("no cancel (%s)", ctx.ctx)
|
||||
}
|
||||
|
||||
// Value returns the value associated with this context for key, or nil
|
||||
// if no value is associated with key. Successive calls to Value with
|
||||
// the same key returns the same result.
|
||||
func (ctx noCancelContext) Value(key interface{}) interface{} {
|
||||
return ctx.ctx.Value(key)
|
||||
}
|
26
private/context2/nocancel_test.go
Normal file
26
private/context2/nocancel_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package context2_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/context2"
|
||||
)
|
||||
|
||||
func TestWithoutCancellation(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
defer ctx.Cleanup()
|
||||
|
||||
parent, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
without := context2.WithoutCancellation(parent)
|
||||
require.Equal(t, error(nil), without.Err())
|
||||
require.Equal(t, (<-chan struct{})(nil), without.Done())
|
||||
}
|
@ -25,6 +25,7 @@ import (
|
||||
"storj.io/common/signing"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/sync2"
|
||||
"storj.io/storj/private/context2"
|
||||
"storj.io/storj/storagenode/bandwidth"
|
||||
"storj.io/storj/storagenode/monitor"
|
||||
"storj.io/storj/storagenode/orders"
|
||||
@ -707,9 +708,8 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) {
|
||||
|
||||
// saveOrder saves the order with all necessary information. It assumes it has been already verified.
|
||||
func (endpoint *Endpoint) saveOrder(ctx context.Context, limit *pb.OrderLimit, order *pb.Order) {
|
||||
// intentionally using background context to ensure that we always save the order,
|
||||
// even when the client cancels the request.
|
||||
alwaysctx := context.Background()
|
||||
// We always want to save order to the database to be able to settle.
|
||||
ctx = context2.WithoutCancellation(ctx)
|
||||
|
||||
var err error
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
@ -718,14 +718,14 @@ func (endpoint *Endpoint) saveOrder(ctx context.Context, limit *pb.OrderLimit, o
|
||||
if order == nil || order.Amount <= 0 {
|
||||
return
|
||||
}
|
||||
err = endpoint.orders.Enqueue(alwaysctx, &orders.Info{
|
||||
err = endpoint.orders.Enqueue(ctx, &orders.Info{
|
||||
Limit: limit,
|
||||
Order: order,
|
||||
})
|
||||
if err != nil {
|
||||
endpoint.log.Error("failed to add order", zap.Error(err))
|
||||
} else {
|
||||
err = endpoint.usage.Add(alwaysctx, limit.SatelliteId, limit.Action, order.Amount, time.Now())
|
||||
err = endpoint.usage.Add(ctx, limit.SatelliteId, limit.Action, order.Amount, time.Now())
|
||||
if err != nil {
|
||||
endpoint.log.Error("failed to add bandwidth usage", zap.Error(err))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user