private/context2: add WithoutCancellation

Change-Id: I38557c16f41b8983886f256353cc6afb7634d9e6
This commit is contained in:
Egon Elbre 2020-01-15 10:14:10 +02:00
parent 19d318ea9d
commit 08f63614be
3 changed files with 80 additions and 5 deletions

View File

@ -0,0 +1,49 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
// Package context2 contains utilities for contexts.
package context2
import (
"context"
"fmt"
"time"
)
// WithoutCancellation returns a context that does not propagate Done message
// down to children. However, Values are propagated.
func WithoutCancellation(ctx context.Context) context.Context {
return noCancelContext{ctx}
}
type noCancelContext struct {
ctx context.Context
}
// Deadline returns the time when work done on behalf of this context
// should be canceled.
func (noCancelContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
// Done returns empty channel.
func (noCancelContext) Done() <-chan struct{} {
return nil
}
// Err always returns nil
func (noCancelContext) Err() error {
return nil
}
// String returns string.
func (ctx noCancelContext) String() string {
return fmt.Sprintf("no cancel (%s)", ctx.ctx)
}
// Value returns the value associated with this context for key, or nil
// if no value is associated with key. Successive calls to Value with
// the same key returns the same result.
func (ctx noCancelContext) Value(key interface{}) interface{} {
return ctx.ctx.Value(key)
}

View File

@ -0,0 +1,26 @@
// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package context2_test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"storj.io/common/testcontext"
"storj.io/storj/private/context2"
)
func TestWithoutCancellation(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()
parent, cancel := context.WithCancel(ctx)
cancel()
without := context2.WithoutCancellation(parent)
require.Equal(t, error(nil), without.Err())
require.Equal(t, (<-chan struct{})(nil), without.Done())
}

View File

@ -25,6 +25,7 @@ import (
"storj.io/common/signing" "storj.io/common/signing"
"storj.io/common/storj" "storj.io/common/storj"
"storj.io/common/sync2" "storj.io/common/sync2"
"storj.io/storj/private/context2"
"storj.io/storj/storagenode/bandwidth" "storj.io/storj/storagenode/bandwidth"
"storj.io/storj/storagenode/monitor" "storj.io/storj/storagenode/monitor"
"storj.io/storj/storagenode/orders" "storj.io/storj/storagenode/orders"
@ -707,9 +708,8 @@ func (endpoint *Endpoint) doDownload(stream downloadStream) (err error) {
// saveOrder saves the order with all necessary information. It assumes it has been already verified. // saveOrder saves the order with all necessary information. It assumes it has been already verified.
func (endpoint *Endpoint) saveOrder(ctx context.Context, limit *pb.OrderLimit, order *pb.Order) { func (endpoint *Endpoint) saveOrder(ctx context.Context, limit *pb.OrderLimit, order *pb.Order) {
// intentionally using background context to ensure that we always save the order, // We always want to save order to the database to be able to settle.
// even when the client cancels the request. ctx = context2.WithoutCancellation(ctx)
alwaysctx := context.Background()
var err error var err error
defer mon.Task()(&ctx)(&err) defer mon.Task()(&ctx)(&err)
@ -718,14 +718,14 @@ func (endpoint *Endpoint) saveOrder(ctx context.Context, limit *pb.OrderLimit, o
if order == nil || order.Amount <= 0 { if order == nil || order.Amount <= 0 {
return return
} }
err = endpoint.orders.Enqueue(alwaysctx, &orders.Info{ err = endpoint.orders.Enqueue(ctx, &orders.Info{
Limit: limit, Limit: limit,
Order: order, Order: order,
}) })
if err != nil { if err != nil {
endpoint.log.Error("failed to add order", zap.Error(err)) endpoint.log.Error("failed to add order", zap.Error(err))
} else { } else {
err = endpoint.usage.Add(alwaysctx, limit.SatelliteId, limit.Action, order.Amount, time.Now()) err = endpoint.usage.Add(ctx, limit.SatelliteId, limit.Action, order.Amount, time.Now())
if err != nil { if err != nil {
endpoint.log.Error("failed to add bandwidth usage", zap.Error(err)) endpoint.log.Error("failed to add bandwidth usage", zap.Error(err))
} }