Compare commits: main...gui-prebui (77 commits)

Commit SHAs:
977a27dde2, 703fd437fe, bf0f3b829f, 7f499e44a6, af93d2090b, 7059e10dfc, 4cb85186b2,
9a871bf3bc, afae5b578e, 5317135416, 7cc873a62a, 31bb6d54c7, 23631dc8bb, b1e7d70a86,
583ad54d86, 2ee0195eba, 0303920da7, df9a6e968e, abe1463a73, c96c83e805, 0f4371e84c,
0a8115b149, 47a4d4986d, 5272fd8497, 95761908b5, 5234727886, 5a1c3f7f19, 4ee647a951,
e8fcdc10a4, 99128ab551, 062ca285a0, 465941b345, 7e03ccfa46, 9370bc4580, 1f92e7acda,
a9d979e4d7, 4108aa72ba, 4e876fbdba, bd4d57c604, c79d1b0d2f, fbda13c752, 0f9a0ba9cd,
73d65fce9a, 1d62dc63f5, 05f30740f5, 97a89c3476, e0b5476e78, 074457fa4e, 5fc6eaab17,
70cdca5d3c, 8b4387a498, ced8657caa, ece0cc5785, a85c080509, a4d68b9b7e, ddf1f1c340,
e3d2f09988, f819b6a210, 1525324384, 2c3464081f, 6a3802de4f, a740f96f75, 7ac2031cac,
21c1e66a85, f2cd7b0928, 500b6244f8, 1851d103f9, 032546219c, 1173877167, cb41c51692,
d38b8fa2c4, 20a47034a5, 01e33e7753, 8482b37c14, f131047f1a, 8d8f6734de, c006126d54
@@ -12,13 +12,27 @@ FROM debian:buster-slim as ca-cert
RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates
RUN update-ca-certificates

# Install storj-up helper (for local/dev runs)
FROM --platform=$TARGETPLATFORM golang:1.19 AS storjup
RUN --mount=type=cache,target=/root/.cache/go-build \
    --mount=type=cache,target=/go/pkg/mod \
    go install storj.io/storj-up@latest

# Install dlv (for local/dev runs)
FROM --platform=$TARGETPLATFORM golang:1.19 AS dlv
RUN --mount=type=cache,target=/root/.cache/go-build \
    --mount=type=cache,target=/go/pkg/mod \
    go install github.com/go-delve/delve/cmd/dlv@latest

FROM ${DOCKER_ARCH:-amd64}/debian:buster-slim
ARG TAG
ARG GOARCH
ENV GOARCH ${GOARCH}
ENV CONF_PATH=/root/.local/share/storj/satellite \
    STORJ_CONSOLE_STATIC_DIR=/app \
    STORJ_MAIL_TEMPLATE_PATH=/app/static/emails \
    STORJ_CONSOLE_ADDRESS=0.0.0.0:10100
ENV PATH=$PATH:/app
EXPOSE 7777
EXPOSE 10100
WORKDIR /app

@@ -30,5 +44,9 @@ COPY release/${TAG}/wasm/wasm_exec.js /app/static/wasm/
COPY release/${TAG}/wasm/access.wasm.br /app/static/wasm/
COPY release/${TAG}/wasm/wasm_exec.js.br /app/static/wasm/
COPY release/${TAG}/satellite_linux_${GOARCH:-amd64} /app/satellite
COPY --from=storjup /go/bin/storj-up /usr/local/bin/storj-up
COPY --from=dlv /go/bin/dlv /usr/local/bin/dlv
# test identities for quick-start
COPY --from=img.dev.storj.io/storjup/base:20230607-1 /var/lib/storj/identities /var/lib/storj/identities
COPY cmd/satellite/entrypoint /entrypoint
ENTRYPOINT ["/entrypoint"]
@@ -1,6 +1,7 @@
#!/bin/bash
set -euo pipefail

## production helpers
SETUP_PARAMS=""

if [ -n "${IDENTITY_ADDR:-}" ]; then

@@ -21,6 +22,10 @@ if [ "${SATELLITE_API:-}" = "true" ]; then
  exec ./satellite run api $RUN_PARAMS "$@"
fi

if [ "${SATELLITE_UI:-}" = "true" ]; then
  exec ./satellite run ui $RUN_PARAMS "$@"
fi

if [ "${SATELLITE_GC:-}" = "true" ]; then
  exec ./satellite run garbage-collection $RUN_PARAMS "$@"
fi

@@ -37,4 +42,63 @@ if [ "${SATELLITE_AUDITOR:-}" = "true" ]; then
  exec ./satellite run auditor $RUN_PARAMS "$@"
fi

exec ./satellite run $RUN_PARAMS "$@"
## storj-up helpers
if [ "${STORJUP_ROLE:-""}" ]; then

  if [ "${STORJ_IDENTITY_DIR:-""}" ]; then
    # generate identity if missing
    if [ ! -f "$STORJ_IDENTITY_DIR/identity.key" ]; then
      if [ "$STORJ_USE_PREDEFINED_IDENTITY" ]; then
        # use predictable, pre-generated identity
        mkdir -p $(dirname $STORJ_IDENTITY_DIR)
        cp -r /var/lib/storj/identities/$STORJ_USE_PREDEFINED_IDENTITY $STORJ_IDENTITY_DIR
      else
        identity --identity-dir $STORJ_IDENTITY_DIR --difficulty 8 create .
      fi
    fi
  fi

  if [ "${STORJ_WAIT_FOR_DB:-""}" ]; then
    storj-up util wait-for-port cockroach:26257
    storj-up util wait-for-port redis:6379
  fi

  if [ "${STORJUP_ROLE:-""}" == "satellite-api" ]; then
    mkdir -p /var/lib/storj/.local

    # only migrate the first time
    if [ ! -f "/var/lib/storj/.local/migrated" ]; then
      satellite run migration --identity-dir $STORJ_IDENTITY_DIR
      touch /var/lib/storj/.local/migrated
    fi
  fi

  # the default config generated without arguments is misleading
  rm /root/.local/share/storj/satellite/config.yaml

  mkdir -p /var/lib/storj/.local/share/storj/satellite || true

  if [ "${GO_DLV:-""}" ]; then
    echo "Starting with go dlv"

    # an absolute file path is required
    CMD=$(which $1)
    shift
    /usr/local/bin/dlv --listen=:2345 --headless=true --api-version=2 --accept-multiclient exec --check-go-version=false -- $CMD "$@"
    exit $?
  fi
fi

# for backward-compatibility reasons, we use the argument as a command only if it's an executable (and use it as satellite flags otherwise)
set +eo nounset
which "$1" > /dev/null
VALID_EXECUTABLE=$?
set -eo nounset

if [ $VALID_EXECUTABLE -eq 0 ]; then
  # this is a full command (what storj-up uses)
  exec "$@"
else
  # legacy, run-only parameters
  exec ./satellite run $RUN_PARAMS "$@"
fi
@@ -40,7 +40,7 @@ import (
    "storj.io/storj/satellite/accounting/live"
    "storj.io/storj/satellite/compensation"
    "storj.io/storj/satellite/metabase"
    "storj.io/storj/satellite/overlay"
    "storj.io/storj/satellite/nodeselection"
    "storj.io/storj/satellite/payments/stripe"
    "storj.io/storj/satellite/satellitedb"
)

@@ -100,6 +100,11 @@
        Short: "Run the satellite API",
        RunE:  cmdAPIRun,
    }
    runUICmd = &cobra.Command{
        Use:   "ui",
        Short: "Run the satellite UI",
        RunE:  cmdUIRun,
    }
    runRepairerCmd = &cobra.Command{
        Use:   "repair",
        Short: "Run the repair service",

@@ -255,12 +260,19 @@
        Long:  "Finalizes all draft stripe invoices known to satellite's stripe account.",
        RunE:  cmdFinalizeCustomerInvoices,
    }
    payCustomerInvoicesCmd = &cobra.Command{
    payInvoicesWithTokenCmd = &cobra.Command{
        Use:   "pay-customer-invoices",
        Short: "pay open finalized invoices for customer",
        Long:  "attempts payment on any open finalized invoices for a specific user.",
        Args:  cobra.ExactArgs(1),
        RunE:  cmdPayCustomerInvoices,
    }
    payAllInvoicesCmd = &cobra.Command{
        Use:   "pay-invoices",
        Short: "pay finalized invoices",
        Long:  "attempts payment on all open finalized invoices according to subscriptions settings.",
        Args:  cobra.ExactArgs(1),
        RunE:  cmdPayCustomerInvoices,
        RunE:  cmdPayAllInvoices,
    }
    stripeCustomerCmd = &cobra.Command{
        Use:   "ensure-stripe-customer",

@@ -366,6 +378,7 @@ func init() {
    rootCmd.AddCommand(runCmd)
    runCmd.AddCommand(runMigrationCmd)
    runCmd.AddCommand(runAPICmd)
    runCmd.AddCommand(runUICmd)
    runCmd.AddCommand(runAdminCmd)
    runCmd.AddCommand(runRepairerCmd)
    runCmd.AddCommand(runAuditorCmd)

@@ -398,12 +411,14 @@ func init() {
    billingCmd.AddCommand(createCustomerInvoicesCmd)
    billingCmd.AddCommand(generateCustomerInvoicesCmd)
    billingCmd.AddCommand(finalizeCustomerInvoicesCmd)
    billingCmd.AddCommand(payCustomerInvoicesCmd)
    billingCmd.AddCommand(payInvoicesWithTokenCmd)
    billingCmd.AddCommand(payAllInvoicesCmd)
    billingCmd.AddCommand(stripeCustomerCmd)
    consistencyCmd.AddCommand(consistencyGECleanupCmd)
    process.Bind(runCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runMigrationCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runAPICmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runUICmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runAdminCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runRepairerCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(runAuditorCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))

@@ -432,7 +447,8 @@ func init() {
    process.Bind(createCustomerInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(generateCustomerInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(finalizeCustomerInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(payCustomerInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(payInvoicesWithTokenCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(payAllInvoicesCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(stripeCustomerCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(consistencyGECleanupCmd, &consistencyGECleanupCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))
    process.Bind(fixLastNetsCmd, &runCfg, defaults, cfgstruct.ConfDir(confDir), cfgstruct.IdentityDir(identityDir))

@@ -862,6 +878,18 @@ func cmdFinalizeCustomerInvoices(cmd *cobra.Command, args []string) (err error)
func cmdPayCustomerInvoices(cmd *cobra.Command, args []string) (err error) {
    ctx, _ := process.Ctx(cmd)

    return runBillingCmd(ctx, func(ctx context.Context, payments *stripe.Service, _ satellite.DB) error {
        err := payments.InvoiceApplyCustomerTokenBalance(ctx, args[0])
        if err != nil {
            return errs.New("error applying native token payments to invoice for customer: %v", err)
        }
        return payments.PayCustomerInvoices(ctx, args[0])
    })
}

func cmdPayAllInvoices(cmd *cobra.Command, args []string) (err error) {
    ctx, _ := process.Ctx(cmd)

    periodStart, err := parseYearMonth(args[0])
    if err != nil {
        return err

@@ -932,7 +960,7 @@ func cmdRestoreTrash(cmd *cobra.Command, args []string) error {
    successes := new(int64)
    failures := new(int64)

    undelete := func(node *overlay.SelectedNode) {
    undelete := func(node *nodeselection.SelectedNode) {
        log.Info("starting restore trash", zap.String("Node ID", node.ID.String()))

        ctx, cancel := context.WithTimeout(ctx, 10*time.Second)

@@ -966,9 +994,9 @@ func cmdRestoreTrash(cmd *cobra.Command, args []string) error {
        log.Info("successful restore trash", zap.String("Node ID", node.ID.String()))
    }

    var nodes []*overlay.SelectedNode
    var nodes []*nodeselection.SelectedNode
    if len(args) == 0 {
        err = db.OverlayCache().IterateAllContactedNodes(ctx, func(ctx context.Context, node *overlay.SelectedNode) error {
        err = db.OverlayCache().IterateAllContactedNodes(ctx, func(ctx context.Context, node *nodeselection.SelectedNode) error {
            nodes = append(nodes, node)
            return nil
        })

@@ -985,7 +1013,7 @@ func cmdRestoreTrash(cmd *cobra.Command, args []string) error {
        if err != nil {
            return err
        }
        nodes = append(nodes, &overlay.SelectedNode{
        nodes = append(nodes, &nodeselection.SelectedNode{
            ID:      dossier.Id,
            Address: dossier.Address,
            LastNet: dossier.LastNet,
@@ -94,7 +94,7 @@ func cmdRepairSegment(cmd *cobra.Command, args []string) (err error) {

    dialer := rpc.NewDefaultDialer(tlsOptions)

    overlay, err := overlay.NewService(log.Named("overlay"), db.OverlayCache(), db.NodeEvents(), config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
    overlayService, err := overlay.NewService(log.Named("overlay"), db.OverlayCache(), db.NodeEvents(), config.Placement.CreateFilters, config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
    if err != nil {
        return err
    }

@@ -102,8 +102,9 @@ func cmdRepairSegment(cmd *cobra.Command, args []string) (err error) {
    orders, err := orders.NewService(
        log.Named("orders"),
        signing.SignerFromFullIdentity(identity),
        overlay,
        overlayService,
        orders.NewNoopDB(),
        config.Placement.CreateFilters,
        config.Orders,
    )
    if err != nil {

@@ -122,9 +123,10 @@ func cmdRepairSegment(cmd *cobra.Command, args []string) (err error) {
        log.Named("segment-repair"),
        metabaseDB,
        orders,
        overlay,
        overlayService,
        nil, // TODO add noop version
        ecRepairer,
        config.Placement.CreateFilters,
        config.Checker.RepairOverrides,
        config.Repairer,
    )

@@ -132,7 +134,7 @@ func cmdRepairSegment(cmd *cobra.Command, args []string) (err error) {
    // TODO reorganize to avoid using peer.

    peer := &satellite.Repairer{}
    peer.Overlay = overlay
    peer.Overlay = overlayService
    peer.Orders.Service = orders
    peer.EcRepairer = ecRepairer
    peer.SegmentRepairer = segmentRepairer
cmd/satellite/ui.go (new file, 47 lines)

@@ -0,0 +1,47 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package main

import (
    "github.com/spf13/cobra"
    "github.com/zeebo/errs"
    "go.uber.org/zap"

    "storj.io/private/process"
    "storj.io/storj/satellite"
)

func cmdUIRun(cmd *cobra.Command, args []string) (err error) {
    ctx, _ := process.Ctx(cmd)
    log := zap.L()

    runCfg.Debug.Address = *process.DebugAddrFlag

    identity, err := runCfg.Identity.Load()
    if err != nil {
        log.Error("Failed to load identity.", zap.Error(err))
        return errs.New("Failed to load identity: %+v", err)
    }

    satAddr := runCfg.Config.Contact.ExternalAddress
    if satAddr == "" {
        return errs.New("cannot run satellite ui if contact.external-address is not set")
    }
    apiAddress := runCfg.Config.Console.ExternalAddress
    if apiAddress == "" {
        apiAddress = runCfg.Config.Console.Address
    }
    peer, err := satellite.NewUI(log, identity, &runCfg.Config, process.AtomicLevel(cmd), satAddr, apiAddress)
    if err != nil {
        return err
    }

    if err := process.InitMetricsWithHostname(ctx, log, nil); err != nil {
        log.Warn("Failed to initialize telemetry batcher on satellite api", zap.Error(err))
    }

    runError := peer.Run(ctx)
    closeError := peer.Close()
    return errs.Combine(runError, closeError)
}
cmd/tools/migrate-segment-copies/main.go (new file, 248 lines)

@@ -0,0 +1,248 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package main

import (
    "context"
    "encoding/csv"
    "errors"
    "os"
    "time"

    "github.com/spacemonkeygo/monkit/v3"
    "github.com/spf13/cobra"
    "github.com/spf13/pflag"
    "github.com/zeebo/errs"
    "go.uber.org/zap"

    "storj.io/common/uuid"
    "storj.io/private/cfgstruct"
    "storj.io/private/dbutil/pgutil"
    "storj.io/private/process"
    "storj.io/private/tagsql"
    "storj.io/storj/satellite/metabase"
)

var mon = monkit.Package()

var (
    rootCmd = &cobra.Command{
        Use:   "migrate-segment-copies",
        Short: "migrate-segment-copies",
    }

    runCmd = &cobra.Command{
        Use:   "run",
        Short: "run migrate-segment-copies",
        RunE:  run,
    }

    config Config
)

func init() {
    rootCmd.AddCommand(runCmd)

    cfgstruct.Bind(pflag.CommandLine, &config)
}

// Config defines configuration for migration.
type Config struct {
    MetabaseDB          string `help:"connection URL for metabaseDB"`
    BatchSize           int    `help:"number of entries from segment_copies processed at once" default:"2000"`
    SegmentCopiesBackup string `help:"csv file where segment_copies entries will be backed up"`
}

// VerifyFlags verifies whether the values provided are valid.
func (config *Config) VerifyFlags() error {
    var errlist errs.Group
    if config.MetabaseDB == "" {
        errlist.Add(errors.New("flag '--metabasedb' is not set"))
    }
    return errlist.Err()
}

func run(cmd *cobra.Command, args []string) error {
    if err := config.VerifyFlags(); err != nil {
        return err
    }

    ctx, _ := process.Ctx(cmd)
    log := zap.L()
    return Migrate(ctx, log, config)
}

func main() {
    process.Exec(rootCmd)
}

// Migrate starts segment copies migration.
func Migrate(ctx context.Context, log *zap.Logger, config Config) (err error) {
    defer mon.Task()(&ctx)(&err)

    db, err := metabase.Open(ctx, log, config.MetabaseDB, metabase.Config{})
    if err != nil {
        return errs.New("unable to connect %q: %w", config.MetabaseDB, err)
    }
    defer func() {
        err = errs.Combine(err, db.Close())
    }()

    return MigrateSegments(ctx, log, db, config)
}

// MigrateSegments updates segment copies with proper metadata (pieces and placement).
func MigrateSegments(ctx context.Context, log *zap.Logger, metabaseDB *metabase.DB, config Config) (err error) {
    defer mon.Task()(&ctx)(&err)

    var backupCSV *csv.Writer
    if config.SegmentCopiesBackup != "" {
        f, err := os.Create(config.SegmentCopiesBackup)
        if err != nil {
            return err
        }

        defer func() {
            err = errs.Combine(err, f.Close())
        }()

        backupCSV = csv.NewWriter(f)

        defer backupCSV.Flush()

        if err := backupCSV.Write([]string{"stream_id", "ancestor_stream_id"}); err != nil {
            return err
        }
    }

    db := metabaseDB.UnderlyingTagSQL()

    var streamIDCursor uuid.UUID
    ancestorStreamIDs := []uuid.UUID{}
    streamIDs := []uuid.UUID{}
    processed := 0

    // what we are doing here:
    // * read a batch of entries from the segment_copies table
    // * read the ancestor (original) segments' metadata from the segments table
    // * update segment copies with the missing metadata, one by one
    // * delete entries from the segment_copies table
    for {
        log.Info("Processed entries", zap.Int("processed", processed))

        ancestorStreamIDs = ancestorStreamIDs[:0]
        streamIDs = streamIDs[:0]

        idsMap := map[uuid.UUID][]uuid.UUID{}
        err := withRows(db.QueryContext(ctx, `
            SELECT stream_id, ancestor_stream_id FROM segment_copies WHERE stream_id > $1 ORDER BY stream_id LIMIT $2
        `, streamIDCursor, config.BatchSize))(func(rows tagsql.Rows) error {
            for rows.Next() {
                var streamID, ancestorStreamID uuid.UUID
                err := rows.Scan(&streamID, &ancestorStreamID)
                if err != nil {
                    return err
                }

                streamIDCursor = streamID
                ancestorStreamIDs = append(ancestorStreamIDs, ancestorStreamID)
                streamIDs = append(streamIDs, streamID)

                idsMap[ancestorStreamID] = append(idsMap[ancestorStreamID], streamID)
            }
            return nil
        })
        if err != nil {
            return err
        }

        type Update struct {
            StreamID          uuid.UUID
            AncestorStreamID  uuid.UUID
            Position          int64
            RemoteAliasPieces []byte
            RootPieceID       []byte
            RepairedAt        *time.Time
            Placement         int64
        }

        updates := []Update{}
        err = withRows(db.QueryContext(ctx, `
            SELECT stream_id, position, remote_alias_pieces, root_piece_id, repaired_at, placement FROM segments WHERE stream_id = ANY($1::BYTEA[])
        `, pgutil.UUIDArray(ancestorStreamIDs)))(func(rows tagsql.Rows) error {
            for rows.Next() {
                var ancestorStreamID uuid.UUID
                var position int64
                var remoteAliasPieces, rootPieceID []byte
                var repairedAt *time.Time
                var placement int64
                err := rows.Scan(&ancestorStreamID, &position, &remoteAliasPieces, &rootPieceID, &repairedAt, &placement)
                if err != nil {
                    return err
                }

                streamIDs, ok := idsMap[ancestorStreamID]
                if !ok {
                    return errs.New("unable to map ancestor stream id: %s", ancestorStreamID)
                }

                for _, streamID := range streamIDs {
                    updates = append(updates, Update{
                        StreamID:          streamID,
                        AncestorStreamID:  ancestorStreamID,
                        Position:          position,
                        RemoteAliasPieces: remoteAliasPieces,
                        RootPieceID:       rootPieceID,
                        RepairedAt:        repairedAt,
                        Placement:         placement,
                    })
                }
            }
            return nil
        })
        if err != nil {
            return err
        }

        for _, update := range updates {
            _, err := db.ExecContext(ctx, `
                UPDATE segments SET
                    remote_alias_pieces = $3,
                    root_piece_id = $4,
                    repaired_at = $5,
                    placement = $6
                WHERE (stream_id, position) = ($1, $2)
            `, update.StreamID, update.Position, update.RemoteAliasPieces, update.RootPieceID, update.RepairedAt, update.Placement)
            if err != nil {
                return err
            }

            if backupCSV != nil {
                if err := backupCSV.Write([]string{update.StreamID.String(), update.AncestorStreamID.String()}); err != nil {
                    return err
                }
            }
        }

        if backupCSV != nil {
            backupCSV.Flush()
        }

        processed += len(streamIDs)

        if len(updates) == 0 {
            return nil
        }
    }
}

func withRows(rows tagsql.Rows, err error) func(func(tagsql.Rows) error) error {
    return func(callback func(tagsql.Rows) error) error {
        if err != nil {
            return err
        }
        err := callback(rows)
        return errs.Combine(rows.Err(), rows.Close(), err)
    }
}
cmd/tools/migrate-segment-copies/main_test.go (new file, 324 lines)

@@ -0,0 +1,324 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package main_test

import (
    "os"
    "testing"

    "github.com/stretchr/testify/require"
    "go.uber.org/zap/zaptest"

    "storj.io/common/memory"
    "storj.io/common/storj"
    "storj.io/common/testcontext"
    "storj.io/common/testrand"
    "storj.io/common/uuid"
    cmd "storj.io/storj/cmd/tools/migrate-segment-copies"
    "storj.io/storj/private/testplanet"
    "storj.io/storj/satellite/metabase"
    "storj.io/storj/satellite/metabase/metabasetest"
)

func TestMigrateSingleCopy(t *testing.T) {
    metabasetest.Run(t, func(ctx *testcontext.Context, t *testing.T, metabaseDB *metabase.DB) {
        obj := metabasetest.RandObjectStream()

        expectedPieces := metabase.Pieces{
            {Number: 1, StorageNode: testrand.NodeID()},
            {Number: 3, StorageNode: testrand.NodeID()},
        }

        object, _ := metabasetest.CreateTestObject{
            CreateSegment: func(object metabase.Object, index int) metabase.Segment {
                metabasetest.CommitSegment{
                    Opts: metabase.CommitSegment{
                        ObjectStream: obj,
                        Position:     metabase.SegmentPosition{Part: 0, Index: uint32(index)},
                        RootPieceID:  testrand.PieceID(),

                        Pieces: expectedPieces,

                        EncryptedKey:      []byte{3},
                        EncryptedKeyNonce: []byte{4},
                        EncryptedETag:     []byte{5},

                        EncryptedSize: 1024,
                        PlainSize:     512,
                        PlainOffset:   0,
                        Redundancy:    metabasetest.DefaultRedundancy,
                        Placement:     storj.EEA,
                    },
                }.Check(ctx, t, metabaseDB)

                return metabase.Segment{}
            },
        }.Run(ctx, t, metabaseDB, obj, 50)

        copyObject, _, _ := metabasetest.CreateObjectCopy{
            OriginalObject: object,
        }.Run(ctx, t, metabaseDB, false)

        segments, err := metabaseDB.TestingAllSegments(ctx)
        require.NoError(t, err)
        for _, segment := range segments {
            if segment.StreamID == copyObject.StreamID {
                require.Len(t, segment.Pieces, 0)
                require.Equal(t, storj.EveryCountry, segment.Placement)
            }
        }

        require.NotZero(t, numberOfSegmentCopies(t, ctx, metabaseDB))

        err = cmd.MigrateSegments(ctx, zaptest.NewLogger(t), metabaseDB, cmd.Config{
            BatchSize: 3,
        })
        require.NoError(t, err)

        segments, err = metabaseDB.TestingAllSegments(ctx)
        require.NoError(t, err)
        for _, segment := range segments {
            require.Equal(t, expectedPieces, segment.Pieces)
            require.Equal(t, storj.EEA, segment.Placement)
        }
    })
}

func TestMigrateManyCopies(t *testing.T) {
    metabasetest.Run(t, func(ctx *testcontext.Context, t *testing.T, metabaseDB *metabase.DB) {
        obj := metabasetest.RandObjectStream()

        expectedPieces := metabase.Pieces{
            {Number: 1, StorageNode: testrand.NodeID()},
            {Number: 3, StorageNode: testrand.NodeID()},
        }

        object, _ := metabasetest.CreateTestObject{
            CreateSegment: func(object metabase.Object, index int) metabase.Segment {
                metabasetest.CommitSegment{
                    Opts: metabase.CommitSegment{
                        ObjectStream: obj,
                        Position:     metabase.SegmentPosition{Part: 0, Index: uint32(index)},
                        RootPieceID:  testrand.PieceID(),

                        Pieces: expectedPieces,

                        EncryptedKey:      []byte{3},
                        EncryptedKeyNonce: []byte{4},
                        EncryptedETag:     []byte{5},

                        EncryptedSize: 1024,
                        PlainSize:     512,
                        PlainOffset:   0,
                        Redundancy:    metabasetest.DefaultRedundancy,
                        Placement:     storj.EEA,
                    },
                }.Check(ctx, t, metabaseDB)

                return metabase.Segment{}
            },
        }.Run(ctx, t, metabaseDB, obj, 20)

        for i := 0; i < 10; i++ {
            copyObject, _, _ := metabasetest.CreateObjectCopy{
                OriginalObject: object,
            }.Run(ctx, t, metabaseDB, false)

            segments, err := metabaseDB.TestingAllSegments(ctx)
            require.NoError(t, err)
            for _, segment := range segments {
                if segment.StreamID == copyObject.StreamID {
                    require.Len(t, segment.Pieces, 0)
                    require.Equal(t, storj.EveryCountry, segment.Placement)
                }
            }
        }

        require.NotZero(t, numberOfSegmentCopies(t, ctx, metabaseDB))

        err := cmd.MigrateSegments(ctx, zaptest.NewLogger(t), metabaseDB, cmd.Config{
            BatchSize: 7,
        })
        require.NoError(t, err)

        segments, err := metabaseDB.TestingAllSegments(ctx)
        require.NoError(t, err)
        for _, segment := range segments {
            require.Equal(t, expectedPieces, segment.Pieces)
            require.Equal(t, storj.EEA, segment.Placement)
        }
    })
}

func TestMigrateDifferentSegment(t *testing.T) {
    metabasetest.Run(t, func(ctx *testcontext.Context, t *testing.T, metabaseDB *metabase.DB) {
        type Segment struct {
            StreamID uuid.UUID
            Position int64
        }

        expectedResults := map[Segment]metabase.Pieces{}
        createData := func(numberOfObjects int, pieces metabase.Pieces) {
            for i := 0; i < numberOfObjects; i++ {
                numberOfSegments := 3
                obj := metabasetest.RandObjectStream()
                object, _ := metabasetest.CreateTestObject{
                    CreateSegment: func(object metabase.Object, index int) metabase.Segment {
                        metabasetest.CommitSegment{
                            Opts: metabase.CommitSegment{
                                ObjectStream: obj,
                                Position:     metabase.SegmentPosition{Part: 0, Index: uint32(index)},
                                RootPieceID:  testrand.PieceID(),

                                Pieces: pieces,

                                EncryptedKey:      []byte{3},
                                EncryptedKeyNonce: []byte{4},
                                EncryptedETag:     []byte{5},

                                EncryptedSize: 1024,
                                PlainSize:     512,
                                PlainOffset:   0,
                                Redundancy:    metabasetest.DefaultRedundancy,
                                Placement:     storj.EEA,
                            },
                        }.Check(ctx, t, metabaseDB)

                        return metabase.Segment{}
                    },
                }.Run(ctx, t, metabaseDB, obj, 3)
                for n := 0; n < numberOfSegments; n++ {
                    expectedResults[Segment{
                        StreamID: object.StreamID,
                        Position: int64(n),
                    }] = pieces
                }

                copyObject, _, _ := metabasetest.CreateObjectCopy{
                    OriginalObject: object,
                }.Run(ctx, t, metabaseDB, false)

                for n := 0; n < numberOfSegments; n++ {
                    expectedResults[Segment{
                        StreamID: copyObject.StreamID,
                        Position: int64(n),
                    }] = pieces

                    segments, err := metabaseDB.TestingAllSegments(ctx)
                    require.NoError(t, err)
                    for _, segment := range segments {
                        if segment.StreamID == copyObject.StreamID {
                            require.Len(t, segment.Pieces, 0)
                            require.Equal(t, storj.EveryCountry, segment.Placement)
                        }
                    }
                }
            }
        }

        expectedPieces := metabase.Pieces{
            {Number: 1, StorageNode: testrand.NodeID()},
            {Number: 3, StorageNode: testrand.NodeID()},
        }
        createData(5, expectedPieces)

        expectedPieces = metabase.Pieces{
            {Number: 2, StorageNode: testrand.NodeID()},
            {Number: 4, StorageNode: testrand.NodeID()},
        }
        createData(5, expectedPieces)

        require.NotZero(t, numberOfSegmentCopies(t, ctx, metabaseDB))

        err := cmd.MigrateSegments(ctx, zaptest.NewLogger(t), metabaseDB, cmd.Config{
            BatchSize: 7,
        })
        require.NoError(t, err)

        segments, err := metabaseDB.TestingAllSegments(ctx)
        require.NoError(t, err)
        require.Equal(t, len(expectedResults), len(segments))
        for _, segment := range segments {
            pieces := expectedResults[Segment{
                StreamID: segment.StreamID,
                Position: int64(segment.Position.Encode()),
            }]
            require.Equal(t, pieces, segment.Pieces)
            require.Equal(t, storj.EEA, segment.Placement)
        }
    })
}

func numberOfSegmentCopies(t *testing.T, ctx *testcontext.Context, metabaseDB *metabase.DB) int {
    var count int
    err := metabaseDB.UnderlyingTagSQL().QueryRow(ctx, "SELECT count(1) FROM segment_copies").Scan(&count)
    require.NoError(t, err)
    return count
}

func TestMigrateEndToEnd(t *testing.T) {
    testplanet.Run(t, testplanet.Config{
        SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 1,
    }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
        expectedData := testrand.Bytes(10 * memory.KiB)
        err := planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "test", "object", expectedData)
        require.NoError(t, err)

        project, err := planet.Uplinks[0].OpenProject(ctx, planet.Satellites[0])
        require.NoError(t, err)
        defer ctx.Check(project.Close)

        _, err = project.CopyObject(ctx, "test", "object", "test", "object-copy", nil)
        require.NoError(t, err)

        data, err := planet.Uplinks[0].Download(ctx, planet.Satellites[0], "test", "object-copy")
        require.NoError(t, err)
        require.Equal(t, expectedData, data)

        err = cmd.MigrateSegments(ctx, zaptest.NewLogger(t), planet.Satellites[0].Metabase.DB, cmd.Config{
            BatchSize: 1,
        })
        require.NoError(t, err)

        data, err = planet.Uplinks[0].Download(ctx, planet.Satellites[0], "test", "object-copy")
        require.NoError(t, err)
        require.Equal(t, expectedData, data)
    })
}

func TestMigrateBackupCSV(t *testing.T) {
    testplanet.Run(t, testplanet.Config{
        SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 1,
    }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
        expectedData := testrand.Bytes(10 * memory.KiB)
        err := planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "test", "object", expectedData)
        require.NoError(t, err)

        project, err := planet.Uplinks[0].OpenProject(ctx, planet.Satellites[0])
        require.NoError(t, err)
        defer ctx.Check(project.Close)

        _, err = project.CopyObject(ctx, "test", "object", "test", "object-copy", nil)
        require.NoError(t, err)

        data, err := planet.Uplinks[0].Download(ctx, planet.Satellites[0], "test", "object-copy")
        require.NoError(t, err)
        require.Equal(t, expectedData, data)

        backupFile := ctx.File("backupcsv")
        err = cmd.MigrateSegments(ctx, zaptest.NewLogger(t), planet.Satellites[0].Metabase.DB, cmd.Config{
            BatchSize:           1,
            SegmentCopiesBackup: backupFile,
        })
        require.NoError(t, err)

        data, err = planet.Uplinks[0].Download(ctx, planet.Satellites[0], "test", "object-copy")
        require.NoError(t, err)
        require.Equal(t, expectedData, data)

        fileBytes, err := os.ReadFile(backupFile)
        require.NoError(t, err)
        require.NotEmpty(t, fileBytes)
    })
}
@@ -203,12 +203,12 @@ func verifySegments(cmd *cobra.Command, args []string) error {
    dialer := rpc.NewDefaultDialer(tlsOptions)

    // setup dependencies for verification
    overlay, err := overlay.NewService(log.Named("overlay"), db.OverlayCache(), db.NodeEvents(), "", "", satelliteCfg.Overlay)
    overlayService, err := overlay.NewService(log.Named("overlay"), db.OverlayCache(), db.NodeEvents(), overlay.NewPlacementRules().CreateFilters, "", "", satelliteCfg.Overlay)
    if err != nil {
        return Error.Wrap(err)
    }

    ordersService, err := orders.NewService(log.Named("orders"), signing.SignerFromFullIdentity(identity), overlay, orders.NewNoopDB(), satelliteCfg.Orders)
    ordersService, err := orders.NewService(log.Named("orders"), signing.SignerFromFullIdentity(identity), overlayService, orders.NewNoopDB(), overlay.NewPlacementRules().CreateFilters, satelliteCfg.Orders)
    if err != nil {
        return Error.Wrap(err)
    }

@@ -243,7 +243,7 @@ func verifySegments(cmd *cobra.Command, args []string) error {

    // setup verifier
    verifier := NewVerifier(log.Named("verifier"), dialer, ordersService, verifyConfig)
    service, err := NewService(log.Named("service"), metabaseDB, verifier, overlay, serviceConfig)
    service, err := NewService(log.Named("service"), metabaseDB, verifier, overlayService, serviceConfig)
    if err != nil {
        return Error.Wrap(err)
    }
@@ -15,6 +15,7 @@ import (
    "storj.io/common/uuid"
    "storj.io/private/process"
    "storj.io/storj/satellite/metabase"
    "storj.io/storj/satellite/nodeselection"
    "storj.io/storj/satellite/overlay"
    "storj.io/storj/satellite/satellitedb"
)

@@ -78,7 +79,7 @@ type NodeCheckConfig struct {

// NodeCheckOverlayDB contains dependencies from overlay that are needed for the processing.
type NodeCheckOverlayDB interface {
    IterateAllContactedNodes(context.Context, func(context.Context, *overlay.SelectedNode) error) error
    IterateAllContactedNodes(context.Context, func(context.Context, *nodeselection.SelectedNode) error) error
    IterateAllNodeDossiers(context.Context, func(context.Context, *overlay.NodeDossier) error) error
}
@@ -21,6 +21,7 @@ import (
    "storj.io/common/uuid"
    "storj.io/storj/satellite/audit"
    "storj.io/storj/satellite/metabase"
    "storj.io/storj/satellite/nodeselection"
    "storj.io/storj/satellite/overlay"
)

@@ -46,7 +47,7 @@ type Verifier interface {
type Overlay interface {
    // Get looks up the node by nodeID
    Get(ctx context.Context, nodeID storj.NodeID) (*overlay.NodeDossier, error)
    SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf overlay.AsOfSystemTimeConfig) ([]*overlay.SelectedNode, error)
    SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf overlay.AsOfSystemTimeConfig) ([]*nodeselection.SelectedNode, error)
}

// SegmentWriter allows writing segments to some output.
@@ -23,6 +23,7 @@ import (
    segmentverify "storj.io/storj/cmd/tools/segment-verify"
    "storj.io/storj/private/testplanet"
    "storj.io/storj/satellite/metabase"
    "storj.io/storj/satellite/nodeselection"
    "storj.io/storj/satellite/overlay"
)

@@ -344,10 +345,10 @@ func (db *metabaseMock) Get(ctx context.Context, nodeID storj.NodeID) (*overlay.
    }, nil
}

func (db *metabaseMock) SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf overlay.AsOfSystemTimeConfig) ([]*overlay.SelectedNode, error) {
    var xs []*overlay.SelectedNode
func (db *metabaseMock) SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf overlay.AsOfSystemTimeConfig) ([]*nodeselection.SelectedNode, error) {
    var xs []*nodeselection.SelectedNode
    for nodeID := range db.nodeIDToAlias {
        xs = append(xs, &overlay.SelectedNode{
        xs = append(xs, &nodeselection.SelectedNode{
            ID: nodeID,
            Address: &pb.NodeAddress{
                Address: fmt.Sprintf("nodeid:%v", nodeID),
cmd/tools/tag-signer/main.go (new file, 186 lines)

@@ -0,0 +1,186 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package main

import (
    "context"
    "encoding/base64"
    "encoding/hex"
    "fmt"
    "path/filepath"
    "strings"
    "time"

    "github.com/gogo/protobuf/proto"
    "github.com/spf13/cobra"
    "github.com/zeebo/errs"

    "storj.io/common/identity"
    "storj.io/common/nodetag"
    "storj.io/common/pb"
    "storj.io/common/signing"
    "storj.io/common/storj"
    "storj.io/private/process"
)

var (
    rootCmd = &cobra.Command{
        Use:   "tag-signer",
        Short: "Sign key=value pairs with identity",
        Long: "Node tags are arbitrary key=value pairs signed by an authority. If the public key is configured on " +
            "the Satellite side, the Satellite will check the signatures and save the tags, which can be used (for example)" +
            " during node selection. Storage nodes can be configured to send encoded node tags to the Satellite. " +
            "This utility helps create/manage the values of this specific configuration value, which is encoded by default.",
    }

    signCmd = &cobra.Command{
        Use:   "sign <key=value> <key2=value> ...",
        Short: "Create signed tagset",
        Args:  cobra.MinimumNArgs(1),
        RunE: func(cmd *cobra.Command, args []string) error {
            ctx, _ := process.Ctx(cmd)
            encoded, err := signTags(ctx, config, args)
            if err != nil {
                return err
            }
            fmt.Println(encoded)
            return nil
        },
    }

    inspectCmd = &cobra.Command{
        Use:   "inspect <encoded string>",
        Short: "Print out the details from an encoded node tag set",
        Args:  cobra.ExactArgs(1),
        RunE: func(cmd *cobra.Command, args []string) error {
            ctx, _ := process.Ctx(cmd)
            return inspect(ctx, args[0])
        },
    }

    config Config
)

// Config contains configuration required for signing.
type Config struct {
    IdentityDir string `help:"location of the identity files" path:"true"`
    NodeID      string `help:"the ID of the node that will use this tag"`
}

func init() {
    rootCmd.AddCommand(signCmd)
    rootCmd.AddCommand(inspectCmd)
    process.Bind(signCmd, &config)
}

func signTags(ctx context.Context, cfg Config, tagPairs []string) (string, error) {
    if cfg.IdentityDir == "" {
        return "", errs.New("Please specify the identity used as a signer with --identity-dir")
    }

    if cfg.NodeID == "" {
        return "", errs.New("Please specify the --node-id")
    }

    identityConfig := identity.Config{
        CertPath: filepath.Join(cfg.IdentityDir, "identity.cert"),
        KeyPath:  filepath.Join(cfg.IdentityDir, "identity.key"),
    }

    fullIdentity, err := identityConfig.Load()
    if err != nil {
        return "", err
    }

    signer := signing.SignerFromFullIdentity(fullIdentity)

    nodeID, err := storj.NodeIDFromString(cfg.NodeID)
    if err != nil {
        return "", errs.New("Wrong NodeID format: %v", err)
    }
    tagSet := &pb.NodeTagSet{
        NodeId:    nodeID.Bytes(),
        Timestamp: time.Now().Unix(),
    }

    for _, tag := range tagPairs {
        tag = strings.TrimSpace(tag)
        if len(tag) == 0 {
            continue
        }
        parts := strings.SplitN(tag, "=", 2)
        if len(parts) != 2 {
            return "", errs.New("tags should be in KEY=VALUE format, but it was %s", tag)
        }
        tagSet.Tags = append(tagSet.Tags, &pb.Tag{
            Name:  parts[0],
            Value: []byte(parts[1]),
        })
    }

    signedMessage, err := nodetag.Sign(ctx, tagSet, signer)
    if err != nil {
        return "", err
    }

    all := &pb.SignedNodeTagSets{
        Tags: []*pb.SignedNodeTagSet{
            signedMessage,
        },
    }

    raw, err := proto.Marshal(all)
    if err != nil {
        return "", errs.Wrap(err)
    }
    return base64.StdEncoding.EncodeToString(raw), nil
}

func inspect(ctx context.Context, s string) error {
    raw, err := base64.StdEncoding.DecodeString(s)
    if err != nil {
        return errs.New("Input is not in base64 format")
    }

    sets := &pb.SignedNodeTagSets{}
    err = proto.Unmarshal(raw, sets)
    if err != nil {
        return errs.New("Input is not a protobuf encoded *pb.SignedNodeTagSets message")
    }

    for _, msg := range sets.Tags {
        signerNodeID, err := storj.NodeIDFromBytes(msg.SignerNodeId)
        if err != nil {
            return err
        }

        fmt.Println("Signer: ", signerNodeID.String())
        fmt.Println("Signature: ", hex.EncodeToString(msg.Signature))

        tags := &pb.NodeTagSet{}
        err = proto.Unmarshal(msg.SerializedTag, tags)
        if err != nil {
            return err
        }
        nodeID, err := storj.NodeIDFromBytes(tags.NodeId)
        if err != nil {
            return err
        }

        fmt.Println("Timestamp: ", time.Unix(tags.Timestamp, 0).Format(time.RFC3339))
        fmt.Println("NodeID: ", nodeID.String())
        fmt.Println("Tags:")
        for _, tag := range tags.Tags {
            fmt.Printf("  %s=%s\n", tag.Name, string(tag.Value))
        }
        fmt.Println()
    }
    return nil
}

func main() {
    process.Exec(rootCmd)
}
docs/blueprints/certified-nodes.md (new file, 273 lines)

@@ -0,0 +1,273 @@
# Node and operator certification

## Abstract

This is a proposal for a small feature and service that allows for nodes and
operators to have signed tags of certain kinds for use in project-specific or
Satellite-specific node selection.

## Background/context

We have a couple of ongoing needs:

* 1099 KYC
* Private storage node networks
* SOC2/HIPAA/etc node certification
* Voting and operator signaling

### 1099 KYC

The United States has a rule that if node operators earn more than $600/year,
we need to file a 1099 for each of them. Our current way of dealing with this
is manual and time consuming, so it would be nice to automate it.

Ultimately, we should be able to automatically:

1) keep track of which nodes are run by operators under or over the $600
   threshold.
2) keep track of whether an automated KYC service has signed off that we have
   the necessary information to file a 1099.
3) automatically suspend nodes that have earned more than $600 but have not
   provided legally required information.

### Private storage node networks

We have seen growing interest from customers that want to bring their own
hard drives, or be extremely choosy about the nodes they are willing to work
with. The current way we are solving this is spinning up private Satellites
that are configured to only work with the nodes those customers provide, but
it would be better if we didn't have to start custom Satellites for this.

Instead, it would be nice to have a per-project configuration on an existing
Satellite that allowed that project to specify a specific subset of verified
or validated nodes, e.g., Project A should be able to say only nodes from
node providers B and C should be selected. Symmetrically, nodes from providers
B and C may only want to accept data from certain projects, like Project A.

When nodes from providers B and C are added to the Satellite, they should be
able to provide a provider-specific signature, along with any customer-specific
requirements.

### SOC2/HIPAA/etc node certification

This is actually just a slightly different shape of the private storage node
network problem, but instead of being provider-specific, it is property
specific.

Perhaps Project D has a compliance requirement. They can only store data
on nodes that meet specific requirements.

Node operators E and F are willing to conform and attest to these compliance
requirements, but don't know about Project D. It would be nice if node
operators E and F could navigate to a compliance portal and see a list of
potential compliance attestations available. For possible compliance
attestations, node operators could sign agreements for these, and then receive
a verified signature that shows their selected compliance options.

Then, Project D's node selection process would filter by nodes that had been
approved for the necessary compliance requirements.

### Voting and operator signaling

As Satellite operators ourselves, we are currently engaged in a discussion about
pricing changes with storage node operators. Future Satellite operators may find
themselves in similar situations. It would be nice if storage node operators
could indicate votes for values. This would potentially be more representative
of network sentiment than posts on a forum.

Note that this isn't a transparent voting scheme, where other voters can see
the votes made, so this may not be a great voting solution in general.

## Design and implementation

I believe there are two basic building blocks that solve all of the above
issues:

* Signed node tags (with potential values)
* A document signing service

### Signed node tags

The network representation:

```
message Tag {
    // Note that there is a single flat namespace of all names per
    // signer node id. Signers should be careful to make sure that
    // there are no name collisions. For self-signed content-hash
    // based values, the name should have the prefix of the content
    // hash.
    string name = 1;
    bytes value = 2; // optional, representation dependent on name.
}

message TagSet {
    // must always be set. this is the node the signer is signing for.
    bytes node_id = 1;

    repeated Tag tags = 2;

    // must always be set. this makes sure the signature is signing the
    // timestamp inside.
    int64 timestamp = 3;
}

message SignedTagSet {
    // this is the serialized form of TagSet, serialized so that
    // the signature process has something stable to work with.
    bytes serialized_tag = 1;

    // this is who signed (could be self signed, could be well known).
    bytes signer_node_id = 3;
    bytes signature = 4;
}

message SignedTagSets {
    repeated SignedTagSet tags = 1;
}
```

Note that every tag is signing a name/value pair (value optional) against
a specific node id.

Note also that names are only unique within the namespace of a given signer.

The database representation on the Satellite. N.B.: nothing should be entered
into this database without validation:

```
model signed_tags (
    field node_id        blob
    field name           text
    field value          blob
    field timestamp      int64
    field signer_node_id blob
)
```

The "signer_node_id" is worth more explanation. Every signer should have a
stable node id. Satellites and storage nodes already have one, but any other
service that validates node tags would also need one.
In particular, the document signing service (below) would have its own unique
node id for signing tags, whereas for voting-style tags or tags based on a
content-addressed identifier (e.g. a hash of a document), the nodes would
self-sign.

### Document signing service

We would start a small web service, where users can log in and sign and fill
out documents. This web service would then create a unique activation code
that storage node operators could run on their storage nodes for activation and
signing. They could run `storagenode activate <code>` and then the node would
reach out to the signing service and get a `SignedTag` related to that node
given the information the user provided. The node could then present these
to the satellite.

Ultimately, the document signing service will require a separate design doc,
but here are some considerations for it:

Activation codes must expire shortly. Even Netflix has two hours of validity
for their service code - for a significantly less critical use case. What would
be a usable validity time for our use case? 15 minutes? 1 hour? Should we make
it configurable?

We want to still keep usability in mind for a SNO who needs to activate 500
nodes.

It would be even better if the SNO could force-invalidate the activation code
when they are done with it.

As activation codes expire, the SNO should be able to generate a new activation
code if they want to associate a new node to an already signed document.

It should be hard to brute-force activation codes. They shouldn't be simple
numbers (4-digit or 6-digit) but something as complex as a UUID.

It's also possible that the SNO uses some signature mechanism during signing
service authentication, and the same signature is used for activation. If the
same signature mechanism is used during activation, then no token is necessary.
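To make the expiry and brute-force requirements above concrete, here is a
minimal sketch in Go of how a signing service could issue and redeem activation
codes. It assumes only the standard library; the names (`Store`, `Issue`,
`Redeem`) and the in-memory store are illustrative assumptions, not part of this
proposal's decided design:

```
package activation

import (
    "crypto/rand"
    "encoding/hex"
    "errors"
    "sync"
    "time"
)

// ActivationCode ties a random, hard-to-brute-force token to a signed document.
type ActivationCode struct {
    Code       string
    DocumentID string
    ExpiresAt  time.Time
    used       bool
}

// Store keeps issued codes in memory; a real service would persist them.
type Store struct {
    mu    sync.Mutex
    codes map[string]*ActivationCode
}

func NewStore() *Store { return &Store{codes: map[string]*ActivationCode{}} }

// Issue creates a UUID-strength (128-bit random) code with a short validity,
// which the SNO can also invalidate early simply by redeeming or deleting it.
func (s *Store) Issue(documentID string, validity time.Duration) (*ActivationCode, error) {
    buf := make([]byte, 16)
    if _, err := rand.Read(buf); err != nil {
        return nil, err
    }
    code := &ActivationCode{
        Code:       hex.EncodeToString(buf),
        DocumentID: documentID,
        ExpiresAt:  time.Now().Add(validity),
    }
    s.mu.Lock()
    defer s.mu.Unlock()
    s.codes[code.Code] = code
    return code, nil
}

// Redeem validates a code exactly once: it must exist, be unexpired, and unused.
func (s *Store) Redeem(code string) (documentID string, err error) {
    s.mu.Lock()
    defer s.mu.Unlock()
    c, ok := s.codes[code]
    if !ok || c.used || time.Now().After(c.ExpiresAt) {
        return "", errors.New("invalid or expired activation code")
    }
    c.used = true
    return c.DocumentID, nil
}
```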
### Update node selection

Once the above two building blocks exist, many problems become much more easily
solvable.

We would want to extend node selection to be able to do queries,
given project-specific configuration, based on these signed_tag values.

Because node selection mostly happens in memory from cached node table data,
it should be easy to add some denormalized data for certain selected cases,
such as:

* Document hashes nodes have self signed.
* Approval states based on well known third party signer nodes (a KYC service).

Once these fields exist, then node selection can happen as before, filtering
for the appropriate value given project settings (a minimal sketch of this
filtering step follows).
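Purely as an illustration of that filtering step, here is a minimal sketch in Go
under the assumption that signed tags are denormalized into the cached node
records; all type and function names here are hypothetical, not the satellite's
actual node-selection structures:

```
package nodeselection

// tagKey identifies a tag: names are only unique per signer.
type tagKey struct {
    SignerNodeID string
    Name         string
}

// cachedNode is a stand-in for a cached node record with denormalized tags.
type cachedNode struct {
    NodeID string
    Tags   map[tagKey]string // denormalized from the signed_tags table
}

// projectTagRequirement is what a project's configuration could demand,
// e.g. "KYC signer X set approved=true" or "self-signed hash of document Y".
type projectTagRequirement struct {
    Key   tagKey
    Value string
}

// filterByTags keeps only the nodes whose cached tags satisfy every requirement.
func filterByTags(nodes []cachedNode, reqs []projectTagRequirement) []cachedNode {
    var selected []cachedNode
    for _, node := range nodes {
        matches := true
        for _, req := range reqs {
            if node.Tags[req.Key] != req.Value {
                matches = false
                break
            }
        }
        if matches {
            selected = append(selected, node)
        }
    }
    return selected
}
```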
## How these building blocks work for the example use cases

### 1099 KYC

The document signing service would have a KYC (Know Your Customer) form. Once
filled out, the document signing service would make a `TagSet` that includes all
of the answers to the KYC questions, for the given node id, signed by the
document signing service's node id.

The node would hang on to this `SignedTagSet` and submit it along with others
in a `SignedTagSets` to Satellites occasionally (maybe once a month during
node CheckIn).
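A minimal sketch of how a Satellite could then act on that cached KYC state,
matching the automation goals in the background section; the threshold handling
and all names here are illustrative assumptions, not existing satellite code:

```
package compliance

// reportingThresholdUSD is the 1099 reporting threshold from the background
// section; treating it as a simple constant is an assumption for illustration.
const reportingThresholdUSD = 600

// nodeBillingInfo is a hypothetical view joining payout totals with the
// KYC tag cached from the document signing service's signer node id.
type nodeBillingInfo struct {
    NodeID        string
    YearEarnedUSD float64
    KYCComplete   bool
}

// shouldSuspend implements the rule: suspend nodes that earned more than
// $600 in a year but have not provided the legally required information.
func shouldSuspend(n nodeBillingInfo) bool {
    return n.YearEarnedUSD > reportingThresholdUSD && !n.KYCComplete
}
```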
### Private storage node networks

Storage node provisioning would provide nodes with a signed `SignedTagSet`
from a provisioning service that had its own node id. Then a private Satellite
could be configured to require that all nodes present a `SignedTagSet` signed
by the configured provisioning service that has that node's id in it.

Notably - this functionality could also be solved by the older waitlist node
identity signing certificate process, but we are slowly removing what remains
of that feature over time.

This functionality could also be solved by setting the Satellite's minimum
allowable node id difficulty to the maximum possible difficulty, thus preventing
any automatic node registration, and manually inserting node ids into the
database. This is what we are currently doing for private network trials, but
if `SignedTagSet`s existed, that would be easier.

### SOC2/HIPAA/etc node certification

For any type of document that doesn't require any third party service
(such as government id validation, etc), the document and its fields can be
filled out and self signed by the node, along with a content hash of the
document in question.

The node would create a `TagSet`, where one field is the hash of the legal
document that was agreed upon, and the remaining fields (with names prefixed
by the document's content hash) would be form fields
that the node operator filled in and ascribed to the document. Then, the
`TagSet` would be signed by the node itself. The cryptographic nature of the
content hash inside the `TagSet` would validate what the node operator had
agreed to.
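A minimal sketch of building such a self-signed `TagSet`, assuming only the Go
standard library; the `tag` struct and helper name are hypothetical, and the
actual signing would then happen with the node's own identity as described
elsewhere in this document:

```
package compliance

import (
    "crypto/sha256"
    "encoding/hex"
    "fmt"
)

// tag mirrors the blueprint's Tag message shape, for illustration only.
type tag struct {
    Name  string
    Value []byte
}

// selfSignedComplianceTags hashes the agreed document and prefixes every
// form-field name with that content hash, keeping names collision-free
// within the node's self-signed namespace.
func selfSignedComplianceTags(document []byte, formFields map[string]string) []tag {
    sum := sha256.Sum256(document)
    docHash := hex.EncodeToString(sum[:])

    tags := []tag{{Name: docHash, Value: []byte("agreed")}}
    for name, value := range formFields {
        tags = append(tags, tag{
            Name:  fmt.Sprintf("%s.%s", docHash, name),
            Value: []byte(value),
        })
    }
    return tags
}
```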
|
||||
|
||||
### Voting and operator signaling

Node operators could self sign additional `Tag`s inside of a miscellaneous
`TagSet`, including `Tag`s such as

```
"storage-node-vote-20230611-network-change": "yes"
```

or similar.

## Open problems

* Revocation? - `TagSet`s have a timestamp inside that must be filled out. In
  the future, certain tags could have an expiry or updated values or similar.

## Other options

## Wrapup

## Related work
25 docs/testplan/project-cowbell-testplan.md Normal file

@@ -0,0 +1,25 @@

# Mini Cowbell Testplan

## Background

We want to deploy the entire Storj stack on environments that have kubernetes running on 5 NUCs.

## Pre-condition

Configuration for satellites that have only 5 nodes; the recommended RS scheme is [2,3,4,4], where:
- 2 is the number of required pieces to reconstitute the segment.
- 3 is the repair threshold, i.e. if a segment remains with only 3 healthy pieces, it will be repaired.
- 4 is the success threshold, i.e. the number of pieces required for a successful upload or repair.
- 4 is the number of total erasure-coded pieces that will be generated.
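As a quick sanity check on these numbers: the expansion factor is 4 / 2 = 2x, so each segment stores twice its size across the cluster, and since every segment lands on only 4 of the 5 nodes, each node ends up holding a piece of roughly 4/5 = 80% of the segments. This is why a single node going offline affects about 80% of the stored data in the scenarios below.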
| Test Scenario | Test Case | Description | Comments |
|---------------|-----------|-------------|----------|
| Upload | Upload with all nodes online | Every file is uploaded to 4 nodes with 2x expansion factor. So one node has no files. | Happy path scenario |
| | Upload with one node offline | If one of five nodes fails and goes offline, 80% of the stored data will lose one erasure-coded piece. The health status of these segments will be reduced from 4 pieces to 3 pieces and will mark these segments for repair. (overlay.node.online-window: 4h0m0s -> for about 4 hours the node will still be selected for uploads.) | Uploads will continue uninterrupted if the client uses the new refactored upload path. This improved upload logic will request a new node from the satellite if the satellite selects the offline node for the upload, unaware it is already offline. If the client uses the old upload logic, uploads may fail if the satellite selects the offline node (20% chance). Once the satellite detects the offline node, all uploads will be successful. |
| Download | Download with one node offline | If one of five nodes fails and goes offline, 80% of the stored data will lose one erasure-coded piece. The health status of these segments will be reduced from 4 pieces to 3 pieces and will mark these segments for repair. (overlay.node.online-window: 4h0m0s -> for about 4 hours the node will still be selected for downloads.) | |
| Repair | Repair with 2 nodes disqualified | Disqualify 2 nodes so that repair downloads are still possible but there is no node available for an upload; the repair shouldn't consume download bandwidth and should error out early. Only spend download bandwidth when there is at least one node available for an upload. | If two nodes are lost, in the worst case only the minimum number of pieces of a segment remains; such segments cannot be repaired, so any further piece loss means de facto data loss. |
| Audit | | Audits can't identify corrupted pieces with just the minimum number of pieces. Reputation should not increase. Audits should be able to identify corrupted pieces with minimum + 1 pieces. Reputation should decrease. | |
| Upgrades | Nodes restart for upgrades | No more than a single node goes offline for maintenance. Otherwise, normal operation of the network cannot be ensured. | Occasionally, nodes may need to restart due to software updates. This brings the node offline for some period of time. |
5 go.mod

@@ -22,6 +22,7 @@ require (
    github.com/jackc/pgx/v5 v5.3.1
    github.com/jtolds/monkit-hw/v2 v2.0.0-20191108235325-141a0da276b3
    github.com/jtolio/eventkit v0.0.0-20230607152326-4668f79ff72d
    github.com/jtolio/mito v0.0.0-20230523171229-d78ef06bb77b
    github.com/jtolio/noiseconn v0.0.0-20230301220541-88105e6c8ac6
    github.com/loov/hrtime v1.0.3
    github.com/mattn/go-sqlite3 v1.14.12
@@ -60,10 +61,10 @@ require (
    golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e
    gopkg.in/segmentio/analytics-go.v3 v3.1.0
    gopkg.in/yaml.v3 v3.0.1
    storj.io/common v0.0.0-20230602145716-d6ea82d58b3d
    storj.io/common v0.0.0-20230719104100-cb5eec2edc30
    storj.io/drpc v0.0.33
    storj.io/monkit-jaeger v0.0.0-20220915074555-d100d7589f41
    storj.io/private v0.0.0-20230627140631-807a2f00d0e1
    storj.io/private v0.0.0-20230703113355-ccd4db5ae659
    storj.io/uplink v1.10.1-0.20230626081029-035890d408c2
)
10 go.sum

@@ -324,6 +324,8 @@ github.com/jtolds/tracetagger/v2 v2.0.0-rc5 h1:SriMFVtftPsQmG+0xaABotz9HnoKoo1QM
github.com/jtolds/tracetagger/v2 v2.0.0-rc5/go.mod h1:61Fh+XhbBONy+RsqkA+xTtmaFbEVL040m9FAF/hTrjQ=
github.com/jtolio/eventkit v0.0.0-20230607152326-4668f79ff72d h1:MAGZUXA8MLSA5oJT1Gua3nLSyTYF2uvBgM4Sfs5+jts=
github.com/jtolio/eventkit v0.0.0-20230607152326-4668f79ff72d/go.mod h1:PXFUrknJu7TkBNyL8t7XWDPtDFFLFrNQQAdsXv9YfJE=
github.com/jtolio/mito v0.0.0-20230523171229-d78ef06bb77b h1:HKvXTXZTeUHXRibg2ilZlkGSQP6A3cs0zXrBd4xMi6M=
github.com/jtolio/mito v0.0.0-20230523171229-d78ef06bb77b/go.mod h1:Mrym6OnPMkBKvN8/uXSkyhFSh6ndKKYE+Q4kxCfQ4V0=
github.com/jtolio/noiseconn v0.0.0-20230301220541-88105e6c8ac6 h1:iVMQyk78uOpX/UKjEbzyBdptXgEz6jwGwo7kM9IQ+3U=
github.com/jtolio/noiseconn v0.0.0-20230301220541-88105e6c8ac6/go.mod h1:MEkhEPFwP3yudWO0lj6vfYpLIB+3eIcuIW+e0AZzUQk=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
@@ -1013,8 +1015,8 @@ rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8
sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
storj.io/common v0.0.0-20220719163320-cd2ef8e1b9b0/go.mod h1:mCYV6Ud5+cdbuaxdPD5Zht/HYaIn0sffnnws9ErkrMQ=
storj.io/common v0.0.0-20230602145716-d6ea82d58b3d h1:AXdJxmg4Jqdz1nmogSrImKOHAU+bn8JCy8lHYnTwP0Y=
storj.io/common v0.0.0-20230602145716-d6ea82d58b3d/go.mod h1:zu2L8WdpvfIBrCbBTgPsz4qhHSArYSiDgRcV1RLlIF8=
storj.io/common v0.0.0-20230719104100-cb5eec2edc30 h1:xso8DyZExwYO2SFV0C/vt7unT/Vg3jQV2mtESiVEpUY=
storj.io/common v0.0.0-20230719104100-cb5eec2edc30/go.mod h1:zu2L8WdpvfIBrCbBTgPsz4qhHSArYSiDgRcV1RLlIF8=
storj.io/drpc v0.0.32/go.mod h1:6rcOyR/QQkSTX/9L5ZGtlZaE2PtXTTZl8d+ulSeeYEg=
storj.io/drpc v0.0.33 h1:yCGZ26r66ZdMP0IcTYsj7WDAUIIjzXk6DJhbhvt9FHI=
storj.io/drpc v0.0.33/go.mod h1:vR804UNzhBa49NOJ6HeLjd2H3MakC1j5Gv8bsOQT6N4=
@@ -1022,7 +1024,7 @@ storj.io/monkit-jaeger v0.0.0-20220915074555-d100d7589f41 h1:SVuEocEhZfFc13J1Aml
storj.io/monkit-jaeger v0.0.0-20220915074555-d100d7589f41/go.mod h1:iK+dmHZZXQlW7ahKdNSOo+raMk5BDL2wbD62FIeXLWs=
storj.io/picobuf v0.0.1 h1:ekEvxSQCbEjTVIi/qxj2za13SJyfRE37yE30IBkZeT0=
storj.io/picobuf v0.0.1/go.mod h1:7ZTAMs6VesgTHbbhFU79oQ9hDaJ+MD4uoFQZ1P4SEz0=
storj.io/private v0.0.0-20230627140631-807a2f00d0e1 h1:O2+Xjq8H4TKad2cnhvjitK3BtwkGtJ2TfRCHOIN8e7w=
storj.io/private v0.0.0-20230627140631-807a2f00d0e1/go.mod h1:mfdHEaAcTARpd4/Hc6N5uxwB1ZG3jtPdVlle57xzQxQ=
storj.io/private v0.0.0-20230703113355-ccd4db5ae659 h1:J72VWwbpllfolJoCsjVMr3YnscUUOQAruzFTsivqIqY=
storj.io/private v0.0.0-20230703113355-ccd4db5ae659/go.mod h1:mfdHEaAcTARpd4/Hc6N5uxwB1ZG3jtPdVlle57xzQxQ=
storj.io/uplink v1.10.1-0.20230626081029-035890d408c2 h1:XnJR9egrqvAqx5oCRu2b13ubK0iu0qTX12EAa6lAPhg=
storj.io/uplink v1.10.1-0.20230626081029-035890d408c2/go.mod h1:cDlpDWGJykXfYE7NtO1EeArGFy12K5Xj8pV8ufpUCKE=
@@ -69,7 +69,9 @@ type DiskSpace struct {
    Allocated int64 `json:"allocated"`
    Used      int64 `json:"usedPieces"`
    Trash     int64 `json:"usedTrash"`
    Free      int64 `json:"free"`
    // Free is the actual amount of free space on the whole disk, not just allocated disk space, in bytes.
    Free int64 `json:"free"`
    // Available is the amount of free space on the allocated disk space, in bytes.
    Available int64 `json:"available"`
    Overused  int64 `json:"overused"`
}
@@ -27,7 +27,9 @@ message DiskSpaceResponse {
    int64 allocated = 1;
    int64 used_pieces = 2;
    int64 used_trash = 3;
    // Free is the actual amount of free space on the whole disk, not just allocated disk space, in bytes.
    int64 free = 4;
    // Available is the amount of free space on the allocated disk space, in bytes.
    int64 available = 5;
    int64 overused = 6;
}
@ -66,10 +66,10 @@ type Satellite struct {
|
||||
|
||||
Core *satellite.Core
|
||||
API *satellite.API
|
||||
UI *satellite.UI
|
||||
Repairer *satellite.Repairer
|
||||
Auditor *satellite.Auditor
|
||||
Admin *satellite.Admin
|
||||
GC *satellite.GarbageCollection
|
||||
GCBF *satellite.GarbageCollectionBF
|
||||
RangedLoop *satellite.RangedLoop
|
||||
|
||||
@ -173,12 +173,17 @@ type Satellite struct {
|
||||
Service *mailservice.Service
|
||||
}
|
||||
|
||||
Console struct {
|
||||
ConsoleBackend struct {
|
||||
Listener net.Listener
|
||||
Service *console.Service
|
||||
Endpoint *consoleweb.Server
|
||||
}
|
||||
|
||||
ConsoleFrontend struct {
|
||||
Listener net.Listener
|
||||
Endpoint *consoleweb.Server
|
||||
}
|
||||
|
||||
NodeStats struct {
|
||||
Endpoint *nodestats.Endpoint
|
||||
}
|
||||
@ -285,7 +290,6 @@ func (system *Satellite) Close() error {
|
||||
system.Repairer.Close(),
|
||||
system.Auditor.Close(),
|
||||
system.Admin.Close(),
|
||||
system.GC.Close(),
|
||||
system.GCBF.Close(),
|
||||
)
|
||||
}
|
||||
@ -300,6 +304,11 @@ func (system *Satellite) Run(ctx context.Context) (err error) {
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.API.Run(ctx))
|
||||
})
|
||||
if system.UI != nil {
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.UI.Run(ctx))
|
||||
})
|
||||
}
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.Repairer.Run(ctx))
|
||||
})
|
||||
@ -309,9 +318,6 @@ func (system *Satellite) Run(ctx context.Context) (err error) {
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.Admin.Run(ctx))
|
||||
})
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.GC.Run(ctx))
|
||||
})
|
||||
group.Go(func() error {
|
||||
return errs2.IgnoreCanceled(system.GCBF.Run(ctx))
|
||||
})
|
||||
@ -524,6 +530,15 @@ func (planet *Planet) newSatellite(ctx context.Context, prefix string, index int
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
// only run if front-end endpoints on console back-end server are disabled.
|
||||
var ui *satellite.UI
|
||||
if !config.Console.FrontendEnable {
|
||||
ui, err = planet.newUI(ctx, index, identity, config, api.ExternalAddress, api.Console.Listener.Addr().String())
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
adminPeer, err := planet.newAdmin(ctx, index, identity, db, metabaseDB, config, versionInfo)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
@ -539,11 +554,6 @@ func (planet *Planet) newSatellite(ctx context.Context, prefix string, index int
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
gcPeer, err := planet.newGarbageCollection(ctx, index, identity, db, metabaseDB, config, versionInfo)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
|
||||
gcBFPeer, err := planet.newGarbageCollectionBF(ctx, index, db, metabaseDB, config, versionInfo)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
@ -558,23 +568,23 @@ func (planet *Planet) newSatellite(ctx context.Context, prefix string, index int
|
||||
peer.Mail.EmailReminders.TestSetLinkAddress("http://" + api.Console.Listener.Addr().String() + "/")
|
||||
}
|
||||
|
||||
return createNewSystem(prefix, log, config, peer, api, repairerPeer, auditorPeer, adminPeer, gcPeer, gcBFPeer, rangedLoopPeer), nil
|
||||
return createNewSystem(prefix, log, config, peer, api, ui, repairerPeer, auditorPeer, adminPeer, gcBFPeer, rangedLoopPeer), nil
|
||||
}
|
||||
|
||||
// createNewSystem makes a new Satellite System and exposes the same interface from
|
||||
// before we split out the API. In the short term this will help keep all the tests passing
|
||||
// without much modification needed. However long term, we probably want to rework this
|
||||
// so it represents how the satellite will run when it is made up of many processes.
|
||||
func createNewSystem(name string, log *zap.Logger, config satellite.Config, peer *satellite.Core, api *satellite.API, repairerPeer *satellite.Repairer, auditorPeer *satellite.Auditor, adminPeer *satellite.Admin, gcPeer *satellite.GarbageCollection, gcBFPeer *satellite.GarbageCollectionBF, rangedLoopPeer *satellite.RangedLoop) *Satellite {
|
||||
func createNewSystem(name string, log *zap.Logger, config satellite.Config, peer *satellite.Core, api *satellite.API, ui *satellite.UI, repairerPeer *satellite.Repairer, auditorPeer *satellite.Auditor, adminPeer *satellite.Admin, gcBFPeer *satellite.GarbageCollectionBF, rangedLoopPeer *satellite.RangedLoop) *Satellite {
|
||||
system := &Satellite{
|
||||
Name: name,
|
||||
Config: config,
|
||||
Core: peer,
|
||||
API: api,
|
||||
UI: ui,
|
||||
Repairer: repairerPeer,
|
||||
Auditor: auditorPeer,
|
||||
Admin: adminPeer,
|
||||
GC: gcPeer,
|
||||
GCBF: gcBFPeer,
|
||||
RangedLoop: rangedLoopPeer,
|
||||
}
|
||||
@ -622,7 +632,7 @@ func createNewSystem(name string, log *zap.Logger, config satellite.Config, peer
|
||||
system.Audit.Reporter = auditorPeer.Audit.Reporter
|
||||
system.Audit.ContainmentSyncChore = peer.Audit.ContainmentSyncChore
|
||||
|
||||
system.GarbageCollection.Sender = gcPeer.GarbageCollection.Sender
|
||||
system.GarbageCollection.Sender = peer.GarbageCollection.Sender
|
||||
|
||||
system.ExpiredDeletion.Chore = peer.ExpiredDeletion.Chore
|
||||
system.ZombieDeletion.Chore = peer.ZombieDeletion.Chore
|
||||
@ -666,6 +676,15 @@ func (planet *Planet) newAPI(ctx context.Context, index int, identity *identity.
|
||||
return satellite.NewAPI(log, identity, db, metabaseDB, revocationDB, liveAccounting, rollupsWriteCache, &config, versionInfo, nil)
|
||||
}
|
||||
|
||||
func (planet *Planet) newUI(ctx context.Context, index int, identity *identity.FullIdentity, config satellite.Config, satelliteAddr, consoleAPIAddr string) (_ *satellite.UI, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
prefix := "satellite-ui" + strconv.Itoa(index)
|
||||
log := planet.log.Named(prefix)
|
||||
|
||||
return satellite.NewUI(log, identity, &config, nil, satelliteAddr, consoleAPIAddr)
|
||||
}
|
||||
|
||||
func (planet *Planet) newAdmin(ctx context.Context, index int, identity *identity.FullIdentity, db satellite.DB, metabaseDB *metabase.DB, config satellite.Config, versionInfo version.Info) (_ *satellite.Admin, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
@ -713,20 +732,6 @@ func (cache rollupsWriteCacheCloser) Close() error {
|
||||
return cache.RollupsWriteCache.CloseAndFlush(context.TODO())
|
||||
}
|
||||
|
||||
func (planet *Planet) newGarbageCollection(ctx context.Context, index int, identity *identity.FullIdentity, db satellite.DB, metabaseDB *metabase.DB, config satellite.Config, versionInfo version.Info) (_ *satellite.GarbageCollection, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
prefix := "satellite-gc" + strconv.Itoa(index)
|
||||
log := planet.log.Named(prefix)
|
||||
|
||||
revocationDB, err := revocation.OpenDBFromCfg(ctx, config.Server.Config)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(err)
|
||||
}
|
||||
planet.databases = append(planet.databases, revocationDB)
|
||||
return satellite.NewGarbageCollection(log, identity, db, metabaseDB, revocationDB, versionInfo, &config, nil)
|
||||
}
|
||||
|
||||
func (planet *Planet) newGarbageCollectionBF(ctx context.Context, index int, db satellite.DB, metabaseDB *metabase.DB, config satellite.Config, versionInfo version.Info) (_ *satellite.GarbageCollectionBF, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
@ -746,7 +751,6 @@ func (planet *Planet) newRangedLoop(ctx context.Context, index int, db satellite
|
||||
|
||||
prefix := "satellite-ranged-loop" + strconv.Itoa(index)
|
||||
log := planet.log.Named(prefix)
|
||||
|
||||
return satellite.NewRangedLoop(log, db, metabaseDB, &config, nil)
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"storj.io/common/peertls/tlsopts"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/private/debug"
|
||||
"storj.io/storj/cmd/storagenode/internalcmd"
|
||||
"storj.io/storj/private/revocation"
|
||||
"storj.io/storj/private/server"
|
||||
"storj.io/storj/storagenode"
|
||||
@ -215,6 +216,10 @@ func (planet *Planet) newStorageNode(ctx context.Context, prefix string, index,
|
||||
MinDownloadTimeout: 2 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
// enable the lazy filewalker
|
||||
config.Pieces.EnableLazyFilewalker = true
|
||||
|
||||
if planet.config.Reconfigure.StorageNode != nil {
|
||||
planet.config.Reconfigure.StorageNode(index, &config)
|
||||
}
|
||||
@ -275,6 +280,21 @@ func (planet *Planet) newStorageNode(ctx context.Context, prefix string, index,
|
||||
return nil, errs.New("error while trying to issue new api key: %v", err)
|
||||
}
|
||||
|
||||
{
|
||||
// set up the used space lazyfilewalker filewalker
|
||||
cmd := internalcmd.NewUsedSpaceFilewalkerCmd()
|
||||
cmd.Logger = log.Named("used-space-filewalker")
|
||||
cmd.Ctx = ctx
|
||||
peer.Storage2.LazyFileWalker.TestingSetUsedSpaceCmd(cmd)
|
||||
}
|
||||
{
|
||||
// set up the GC lazyfilewalker filewalker
|
||||
cmd := internalcmd.NewGCFilewalkerCmd()
|
||||
cmd.Logger = log.Named("gc-filewalker")
|
||||
cmd.Ctx = ctx
|
||||
peer.Storage2.LazyFileWalker.TestingSetGCCmd(cmd)
|
||||
}
|
||||
|
||||
return &StorageNode{
|
||||
Name: prefix,
|
||||
Config: config,
|
||||
|
@@ -105,9 +105,9 @@ func TestDownloadWithSomeNodesOffline(t *testing.T) {
    }

    // confirm that we marked the correct number of storage nodes as offline
    nodes, err := satellite.Overlay.Service.Reliable(ctx)
    online, _, err := satellite.Overlay.Service.Reliable(ctx)
    require.NoError(t, err)
    require.Len(t, nodes, len(planet.StorageNodes)-toKill)
    require.Len(t, online, len(planet.StorageNodes)-toKill)

    // we should be able to download data without any of the original nodes
    newData, err := ul.Download(ctx, satellite, "testbucket", "test/path")
@@ -6,16 +6,16 @@ package version
import _ "unsafe" // needed for go:linkname

//go:linkname buildTimestamp storj.io/private/version.buildTimestamp
var buildTimestamp string
var buildTimestamp string = "1690910649"

//go:linkname buildCommitHash storj.io/private/version.buildCommitHash
var buildCommitHash string
var buildCommitHash string = "bf0f3b829f699bc5fc7029c4acf747e7857e13d8"

//go:linkname buildVersion storj.io/private/version.buildVersion
var buildVersion string
var buildVersion string = "v1.84.2"

//go:linkname buildRelease storj.io/private/version.buildRelease
var buildRelease string
var buildRelease string = "true"

// ensure that linter understands that the variables are being used.
func init() { use(buildTimestamp, buildCommitHash, buildVersion, buildRelease) }
@@ -4,24 +4,34 @@
package web

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"

    "go.uber.org/zap"

    "storj.io/common/http/requestid"
)

// ServeJSONError writes a JSON error to the response output stream.
func ServeJSONError(log *zap.Logger, w http.ResponseWriter, status int, err error) {
    ServeCustomJSONError(log, w, status, err, err.Error())
func ServeJSONError(ctx context.Context, log *zap.Logger, w http.ResponseWriter, status int, err error) {
    ServeCustomJSONError(ctx, log, w, status, err, err.Error())
}

// ServeCustomJSONError writes a JSON error with a custom message to the response output stream.
func ServeCustomJSONError(log *zap.Logger, w http.ResponseWriter, status int, err error, msg string) {
func ServeCustomJSONError(ctx context.Context, log *zap.Logger, w http.ResponseWriter, status int, err error, msg string) {
    fields := []zap.Field{
        zap.Int("code", status),
        zap.String("message", msg),
        zap.Error(err),
    }

    if requestID := requestid.FromContext(ctx); requestID != "" {
        fields = append(fields, zap.String("requestID", requestID))
        msg += fmt.Sprintf(" (request id: %s)", requestID)
    }

    switch status {
    case http.StatusNoContent:
        return
@@ -87,12 +87,12 @@ func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        key, err := rl.keyFunc(r)
        if err != nil {
            ServeCustomJSONError(rl.log, w, http.StatusInternalServerError, err, internalServerErrMsg)
            ServeCustomJSONError(r.Context(), rl.log, w, http.StatusInternalServerError, err, internalServerErrMsg)
            return
        }
        limit := rl.getUserLimit(key)
        if !limit.Allow() {
            ServeJSONError(rl.log, w, http.StatusTooManyRequests, errs.New(rateLimitErrMsg))
            ServeJSONError(r.Context(), rl.log, w, http.StatusTooManyRequests, errs.New(rateLimitErrMsg))
            return
        }
        next.ServeHTTP(w, r)
@@ -219,6 +219,8 @@ type ProjectAccounting interface {
    GetProjectSettledBandwidthTotal(ctx context.Context, projectID uuid.UUID, from time.Time) (_ int64, err error)
    // GetProjectBandwidth returns project allocated bandwidth for the specified year, month and day.
    GetProjectBandwidth(ctx context.Context, projectID uuid.UUID, year int, month time.Month, day int, asOfSystemInterval time.Duration) (int64, error)
    // GetProjectSettledBandwidth returns the used settled bandwidth for the specified year and month.
    GetProjectSettledBandwidth(ctx context.Context, projectID uuid.UUID, year int, month time.Month, asOfSystemInterval time.Duration) (int64, error)
    // GetProjectDailyBandwidth returns bandwidth (allocated and settled) for the specified day.
    GetProjectDailyBandwidth(ctx context.Context, projectID uuid.UUID, year int, month time.Month, day int) (int64, int64, int64, error)
    // DeleteProjectBandwidthBefore deletes project bandwidth rollups before the given time
@@ -218,6 +218,17 @@ func (usage *Service) GetProjectBandwidthTotals(ctx context.Context, projectID u
    return total, ErrProjectUsage.Wrap(err)
}

// GetProjectSettledBandwidth returns total amount of settled bandwidth used for past 30 days.
func (usage *Service) GetProjectSettledBandwidth(ctx context.Context, projectID uuid.UUID) (_ int64, err error) {
    defer mon.Task()(&ctx, projectID)(&err)

    // from the beginning of the current month
    year, month, _ := usage.nowFn().Date()

    total, err := usage.projectAccountingDB.GetProjectSettledBandwidth(ctx, projectID, year, month, usage.asOfSystemInterval)
    return total, ErrProjectUsage.Wrap(err)
}

// GetProjectSegmentTotals returns total amount of allocated segments used for past 30 days.
func (usage *Service) GetProjectSegmentTotals(ctx context.Context, projectID uuid.UUID) (total int64, err error) {
    defer mon.Task()(&ctx, projectID)(&err)
@ -182,7 +182,8 @@ func TestProjectSegmentLimit(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
data := testrand.Bytes(160 * memory.KiB)
|
||||
// tally self-corrects live accounting, however, it may cause things to be temporarily off by a few segments.
|
||||
planet.Satellites[0].Accounting.Tally.Loop.Pause()
|
||||
|
||||
// set limit manually to 10 segments
|
||||
accountingDB := planet.Satellites[0].DB.ProjectAccounting()
|
||||
@ -190,6 +191,7 @@ func TestProjectSegmentLimit(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// successful upload
|
||||
data := testrand.Bytes(160 * memory.KiB)
|
||||
err = planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "testbucket", "test/path/0", data)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -203,14 +205,17 @@ func TestProjectSegmentLimit(t *testing.T) {
|
||||
|
||||
func TestProjectSegmentLimitInline(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, UplinkCount: 1}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
data := testrand.Bytes(1 * memory.KiB)
|
||||
SatelliteCount: 1, UplinkCount: 1,
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
// tally self-corrects live accounting, however, it may cause things to be temporarily off by a few segments.
|
||||
planet.Satellites[0].Accounting.Tally.Loop.Pause()
|
||||
|
||||
// set limit manually to 10 segments
|
||||
accountingDB := planet.Satellites[0].DB.ProjectAccounting()
|
||||
err := accountingDB.UpdateProjectSegmentLimit(ctx, planet.Uplinks[0].Projects[0].ID, 10)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := testrand.Bytes(1 * memory.KiB)
|
||||
for i := 0; i < 10; i++ {
|
||||
// successful upload
|
||||
err = planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "testbucket", "test/path/"+strconv.Itoa(i), data)
|
||||
@ -260,14 +265,17 @@ func TestProjectBandwidthLimitWithoutCache(t *testing.T) {
|
||||
|
||||
func TestProjectSegmentLimitMultipartUpload(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, UplinkCount: 1}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
data := testrand.Bytes(1 * memory.KiB)
|
||||
SatelliteCount: 1, UplinkCount: 1,
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
// tally self-corrects live accounting, however, it may cause things to be temporarily off by a few segments.
|
||||
planet.Satellites[0].Accounting.Tally.Loop.Pause()
|
||||
|
||||
// set limit manually to 10 segments
|
||||
accountingDB := planet.Satellites[0].DB.ProjectAccounting()
|
||||
err := accountingDB.UpdateProjectSegmentLimit(ctx, planet.Uplinks[0].Projects[0].ID, 4)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := testrand.Bytes(1 * memory.KiB)
|
||||
for i := 0; i < 4; i++ {
|
||||
// successful upload
|
||||
err = planet.Uplinks[0].Upload(ctx, planet.Satellites[0], "testbucket", "test/path/"+strconv.Itoa(i), data)
|
||||
|
@ -151,6 +151,7 @@ func NewServer(log *zap.Logger, listener net.Listener, db DB, buckets *buckets.S
|
||||
limitUpdateAPI.HandleFunc("/users/{useremail}/limits", server.updateLimits).Methods("PUT")
|
||||
limitUpdateAPI.HandleFunc("/users/{useremail}/freeze", server.freezeUser).Methods("PUT")
|
||||
limitUpdateAPI.HandleFunc("/users/{useremail}/freeze", server.unfreezeUser).Methods("DELETE")
|
||||
limitUpdateAPI.HandleFunc("/users/{useremail}/warning", server.unWarnUser).Methods("DELETE")
|
||||
limitUpdateAPI.HandleFunc("/projects/{project}/limit", server.getProjectLimit).Methods("GET")
|
||||
limitUpdateAPI.HandleFunc("/projects/{project}/limit", server.putProjectLimit).Methods("PUT", "POST")
|
||||
|
||||
|
@ -249,7 +249,7 @@ export class Admin {
|
||||
desc: 'Get the API keys of a specific project',
|
||||
params: [['Project ID', new InputText('text', true)]],
|
||||
func: async (projectId: string): Promise<Record<string, unknown>> => {
|
||||
return this.fetch('GET', `projects/${projectId}/apiKeys`);
|
||||
return this.fetch('GET', `projects/${projectId}/apikeys`);
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -464,6 +464,14 @@ Blank fields will not be updated.`,
|
||||
func: async (email: string): Promise<null> => {
|
||||
return this.fetch('DELETE', `users/${email}/freeze`) as Promise<null>;
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'unwarn user',
|
||||
desc: "Remove a user's warning status",
|
||||
params: [['email', new InputText('email', true)]],
|
||||
func: async (email: string): Promise<null> => {
|
||||
return this.fetch('DELETE', `users/${email}/warning`) as Promise<null>;
|
||||
}
|
||||
}
|
||||
],
|
||||
rest_api_keys: [
|
||||
|
@ -630,6 +630,35 @@ func (server *Server) unfreezeUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) unWarnUser(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userEmail, ok := vars["useremail"]
|
||||
if !ok {
|
||||
sendJSONError(w, "user-email missing", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := server.db.Console().Users().GetByEmail(ctx, userEmail)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
sendJSONError(w, fmt.Sprintf("user with email %q does not exist", userEmail),
|
||||
"", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
sendJSONError(w, "failed to get user details",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err = server.freezeAccounts.UnWarnUser(ctx, u.ID); err != nil {
|
||||
sendJSONError(w, "failed to unwarn user",
|
||||
err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
|
@ -428,6 +428,43 @@ func TestFreezeUnfreezeUser(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestWarnUnwarnUser(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1,
|
||||
StorageNodeCount: 0,
|
||||
UplinkCount: 1,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: func(_ *zap.Logger, _ int, config *satellite.Config) {
|
||||
config.Admin.Address = "127.0.0.1:0"
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
address := planet.Satellites[0].Admin.Admin.Listener.Addr()
|
||||
user, err := planet.Satellites[0].DB.Console().Users().Get(ctx, planet.Uplinks[0].Projects[0].Owner.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = planet.Satellites[0].Admin.FreezeAccounts.Service.WarnUser(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
freeze, warning, err := planet.Satellites[0].DB.Console().AccountFreezeEvents().GetAll(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, freeze)
|
||||
require.NotNil(t, warning)
|
||||
|
||||
link := fmt.Sprintf("http://"+address.String()+"/api/users/%s/warning", user.Email)
|
||||
body := assertReq(ctx, t, link, http.MethodDelete, "", http.StatusOK, "", planet.Satellites[0].Config.Console.AuthToken)
|
||||
require.Len(t, body, 0)
|
||||
|
||||
freeze, warning, err = planet.Satellites[0].DB.Console().AccountFreezeEvents().GetAll(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, freeze)
|
||||
require.Nil(t, warning)
|
||||
|
||||
body = assertReq(ctx, t, link, http.MethodDelete, "", http.StatusInternalServerError, "", planet.Satellites[0].Config.Console.AuthToken)
|
||||
require.Contains(t, string(body), "user is not warned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserDelete(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1,
|
||||
|
@ -84,6 +84,7 @@ const (
|
||||
eventAccountUnwarned = "Account Unwarned"
|
||||
eventAccountFreezeWarning = "Account Freeze Warning"
|
||||
eventUnpaidLargeInvoice = "Large Invoice Unpaid"
|
||||
eventUnpaidStorjscanInvoice = "Storjscan Invoice Unpaid"
|
||||
eventExpiredCreditNeedsRemoval = "Expired Credit Needs Removal"
|
||||
eventExpiredCreditRemoved = "Expired Credit Removed"
|
||||
eventProjectInvitationAccepted = "Project Invitation Accepted"
|
||||
@ -122,6 +123,9 @@ type FreezeTracker interface {
|
||||
|
||||
// TrackLargeUnpaidInvoice sends an event to Segment indicating that a user has not paid a large invoice.
|
||||
TrackLargeUnpaidInvoice(invID string, userID uuid.UUID, email string)
|
||||
|
||||
// TrackStorjscanUnpaidInvoice sends an event to Segment indicating that a user has not paid an invoice, but has storjscan transaction history.
|
||||
TrackStorjscanUnpaidInvoice(invID string, userID uuid.UUID, email string)
|
||||
}
|
||||
|
||||
// Service for sending analytics.
|
||||
@ -418,6 +422,23 @@ func (service *Service) TrackLargeUnpaidInvoice(invID string, userID uuid.UUID,
|
||||
})
|
||||
}
|
||||
|
||||
// TrackStorjscanUnpaidInvoice sends an event to Segment indicating that a user has not paid an invoice, but has storjscan transaction history.
|
||||
func (service *Service) TrackStorjscanUnpaidInvoice(invID string, userID uuid.UUID, email string) {
|
||||
if !service.config.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
props := segment.NewProperties()
|
||||
props.Set("email", email)
|
||||
props.Set("invoice", invID)
|
||||
|
||||
service.enqueueMessage(segment.Track{
|
||||
UserId: userID.String(),
|
||||
Event: service.satelliteName + " " + eventUnpaidStorjscanInvoice,
|
||||
Properties: props,
|
||||
})
|
||||
}
|
||||
|
||||
// TrackAccessGrantCreated sends an "Access Grant Created" event to Segment.
|
||||
func (service *Service) TrackAccessGrantCreated(userID uuid.UUID, email string) {
|
||||
if !service.config.Enabled {
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"storj.io/common/identity"
|
||||
"storj.io/common/nodetag"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/peertls/extensions"
|
||||
"storj.io/common/peertls/tlsopts"
|
||||
@ -281,7 +282,7 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB,
|
||||
{ // setup overlay
|
||||
peer.Overlay.DB = peer.DB.OverlayCache()
|
||||
|
||||
peer.Overlay.Service, err = overlay.NewService(peer.Log.Named("overlay"), peer.Overlay.DB, peer.DB.NodeEvents(), config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
|
||||
peer.Overlay.Service, err = overlay.NewService(peer.Log.Named("overlay"), peer.Overlay.DB, peer.DB.NodeEvents(), config.Placement.CreateFilters, config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
|
||||
if err != nil {
|
||||
return nil, errs.Combine(err, peer.Close())
|
||||
}
|
||||
@ -325,7 +326,12 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB,
|
||||
Type: pb.NodeType_SATELLITE,
|
||||
Version: *pbVersion,
|
||||
}
|
||||
peer.Contact.Service = contact.NewService(peer.Log.Named("contact:service"), self, peer.Overlay.Service, peer.DB.PeerIdentities(), peer.Dialer, config.Contact)
|
||||
|
||||
var authority nodetag.Authority
|
||||
peerIdentity := full.PeerIdentity()
|
||||
authority = append(authority, signing.SigneeFromPeerIdentity(peerIdentity))
|
||||
|
||||
peer.Contact.Service = contact.NewService(peer.Log.Named("contact:service"), self, peer.Overlay.Service, peer.DB.PeerIdentities(), peer.Dialer, authority, config.Contact)
|
||||
peer.Contact.Endpoint = contact.NewEndpoint(peer.Log.Named("contact:endpoint"), peer.Contact.Service)
|
||||
if err := pb.DRPCRegisterNode(peer.Server.DRPC(), peer.Contact.Endpoint); err != nil {
|
||||
return nil, errs.Combine(err, peer.Close())
|
||||
@ -381,6 +387,7 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB,
|
||||
signing.SignerFromFullIdentity(peer.Identity),
|
||||
peer.Overlay.Service,
|
||||
peer.Orders.DB,
|
||||
config.Placement.CreateFilters,
|
||||
config.Orders,
|
||||
)
|
||||
if err != nil {
|
||||
@ -540,7 +547,9 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB,
|
||||
peer.Payments.StorjscanService = storjscan.NewService(log.Named("storjscan-service"),
|
||||
peer.DB.Wallets(),
|
||||
peer.DB.StorjscanPayments(),
|
||||
peer.Payments.StorjscanClient)
|
||||
peer.Payments.StorjscanClient,
|
||||
pc.Storjscan.Confirmations,
|
||||
pc.BonusRate)
|
||||
if err != nil {
|
||||
return nil, errs.Combine(err, peer.Close())
|
||||
}
|
||||
@ -603,6 +612,7 @@ func NewAPI(log *zap.Logger, full *identity.FullIdentity, db DB,
|
||||
accountFreezeService,
|
||||
peer.Console.Listener,
|
||||
config.Payments.StripeCoinPayments.StripePublicKey,
|
||||
config.Payments.Storjscan.Confirmations,
|
||||
peer.URL(),
|
||||
config.Payments.PackagePlans,
|
||||
)
|
||||
|
@ -141,7 +141,7 @@ func NewAuditor(log *zap.Logger, full *identity.FullIdentity,
|
||||
|
||||
{ // setup overlay
|
||||
var err error
|
||||
peer.Overlay, err = overlay.NewService(log.Named("overlay"), overlayCache, nodeEvents, config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
|
||||
peer.Overlay, err = overlay.NewService(log.Named("overlay"), overlayCache, nodeEvents, config.Placement.CreateFilters, config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
|
||||
if err != nil {
|
||||
return nil, errs.Combine(err, peer.Close())
|
||||
}
|
||||
@ -183,6 +183,7 @@ func NewAuditor(log *zap.Logger, full *identity.FullIdentity,
|
||||
// PUT and GET actions which are not used by
|
||||
// auditor so we can set noop implementation.
|
||||
orders.NewNoopDB(),
|
||||
config.Placement.CreateFilters,
|
||||
config.Orders,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -48,6 +48,7 @@ type FrontendConfig struct {
|
||||
PricingPackagesEnabled bool `json:"pricingPackagesEnabled"`
|
||||
NewUploadModalEnabled bool `json:"newUploadModalEnabled"`
|
||||
GalleryViewEnabled bool `json:"galleryViewEnabled"`
|
||||
NeededTransactionConfirmations int `json:"neededTransactionConfirmations"`
|
||||
}
|
||||
|
||||
// Satellites is a configuration value that contains a list of satellite names and addresses.
|
||||
|
@ -41,13 +41,13 @@ func (a *ABTesting) GetABValues(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
user, err := console.GetUser(ctx)
|
||||
if err != nil {
|
||||
web.ServeJSONError(a.log, w, http.StatusUnauthorized, err)
|
||||
web.ServeJSONError(ctx, a.log, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
values, err := a.service.GetABValues(ctx, *user)
|
||||
if err != nil {
|
||||
web.ServeJSONError(a.log, w, http.StatusInternalServerError, err)
|
||||
web.ServeJSONError(ctx, a.log, w, http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@ -66,13 +66,13 @@ func (a *ABTesting) SendHit(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
action := mux.Vars(r)["action"]
|
||||
if action == "" {
|
||||
web.ServeJSONError(a.log, w, http.StatusBadRequest, errs.New("parameter 'action' can't be empty"))
|
||||
web.ServeJSONError(ctx, a.log, w, http.StatusBadRequest, errs.New("parameter 'action' can't be empty"))
|
||||
return
|
||||
}
|
||||
|
||||
user, err := console.GetUser(ctx)
|
||||
if err != nil {
|
||||
web.ServeJSONError(a.log, w, http.StatusUnauthorized, err)
|
||||
web.ServeJSONError(ctx, a.log, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
package consoleapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -54,17 +55,17 @@ func (a *Analytics) EventTriggered(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
a.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
}
|
||||
var et eventTriggeredBody
|
||||
err = json.Unmarshal(body, &et)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
a.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
user, err := console.GetUser(ctx)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusUnauthorized, err)
|
||||
a.serveJSONError(ctx, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -86,17 +87,17 @@ func (a *Analytics) PageEventTriggered(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
a.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
}
|
||||
var pv pageVisitBody
|
||||
err = json.Unmarshal(body, &pv)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
a.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
user, err := console.GetUser(ctx)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, http.StatusUnauthorized, err)
|
||||
a.serveJSONError(ctx, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -106,6 +107,6 @@ func (a *Analytics) PageEventTriggered(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// serveJSONError writes JSON error to response output stream.
|
||||
func (a *Analytics) serveJSONError(w http.ResponseWriter, status int, err error) {
|
||||
web.ServeJSONError(a.log, w, status, err)
|
||||
func (a *Analytics) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
|
||||
web.ServeJSONError(ctx, a.log, w, status, err)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
package consoleapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
@ -42,24 +43,24 @@ func (keys *APIKeys) GetAllAPIKeyNames(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
projectIDString := r.URL.Query().Get("projectID")
|
||||
if projectIDString == "" {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, errs.New("Project ID was not provided."))
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("Project ID was not provided."))
|
||||
return
|
||||
}
|
||||
|
||||
projectID, err := uuid.FromString(projectIDString)
|
||||
if err != nil {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyNames, err := keys.service.GetAllAPIKeyNamesByProjectID(ctx, projectID)
|
||||
if err != nil {
|
||||
if console.ErrUnauthorized.Has(err) {
|
||||
keys.serveJSONError(w, http.StatusUnauthorized, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
keys.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,7 +82,7 @@ func (keys *APIKeys) DeleteByNameAndProjectID(w http.ResponseWriter, r *http.Req
|
||||
publicIDString := r.URL.Query().Get("publicID")
|
||||
|
||||
if name == "" {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -89,38 +90,38 @@ func (keys *APIKeys) DeleteByNameAndProjectID(w http.ResponseWriter, r *http.Req
|
||||
if projectIDString != "" {
|
||||
projectID, err = uuid.FromString(projectIDString)
|
||||
if err != nil {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
} else if publicIDString != "" {
|
||||
projectID, err = uuid.FromString(publicIDString)
|
||||
if err != nil {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
keys.serveJSONError(w, http.StatusBadRequest, errs.New("Project ID was not provided."))
|
||||
keys.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("Project ID was not provided."))
|
||||
return
|
||||
}
|
||||
|
||||
err = keys.service.DeleteAPIKeyByNameAndProjectID(ctx, name, projectID)
|
||||
if err != nil {
|
||||
if console.ErrUnauthorized.Has(err) {
|
||||
keys.serveJSONError(w, http.StatusUnauthorized, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
|
||||
if console.ErrNoAPIKey.Has(err) {
|
||||
keys.serveJSONError(w, http.StatusNoContent, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusNoContent, err)
|
||||
return
|
||||
}
|
||||
|
||||
keys.serveJSONError(w, http.StatusInternalServerError, err)
|
||||
keys.serveJSONError(ctx, w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// serveJSONError writes JSON error to response output stream.
|
||||
func (keys *APIKeys) serveJSONError(w http.ResponseWriter, status int, err error) {
|
||||
web.ServeJSONError(keys.log, w, status, err)
|
||||
func (keys *APIKeys) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
|
||||
web.ServeJSONError(ctx, keys.log, w, status, err)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
package consoleapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
@ -53,7 +54,7 @@ type Auth struct {
|
||||
}
|
||||
|
||||
// NewAuth is a constructor for api auth controller.
|
||||
func NewAuth(log *zap.Logger, service *console.Service, accountFreezeService *console.AccountFreezeService, mailService *mailservice.Service, cookieAuth *consolewebauth.CookieAuth, analytics *analytics.Service, satelliteName string, externalAddress string, letUsKnowURL string, termsAndConditionsURL string, contactInfoURL string, generalRequestURL string) *Auth {
|
||||
func NewAuth(log *zap.Logger, service *console.Service, accountFreezeService *console.AccountFreezeService, mailService *mailservice.Service, cookieAuth *consolewebauth.CookieAuth, analytics *analytics.Service, satelliteName, externalAddress, letUsKnowURL, termsAndConditionsURL, contactInfoURL, generalRequestURL string) *Auth {
|
||||
return &Auth{
|
||||
log: log,
|
||||
ExternalAddress: externalAddress,
|
||||
@ -82,24 +83,24 @@ func (a *Auth) Token(w http.ResponseWriter, r *http.Request) {
|
||||
tokenRequest := console.AuthUser{}
|
||||
err = json.NewDecoder(r.Body).Decode(&tokenRequest)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
tokenRequest.UserAgent = r.UserAgent()
|
||||
tokenRequest.IP, err = web.GetRequestIP(r)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := a.service.Token(ctx, tokenRequest)
|
||||
if err != nil {
|
||||
if console.ErrMFAMissing.Has(err) {
|
||||
web.ServeCustomJSONError(a.log, w, http.StatusOK, err, a.getUserErrorMessage(err))
|
||||
web.ServeCustomJSONError(ctx, a.log, w, http.StatusOK, err, a.getUserErrorMessage(err))
|
||||
} else {
|
||||
a.log.Info("Error authenticating token request", zap.String("email", tokenRequest.Email), zap.Error(ErrAuthAPI.Wrap(err)))
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -126,7 +127,7 @@ func (a *Auth) TokenByAPIKey(w http.ResponseWriter, r *http.Request) {
|
||||
authToken := r.Header.Get("Authorization")
|
||||
if !(strings.HasPrefix(authToken, "Bearer ")) {
|
||||
a.log.Info("authorization key format is incorrect. Should be 'Bearer <key>'")
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -135,14 +136,14 @@ func (a *Auth) TokenByAPIKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAgent := r.UserAgent()
|
||||
ip, err := web.GetRequestIP(r)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
tokenInfo, err := a.service.TokenByAPIKey(ctx, userAgent, ip, apiKey)
|
||||
if err != nil {
|
||||
a.log.Info("Error authenticating token request", zap.Error(ErrAuthAPI.Wrap(err)))
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -184,13 +185,13 @@ func (a *Auth) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sessionID, err := a.getSessionID(r)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = a.service.DeleteSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -225,7 +226,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = json.NewDecoder(r.Body).Decode(®isterData)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -234,23 +235,23 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
isValidEmail := utils.ValidateEmail(registerData.Email)
|
||||
if !isValidEmail {
|
||||
a.serveJSONError(w, console.ErrValidation.Wrap(errs.New("Invalid email.")))
|
||||
a.serveJSONError(ctx, w, console.ErrValidation.Wrap(errs.New("Invalid email.")))
|
||||
return
|
||||
}
|
||||
|
||||
if len([]rune(registerData.Partner)) > 100 {
|
||||
a.serveJSONError(w, console.ErrValidation.Wrap(errs.New("Partner must be less than or equal to 100 characters")))
|
||||
a.serveJSONError(ctx, w, console.ErrValidation.Wrap(errs.New("Partner must be less than or equal to 100 characters")))
|
||||
return
|
||||
}
|
||||
|
||||
if len([]rune(registerData.SignupPromoCode)) > 100 {
|
||||
a.serveJSONError(w, console.ErrValidation.Wrap(errs.New("Promo code must be less than or equal to 100 characters")))
|
||||
a.serveJSONError(ctx, w, console.ErrValidation.Wrap(errs.New("Promo code must be less than or equal to 100 characters")))
|
||||
return
|
||||
}
|
||||
|
||||
verified, unverified, err := a.service.GetUserByEmailWithUnverified(ctx, registerData.Email)
|
||||
if err != nil && !console.ErrEmailNotFound.Has(err) {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -279,7 +280,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
secret, err := console.RegistrationSecretFromBase64(registerData.SecretInput)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -289,7 +290,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ip, err := web.GetRequestIP(r)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -312,7 +313,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
secret,
|
||||
)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -351,7 +352,7 @@ func (a *Auth) Register(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
token, err := a.service.GenerateActivationToken(ctx, user.ID, user.Email)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -390,13 +391,13 @@ func (a *Auth) GetFreezeStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userID, err := a.service.GetUserID(ctx)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
freeze, warning, err := a.accountFreezeService.GetAll(ctx, userID)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -424,12 +425,12 @@ func (a *Auth) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = json.NewDecoder(r.Body).Decode(&updatedInfo)
|
||||
if err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = a.service.UpdateAccount(ctx, updatedInfo.FullName, updatedInfo.ShortName); err != nil {
|
||||
a.serveJSONError(w, err)
|
||||
a.serveJSONError(ctx, w, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -440,27 +441,29 @@ func (a *Auth) GetAccount(w http.ResponseWriter, r *http.Request) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
var user struct {
|
    ID uuid.UUID `json:"id"`
    FullName string `json:"fullName"`
    ShortName string `json:"shortName"`
    Email string `json:"email"`
    Partner string `json:"partner"`
    ProjectLimit int `json:"projectLimit"`
    ProjectStorageLimit int64 `json:"projectStorageLimit"`
+   ProjectBandwidthLimit int64 `json:"projectBandwidthLimit"`
+   ProjectSegmentLimit int64 `json:"projectSegmentLimit"`
    IsProfessional bool `json:"isProfessional"`
    Position string `json:"position"`
    CompanyName string `json:"companyName"`
    EmployeeCount string `json:"employeeCount"`
    HaveSalesContact bool `json:"haveSalesContact"`
    PaidTier bool `json:"paidTier"`
    MFAEnabled bool `json:"isMFAEnabled"`
    MFARecoveryCodeCount int `json:"mfaRecoveryCodeCount"`
    CreatedAt time.Time `json:"createdAt"`
}

consoleUser, err := console.GetUser(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -473,6 +476,8 @@ func (a *Auth) GetAccount(w http.ResponseWriter, r *http.Request) {
}
user.ProjectLimit = consoleUser.ProjectLimit
user.ProjectStorageLimit = consoleUser.ProjectStorageLimit
+user.ProjectBandwidthLimit = consoleUser.ProjectBandwidthLimit
+user.ProjectSegmentLimit = consoleUser.ProjectSegmentLimit
user.IsProfessional = consoleUser.IsProfessional
user.CompanyName = consoleUser.CompanyName
user.Position = consoleUser.Position
@@ -497,7 +502,7 @@ func (a *Auth) DeleteAccount(w http.ResponseWriter, r *http.Request) {
defer mon.Task()(&ctx)(&errNotImplemented)

// We do not want to allow account deletion via API currently.
-a.serveJSONError(w, errNotImplemented)
+a.serveJSONError(ctx, w, errNotImplemented)
}

// ChangeEmail auth user, changes users email for a new one.
@@ -512,13 +517,13 @@ func (a *Auth) ChangeEmail(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&emailChange)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.ChangeEmail(ctx, emailChange.NewEmail)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
}
@@ -536,13 +541,13 @@ func (a *Auth) ChangePassword(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&passwordChange)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.ChangePassword(ctx, passwordChange.CurrentPassword, passwordChange.NewPassword)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
}
@@ -560,23 +565,23 @@ func (a *Auth) ForgotPassword(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&forgotPassword)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

ip, err := web.GetRequestIP(r)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

valid, err := a.service.VerifyForgotPasswordCaptcha(ctx, forgotPassword.CaptchaResponse, ip)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
if !valid {
-   a.serveJSONError(w, console.ErrCaptcha.New("captcha validation unsuccessful"))
+   a.serveJSONError(ctx, w, console.ErrCaptcha.New("captcha validation unsuccessful"))
    return
}

@@ -608,7 +613,7 @@ func (a *Auth) ForgotPassword(w http.ResponseWriter, r *http.Request) {

recoveryToken, err := a.service.GeneratePasswordRecoveryToken(ctx, user.ID)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -659,7 +664,7 @@ func (a *Auth) ResendEmail(w http.ResponseWriter, r *http.Request) {
if verified != nil {
    recoveryToken, err := a.service.GeneratePasswordRecoveryToken(ctx, verified.ID)
    if err != nil {
-       a.serveJSONError(w, err)
+       a.serveJSONError(ctx, w, err)
        return
    }

@@ -688,7 +693,7 @@ func (a *Auth) ResendEmail(w http.ResponseWriter, r *http.Request) {

token, err := a.service.GenerateActivationToken(ctx, user.ID, user.Email)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -719,31 +724,31 @@ func (a *Auth) EnableUserMFA(w http.ResponseWriter, r *http.Request) {
}
err = json.NewDecoder(r.Body).Decode(&data)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.EnableUserMFA(ctx, data.Passcode, time.Now())
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

sessionID, err := a.getSessionID(r)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

consoleUser, err := console.GetUser(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.DeleteAllSessionsByUserIDExcept(ctx, consoleUser.ID, sessionID)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
}
@@ -760,31 +765,31 @@ func (a *Auth) DisableUserMFA(w http.ResponseWriter, r *http.Request) {
}
err = json.NewDecoder(r.Body).Decode(&data)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.DisableUserMFA(ctx, data.Passcode, time.Now(), data.RecoveryCode)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

sessionID, err := a.getSessionID(r)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

consoleUser, err := console.GetUser(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

err = a.service.DeleteAllSessionsByUserIDExcept(ctx, consoleUser.ID, sessionID)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
}
@@ -797,7 +802,7 @@ func (a *Auth) GenerateMFASecretKey(w http.ResponseWriter, r *http.Request) {

key, err := a.service.ResetMFASecretKey(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -817,7 +822,7 @@ func (a *Auth) GenerateMFARecoveryCodes(w http.ResponseWriter, r *http.Request)

codes, err := a.service.ResetMFARecoveryCodes(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -844,7 +849,7 @@ func (a *Auth) ResetPassword(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&resetPassword)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
}

err = a.service.ResetPassword(ctx, resetPassword.RecoveryToken, resetPassword.NewPassword, resetPassword.MFAPasscode, resetPassword.MFARecoveryCode, time.Now())
@@ -882,7 +887,7 @@ func (a *Auth) ResetPassword(w http.ResponseWriter, r *http.Request) {
}

if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
} else {
    a.cookieAuth.RemoveTokenCookie(w)
}
@@ -896,19 +901,19 @@ func (a *Auth) RefreshSession(w http.ResponseWriter, r *http.Request) {

tokenInfo, err := a.cookieAuth.GetToken(r)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

id, err := uuid.FromBytes(tokenInfo.Token.Payload)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

tokenInfo.ExpiresAt, err = a.service.RefreshSession(ctx, id)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -929,7 +934,7 @@ func (a *Auth) GetUserSettings(w http.ResponseWriter, r *http.Request) {

settings, err := a.service.GetUserSettings(ctx)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -954,7 +959,7 @@ func (a *Auth) SetOnboardingStatus(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&updateInfo)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -964,7 +969,7 @@ func (a *Auth) SetOnboardingStatus(w http.ResponseWriter, r *http.Request) {
    OnboardingStep: updateInfo.OnboardingStep,
})
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}
}
@@ -985,7 +990,7 @@ func (a *Auth) SetUserSettings(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&updateInfo)
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -1006,7 +1011,7 @@ func (a *Auth) SetUserSettings(w http.ResponseWriter, r *http.Request) {
    SessionDuration: newDuration,
})
if err != nil {
-   a.serveJSONError(w, err)
+   a.serveJSONError(ctx, w, err)
    return
}

@@ -1018,9 +1023,9 @@ func (a *Auth) SetUserSettings(w http.ResponseWriter, r *http.Request) {
}

// serveJSONError writes JSON error to response output stream.
-func (a *Auth) serveJSONError(w http.ResponseWriter, err error) {
+func (a *Auth) serveJSONError(ctx context.Context, w http.ResponseWriter, err error) {
    status := a.getStatusCode(err)
-   web.ServeCustomJSONError(a.log, w, status, err, a.getUserErrorMessage(err))
+   web.ServeCustomJSONError(ctx, a.log, w, status, err, a.getUserErrorMessage(err))
}

// getStatusCode returns http.StatusCode depends on console error class.
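The hunks above are mechanical: the handler's request context is now threaded into `serveJSONError`, so the JSON error response can be correlated with the request ID that the `requestid` middleware (added to the console web server later in this diff) stores in the context. A minimal, hypothetical handler illustrating the updated call pattern — it is not part of the diff, and the method name is invented:

```go
// exampleHandler is a hypothetical console API handler showing the new
// error-serving pattern: the request context is passed to serveJSONError.
func (a *Auth) exampleHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	var err error
	defer mon.Task()(&ctx)(&err)

	consoleUser, err := console.GetUser(ctx)
	if err != nil {
		a.serveJSONError(ctx, w, err) // ctx now comes first
		return
	}

	_ = consoleUser // placeholder for real handler logic
}
```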
@@ -4,6 +4,7 @@
package consoleapi

import (
+   "context"
    "encoding/json"
    "net/http"

@@ -49,28 +50,28 @@ func (b *Buckets) AllBucketNames(w http.ResponseWriter, r *http.Request) {
if projectIDString != "" {
    projectID, err = uuid.FromString(projectIDString)
    if err != nil {
-       b.serveJSONError(w, http.StatusBadRequest, err)
+       b.serveJSONError(ctx, w, http.StatusBadRequest, err)
        return
    }
} else if publicIDString != "" {
    projectID, err = uuid.FromString(publicIDString)
    if err != nil {
-       b.serveJSONError(w, http.StatusBadRequest, err)
+       b.serveJSONError(ctx, w, http.StatusBadRequest, err)
        return
    }
} else {
-   b.serveJSONError(w, http.StatusBadRequest, errs.New("Project ID was not provided."))
+   b.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("Project ID was not provided."))
    return
}

bucketNames, err := b.service.GetAllBucketNames(ctx, projectID)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       b.serveJSONError(w, http.StatusUnauthorized, err)
+       b.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   b.serveJSONError(w, http.StatusInternalServerError, err)
+   b.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -81,6 +82,6 @@ func (b *Buckets) AllBucketNames(w http.ResponseWriter, r *http.Request) {
}

// serveJSONError writes JSON error to response output stream.
-func (b *Buckets) serveJSONError(w http.ResponseWriter, status int, err error) {
-   web.ServeJSONError(b.log, w, status, err)
+func (b *Buckets) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
+   web.ServeJSONError(ctx, b.log, w, status, err)
}
@@ -58,11 +58,11 @@ func (p *Payments) SetupAccount(w http.ResponseWriter, r *http.Request) {

if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -83,11 +83,11 @@ func (p *Payments) AccountBalance(w http.ResponseWriter, r *http.Request) {
balance, err := p.service.Payments().AccountBalance(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -112,12 +112,12 @@ func (p *Payments) ProjectsCharges(w http.ResponseWriter, r *http.Request) {

sinceStamp, err := strconv.ParseInt(r.URL.Query().Get("from"), 10, 64)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}
beforeStamp, err := strconv.ParseInt(r.URL.Query().Get("to"), 10, 64)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

@@ -127,11 +127,11 @@ func (p *Payments) ProjectsCharges(w http.ResponseWriter, r *http.Request) {
charges, err := p.service.Payments().ProjectsCharges(ctx, since, before)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -155,8 +155,8 @@ func (p *Payments) ProjectsCharges(w http.ResponseWriter, r *http.Request) {
}
}

-// triggerAttemptPaymentIfFrozenOrWarned checks if the account is frozen and if frozen, will trigger attempt to pay outstanding invoices.
-func (p *Payments) triggerAttemptPaymentIfFrozenOrWarned(ctx context.Context) (err error) {
+// triggerAttemptPayment attempts payment and unfreezes/unwarn user if needed.
+func (p *Payments) triggerAttemptPayment(ctx context.Context) (err error) {
    defer mon.Task()(&ctx)(&err)

    userID, err := p.service.GetUserID(ctx)
@@ -169,12 +169,11 @@ func (p *Payments) triggerAttemptPaymentIfFrozenOrWarned(ctx context.Context) (e
    return err
}

-if freeze != nil || warning != nil {
-   err = p.service.Payments().AttemptPayOverdueInvoices(ctx)
-   if err != nil {
-       return err
-   }
+err = p.service.Payments().AttemptPayOverdueInvoices(ctx)
+if err != nil {
+   return err
+}

if freeze != nil {
    err = p.accountFreezeService.UnfreezeUser(ctx, userID)
    if err != nil {
@@ -197,7 +196,7 @@ func (p *Payments) AddCreditCard(w http.ResponseWriter, r *http.Request) {

bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

@@ -206,17 +205,17 @@ func (p *Payments) AddCreditCard(w http.ResponseWriter, r *http.Request) {
_, err = p.service.Payments().AddCreditCard(ctx, token)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

-err = p.triggerAttemptPaymentIfFrozenOrWarned(ctx)
+err = p.triggerAttemptPayment(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}
}
@@ -232,11 +231,11 @@ func (p *Payments) ListCreditCards(w http.ResponseWriter, r *http.Request) {
cards, err := p.service.Payments().ListCreditCards(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -259,24 +258,24 @@ func (p *Payments) MakeCreditCardDefault(w http.ResponseWriter, r *http.Request)

cardID, err := io.ReadAll(r.Body)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

err = p.service.Payments().MakeCreditCardDefault(ctx, string(cardID))
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

-err = p.triggerAttemptPaymentIfFrozenOrWarned(ctx)
+err = p.triggerAttemptPayment(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}
}
@@ -291,18 +290,18 @@ func (p *Payments) RemoveCreditCard(w http.ResponseWriter, r *http.Request) {
cardID := vars["cardId"]

if cardID == "" {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

err = p.service.Payments().RemoveCreditCard(ctx, cardID)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}
}
@@ -318,11 +317,11 @@ func (p *Payments) BillingHistory(w http.ResponseWriter, r *http.Request) {
billingHistory, err := p.service.Payments().BillingHistory(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -345,7 +344,7 @@ func (p *Payments) ApplyCouponCode(w http.ResponseWriter, r *http.Request) {

bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}
couponCode := string(bodyBytes)
@@ -358,7 +357,7 @@ func (p *Payments) ApplyCouponCode(w http.ResponseWriter, r *http.Request) {
} else if payments.ErrCouponConflict.Has(err) {
    status = http.StatusConflict
}
-p.serveJSONError(w, status, err)
+p.serveJSONError(ctx, w, status, err)
return
}

@@ -378,11 +377,11 @@ func (p *Payments) GetCoupon(w http.ResponseWriter, r *http.Request) {
coupon, err := p.service.Payments().GetCoupon(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -402,15 +401,15 @@ func (p *Payments) GetWallet(w http.ResponseWriter, r *http.Request) {
walletInfo, err := p.service.Payments().GetWallet(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }
    if errs.Is(err, billing.ErrNoWallet) {
-       p.serveJSONError(w, http.StatusNotFound, err)
+       p.serveJSONError(ctx, w, http.StatusNotFound, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -430,11 +429,11 @@ func (p *Payments) ClaimWallet(w http.ResponseWriter, r *http.Request) {
walletInfo, err := p.service.Payments().ClaimWallet(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -454,11 +453,11 @@ func (p *Payments) WalletPayments(w http.ResponseWriter, r *http.Request) {
walletPayments, err := p.service.Payments().WalletPayments(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -467,6 +466,30 @@ func (p *Payments) WalletPayments(w http.ResponseWriter, r *http.Request) {
}
}

+// WalletPaymentsWithConfirmations returns with the list of storjscan transactions (including confirmations count) for user`s wallet.
+func (p *Payments) WalletPaymentsWithConfirmations(w http.ResponseWriter, r *http.Request) {
+   ctx := r.Context()
+   var err error
+   defer mon.Task()(&ctx)(&err)
+
+   w.Header().Set("Content-Type", "application/json")
+
+   walletPayments, err := p.service.Payments().WalletPaymentsWithConfirmations(ctx)
+   if err != nil {
+       if console.ErrUnauthorized.Has(err) {
+           p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
+           return
+       }
+
+       p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
+       return
+   }
+
+   if err = json.NewEncoder(w).Encode(walletPayments); err != nil {
+       p.log.Error("failed to encode wallet payments with confirmations", zap.Error(ErrPaymentsAPI.Wrap(err)))
+   }
+}
+
// GetProjectUsagePriceModel returns the project usage price model for the user.
func (p *Payments) GetProjectUsagePriceModel(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@@ -477,7 +500,7 @@ func (p *Payments) GetProjectUsagePriceModel(w http.ResponseWriter, r *http.Requ

user, err := console.GetUser(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -496,7 +519,7 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {

bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

@@ -504,13 +527,13 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {

u, err := console.GetUser(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusUnauthorized, err)
+   p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
    return
}

pkg, err := p.packagePlans.Get(u.UserAgent)
if err != nil {
-   p.serveJSONError(w, http.StatusNotFound, err)
+   p.serveJSONError(ctx, w, http.StatusNotFound, err)
    return
}

@@ -518,9 +541,9 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {
if err != nil {
    switch {
    case console.ErrUnauthorized.Has(err):
-       p.serveJSONError(w, http.StatusUnauthorized, err)
+       p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
    default:
-       p.serveJSONError(w, http.StatusInternalServerError, err)
+       p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    }
    return
}
@@ -529,19 +552,19 @@ func (p *Payments) PurchasePackage(w http.ResponseWriter, r *http.Request) {
err = p.service.Payments().UpdatePackage(ctx, description, time.Now())
if err != nil {
    if !console.ErrAlreadyHasPackage.Has(err) {
-       p.serveJSONError(w, http.StatusInternalServerError, err)
+       p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
        return
    }
}

err = p.service.Payments().Purchase(ctx, pkg.Price, description, card.ID)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

if err = p.service.Payments().ApplyCredit(ctx, pkg.Credit, description); err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}
}
@@ -554,7 +577,7 @@ func (p *Payments) PackageAvailable(w http.ResponseWriter, r *http.Request) {

u, err := console.GetUser(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusUnauthorized, err)
+   p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
    return
}

@@ -567,6 +590,6 @@ func (p *Payments) PackageAvailable(w http.ResponseWriter, r *http.Request) {
}

// serveJSONError writes JSON error to response output stream.
-func (p *Payments) serveJSONError(w http.ResponseWriter, status int, err error) {
-   web.ServeJSONError(p.log, w, status, err)
+func (p *Payments) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
+   web.ServeJSONError(ctx, p.log, w, status, err)
}
@@ -4,9 +4,11 @@
package consoleapi

import (
+   "context"
    "encoding/base64"
    "encoding/json"
    "net/http"
+   "strings"
    "time"

    "github.com/gorilla/mux"
@@ -42,18 +44,18 @@ func (p *Projects) GetSalt(w http.ResponseWriter, r *http.Request) {

idParam, ok := mux.Vars(r)["id"]
if !ok {
-   p.serveJSONError(w, http.StatusBadRequest, errs.New("missing id route param"))
+   p.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing id route param"))
    return
}

id, err := uuid.FromString(idParam)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
}

salt, err := p.service.GetSalt(ctx, id)
if err != nil {
-   p.serveJSONError(w, http.StatusUnauthorized, err)
+   p.serveJSONError(ctx, w, http.StatusUnauthorized, err)
    return
}

@@ -61,7 +63,7 @@ func (p *Projects) GetSalt(w http.ResponseWriter, r *http.Request) {

err = json.NewEncoder(w).Encode(b64SaltString)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
}
}

@@ -73,12 +75,12 @@ func (p *Projects) InviteUsers(w http.ResponseWriter, r *http.Request) {

idParam, ok := mux.Vars(r)["id"]
if !ok {
-   p.serveJSONError(w, http.StatusBadRequest, errs.New("missing project id route param"))
+   p.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing project id route param"))
    return
}
id, err := uuid.FromString(idParam)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
}

var data struct {
@@ -87,13 +89,17 @@ func (p *Projects) InviteUsers(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&data)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

+for i, email := range data.Emails {
+   data.Emails[i] = strings.TrimSpace(email)
+}
+
_, err = p.service.InviteProjectMembers(ctx, id, data.Emails)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
}
}

@@ -104,28 +110,28 @@ func (p *Projects) GetInviteLink(w http.ResponseWriter, r *http.Request) {
defer mon.Task()(&ctx)(&err)
idParam, ok := mux.Vars(r)["id"]
if !ok {
-   p.serveJSONError(w, http.StatusBadRequest, errs.New("missing project id route param"))
+   p.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing project id route param"))
    return
}
id, err := uuid.FromString(idParam)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
}

email := r.URL.Query().Get("email")
if email == "" {
-   p.serveJSONError(w, http.StatusBadRequest, errs.New("missing email query param"))
+   p.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing email query param"))
    return
}

link, err := p.service.GetInviteLink(ctx, id, email)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
}

err = json.NewEncoder(w).Encode(link)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
}
}

@@ -139,7 +145,7 @@ func (p *Projects) GetUserInvitations(w http.ResponseWriter, r *http.Request) {

invites, err := p.service.GetUserProjectInvitations(ctx)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -156,7 +162,7 @@ func (p *Projects) GetUserInvitations(w http.ResponseWriter, r *http.Request) {
for _, invite := range invites {
    proj, err := p.service.GetProjectNoAuth(ctx, invite.ProjectID)
    if err != nil {
-       p.serveJSONError(w, http.StatusInternalServerError, err)
+       p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
        return
    }

@@ -170,7 +176,7 @@ func (p *Projects) GetUserInvitations(w http.ResponseWriter, r *http.Request) {
    if invite.InviterID != nil {
        inviter, err := p.service.GetUser(ctx, *invite.InviterID)
        if err != nil {
-           p.serveJSONError(w, http.StatusInternalServerError, err)
+           p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
            return
        }
        respInvite.InviterEmail = inviter.Email
@@ -181,7 +187,7 @@ func (p *Projects) GetUserInvitations(w http.ResponseWriter, r *http.Request) {

err = json.NewEncoder(w).Encode(response)
if err != nil {
-   p.serveJSONError(w, http.StatusInternalServerError, err)
+   p.serveJSONError(ctx, w, http.StatusInternalServerError, err)
}
}

@@ -195,13 +201,13 @@ func (p *Projects) RespondToInvitation(w http.ResponseWriter, r *http.Request) {
var idParam string

if idParam, ok = mux.Vars(r)["id"]; !ok {
-   p.serveJSONError(w, http.StatusBadRequest, errs.New("missing project id route param"))
+   p.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing project id route param"))
    return
}

id, err := uuid.FromString(idParam)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
}

var payload struct {
@@ -210,7 +216,7 @@ func (p *Projects) RespondToInvitation(w http.ResponseWriter, r *http.Request) {

err = json.NewDecoder(r.Body).Decode(&payload)
if err != nil {
-   p.serveJSONError(w, http.StatusBadRequest, err)
+   p.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

@@ -225,11 +231,11 @@ func (p *Projects) RespondToInvitation(w http.ResponseWriter, r *http.Request) {
    case console.ErrValidation.Has(err):
        status = http.StatusBadRequest
    }
-   p.serveJSONError(w, status, err)
+   p.serveJSONError(ctx, w, status, err)
}
}

// serveJSONError writes JSON error to response output stream.
-func (p *Projects) serveJSONError(w http.ResponseWriter, status int, err error) {
-   web.ServeJSONError(p.log, w, status, err)
+func (p *Projects) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
+   web.ServeJSONError(ctx, p.log, w, status, err)
}
@@ -4,6 +4,7 @@
package consoleapi

import (
+   "context"
    "encoding/json"
    "net/http"
    "strconv"
@@ -50,13 +51,13 @@ func (ul *UsageLimits) ProjectUsageLimits(w http.ResponseWriter, r *http.Request
var idParam string

if idParam, ok = mux.Vars(r)["id"]; !ok {
-   ul.serveJSONError(w, http.StatusBadRequest, errs.New("missing project id route param"))
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing project id route param"))
    return
}

projectID, err := uuid.FromString(idParam)
if err != nil {
-   ul.serveJSONError(w, http.StatusBadRequest, errs.New("invalid project id: %v", err))
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("invalid project id: %v", err))
    return
}

@@ -64,13 +65,13 @@ func (ul *UsageLimits) ProjectUsageLimits(w http.ResponseWriter, r *http.Request
if err != nil {
    switch {
    case console.ErrUnauthorized.Has(err):
-       ul.serveJSONError(w, http.StatusUnauthorized, err)
+       ul.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    case accounting.ErrInvalidArgument.Has(err):
-       ul.serveJSONError(w, http.StatusBadRequest, err)
+       ul.serveJSONError(ctx, w, http.StatusBadRequest, err)
        return
    default:
-       ul.serveJSONError(w, http.StatusInternalServerError, err)
+       ul.serveJSONError(ctx, w, http.StatusInternalServerError, err)
        return
    }
}
@@ -90,11 +91,11 @@ func (ul *UsageLimits) TotalUsageLimits(w http.ResponseWriter, r *http.Request)
usageLimits, err := ul.service.GetTotalUsageLimits(ctx)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       ul.serveJSONError(w, http.StatusUnauthorized, err)
+       ul.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   ul.serveJSONError(w, http.StatusInternalServerError, err)
+   ul.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -114,23 +115,23 @@ func (ul *UsageLimits) DailyUsage(w http.ResponseWriter, r *http.Request) {
var idParam string

if idParam, ok = mux.Vars(r)["id"]; !ok {
-   ul.serveJSONError(w, http.StatusBadRequest, errs.New("missing project id route param"))
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("missing project id route param"))
    return
}
projectID, err := uuid.FromString(idParam)
if err != nil {
-   ul.serveJSONError(w, http.StatusBadRequest, errs.New("invalid project id: %v", err))
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, errs.New("invalid project id: %v", err))
    return
}

sinceStamp, err := strconv.ParseInt(r.URL.Query().Get("from"), 10, 64)
if err != nil {
-   ul.serveJSONError(w, http.StatusBadRequest, err)
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}
beforeStamp, err := strconv.ParseInt(r.URL.Query().Get("to"), 10, 64)
if err != nil {
-   ul.serveJSONError(w, http.StatusBadRequest, err)
+   ul.serveJSONError(ctx, w, http.StatusBadRequest, err)
    return
}

@@ -140,11 +141,11 @@ func (ul *UsageLimits) DailyUsage(w http.ResponseWriter, r *http.Request) {
dailyUsage, err := ul.service.GetDailyProjectUsage(ctx, projectID, since, before)
if err != nil {
    if console.ErrUnauthorized.Has(err) {
-       ul.serveJSONError(w, http.StatusUnauthorized, err)
+       ul.serveJSONError(ctx, w, http.StatusUnauthorized, err)
        return
    }

-   ul.serveJSONError(w, http.StatusInternalServerError, err)
+   ul.serveJSONError(ctx, w, http.StatusInternalServerError, err)
    return
}

@@ -155,6 +156,6 @@ func (ul *UsageLimits) DailyUsage(w http.ResponseWriter, r *http.Request) {
}

// serveJSONError writes JSON error to response output stream.
-func (ul *UsageLimits) serveJSONError(w http.ResponseWriter, status int, err error) {
-   web.ServeJSONError(ul.log, w, status, err)
+func (ul *UsageLimits) serveJSONError(ctx context.Context, w http.ResponseWriter, status int, err error) {
+   web.ServeJSONError(ctx, ul.log, w, status, err)
}
@@ -15,6 +15,7 @@ import (
    "mime"
    "net"
    "net/http"
+   "net/http/httputil"
    "net/url"
    "os"
    "path/filepath"
@@ -32,6 +33,7 @@ import (
    "golang.org/x/sync/errgroup"

    "storj.io/common/errs2"
+   "storj.io/common/http/requestid"
    "storj.io/common/memory"
    "storj.io/common/storj"
    "storj.io/storj/private/web"
@@ -62,10 +64,14 @@ var (

// Config contains configuration for console web server.
type Config struct {
-   Address string `help:"server address of the graphql api gateway and frontend app" devDefault:"127.0.0.1:0" releaseDefault:":10100"`
-   StaticDir string `help:"path to static resources" default:""`
-   Watch bool `help:"whether to load templates on each request" default:"false" devDefault:"true"`
-   ExternalAddress string `help:"external endpoint of the satellite if hosted" default:""`
+   Address string `help:"server address of the graphql api gateway and frontend app" devDefault:"127.0.0.1:0" releaseDefault:":10100"`
+   FrontendAddress string `help:"server address of the front-end app" devDefault:"127.0.0.1:0" releaseDefault:":10200"`
+   ExternalAddress string `help:"external endpoint of the satellite if hosted" default:""`
+   FrontendEnable bool `help:"feature flag to toggle whether console back-end server should also serve front-end endpoints" default:"true"`
+   BackendReverseProxy string `help:"the target URL of console back-end reverse proxy for local development when running a UI server" default:""`
+
+   StaticDir string `help:"path to static resources" default:""`
+   Watch bool `help:"whether to load templates on each request" default:"false" devDefault:"true"`

    AuthToken string `help:"auth token needed for access to registration token creation endpoint" default:"" testDefault:"very-secret-token"`
    AuthTokenSecret string `help:"secret used to sign auth tokens" releaseDefault:"" devDefault:"my-suppa-secret-key"`
@@ -138,7 +144,8 @@ type Server struct {
    userIDRateLimiter *web.RateLimiter
    nodeURL storj.NodeURL

-   stripePublicKey string
+   stripePublicKey string
+   neededTokenPaymentConfirmations int

    packagePlans paymentsconfig.PackagePlans

@@ -204,23 +211,24 @@ func (a *apiAuth) RemoveAuthCookie(w http.ResponseWriter) {
}

// NewServer creates new instance of console server.
-func NewServer(logger *zap.Logger, config Config, service *console.Service, oidcService *oidc.Service, mailService *mailservice.Service, analytics *analytics.Service, abTesting *abtesting.Service, accountFreezeService *console.AccountFreezeService, listener net.Listener, stripePublicKey string, nodeURL storj.NodeURL, packagePlans paymentsconfig.PackagePlans) *Server {
+func NewServer(logger *zap.Logger, config Config, service *console.Service, oidcService *oidc.Service, mailService *mailservice.Service, analytics *analytics.Service, abTesting *abtesting.Service, accountFreezeService *console.AccountFreezeService, listener net.Listener, stripePublicKey string, neededTokenPaymentConfirmations int, nodeURL storj.NodeURL, packagePlans paymentsconfig.PackagePlans) *Server {
    server := Server{
        log: logger,
        config: config,
        listener: listener,
        service: service,
        mailService: mailService,
        analytics: analytics,
        abTesting: abTesting,
        stripePublicKey: stripePublicKey,
+       neededTokenPaymentConfirmations: neededTokenPaymentConfirmations,
        ipRateLimiter: web.NewIPRateLimiter(config.RateLimit, logger),
        userIDRateLimiter: NewUserIDRateLimiter(config.RateLimit, logger),
        nodeURL: nodeURL,
        packagePlans: packagePlans,
    }

-   logger.Debug("Starting Satellite UI.", zap.Stringer("Address", server.listener.Addr()))
+   logger.Debug("Starting Satellite Console server.", zap.Stringer("Address", server.listener.Addr()))

    server.cookieAuth = consolewebauth.NewCookieAuth(consolewebauth.CookieSettings{
        Name: "_tokenKey",
@@ -245,6 +253,8 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
    // the earliest in the HTTP chain.
    router.Use(newTraceRequestMiddleware(logger, router))

+   router.Use(requestid.AddToContext)
+
    // limit body size
    router.Use(newBodyLimiterMiddleware(logger.Named("body-limiter-middleware"), config.BodySizeLimit))

@@ -324,6 +334,7 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
    paymentsRouter.HandleFunc("/wallet", paymentController.GetWallet).Methods(http.MethodGet, http.MethodOptions)
    paymentsRouter.HandleFunc("/wallet", paymentController.ClaimWallet).Methods(http.MethodPost, http.MethodOptions)
    paymentsRouter.HandleFunc("/wallet/payments", paymentController.WalletPayments).Methods(http.MethodGet, http.MethodOptions)
+   paymentsRouter.HandleFunc("/wallet/payments-with-confirmations", paymentController.WalletPaymentsWithConfirmations).Methods(http.MethodGet, http.MethodOptions)
    paymentsRouter.HandleFunc("/billing-history", paymentController.BillingHistory).Methods(http.MethodGet, http.MethodOptions)
    paymentsRouter.Handle("/coupon/apply", server.userIDRateLimiter.Limit(http.HandlerFunc(paymentController.ApplyCouponCode))).Methods(http.MethodPatch, http.MethodOptions)
    paymentsRouter.HandleFunc("/coupon", paymentController.GetCoupon).Methods(http.MethodGet, http.MethodOptions)
@@ -353,30 +364,26 @@ func NewServer(logger *zap.Logger, config Config, service *console.Service, oidc
    analyticsRouter.HandleFunc("/event", analyticsController.EventTriggered).Methods(http.MethodPost, http.MethodOptions)
    analyticsRouter.HandleFunc("/page", analyticsController.PageEventTriggered).Methods(http.MethodPost, http.MethodOptions)

-   if server.config.StaticDir != "" {
-       oidc := oidc.NewEndpoint(
-           server.nodeURL, server.config.ExternalAddress,
-           logger, oidcService, service,
-           server.config.OauthCodeExpiry, server.config.OauthAccessTokenExpiry, server.config.OauthRefreshTokenExpiry,
-       )
+   oidc := oidc.NewEndpoint(
+       server.nodeURL, server.config.ExternalAddress,
+       logger, oidcService, service,
+       server.config.OauthCodeExpiry, server.config.OauthAccessTokenExpiry, server.config.OauthRefreshTokenExpiry,
+   )

-       router.HandleFunc("/.well-known/openid-configuration", oidc.WellKnownConfiguration)
-       router.Handle("/oauth/v2/authorize", server.withAuth(http.HandlerFunc(oidc.AuthorizeUser))).Methods(http.MethodPost)
-       router.Handle("/oauth/v2/tokens", server.ipRateLimiter.Limit(http.HandlerFunc(oidc.Tokens))).Methods(http.MethodPost)
-       router.Handle("/oauth/v2/userinfo", server.ipRateLimiter.Limit(http.HandlerFunc(oidc.UserInfo))).Methods(http.MethodGet)
-       router.Handle("/oauth/v2/clients/{id}", server.withAuth(http.HandlerFunc(oidc.GetClient))).Methods(http.MethodGet)
+   router.HandleFunc("/.well-known/openid-configuration", oidc.WellKnownConfiguration)
+   router.Handle("/oauth/v2/authorize", server.withAuth(http.HandlerFunc(oidc.AuthorizeUser))).Methods(http.MethodPost)
+   router.Handle("/oauth/v2/tokens", server.ipRateLimiter.Limit(http.HandlerFunc(oidc.Tokens))).Methods(http.MethodPost)
+   router.Handle("/oauth/v2/userinfo", server.ipRateLimiter.Limit(http.HandlerFunc(oidc.UserInfo))).Methods(http.MethodGet)
+   router.Handle("/oauth/v2/clients/{id}", server.withAuth(http.HandlerFunc(oidc.GetClient))).Methods(http.MethodGet)

+   router.HandleFunc("/invited", server.handleInvited)
+   router.HandleFunc("/activation", server.accountActivationHandler)
+   router.HandleFunc("/cancel-password-recovery", server.cancelPasswordRecoveryHandler)
+
+   if server.config.StaticDir != "" && server.config.FrontendEnable {
        fs := http.FileServer(http.Dir(server.config.StaticDir))
        router.PathPrefix("/static/").Handler(server.withCORS(server.brotliMiddleware(http.StripPrefix("/static", fs))))

-       router.HandleFunc("/invited", server.handleInvited)
-
-       // These paths previously required a trailing slash, so we support both forms for now
-       slashRouter := router.NewRoute().Subrouter()
-       slashRouter.StrictSlash(true)
-       slashRouter.HandleFunc("/activation", server.accountActivationHandler)
-       slashRouter.HandleFunc("/cancel-password-recovery", server.cancelPasswordRecoveryHandler)
-
        if server.config.UseVuetifyProject {
            router.PathPrefix("/vuetifypoc").Handler(server.withCORS(http.HandlerFunc(server.vuetifyAppHandler)))
        }
@@ -427,6 +434,100 @@ func (server *Server) Run(ctx context.Context) (err error) {
    return group.Wait()
}

+// NewFrontendServer creates new instance of console front-end server.
+// NB: The return type is currently consoleweb.Server, but it does not contain all the dependencies.
+// It should only be used with RunFrontEnd and Close. We plan on moving this to its own type, but
+// right now since we have a feature flag to allow the backend server to continue serving the frontend, it
+// makes it easier if they are the same type.
+func NewFrontendServer(logger *zap.Logger, config Config, listener net.Listener, nodeURL storj.NodeURL, stripePublicKey string) (server *Server, err error) {
+   server = &Server{
+       log: logger,
+       config: config,
+       listener: listener,
+       nodeURL: nodeURL,
+       stripePublicKey: stripePublicKey,
+   }
+
+   logger.Debug("Starting Satellite UI server.", zap.Stringer("Address", server.listener.Addr()))
+
+   router := mux.NewRouter()
+
+   // N.B. This middleware has to be the first one because it has to be called
+   // the earliest in the HTTP chain.
+   router.Use(newTraceRequestMiddleware(logger, router))
+
+   // in local development, proxy certain requests to the console back-end server
+   if config.BackendReverseProxy != "" {
+       target, err := url.Parse(config.BackendReverseProxy)
+       if err != nil {
+           return nil, Error.Wrap(err)
+       }
+       proxy := httputil.NewSingleHostReverseProxy(target)
+       logger.Debug("Reverse proxy targeting", zap.String("address", config.BackendReverseProxy))
+
+       router.PathPrefix("/api").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       }))
+       router.PathPrefix("/oauth").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       }))
+       router.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+       router.HandleFunc("/invited", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+       router.HandleFunc("/activation", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+       router.HandleFunc("/cancel-password-recovery", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+       router.HandleFunc("/registrationToken/", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+       router.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) {
+           proxy.ServeHTTP(w, r)
+       })
+   }
+
+   fs := http.FileServer(http.Dir(server.config.StaticDir))
+
+   router.HandleFunc("/robots.txt", server.seoHandler)
+   router.PathPrefix("/static/").Handler(server.brotliMiddleware(http.StripPrefix("/static", fs)))
+   router.HandleFunc("/config", server.frontendConfigHandler)
+   if server.config.UseVuetifyProject {
+       router.PathPrefix("/vuetifypoc").Handler(http.HandlerFunc(server.vuetifyAppHandler))
+   }
+   router.PathPrefix("/").Handler(http.HandlerFunc(server.appHandler))
+   server.server = http.Server{
+       Handler: server.withRequest(router),
+       MaxHeaderBytes: ContentLengthLimit.Int(),
+   }
+   return server, nil
+}
+
+// RunFrontend starts the server that runs the webapp.
+func (server *Server) RunFrontend(ctx context.Context) (err error) {
+   defer mon.Task()(&ctx)(&err)
+
+   ctx, cancel := context.WithCancel(ctx)
+   var group errgroup.Group
+   group.Go(func() error {
+       <-ctx.Done()
+       return server.server.Shutdown(context.Background())
+   })
+   group.Go(func() error {
+       defer cancel()
+       err := server.server.Serve(server.listener)
+       if errs2.IsCanceled(err) || errors.Is(err, http.ErrServerClosed) {
+           err = nil
+       }
+       return err
+   })
+   return group.Wait()
+}
+
// Close closes server and underlying listener.
func (server *Server) Close() error {
    return server.server.Close()
@@ -550,7 +651,7 @@ func (server *Server) withAuth(handler http.Handler) http.Handler {

    defer func() {
        if err != nil {
-           web.ServeJSONError(server.log, w, http.StatusUnauthorized, console.ErrUnauthorized.Wrap(err))
+           web.ServeJSONError(ctx, server.log, w, http.StatusUnauthorized, console.ErrUnauthorized.Wrap(err))
            server.cookieAuth.RemoveTokenCookie(w)
        }
    }()
@@ -620,6 +721,7 @@ func (server *Server) frontendConfigHandler(w http.ResponseWriter, r *http.Reque
    PricingPackagesEnabled: server.config.PricingPackagesEnabled,
    NewUploadModalEnabled: server.config.NewUploadModalEnabled,
    GalleryViewEnabled: server.config.GalleryViewEnabled,
+   NeededTransactionConfirmations: server.neededTokenPaymentConfirmations,
}

err := json.NewEncoder(w).Encode(&cfg)
@@ -783,7 +885,7 @@ func (server *Server) handleInvited(w http.ResponseWriter, r *http.Request) {
    return
}
if user != nil {
-   http.Redirect(w, r, loginLink+"?email="+user.Email, http.StatusTemporaryRedirect)
+   http.Redirect(w, r, loginLink+"?email="+url.QueryEscape(user.Email), http.StatusTemporaryRedirect)
    return
}

@@ -829,6 +931,10 @@ func (server *Server) graphqlHandler(w http.ResponseWriter, r *http.Request) {

jsonError.Error = err.Error()

+if requestID := requestid.FromContext(ctx); requestID != "" {
+   jsonError.Error += fmt.Sprintf(" (request id: %s)", requestID)
+}
+
if err := json.NewEncoder(w).Encode(jsonError); err != nil {
    server.log.Error("error graphql error", zap.Error(err))
}
@@ -893,6 +999,10 @@ func (server *Server) graphqlHandler(w http.ResponseWriter, r *http.Request) {
    jsonError.Errors = append(jsonError.Errors, err.Message)
}

+if requestID := requestid.FromContext(ctx); requestID != "" {
+   jsonError.Errors = append(jsonError.Errors, fmt.Sprintf("request id: %s", requestID))
+}
+
if err := json.NewEncoder(w).Encode(jsonError); err != nil {
    server.log.Error("error graphql error", zap.Error(err))
}
@@ -1124,7 +1234,7 @@ func newBodyLimiterMiddleware(log *zap.Logger, limit memory.Size) mux.Middleware
return func(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if r.ContentLength > limit.Int64() {
-           web.ServeJSONError(log, w, http.StatusRequestEntityTooLarge, errs.New("Request body is too large"))
+           web.ServeJSONError(r.Context(), log, w, http.StatusRequestEntityTooLarge, errs.New("Request body is too large"))
            return
        }

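With the new `FrontendEnable`, `FrontendAddress`, and `BackendReverseProxy` options, the UI can be served by its own process while API traffic is proxied to the back-end console. A hypothetical local-dev wiring, not part of the diff: the listener setup, static-dir path, addresses, and the way `nodeURL`/`stripePublicKey` are obtained are all assumptions; only `NewFrontendServer`, `RunFrontend`, `Close`, and the new config fields come from the changes above.

```go
package main

import (
	"context"
	"log"
	"net"

	"go.uber.org/zap"

	"storj.io/common/storj"
	"storj.io/storj/satellite/console/consoleweb"
)

// runUI sketches running the UI-only server against a separately running console back-end.
func runUI(ctx context.Context, nodeURL storj.NodeURL, stripePublicKey string) error {
	cfg := consoleweb.Config{
		StaticDir:           "web/satellite/dist",     // assumed build output location
		FrontendAddress:     "127.0.0.1:10200",        // assumed local UI address
		BackendReverseProxy: "http://127.0.0.1:10100", // assumed back-end console address
	}

	listener, err := net.Listen("tcp", cfg.FrontendAddress)
	if err != nil {
		return err
	}

	ui, err := consoleweb.NewFrontendServer(zap.L(), cfg, listener, nodeURL, stripePublicKey)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := ui.Close(); cerr != nil {
			log.Println(cerr)
		}
	}()

	return ui.RunFrontend(ctx)
}
```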
@@ -5,6 +5,7 @@ package consoleweb_test

import (
    "bytes"
+   "context"
    "fmt"
    "net/http"
    "testing"
@@ -140,14 +141,14 @@ func TestInvitedRouting(t *testing.T) {
    params := "email=invited%40mail.test&inviter=Project+Owner&inviter_email=owner%40mail.test&project=Test+Project"
    checkInvitedRedirect("Invited - Nonexistent user", baseURL+"signup?"+params, token)

-   invitedUser, err := sat.AddUser(ctx, console.CreateUser{
+   _, err = sat.AddUser(ctx, console.CreateUser{
        FullName: "Invited User",
        Email: invitedEmail,
    }, 1)
    require.NoError(t, err)

    // valid invite should redirect to login page with email.
-   checkInvitedRedirect("Invited - User invited", loginURL+"?email="+invitedUser.Email, token)
+   checkInvitedRedirect("Invited - User invited", loginURL+"?email=invited%40mail.test", token)
    })
}

@@ -219,3 +220,56 @@ func TestUserIDRateLimiter(t *testing.T) {
    require.Equal(t, http.StatusTooManyRequests, applyCouponStatus(firstToken))
    })
}
+
+func TestConsoleBackendWithDisabledFrontEnd(t *testing.T) {
+   testplanet.Run(t, testplanet.Config{
+       SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
+       Reconfigure: testplanet.Reconfigure{
+           Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
+               config.Console.FrontendEnable = false
+               config.Console.UseVuetifyProject = true
+           },
+       },
+   }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
+       apiAddr := planet.Satellites[0].API.Console.Listener.Addr().String()
+       uiAddr := planet.Satellites[0].UI.Console.Listener.Addr().String()
+
+       testEndpoint(ctx, t, apiAddr, "/", http.StatusNotFound)
+       testEndpoint(ctx, t, apiAddr, "/vuetifypoc", http.StatusNotFound)
+       testEndpoint(ctx, t, apiAddr, "/static/", http.StatusNotFound)
+
+       testEndpoint(ctx, t, uiAddr, "/", http.StatusOK)
+       testEndpoint(ctx, t, uiAddr, "/vuetifypoc", http.StatusOK)
+       testEndpoint(ctx, t, uiAddr, "/static/", http.StatusOK)
+   })
+}
+
+func TestConsoleBackendWithEnabledFrontEnd(t *testing.T) {
+   testplanet.Run(t, testplanet.Config{
+       SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
+       Reconfigure: testplanet.Reconfigure{
+           Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
+               config.Console.UseVuetifyProject = true
+           },
+       },
+   }, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
+       apiAddr := planet.Satellites[0].API.Console.Listener.Addr().String()
+
+       testEndpoint(ctx, t, apiAddr, "/", http.StatusOK)
+       testEndpoint(ctx, t, apiAddr, "/vuetifypoc", http.StatusOK)
+       testEndpoint(ctx, t, apiAddr, "/static/", http.StatusOK)
+   })
+}
+
+func testEndpoint(ctx context.Context, t *testing.T, addr, endpoint string, expectedStatus int) {
+   client := http.Client{}
+   url := "http://" + addr + endpoint
+   req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
+   require.NoError(t, err)
+
+   result, err := client.Do(req)
+   require.NoError(t, err)
+
+   require.Equal(t, expectedStatus, result.StatusCode)
+   require.NoError(t, result.Body.Close())
+}
89
satellite/console/observerupgradeuser.go
Normal file
89
satellite/console/observerupgradeuser.go
Normal file
@@ -0,0 +1,89 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package console

import (
"context"

"storj.io/common/memory"
"storj.io/storj/satellite/payments/billing"
)

var _ billing.Observer = (*UpgradeUserObserver)(nil)

// UpgradeUserObserver used to upgrade user if their balance is more than $10 after confirmed token transaction.
type UpgradeUserObserver struct {
consoleDB DB
transactionsDB billing.TransactionsDB
usageLimitsConfig UsageLimitsConfig
userBalanceForUpgrade int64
}

// NewUpgradeUserObserver creates new observer instance.
func NewUpgradeUserObserver(consoleDB DB, transactionsDB billing.TransactionsDB, usageLimitsConfig UsageLimitsConfig, userBalanceForUpgrade int64) *UpgradeUserObserver {
return &UpgradeUserObserver{
consoleDB: consoleDB,
transactionsDB: transactionsDB,
usageLimitsConfig: usageLimitsConfig,
userBalanceForUpgrade: userBalanceForUpgrade,
}
}

// Process puts user into the paid tier and converts projects to upgraded limits.
func (o *UpgradeUserObserver) Process(ctx context.Context, transaction billing.Transaction) (err error) {
defer mon.Task()(&ctx)(&err)

user, err := o.consoleDB.Users().Get(ctx, transaction.UserID)
if err != nil {
return err
}

if user.PaidTier {
return nil
}

balance, err := o.transactionsDB.GetBalance(ctx, user.ID)
if err != nil {
return err
}

// check if user's balance is less than needed amount for upgrade.
if balance.BaseUnits() < o.userBalanceForUpgrade {
return nil
}

err = o.consoleDB.Users().UpdatePaidTier(ctx, user.ID, true,
o.usageLimitsConfig.Bandwidth.Paid,
o.usageLimitsConfig.Storage.Paid,
o.usageLimitsConfig.Segment.Paid,
o.usageLimitsConfig.Project.Paid,
)
if err != nil {
return err
}

projects, err := o.consoleDB.Projects().GetOwn(ctx, user.ID)
if err != nil {
return err
}
for _, project := range projects {
if project.StorageLimit == nil || *project.StorageLimit < o.usageLimitsConfig.Storage.Paid {
project.StorageLimit = new(memory.Size)
*project.StorageLimit = o.usageLimitsConfig.Storage.Paid
}
if project.BandwidthLimit == nil || *project.BandwidthLimit < o.usageLimitsConfig.Bandwidth.Paid {
project.BandwidthLimit = new(memory.Size)
*project.BandwidthLimit = o.usageLimitsConfig.Bandwidth.Paid
}
if project.SegmentLimit == nil || *project.SegmentLimit < o.usageLimitsConfig.Segment.Paid {
*project.SegmentLimit = o.usageLimitsConfig.Segment.Paid
}
err = o.consoleDB.Projects().Update(ctx, &project)
if err != nil {
return err
}
}

return nil
}
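The observer only fires for confirmed token transactions picked up by the billing chore; it is attached through the observer map added in satellite/core.go on this branch. An excerpt-style sketch of that wiring (not a standalone program — `peer` and `config` stand in for the surrounding satellite setup code):

```go
// Sketch of the wiring from satellite/core.go in this branch.
choreObservers := map[billing.ObserverBilling]billing.Observer{
	billing.ObserverUpgradeUser: console.NewUpgradeUserObserver(
		peer.DB.Console(),                    // users and projects
		peer.DB.Billing(),                    // confirmed token transactions / balances
		config.Console.UsageLimits,           // paid-tier limits applied on upgrade
		config.Console.UserBalanceForUpgrade, // threshold in micro-dollar base units
	),
}
// choreObservers is then passed as the final argument to billing.NewChore(...).
```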
@@ -24,6 +24,7 @@ import (
"golang.org/x/crypto/bcrypt"

"storj.io/common/currency"
"storj.io/common/http/requestid"
"storj.io/common/macaroon"
"storj.io/common/memory"
"storj.io/common/uuid"
@@ -75,7 +76,7 @@ const (
projInviteInvalidErrMsg = "The invitation has expired or is invalid"
projInviteAlreadyMemberErrMsg = "You are already a member of the project"
projInviteResponseInvalidErrMsg = "Invalid project member invitation response"
projInviteActiveErrMsg = "The invitation for '%s' has not expired yet"
projInviteExistsErrMsg = "An active invitation for '%s' already exists"
)

var (
@@ -143,8 +144,8 @@ var (
// or has expired.
ErrProjectInviteInvalid = errs.Class("invalid project invitation")

// ErrProjectInviteActive occurs when trying to reinvite a user whose invitation hasn't expired yet.
ErrProjectInviteActive = errs.Class("project invitation active")
// ErrAlreadyInvited occurs when trying to invite a user who already has an unexpired invitation.
ErrAlreadyInvited = errs.Class("user is already invited")
)

// Service is handling accounts related logic.
@@ -193,6 +194,7 @@ type Config struct {
LoginAttemptsWithoutPenalty int `help:"number of times user can try to login without penalty" default:"3"`
FailedLoginPenalty float64 `help:"incremental duration of penalty for failed login attempts in minutes" default:"2.0"`
ProjectInvitationExpiration time.Duration `help:"duration that project member invitations are valid for" default:"168h"`
UserBalanceForUpgrade int64 `help:"amount of base units of US micro dollars needed to upgrade user's tier status" default:"10000000"`
UsageLimits UsageLimitsConfig
Captcha CaptchaConfig
Session SessionConfig
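The default of `10000000` is expressed in base units of US micro dollars, so the upgrade threshold works out to $10 (matching the observer's doc comment): 10,000,000 × $0.000001 = $10.00. A quick sanity check of the conversion:

```go
package main

import "fmt"

func main() {
	const userBalanceForUpgrade = 10_000_000 // base units of US micro dollars (the default)
	const microDollar = 1e-6                 // one micro dollar in USD
	fmt.Printf("$%.2f\n", userBalanceForUpgrade*microDollar) // prints $10.00
}
```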
@@ -302,6 +304,10 @@ func (s *Service) auditLog(ctx context.Context, operation string, userID *uuid.U
if email != "" {
fields = append(fields, zap.String("email", email))
}
if requestID := requestid.FromContext(ctx); requestID != "" {
fields = append(fields, zap.String("requestID", requestID))
}

fields = append(fields, fields...)
s.auditLogger.Info("console activity", fields...)
}
@@ -2725,7 +2731,7 @@ func (s *Service) GetProjectUsageLimits(ctx context.Context, projectID uuid.UUID
return nil, Error.Wrap(err)
}

prUsageLimits, err := s.getProjectUsageLimits(ctx, isMember.project.ID)
prUsageLimits, err := s.getProjectUsageLimits(ctx, isMember.project.ID, true)
if err != nil {
return nil, Error.Wrap(err)
}
@@ -2767,7 +2773,7 @@ func (s *Service) GetTotalUsageLimits(ctx context.Context) (_ *ProjectUsageLimit
var totalBandwidthUsed int64

for _, pr := range projects {
prUsageLimits, err := s.getProjectUsageLimits(ctx, pr.ID)
prUsageLimits, err := s.getProjectUsageLimits(ctx, pr.ID, false)
if err != nil {
return nil, Error.Wrap(err)
}
@@ -2786,7 +2792,7 @@ func (s *Service) GetTotalUsageLimits(ctx context.Context) (_ *ProjectUsageLimit
}, nil
}

func (s *Service) getProjectUsageLimits(ctx context.Context, projectID uuid.UUID) (_ *ProjectUsageLimits, err error) {
func (s *Service) getProjectUsageLimits(ctx context.Context, projectID uuid.UUID, onlySettledBandwidth bool) (_ *ProjectUsageLimits, err error) {
defer mon.Task()(&ctx)(&err)

storageLimit, err := s.projectUsage.GetProjectStorageLimit(ctx, projectID)
@@ -2806,10 +2812,17 @@ func (s *Service) getProjectUsageLimits(ctx context.Context, projectID uuid.UUID
if err != nil {
return nil, err
}
bandwidthUsed, err := s.projectUsage.GetProjectBandwidthTotals(ctx, projectID)

var bandwidthUsed int64
if onlySettledBandwidth {
bandwidthUsed, err = s.projectUsage.GetProjectSettledBandwidth(ctx, projectID)
} else {
bandwidthUsed, err = s.projectUsage.GetProjectBandwidthTotals(ctx, projectID)
}
if err != nil {
return nil, err
}

segmentUsed, err := s.projectUsage.GetProjectSegmentTotals(ctx, projectID)
if err != nil {
return nil, err
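The new `onlySettledBandwidth` parameter splits the two callers: the per-project view (`GetProjectUsageLimits`) now reports only settled bandwidth, while the account-wide `GetTotalUsageLimits` keeps using allocated bandwidth totals. An illustrative (not verbatim) caller-side view of the difference, matching the assertion added to the service test further down:

```go
// Illustrative sketch: after bandwidth has been allocated for a project but
// not yet settled, the two views of usage diverge.
limits, _ := service.GetProjectUsageLimits(userCtx, project.PublicID) // settled bandwidth only
totals, _ := service.GetTotalUsageLimits(userCtx)                     // allocated bandwidth

fmt.Println(limits.BandwidthUsed) // 0 until orders settle
fmt.Println(totals.BandwidthUsed) // includes the allocation
```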
@@ -2923,7 +2936,7 @@ func (s *Service) checkProjectLimit(ctx context.Context, userID uuid.UUID) (curr
return 0, Error.Wrap(err)
}

projects, err := s.GetUsersProjects(ctx)
projects, err := s.store.Projects().GetOwn(ctx, userID)
if err != nil {
return 0, Error.Wrap(err)
}
@@ -3078,6 +3091,12 @@ func EtherscanURL(tx string) string {
// ErrWalletNotClaimed shows that no address is claimed by the user.
var ErrWalletNotClaimed = errs.Class("wallet is not claimed")

// TestSwapDepositWallets replaces the existing handler for deposit wallets with
// the one specified for use in testing.
func (payment Payments) TestSwapDepositWallets(dw payments.DepositWallets) {
payment.service.depositWallets = dw
}

// ClaimWallet requests a new wallet for the users to be used for payments. If wallet is already claimed,
// it will return with the info without error.
func (payment Payments) ClaimWallet(ctx context.Context) (_ WalletInfo, err error) {
@@ -3198,6 +3217,27 @@ func (payment Payments) WalletPayments(ctx context.Context) (_ WalletPayments, e
}, nil
}

// WalletPaymentsWithConfirmations returns with all the native blockchain payments (including pending) for a user's wallet.
func (payment Payments) WalletPaymentsWithConfirmations(ctx context.Context) (paymentsWithConfirmations []payments.WalletPaymentWithConfirmations, err error) {
defer mon.Task()(&ctx)(&err)

user, err := GetUser(ctx)
if err != nil {
return nil, Error.Wrap(err)
}
address, err := payment.service.depositWallets.Get(ctx, user.ID)
if err != nil {
return nil, Error.Wrap(err)
}

paymentsWithConfirmations, err = payment.service.depositWallets.PaymentsWithConfirmations(ctx, address)
if err != nil {
return nil, Error.Wrap(err)
}

return
}

// Purchase makes a purchase of `price` amount with description of `desc` and payment method with id of `paymentMethodID`.
// If a paid invoice with the same description exists, then we assume this is a retried request and don't create and pay
// another invoice.
@@ -3563,7 +3603,6 @@ func (s *Service) RespondToProjectInvitation(ctx context.Context, projectID uuid

// InviteProjectMembers invites users by email to given project.
// If an invitation already exists and has expired, it will be replaced and the user will be sent a new email.
// Email addresses not belonging to a user are ignored.
// projectID here may be project.PublicID or project.ID.
func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID, emails []string) (invites []ProjectInvitation, err error) {
defer mon.Task()(&ctx)(&err)
@@ -3581,6 +3620,14 @@ func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID,
var users []*User
var newUserEmails []string
for _, email := range emails {
invite, err := s.store.ProjectInvitations().Get(ctx, projectID, email)
if err != nil && !errs.Is(err, sql.ErrNoRows) {
return nil, Error.Wrap(err)
}
if invite != nil && !s.IsProjectInvitationExpired(invite) {
return nil, ErrAlreadyInvited.New(projInviteExistsErrMsg, email)
}

invitedUser, err := s.store.Users().GetByEmail(ctx, email)
if err == nil {
_, err = s.isProjectMember(ctx, invitedUser.ID, projectID)
@@ -3589,14 +3636,6 @@ func (s *Service) InviteProjectMembers(ctx context.Context, projectID uuid.UUID,
} else if err == nil {
return nil, ErrAlreadyMember.New("%s is already a member", email)
}

invite, err := s.store.ProjectInvitations().Get(ctx, projectID, email)
if err != nil && !errs.Is(err, sql.ErrNoRows) {
return nil, Error.Wrap(err)
}
if invite != nil && !s.IsProjectInvitationExpired(invite) {
return nil, ErrProjectInviteActive.New(projInviteActiveErrMsg, invitedUser.Email)
}
users = append(users, invitedUser)
} else if errs.Is(err, sql.ErrNoRows) {
newUserEmails = append(newUserEmails, email)
@@ -23,6 +23,7 @@ import (
"storj.io/common/currency"
"storj.io/common/macaroon"
"storj.io/common/memory"
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/testcontext"
"storj.io/common/testrand"
@@ -434,6 +435,20 @@ func TestService(t *testing.T) {
require.Equal(t, updatedBandwidthLimit.Int64(), limits1.BandwidthLimit)
require.Equal(t, updatedStorageLimit.Int64(), limits2.StorageLimit)
require.Equal(t, updatedBandwidthLimit.Int64(), limits2.BandwidthLimit)

bucket := "testbucket1"
err = planet.Uplinks[1].CreateBucket(ctx, sat, bucket)
require.NoError(t, err)

now := time.Now().UTC()
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
err = sat.DB.Orders().UpdateBucketBandwidthAllocation(ctx, up2Proj.ID, []byte(bucket), pb.PieceAction_GET, 1000, startOfMonth)
require.NoError(t, err)

limits2, err = service.GetProjectUsageLimits(userCtx2, up2Proj.PublicID)
require.NoError(t, err)
require.NotNil(t, limits2)
require.Equal(t, int64(0), limits2.BandwidthUsed)
})

t.Run("ChangeEmail", func(t *testing.T) {
@@ -1687,6 +1702,86 @@ func TestPaymentsWalletPayments(t *testing.T) {
})
}

type mockDepositWallets struct {
address blockchain.Address
payments []payments.WalletPaymentWithConfirmations
}

func (dw mockDepositWallets) Claim(_ context.Context, _ uuid.UUID) (blockchain.Address, error) {
return dw.address, nil
}

func (dw mockDepositWallets) Get(_ context.Context, _ uuid.UUID) (blockchain.Address, error) {
return dw.address, nil
}

func (dw mockDepositWallets) Payments(_ context.Context, _ blockchain.Address, _ int, _ int64) (p []payments.WalletPayment, err error) {
return
}

func (dw mockDepositWallets) PaymentsWithConfirmations(_ context.Context, _ blockchain.Address) ([]payments.WalletPaymentWithConfirmations, error) {
return dw.payments, nil
}

func TestWalletPaymentsWithConfirmations(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
sat := planet.Satellites[0]
service := sat.API.Console.Service
paymentsService := service.Payments()

user, err := sat.AddUser(ctx, console.CreateUser{
FullName: "Test User",
Email: "test@mail.test",
Password: "example",
}, 1)
require.NoError(t, err)

now := time.Now()
wallet := blockchaintest.NewAddress()

var expected []payments.WalletPaymentWithConfirmations
for i := 0; i < 3; i++ {
expected = append(expected, payments.WalletPaymentWithConfirmations{
From: blockchaintest.NewAddress().Hex(),
To: wallet.Hex(),
TokenValue: currency.AmountFromBaseUnits(int64(i), currency.StorjToken).AsDecimal(),
USDValue: currency.AmountFromBaseUnits(int64(i), currency.USDollarsMicro).AsDecimal(),
Status: payments.PaymentStatusConfirmed,
BlockHash: blockchaintest.NewHash().Hex(),
BlockNumber: int64(i),
Transaction: blockchaintest.NewHash().Hex(),
LogIndex: i,
Timestamp: now,
Confirmations: int64(i),
BonusTokens: decimal.NewFromInt(int64(i)),
})
}

paymentsService.TestSwapDepositWallets(mockDepositWallets{address: wallet, payments: expected})

reqCtx := console.WithUser(ctx, user)

walletPayments, err := paymentsService.WalletPaymentsWithConfirmations(reqCtx)
require.NoError(t, err)
require.NotZero(t, len(walletPayments))

for i, wp := range walletPayments {
require.Equal(t, expected[i].From, wp.From)
require.Equal(t, expected[i].To, wp.To)
require.Equal(t, expected[i].TokenValue, wp.TokenValue)
require.Equal(t, expected[i].USDValue, wp.USDValue)
require.Equal(t, expected[i].Status, wp.Status)
require.Equal(t, expected[i].BlockHash, wp.BlockHash)
require.Equal(t, expected[i].BlockNumber, wp.BlockNumber)
require.Equal(t, expected[i].Transaction, wp.Transaction)
require.Equal(t, expected[i].LogIndex, wp.LogIndex)
require.Equal(t, expected[i].Timestamp, wp.Timestamp)
}
})
}

func TestPaymentsPurchase(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 0,
@@ -2049,7 +2144,7 @@ func TestProjectInvitations(t *testing.T) {

// resending an active invitation should fail.
invites, err = service.InviteProjectMembers(ctx2, project.ID, []string{user3.Email})
require.True(t, console.ErrProjectInviteActive.Has(err))
require.True(t, console.ErrAlreadyInvited.Has(err))
require.Empty(t, invites)

// expire the invitation.
@ -7,17 +7,22 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/identity/testidentity"
|
||||
"storj.io/common/nodetag"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/rpc/rpcpeer"
|
||||
"storj.io/common/rpc/rpcstatus"
|
||||
"storj.io/common/signing"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/storagenode"
|
||||
"storj.io/storj/storagenode/contact"
|
||||
)
|
||||
|
||||
func TestSatelliteContactEndpoint(t *testing.T) {
|
||||
@ -177,3 +182,143 @@ func TestSatellitePingMe_Failure(t *testing.T) {
|
||||
require.Nil(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSatelliteContactEndpoint_WithNodeTags(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 1, UplinkCount: 0,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
StorageNode: func(index int, config *storagenode.Config) {
|
||||
config.Server.DisableQUIC = true
|
||||
config.Contact.Tags = contact.SignedTags(pb.SignedNodeTagSets{
|
||||
Tags: []*pb.SignedNodeTagSet{},
|
||||
})
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
nodeInfo := planet.StorageNodes[0].Contact.Service.Local()
|
||||
ident := planet.StorageNodes[0].Identity
|
||||
|
||||
peer := rpcpeer.Peer{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP(nodeInfo.Address),
|
||||
Port: 5,
|
||||
},
|
||||
State: tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA},
|
||||
},
|
||||
}
|
||||
|
||||
unsignedTags := &pb.NodeTagSet{
|
||||
NodeId: ident.ID.Bytes(),
|
||||
Tags: []*pb.Tag{
|
||||
{
|
||||
Name: "soc",
|
||||
Value: []byte{1},
|
||||
},
|
||||
{
|
||||
Name: "foo",
|
||||
Value: []byte("bar"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
signedTags, err := nodetag.Sign(ctx, unsignedTags, signing.SignerFromFullIdentity(planet.Satellites[0].Identity))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerCtx := rpcpeer.NewContext(ctx, &peer)
|
||||
resp, err := planet.Satellites[0].Contact.Endpoint.CheckIn(peerCtx, &pb.CheckInRequest{
|
||||
Address: nodeInfo.Address,
|
||||
Version: &nodeInfo.Version,
|
||||
Capacity: &nodeInfo.Capacity,
|
||||
Operator: &nodeInfo.Operator,
|
||||
DebounceLimit: 3,
|
||||
Features: 0xf,
|
||||
SignedTags: &pb.SignedNodeTagSets{
|
||||
Tags: []*pb.SignedNodeTagSet{
|
||||
signedTags,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
tags, err := planet.Satellites[0].DB.OverlayCache().GetNodeTags(ctx, ident.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tags, 2)
|
||||
sort.Slice(tags, func(i, j int) bool {
|
||||
return tags[i].Name < tags[j].Name
|
||||
})
|
||||
require.Equal(t, "foo", tags[0].Name)
|
||||
require.Equal(t, "bar", string(tags[0].Value))
|
||||
|
||||
require.Equal(t, "soc", tags[1].Name)
|
||||
require.Equal(t, []byte{1}, tags[1].Value)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestSatelliteContactEndpoint_WithWrongNodeTags(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 1, UplinkCount: 0,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
StorageNode: func(index int, config *storagenode.Config) {
|
||||
config.Server.DisableQUIC = true
|
||||
config.Contact.Tags = contact.SignedTags(pb.SignedNodeTagSets{
|
||||
Tags: []*pb.SignedNodeTagSet{},
|
||||
})
|
||||
},
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
nodeInfo := planet.StorageNodes[0].Contact.Service.Local()
|
||||
ident := planet.StorageNodes[0].Identity
|
||||
|
||||
peer := rpcpeer.Peer{
|
||||
Addr: &net.TCPAddr{
|
||||
IP: net.ParseIP(nodeInfo.Address),
|
||||
Port: 5,
|
||||
},
|
||||
State: tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{ident.Leaf, ident.CA},
|
||||
},
|
||||
}
|
||||
|
||||
wrongNodeID := testidentity.MustPregeneratedIdentity(99, storj.LatestIDVersion()).ID
|
||||
unsignedTags := &pb.NodeTagSet{
|
||||
NodeId: wrongNodeID.Bytes(),
|
||||
Tags: []*pb.Tag{
|
||||
{
|
||||
Name: "soc",
|
||||
Value: []byte{1},
|
||||
},
|
||||
{
|
||||
Name: "foo",
|
||||
Value: []byte("bar"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
signedTags, err := nodetag.Sign(ctx, unsignedTags, signing.SignerFromFullIdentity(planet.Satellites[0].Identity))
|
||||
require.NoError(t, err)
|
||||
|
||||
peerCtx := rpcpeer.NewContext(ctx, &peer)
|
||||
resp, err := planet.Satellites[0].Contact.Endpoint.CheckIn(peerCtx, &pb.CheckInRequest{
|
||||
Address: nodeInfo.Address,
|
||||
Version: &nodeInfo.Version,
|
||||
Capacity: &nodeInfo.Capacity,
|
||||
Operator: &nodeInfo.Operator,
|
||||
DebounceLimit: 3,
|
||||
Features: 0xf,
|
||||
SignedTags: &pb.SignedNodeTagSets{
|
||||
Tags: []*pb.SignedNodeTagSet{
|
||||
signedTags,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
tags, err := planet.Satellites[0].DB.OverlayCache().GetNodeTags(ctx, ident.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tags, 0)
|
||||
})
|
||||
}
|
||||
|
@@ -114,6 +114,10 @@ func (endpoint *Endpoint) CheckIn(ctx context.Context, req *pb.CheckInRequest) (
req.Operator.WalletFeatures = nil
}
}
err = endpoint.service.processNodeTags(ctx, nodeID, req.SignedTags)
if err != nil {
endpoint.log.Info("failed to update node tags", zap.String("node address", req.Address), zap.Stringer("Node ID", nodeID), zap.Error(err))
}

nodeInfo := overlay.NodeCheckInInfo{
NodeID: peerID.ID,
@@ -12,11 +12,13 @@ import (
"github.com/zeebo/errs"
"go.uber.org/zap"

"storj.io/common/nodetag"
"storj.io/common/pb"
"storj.io/common/rpc"
"storj.io/common/rpc/quic"
"storj.io/common/rpc/rpcstatus"
"storj.io/common/storj"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/overlay"
)

@@ -49,19 +51,22 @@ type Service struct {
timeout time.Duration
idLimiter *RateLimiter
allowPrivateIP bool

nodeTagAuthority nodetag.Authority
}

// NewService creates a new contact service.
func NewService(log *zap.Logger, self *overlay.NodeDossier, overlay *overlay.Service, peerIDs overlay.PeerIdentities, dialer rpc.Dialer, config Config) *Service {
func NewService(log *zap.Logger, self *overlay.NodeDossier, overlay *overlay.Service, peerIDs overlay.PeerIdentities, dialer rpc.Dialer, authority nodetag.Authority, config Config) *Service {
return &Service{
log: log,
self: self,
overlay: overlay,
peerIDs: peerIDs,
dialer: dialer,
timeout: config.Timeout,
idLimiter: NewRateLimiter(config.RateLimitInterval, config.RateLimitBurst, config.RateLimitCacheSize),
allowPrivateIP: config.AllowPrivateIP,
log: log,
self: self,
overlay: overlay,
peerIDs: peerIDs,
dialer: dialer,
timeout: config.Timeout,
idLimiter: NewRateLimiter(config.RateLimitInterval, config.RateLimitBurst, config.RateLimitCacheSize),
allowPrivateIP: config.AllowPrivateIP,
nodeTagAuthority: authority,
}
}

@@ -151,3 +156,56 @@ func (service *Service) pingNodeQUIC(ctx context.Context, nodeurl storj.NodeURL)

return nil
}

func (service *Service) processNodeTags(ctx context.Context, nodeID storj.NodeID, req *pb.SignedNodeTagSets) error {
if req != nil {
tags := nodeselection.NodeTags{}
for _, t := range req.Tags {
verifiedTags, signerID, err := verifyTags(ctx, service.nodeTagAuthority, nodeID, t)
if err != nil {
service.log.Info("Failed to verify tags.", zap.Error(err), zap.Stringer("NodeID", nodeID))
continue
}

ts := time.Unix(verifiedTags.Timestamp, 0)
for _, vt := range verifiedTags.Tags {
tags = append(tags, nodeselection.NodeTag{
NodeID: nodeID,
Name: vt.Name,
Value: vt.Value,
SignedAt: ts,
Signer: signerID,
})
}
}
if len(tags) > 0 {
err := service.overlay.UpdateNodeTags(ctx, tags)
if err != nil {
return Error.Wrap(err)
}
}
}
return nil
}

func verifyTags(ctx context.Context, authority nodetag.Authority, nodeID storj.NodeID, t *pb.SignedNodeTagSet) (*pb.NodeTagSet, storj.NodeID, error) {
signerID, err := storj.NodeIDFromBytes(t.SignerNodeId)
if err != nil {
return nil, signerID, errs.New("failed to parse signerNodeID from verifiedTags: '%x', %s", t.SignerNodeId, err.Error())
}

verifiedTags, err := authority.Verify(ctx, t)
if err != nil {
return nil, signerID, errs.New("received node tags with wrong/unknown signature: '%x', %s", t.Signature, err.Error())
}

signedNodeID, err := storj.NodeIDFromBytes(verifiedTags.NodeId)
if err != nil {
return nil, signerID, errs.New("failed to parse nodeID from verifiedTags: '%x', %s", verifiedTags.NodeId, err.Error())
}

if signedNodeID != nodeID {
return nil, signerID, errs.New("the tag is signed for a different node. Expected NodeID: '%s', Received NodeID: '%s'", nodeID, signedNodeID)
}
return verifiedTags, signerID, nil
}
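For context, the signing side of this flow uses the same helpers as the verification above (`storj.io/common/nodetag` and `storj.io/common/signing`). A minimal sketch of producing a signed tag set and checking it against an authority, assuming test identities stand in for a real signer:

```go
// Sketch only: identities are assumed test identities; in production the
// authority would hold the trusted signers' identities.
signer := signing.SignerFromFullIdentity(signerIdentity)
authority := nodetag.Authority{signer}

signed, err := nodetag.Sign(ctx, &pb.NodeTagSet{
	NodeId: nodeIdentity.ID.Bytes(),
	Tags:   []*pb.Tag{{Name: "foo", Value: []byte("bar")}},
}, signer)
if err != nil {
	return err
}

// verifyTags (above) wraps this call and additionally checks that the signed
// NodeID matches the node that is checking in.
verified, err := authority.Verify(ctx, signed)
_ = verified
```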
149
satellite/contact/service_test.go
Normal file
@@ -0,0 +1,149 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package contact

import (
"testing"

"github.com/stretchr/testify/require"

"storj.io/common/identity/testidentity"
"storj.io/common/nodetag"
"storj.io/common/pb"
"storj.io/common/signing"
"storj.io/common/storj"
"storj.io/common/testcontext"
)

func TestVerifyTags(t *testing.T) {
ctx := testcontext.New(t)
snIdentity := testidentity.MustPregeneratedIdentity(0, storj.LatestIDVersion())
signerIdentity := testidentity.MustPregeneratedIdentity(1, storj.LatestIDVersion())
signer := signing.SignerFromFullIdentity(signerIdentity)
authority := nodetag.Authority{
signing.SignerFromFullIdentity(signerIdentity),
}
t.Run("ok tags", func(t *testing.T) {
tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: snIdentity.ID.Bytes(),
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, signer)
require.NoError(t, err)

verifiedTags, signerID, err := verifyTags(ctx, authority, snIdentity.ID, tags)
require.NoError(t, err)

require.Equal(t, signerIdentity.ID, signerID)
require.Len(t, verifiedTags.Tags, 1)
require.Equal(t, "foo", verifiedTags.Tags[0].Name)
require.Equal(t, []byte("bar"), verifiedTags.Tags[0].Value)
})

t.Run("wrong signer ID", func(t *testing.T) {
tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: snIdentity.ID.Bytes(),
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, signer)
require.NoError(t, err)
tags.SignerNodeId = []byte{1, 2, 3, 4}

_, _, err = verifyTags(ctx, authority, snIdentity.ID, tags)

require.Error(t, err)
require.ErrorContains(t, err, "01020304")
require.ErrorContains(t, err, "failed to parse signerNodeID")

})

t.Run("wrong signature", func(t *testing.T) {
tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: snIdentity.ID.Bytes(),
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, signer)
require.NoError(t, err)
tags.Signature = []byte{4, 3, 2, 1}

_, _, err = verifyTags(ctx, authority, snIdentity.ID, tags)

require.Error(t, err)
require.ErrorContains(t, err, "04030201")
require.ErrorContains(t, err, "wrong/unknown signature")
})

t.Run("unknown signer", func(t *testing.T) {
otherSignerIdentity := testidentity.MustPregeneratedIdentity(2, storj.LatestIDVersion())
otherSigner := signing.SignerFromFullIdentity(otherSignerIdentity)

tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: snIdentity.ID.Bytes(),
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, otherSigner)
require.NoError(t, err)

_, _, err = verifyTags(ctx, authority, snIdentity.ID, tags)

require.Error(t, err)
require.ErrorContains(t, err, "wrong/unknown signature")
})

t.Run("signed for different node", func(t *testing.T) {
otherNodeID := testidentity.MustPregeneratedIdentity(3, storj.LatestIDVersion()).ID
tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: otherNodeID.Bytes(),
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, signer)
require.NoError(t, err)

_, _, err = verifyTags(ctx, authority, snIdentity.ID, tags)

require.Error(t, err)
require.ErrorContains(t, err, snIdentity.ID.String())
require.ErrorContains(t, err, "the tag is signed for a different node")
})

t.Run("wrong NodeID", func(t *testing.T) {
tags, err := nodetag.Sign(ctx, &pb.NodeTagSet{
NodeId: []byte{4, 4, 4},
Tags: []*pb.Tag{
{
Name: "foo",
Value: []byte("bar"),
},
},
}, signer)
require.NoError(t, err)

_, _, err = verifyTags(ctx, authority, snIdentity.ID, tags)

require.Error(t, err)
require.ErrorContains(t, err, "040404")
require.ErrorContains(t, err, "failed to parse nodeID")
})

}
@@ -34,6 +34,7 @@ import (
"storj.io/storj/satellite/console/consoleauth"
"storj.io/storj/satellite/console/dbcleanup"
"storj.io/storj/satellite/console/emailreminders"
"storj.io/storj/satellite/gc/sender"
"storj.io/storj/satellite/mailservice"
"storj.io/storj/satellite/metabase"
"storj.io/storj/satellite/metabase/zombiedeletion"
@@ -142,6 +143,10 @@ type Core struct {
ConsoleDBCleanup struct {
Chore *dbcleanup.Chore
}

GarbageCollection struct {
Sender *sender.Service
}
}

// New creates a new satellite.
@@ -244,7 +249,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,

{ // setup overlay
peer.Overlay.DB = peer.DB.OverlayCache()
peer.Overlay.Service, err = overlay.NewService(peer.Log.Named("overlay"), peer.Overlay.DB, peer.DB.NodeEvents(), config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
peer.Overlay.Service, err = overlay.NewService(peer.Log.Named("overlay"), peer.Overlay.DB, peer.DB.NodeEvents(), config.Placement.CreateFilters, config.Console.ExternalAddress, config.Console.SatelliteName, config.Overlay)
if err != nil {
return nil, errs.Combine(err, peer.Close())
}
@@ -491,7 +496,9 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,
peer.Payments.StorjscanService = storjscan.NewService(log.Named("storjscan-service"),
peer.DB.Wallets(),
peer.DB.StorjscanPayments(),
peer.Payments.StorjscanClient)
peer.Payments.StorjscanClient,
pc.Storjscan.Confirmations,
pc.BonusRate)
if err != nil {
return nil, errs.Combine(err, peer.Close())
}
@@ -512,6 +519,10 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,
debug.Cycle("Payments Storjscan", peer.Payments.StorjscanChore.TransactionCycle),
)

choreObservers := map[billing.ObserverBilling]billing.Observer{
billing.ObserverUpgradeUser: console.NewUpgradeUserObserver(peer.DB.Console(), peer.DB.Billing(), config.Console.UsageLimits, config.Console.UserBalanceForUpgrade),
}

peer.Payments.BillingChore = billing.NewChore(
peer.Log.Named("payments.billing:chore"),
[]billing.PaymentType{peer.Payments.StorjscanService},
@@ -519,6 +530,7 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,
config.Payments.BillingConfig.Interval,
config.Payments.BillingConfig.DisableLoop,
config.Payments.BonusRate,
choreObservers,
)
peer.Services.Add(lifecycle.Item{
Name: "billing:chore",
@@ -534,6 +546,8 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,
peer.DB.StripeCoinPayments(),
peer.Payments.Accounts,
peer.DB.Console().Users(),
peer.DB.Wallets(),
peer.DB.StorjscanPayments(),
console.NewAccountFreezeService(db.Console().AccountFreezeEvents(), db.Console().Users(), db.Console().Projects(), peer.Analytics.Service),
peer.Analytics.Service,
config.AccountFreeze,
@@ -562,6 +576,22 @@ func New(log *zap.Logger, full *identity.FullIdentity, db DB,
})
}

{ // setup garbage collection
peer.GarbageCollection.Sender = sender.NewService(
peer.Log.Named("gc-sender"),
config.GarbageCollection,
peer.Dialer,
peer.Overlay.DB,
)

peer.Services.Add(lifecycle.Item{
Name: "gc-sender",
Run: peer.GarbageCollection.Sender.Run,
})
peer.Debug.Server.Panel.Add(
debug.Cycle("Garbage Collection", peer.GarbageCollection.Sender.Loop))
}

return peer, nil
}
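The `choreObservers` map makes the billing chore extensible: anything implementing `billing.Observer` can be attached under its own `billing.ObserverBilling` key. Only the upgrade-user observer is wired in this branch; the sketch below is a hypothetical second observer included purely to show the shape of such an extension (it would also need its own key constant, which does not exist here).

```go
// Hypothetical example, not part of this branch: an observer that only logs
// confirmed token transactions.
type loggingObserver struct{ log *zap.Logger }

func (o *loggingObserver) Process(ctx context.Context, tx billing.Transaction) error {
	o.log.Info("confirmed token transaction", zap.Stringer("user", tx.UserID))
	return nil
}
```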
@ -22,6 +22,7 @@ import (
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
"storj.io/storj/private/testplanet"
|
||||
"storj.io/storj/satellite"
|
||||
"storj.io/storj/satellite/gc/bloomfilter"
|
||||
"storj.io/storj/satellite/metabase"
|
||||
"storj.io/storj/satellite/metabase/rangedloop"
|
||||
@ -299,6 +300,157 @@ func TestGarbageCollectionWithCopies(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// TestGarbageCollectionWithCopies checks that server-side copy elements are not
|
||||
// affecting GC and nothing unexpected was deleted from storage nodes.
|
||||
func TestGarbageCollectionWithCopiesWithDuplicateMetadata(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 1,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: testplanet.Combine(
|
||||
testplanet.ReconfigureRS(2, 3, 4, 4),
|
||||
func(log *zap.Logger, index int, config *satellite.Config) {
|
||||
config.Metainfo.ServerSideCopyDuplicateMetadata = true
|
||||
},
|
||||
),
|
||||
},
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
satellite := planet.Satellites[0]
|
||||
|
||||
access := planet.Uplinks[0].Access[planet.Satellites[0].NodeURL().ID]
|
||||
accessString, err := access.Serialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
gcsender := planet.Satellites[0].GarbageCollection.Sender
|
||||
gcsender.Config.AccessGrant = accessString
|
||||
|
||||
// configure filter uploader
|
||||
config := planet.Satellites[0].Config.GarbageCollectionBF
|
||||
config.AccessGrant = accessString
|
||||
|
||||
project, err := planet.Uplinks[0].OpenProject(ctx, satellite)
|
||||
require.NoError(t, err)
|
||||
defer ctx.Check(project.Close)
|
||||
|
||||
allSpaceUsedForPieces := func() (all int64) {
|
||||
for _, node := range planet.StorageNodes {
|
||||
_, piecesContent, _, err := node.Storage2.Store.SpaceUsedTotalAndBySatellite(ctx)
|
||||
require.NoError(t, err)
|
||||
all += piecesContent
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
expectedRemoteData := testrand.Bytes(8 * memory.KiB)
|
||||
expectedInlineData := testrand.Bytes(1 * memory.KiB)
|
||||
|
||||
encryptedSize, err := encryption.CalcEncryptedSize(int64(len(expectedRemoteData)), storj.EncryptionParameters{
|
||||
CipherSuite: storj.EncAESGCM,
|
||||
BlockSize: 29 * 256 * memory.B.Int32(), // hardcoded value from uplink
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
redundancyStrategy, err := planet.Satellites[0].Config.Metainfo.RS.RedundancyStrategy()
|
||||
require.NoError(t, err)
|
||||
|
||||
pieceSize := eestream.CalcPieceSize(encryptedSize, redundancyStrategy.ErasureScheme)
|
||||
singleRemoteUsed := pieceSize * int64(len(planet.StorageNodes))
|
||||
totalUsedByNodes := 2 * singleRemoteUsed // two remote objects
|
||||
|
||||
require.NoError(t, planet.Uplinks[0].Upload(ctx, satellite, "testbucket", "remote", expectedRemoteData))
|
||||
require.NoError(t, planet.Uplinks[0].Upload(ctx, satellite, "testbucket", "inline", expectedInlineData))
|
||||
require.NoError(t, planet.Uplinks[0].Upload(ctx, satellite, "testbucket", "remote-no-copy", expectedRemoteData))
|
||||
|
||||
_, err = project.CopyObject(ctx, "testbucket", "remote", "testbucket", "remote-copy", nil)
|
||||
require.NoError(t, err)
|
||||
_, err = project.CopyObject(ctx, "testbucket", "inline", "testbucket", "inline-copy", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, planet.WaitForStorageNodeEndpoints(ctx))
|
||||
|
||||
afterTotalUsedByNodes := allSpaceUsedForPieces()
|
||||
require.Equal(t, totalUsedByNodes, afterTotalUsedByNodes)
|
||||
|
||||
// Wait for bloom filter observer to finish
|
||||
rangedloopConfig := planet.Satellites[0].Config.RangedLoop
|
||||
|
||||
observer := bloomfilter.NewObserver(zaptest.NewLogger(t), config, planet.Satellites[0].Overlay.DB)
|
||||
segments := rangedloop.NewMetabaseRangeSplitter(planet.Satellites[0].Metabase.DB, rangedloopConfig.AsOfSystemInterval, rangedloopConfig.BatchSize)
|
||||
rangedLoop := rangedloop.NewService(zap.NewNop(), planet.Satellites[0].Config.RangedLoop, segments,
|
||||
[]rangedloop.Observer{observer})
|
||||
|
||||
_, err = rangedLoop.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// send to storagenode
|
||||
err = gcsender.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range planet.StorageNodes {
|
||||
node.Storage2.RetainService.TestWaitUntilEmpty()
|
||||
}
|
||||
|
||||
// we should see all space used by all objects
|
||||
afterTotalUsedByNodes = allSpaceUsedForPieces()
|
||||
require.Equal(t, totalUsedByNodes, afterTotalUsedByNodes)
|
||||
|
||||
for _, toDelete := range []string{
|
||||
// delete ancestors, no change in used space
|
||||
"remote",
|
||||
"inline",
|
||||
// delete object without copy, used space should be decreased
|
||||
"remote-no-copy",
|
||||
} {
|
||||
_, err = project.DeleteObject(ctx, "testbucket", toDelete)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
planet.WaitForStorageNodeDeleters(ctx)
|
||||
|
||||
// run GC
|
||||
_, err = rangedLoop.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// send to storagenode
|
||||
err = gcsender.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range planet.StorageNodes {
|
||||
node.Storage2.RetainService.TestWaitUntilEmpty()
|
||||
}
|
||||
|
||||
// verify that we deleted only pieces for "remote-no-copy" object
|
||||
afterTotalUsedByNodes = allSpaceUsedForPieces()
|
||||
require.Equal(t, totalUsedByNodes, afterTotalUsedByNodes)
|
||||
|
||||
// delete rest of objects to verify that everything will be removed also from SNs
|
||||
for _, toDelete := range []string{
|
||||
"remote-copy",
|
||||
"inline-copy",
|
||||
} {
|
||||
_, err = project.DeleteObject(ctx, "testbucket", toDelete)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
planet.WaitForStorageNodeDeleters(ctx)
|
||||
|
||||
// run GC
|
||||
_, err = rangedLoop.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// send to storagenode
|
||||
err = gcsender.RunOnce(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, node := range planet.StorageNodes {
|
||||
node.Storage2.RetainService.TestWaitUntilEmpty()
|
||||
}
|
||||
|
||||
// verify that nothing more was deleted from storage nodes after GC
|
||||
afterTotalUsedByNodes = allSpaceUsedForPieces()
|
||||
require.EqualValues(t, totalUsedByNodes, afterTotalUsedByNodes)
|
||||
})
|
||||
}
|
||||
|
||||
func getSegment(ctx *testcontext.Context, t *testing.T, satellite *testplanet.Satellite, upl *testplanet.Uplink, bucket, path string) (_ metabase.ObjectLocation, _ metabase.Segment) {
|
||||
access := upl.Access[satellite.ID()]
|
||||
|
||||
|
@ -201,13 +201,21 @@ func (cache *NodeAliasCache) EnsurePiecesToAliases(ctx context.Context, pieces P
|
||||
|
||||
// ConvertAliasesToPieces converts alias pieces to pieces.
|
||||
func (cache *NodeAliasCache) ConvertAliasesToPieces(ctx context.Context, aliasPieces AliasPieces) (_ Pieces, err error) {
|
||||
return cache.convertAliasesToPieces(ctx, aliasPieces, make(Pieces, len(aliasPieces)))
|
||||
}
|
||||
|
||||
// convertAliasesToPieces converts AliasPieces by populating Pieces with converted data.
|
||||
func (cache *NodeAliasCache) convertAliasesToPieces(ctx context.Context, aliasPieces AliasPieces, pieces Pieces) (_ Pieces, err error) {
|
||||
if len(aliasPieces) == 0 {
|
||||
return Pieces{}, nil
|
||||
}
|
||||
|
||||
if len(aliasPieces) != len(pieces) {
|
||||
return Pieces{}, Error.New("aliasPieces and pieces length must be equal")
|
||||
}
|
||||
|
||||
latest := cache.getLatest()
|
||||
|
||||
pieces := make(Pieces, len(aliasPieces))
|
||||
var missing []NodeAlias
|
||||
|
||||
for i, aliasPiece := range aliasPieces {
|
||||
@ -224,13 +232,13 @@ func (cache *NodeAliasCache) ConvertAliasesToPieces(ctx context.Context, aliasPi
|
||||
var err error
|
||||
latest, err = cache.refresh(ctx, nil, missing)
|
||||
if err != nil {
|
||||
return nil, Error.New("failed to refresh node alias db: %w", err)
|
||||
return Pieces{}, Error.New("failed to refresh node alias db: %w", err)
|
||||
}
|
||||
|
||||
for i, aliasPiece := range aliasPieces {
|
||||
node, ok := latest.Node(aliasPiece.Alias)
|
||||
if !ok {
|
||||
return nil, Error.New("aliases missing in database: %v", missing)
|
||||
return Pieces{}, Error.New("aliases missing in database: %v", missing)
|
||||
}
|
||||
pieces[i].Number = aliasPiece.Number
|
||||
pieces[i].StorageNode = node
|
||||
|
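The exported `ConvertAliasesToPieces` now delegates to an unexported variant that writes into a caller-supplied slice, so hot paths inside the package can reuse a buffer instead of allocating per call; error paths also return an empty `Pieces` value rather than `nil`. A sketch of the caller-facing behaviour (the buffer-reusing variant is unexported and only reachable from inside the metabase package):

```go
// Sketch of external usage, assuming an initialized *metabase.NodeAliasCache.
pieces, err := cache.ConvertAliasesToPieces(ctx, aliasPieces)
if err != nil {
	// On failure the function now returns an empty Pieces value, not nil.
	return err
}
for _, piece := range pieces {
	_ = piece.Number      // piece number within the segment
	_ = piece.StorageNode // node ID resolved from the alias
}
```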
@@ -52,6 +52,10 @@ type FinishCopyObject struct {

NewSegmentKeys []EncryptedKeyAndNonce

// If set, copy the object by duplicating the metadata and
// remote_alias_pieces list, rather than using segment_copies.
DuplicateMetadata bool

// VerifyLimits holds a callback by which the caller can interrupt the copy
// if it turns out completing the copy would exceed a limit.
// It will be called only once.
@@ -147,47 +151,96 @@ func (db *DB) FinishCopyObject(ctx context.Context, opts FinishCopyObject) (obje
plainSizes := make([]int32, sourceObject.SegmentCount)
plainOffsets := make([]int64, sourceObject.SegmentCount)
inlineDatas := make([][]byte, sourceObject.SegmentCount)
placementConstraints := make([]storj.PlacementConstraint, sourceObject.SegmentCount)
remoteAliasPiecesLists := make([][]byte, sourceObject.SegmentCount)

redundancySchemes := make([]int64, sourceObject.SegmentCount)
err = withRows(db.db.QueryContext(ctx, `
SELECT
position,
expires_at,
root_piece_id,
encrypted_size, plain_offset, plain_size,
redundancy,
inline_data
FROM segments
WHERE stream_id = $1
ORDER BY position ASC
LIMIT $2

if opts.DuplicateMetadata {
err = withRows(db.db.QueryContext(ctx, `
SELECT
position,
expires_at,
root_piece_id,
encrypted_size, plain_offset, plain_size,
redundancy,
remote_alias_pieces,
placement,
inline_data
FROM segments
WHERE stream_id = $1
ORDER BY position ASC
LIMIT $2
`, sourceObject.StreamID, sourceObject.SegmentCount))(func(rows tagsql.Rows) error {
index := 0
for rows.Next() {
err := rows.Scan(
&positions[index],
&expiresAts[index],
&rootPieceIDs[index],
&encryptedSizes[index], &plainOffsets[index], &plainSizes[index],
&redundancySchemes[index],
&inlineDatas[index],
)
if err != nil {
index := 0
for rows.Next() {
err := rows.Scan(
&positions[index],
&expiresAts[index],
&rootPieceIDs[index],
&encryptedSizes[index], &plainOffsets[index], &plainSizes[index],
&redundancySchemes[index],
&remoteAliasPiecesLists[index],
&placementConstraints[index],
&inlineDatas[index],
)
if err != nil {
return err
}
index++
}

if err := rows.Err(); err != nil {
return err
}
index++
}

if err := rows.Err(); err != nil {
return err
}
if index != int(sourceObject.SegmentCount) {
return Error.New("could not load all of the segment information")
}

if index != int(sourceObject.SegmentCount) {
return Error.New("could not load all of the segment information")
}
return nil
})
} else {
err = withRows(db.db.QueryContext(ctx, `
SELECT
position,
expires_at,
root_piece_id,
encrypted_size, plain_offset, plain_size,
redundancy,
inline_data
FROM segments
WHERE stream_id = $1
ORDER BY position ASC
LIMIT $2
`, sourceObject.StreamID, sourceObject.SegmentCount))(func(rows tagsql.Rows) error {
index := 0
for rows.Next() {
err := rows.Scan(
&positions[index],
&expiresAts[index],
&rootPieceIDs[index],
&encryptedSizes[index], &plainOffsets[index], &plainSizes[index],
&redundancySchemes[index],
&inlineDatas[index],
)
if err != nil {
return err
}
index++
}

return nil
})
if err := rows.Err(); err != nil {
return err
}

if index != int(sourceObject.SegmentCount) {
return Error.New("could not load all of the segment information")
}

return nil
})
}
if err != nil {
return Error.New("unable to copy object: %w", err)
}
@@ -275,6 +328,7 @@ func (db *DB) FinishCopyObject(ctx context.Context, opts FinishCopyObject) (obje
root_piece_id,
redundancy,
encrypted_size, plain_offset, plain_size,
remote_alias_pieces, placement,
inline_data
) SELECT
$1, UNNEST($2::INT8[]), UNNEST($3::timestamptz[]),
@@ -282,12 +336,14 @@ func (db *DB) FinishCopyObject(ctx context.Context, opts FinishCopyObject) (obje
UNNEST($6::BYTEA[]),
UNNEST($7::INT8[]),
UNNEST($8::INT4[]), UNNEST($9::INT8[]), UNNEST($10::INT4[]),
UNNEST($11::BYTEA[])
UNNEST($11::BYTEA[]), UNNEST($12::INT2[]),
UNNEST($13::BYTEA[])
`, opts.NewStreamID, pgutil.Int8Array(newSegments.Positions), pgutil.NullTimestampTZArray(expiresAts),
pgutil.ByteaArray(newSegments.EncryptedKeyNonces), pgutil.ByteaArray(newSegments.EncryptedKeys),
pgutil.ByteaArray(rootPieceIDs),
pgutil.Int8Array(redundancySchemes),
pgutil.Int4Array(encryptedSizes), pgutil.Int8Array(plainOffsets), pgutil.Int4Array(plainSizes),
pgutil.ByteaArray(remoteAliasPiecesLists), pgutil.PlacementConstraintArray(placementConstraints),
pgutil.ByteaArray(inlineDatas),
)
if err != nil {
@@ -298,15 +354,17 @@ func (db *DB) FinishCopyObject(ctx context.Context, opts FinishCopyObject) (obje
return nil
}

_, err = tx.ExecContext(ctx, `
INSERT INTO segment_copies (
stream_id, ancestor_stream_id
) VALUES (
$1, $2
)
`, opts.NewStreamID, ancestorStreamID)
if err != nil {
return Error.New("unable to copy object: %w", err)
if !opts.DuplicateMetadata {
_, err = tx.ExecContext(ctx, `
INSERT INTO segment_copies (
stream_id, ancestor_stream_id
) VALUES (
$1, $2
)
`, opts.NewStreamID, ancestorStreamID)
if err != nil {
return Error.New("unable to copy object: %w", err)
}
}

return nil
return nil
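In short: with `DuplicateMetadata` set, the copied segments carry their own `remote_alias_pieces` and `placement`, and no `segment_copies` row is written; without it the previous ancestor-tracking behaviour is kept. A caller-side sketch of the two modes (fields other than `DuplicateMetadata` are abbreviated; see the struct definition above for the full option set):

```go
// Sketch of the caller's choice between the two copy modes.
_, err := db.FinishCopyObject(ctx, metabase.FinishCopyObject{
	// ... source object, new location, NewStreamID, NewSegmentKeys ...
	DuplicateMetadata: true, // self-contained copy: pieces and placement are
	// duplicated, no segment_copies row is written
})
if err != nil {
	return err
}

// With DuplicateMetadata: false (the default), the new stream instead
// references its ancestor through the segment_copies table, as before.
```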
File diff suppressed because it is too large
@@ -65,11 +65,6 @@ type deletedRemoteSegmentInfo struct {
RepairedAt *time.Time
}

// DeleteObjectAnyStatusAllVersions contains arguments necessary for deleting all object versions.
type DeleteObjectAnyStatusAllVersions struct {
ObjectLocation
}

// DeleteObjectsAllVersions contains arguments necessary for deleting all versions of multiple objects from the same bucket.
type DeleteObjectsAllVersions struct {
Locations []ObjectLocation
@@ -566,66 +561,6 @@ func (db *DB) DeletePendingObject(ctx context.Context, opts DeletePendingObject)
return result, nil
}

// DeleteObjectAnyStatusAllVersions deletes all object versions.
func (db *DB) DeleteObjectAnyStatusAllVersions(ctx context.Context, opts DeleteObjectAnyStatusAllVersions) (result DeleteObjectResult, err error) {
defer mon.Task()(&ctx)(&err)

if db.config.ServerSideCopy {
return DeleteObjectResult{}, errs.New("method cannot be used when server-side copy is enabled")
}

if err := opts.Verify(); err != nil {
return DeleteObjectResult{}, err
}

err = withRows(db.db.QueryContext(ctx, `
WITH deleted_objects AS (
DELETE FROM objects
WHERE
project_id = $1 AND
bucket_name = $2 AND
object_key = $3
RETURNING
version, stream_id,
created_at, expires_at,
status, segment_count,
encrypted_metadata_nonce, encrypted_metadata, encrypted_metadata_encrypted_key,
total_plain_size, total_encrypted_size, fixed_segment_size,
encryption
), deleted_segments AS (
DELETE FROM segments
WHERE segments.stream_id IN (SELECT deleted_objects.stream_id FROM deleted_objects)
RETURNING segments.stream_id,segments.root_piece_id, segments.remote_alias_pieces
)
SELECT
deleted_objects.version, deleted_objects.stream_id,
deleted_objects.created_at, deleted_objects.expires_at,
deleted_objects.status, deleted_objects.segment_count,
deleted_objects.encrypted_metadata_nonce, deleted_objects.encrypted_metadata, deleted_objects.encrypted_metadata_encrypted_key,
deleted_objects.total_plain_size, deleted_objects.total_encrypted_size, deleted_objects.fixed_segment_size,
deleted_objects.encryption,
deleted_segments.root_piece_id, deleted_segments.remote_alias_pieces
FROM deleted_objects
LEFT JOIN deleted_segments ON deleted_objects.stream_id = deleted_segments.stream_id
`, opts.ProjectID, []byte(opts.BucketName), opts.ObjectKey))(func(rows tagsql.Rows) error {
result.Objects, result.Segments, err = db.scanObjectDeletion(ctx, opts.ObjectLocation, rows)
return err
})

if err != nil {
return DeleteObjectResult{}, err
}

if len(result.Objects) == 0 {
return DeleteObjectResult{}, ErrObjectNotFound.Wrap(Error.New("no rows deleted"))
}

mon.Meter("object_delete").Mark(len(result.Objects))
mon.Meter("segment_delete").Mark(len(result.Segments))

return result, nil
}

// DeleteObjectsAllVersions deletes all versions of multiple objects from the same bucket.
func (db *DB) DeleteObjectsAllVersions(ctx context.Context, opts DeleteObjectsAllVersions) (result DeleteObjectResult, err error) {
defer mon.Task()(&ctx)(&err)
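Note: the removed `DeleteObjectAnyStatusAllVersions` refused to run whenever server-side copy was enabled, so it is dropped here along with its tests (see the test diff below). Callers go through the surviving deletion paths instead, for example `DeleteObjectsAllVersions` (defined above) or `DeleteObjectExactVersion` (exercised by the remaining tests). A minimal caller-side sketch:

```go
// Sketch of deleting every version of a set of objects through the remaining
// API; location is an assumed metabase.ObjectLocation for the target objects.
_, err := db.DeleteObjectsAllVersions(ctx, metabase.DeleteObjectsAllVersions{
	Locations: []metabase.ObjectLocation{location},
})
if err != nil {
	return err
}
```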
@ -321,7 +321,7 @@ func TestDeleteBucketWithCopies(t *testing.T) {
|
||||
metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
CopyObjectStream: ©ObjectStream,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
_, err := db.DeleteBucketObjects(ctx, metabase.DeleteBucketObjects{
|
||||
Bucket: metabase.BucketLocation{
|
||||
@ -362,7 +362,7 @@ func TestDeleteBucketWithCopies(t *testing.T) {
|
||||
copyObj, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
CopyObjectStream: ©ObjectStream,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
_, err := db.DeleteBucketObjects(ctx, metabase.DeleteBucketObjects{
|
||||
Bucket: metabase.BucketLocation{
|
||||
@ -420,12 +420,78 @@ func TestDeleteBucketWithCopies(t *testing.T) {
|
||||
metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj1,
|
||||
CopyObjectStream: ©ObjectStream1,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
copyObj2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj2,
|
||||
CopyObjectStream: ©ObjectStream2,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
// done preparing, delete bucket 1
|
||||
_, err := db.DeleteBucketObjects(ctx, metabase.DeleteBucketObjects{
|
||||
Bucket: metabase.BucketLocation{
|
||||
ProjectID: projectID,
|
||||
BucketName: "bucket2",
|
||||
},
|
||||
BatchSize: 2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare for check.
|
||||
// obj1 is the same as before, copyObj2 should now be the original
|
||||
for i := range copySegments2 {
|
||||
copySegments2[i].Pieces = originalSegments2[i].Pieces
|
||||
}
|
||||
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(originalObj1),
|
||||
metabase.RawObject(copyObj2),
|
||||
},
|
||||
Segments: append(copySegments2, metabasetest.SegmentsToRaw(originalSegments1)...),
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("delete bucket which has one ancestor and one copy with duplicate metadata", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
originalObjStream1 := metabasetest.RandObjectStream()
|
||||
originalObjStream1.BucketName = "bucket1"
|
||||
|
||||
projectID := originalObjStream1.ProjectID
|
||||
|
||||
originalObjStream2 := metabasetest.RandObjectStream()
|
||||
originalObjStream2.ProjectID = projectID
|
||||
originalObjStream2.BucketName = "bucket2"
|
||||
|
||||
originalObj1, originalSegments1 := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjStream1,
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjStream1, byte(numberOfSegments))
|
||||
|
||||
originalObj2, originalSegments2 := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjStream2,
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjStream2, byte(numberOfSegments))
|
||||
|
||||
copyObjectStream1 := metabasetest.RandObjectStream()
|
||||
copyObjectStream1.ProjectID = projectID
|
||||
copyObjectStream1.BucketName = "bucket2" // copy from bucket 1 to bucket 2
|
||||
|
||||
copyObjectStream2 := metabasetest.RandObjectStream()
|
||||
copyObjectStream2.ProjectID = projectID
|
||||
copyObjectStream2.BucketName = "bucket1" // copy from bucket 2 to bucket 1
|
||||
|
||||
metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj1,
|
||||
CopyObjectStream: ©ObjectStream1,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
copyObj2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj2,
|
||||
CopyObjectStream: ©ObjectStream2,
}.Run(ctx, t, db, true)
|
||||
|
||||
// done preparing, delete bucket 1
|
||||
_, err := db.DeleteBucketObjects(ctx, metabase.DeleteBucketObjects{
|
||||
|
@ -466,205 +466,6 @@ func TestDeleteObjectExactVersion(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteObjectAnyStatusAllVersions(t *testing.T) {
|
||||
metabasetest.RunWithConfig(t, noServerSideCopyConfig, func(ctx *testcontext.Context, t *testing.T, db *metabase.DB) {
|
||||
obj := metabasetest.RandObjectStream()
|
||||
|
||||
location := obj.Location()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for _, test := range metabasetest.InvalidObjectLocations(location) {
|
||||
test := test
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: test.ObjectLocation},
|
||||
ErrClass: test.ErrClass,
|
||||
ErrText: test.ErrText,
|
||||
}.Check(ctx, t, db)
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Object missing", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
ErrClass: &metabase.ErrObjectNotFound,
|
||||
ErrText: "metabase: no rows deleted",
|
||||
}.Check(ctx, t, db)
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete non existing object version", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
ErrClass: &metabase.ErrObjectNotFound,
|
||||
ErrText: "metabase: no rows deleted",
|
||||
}.Check(ctx, t, db)
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete partial object", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.BeginObjectExactVersion{
|
||||
Opts: metabase.BeginObjectExactVersion{
|
||||
ObjectStream: obj,
|
||||
Encryption: metabasetest.DefaultEncryption,
|
||||
},
|
||||
Version: 1,
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{{
|
||||
ObjectStream: obj,
|
||||
CreatedAt: now,
|
||||
Status: metabase.Pending,
|
||||
|
||||
Encryption: metabasetest.DefaultEncryption,
|
||||
}},
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete object without segments", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
encryptedMetadata := testrand.Bytes(1024)
|
||||
encryptedMetadataNonce := testrand.Nonce()
|
||||
encryptedMetadataKey := testrand.Bytes(265)
|
||||
|
||||
object, _ := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: obj,
|
||||
EncryptedMetadataNonce: encryptedMetadataNonce[:],
|
||||
EncryptedMetadata: encryptedMetadata,
|
||||
EncryptedMetadataEncryptedKey: encryptedMetadataKey,
|
||||
},
|
||||
}.Run(ctx, t, db, obj, 0)
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{object},
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete object with segments", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
object := metabasetest.CreateObject(ctx, t, db, obj, 2)
|
||||
|
||||
expectedSegmentInfo := metabase.DeletedSegmentInfo{
|
||||
RootPieceID: storj.PieceID{1},
|
||||
Pieces: metabase.Pieces{{Number: 0, StorageNode: storj.NodeID{2}}},
|
||||
}
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{
|
||||
ObjectLocation: location,
|
||||
},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{object},
|
||||
Segments: []metabase.DeletedSegmentInfo{expectedSegmentInfo, expectedSegmentInfo},
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete object with inline segment", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.BeginObjectExactVersion{
|
||||
Opts: metabase.BeginObjectExactVersion{
|
||||
ObjectStream: obj,
|
||||
Encryption: metabasetest.DefaultEncryption,
|
||||
},
|
||||
Version: obj.Version,
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.CommitInlineSegment{
|
||||
Opts: metabase.CommitInlineSegment{
|
||||
ObjectStream: obj,
|
||||
Position: metabase.SegmentPosition{Part: 0, Index: 0},
|
||||
|
||||
EncryptedKey: testrand.Bytes(32),
|
||||
EncryptedKeyNonce: testrand.Bytes(32),
|
||||
|
||||
InlineData: testrand.Bytes(1024),
|
||||
|
||||
PlainSize: 512,
|
||||
PlainOffset: 0,
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
object := metabasetest.CommitObject{
|
||||
Opts: metabase.CommitObject{
|
||||
ObjectStream: obj,
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{object},
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Delete multiple versions of the same object at once", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
expected := metabase.DeleteObjectResult{}
|
||||
|
||||
// committed object
|
||||
obj := metabasetest.RandObjectStream()
|
||||
expected.Objects = append(expected.Objects, metabasetest.CreateObject(ctx, t, db, obj, 1))
|
||||
expected.Segments = append(expected.Segments, metabase.DeletedSegmentInfo{
|
||||
RootPieceID: storj.PieceID{1},
|
||||
Pieces: metabase.Pieces{{Number: 0, StorageNode: storj.NodeID{2}}},
|
||||
})
|
||||
|
||||
// pending objects
|
||||
for i := 1; i <= 10; i++ {
|
||||
obj.StreamID = testrand.UUID()
|
||||
obj.Version = metabase.NextVersion
|
||||
|
||||
pendingObject, err := db.BeginObjectNextVersion(ctx, metabase.BeginObjectNextVersion{
|
||||
ObjectStream: obj,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// nil ZombieDeletionDeadline because while deleting we are not returning this value with object metadata
|
||||
pendingObject.ZombieDeletionDeadline = nil
|
||||
expected.Objects = append(expected.Objects, pendingObject)
|
||||
}
|
||||
|
||||
metabasetest.DeleteObjectAnyStatusAllVersions{
|
||||
Opts: metabase.DeleteObjectAnyStatusAllVersions{ObjectLocation: obj.Location()},
|
||||
Result: expected,
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{}.Check(ctx, t, db)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteObjectsAllVersions(t *testing.T) {
|
||||
metabasetest.RunWithConfig(t, noServerSideCopyConfig, func(ctx *testcontext.Context, t *testing.T, db *metabase.DB) {
|
||||
obj := metabasetest.RandObjectStream()
|
||||
@ -987,7 +788,7 @@ func TestDeleteCopy(t *testing.T) {
|
||||
|
||||
copyObj, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
var copies []metabase.RawCopy
|
||||
if numberOfSegments > 0 {
|
||||
@ -1042,10 +843,10 @@ func TestDeleteCopy(t *testing.T) {
|
||||
|
||||
copyObject1, _, _ := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
copyObject2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
@ -1092,7 +893,7 @@ func TestDeleteCopy(t *testing.T) {
|
||||
|
||||
copyObject, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
@ -1134,10 +935,10 @@ func TestDeleteCopy(t *testing.T) {
|
||||
|
||||
copyObject1, _, copySegments1 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
copyObject2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
_, err := db.DeleteObjectExactVersion(ctx, metabase.DeleteObjectExactVersion{
|
||||
Version: originalObj.Version,
|
||||
@ -1206,6 +1007,201 @@ func TestDeleteCopy(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteCopyWithDuplicateMetadata(t *testing.T) {
|
||||
metabasetest.Run(t, func(ctx *testcontext.Context, t *testing.T, db *metabase.DB) {
|
||||
for _, numberOfSegments := range []int{0, 1, 3} {
|
||||
t.Run(fmt.Sprintf("%d segments", numberOfSegments), func(t *testing.T) {
|
||||
t.Run("delete copy", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
originalObjStream := metabasetest.RandObjectStream()
|
||||
|
||||
originalObj, originalSegments := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjStream,
|
||||
EncryptedMetadata: testrand.Bytes(64),
|
||||
EncryptedMetadataNonce: testrand.Nonce().Bytes(),
|
||||
EncryptedMetadataEncryptedKey: testrand.Bytes(265),
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjStream, byte(numberOfSegments))
|
||||
|
||||
copyObj, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
// check that copy went OK
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(originalObj),
|
||||
metabase.RawObject(copyObj),
|
||||
},
|
||||
Segments: append(metabasetest.SegmentsToRaw(originalSegments), copySegments...),
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
ObjectLocation: copyObj.Location(),
|
||||
Version: copyObj.Version,
|
||||
},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{copyObj},
|
||||
Segments: rawSegmentsToDeletedSegmentInfo(copySegments),
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
// Verify that we are back at the original single object
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(originalObj),
|
||||
},
|
||||
Segments: metabasetest.SegmentsToRaw(originalSegments),
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("delete one of two copies", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
originalObjectStream := metabasetest.RandObjectStream()
|
||||
|
||||
originalObj, originalSegments := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjectStream,
|
||||
EncryptedMetadata: testrand.Bytes(64),
|
||||
EncryptedMetadataNonce: testrand.Nonce().Bytes(),
|
||||
EncryptedMetadataEncryptedKey: testrand.Bytes(265),
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjectStream, byte(numberOfSegments))
|
||||
|
||||
copyObject1, _, copySegments1 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
copyObject2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
ObjectLocation: copyObject1.Location(),
|
||||
Version: copyObject1.Version,
|
||||
},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{copyObject1},
|
||||
Segments: rawSegmentsToDeletedSegmentInfo(copySegments1),
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
// Verify that only one of the copies is deleted
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(originalObj),
|
||||
metabase.RawObject(copyObject2),
|
||||
},
|
||||
Segments: append(metabasetest.SegmentsToRaw(originalSegments), copySegments2...),
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("delete original", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
originalObjectStream := metabasetest.RandObjectStream()
|
||||
|
||||
originalObj, originalSegments := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjectStream,
|
||||
EncryptedMetadata: testrand.Bytes(64),
|
||||
EncryptedMetadataNonce: testrand.Nonce().Bytes(),
|
||||
EncryptedMetadataEncryptedKey: testrand.Bytes(265),
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjectStream, byte(numberOfSegments))
|
||||
|
||||
copyObject, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
ObjectLocation: originalObj.Location(),
|
||||
Version: originalObj.Version,
|
||||
},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{originalObj},
|
||||
Segments: rawSegmentsToDeletedSegmentInfo(copySegments),
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
for i := range copySegments {
|
||||
copySegments[i].Pieces = originalSegments[i].Pieces
|
||||
}
|
||||
|
||||
// verify that the copy is left
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(copyObject),
|
||||
},
|
||||
Segments: copySegments,
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("delete original and leave two copies", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
originalObjectStream := metabasetest.RandObjectStream()
|
||||
|
||||
originalObj, originalSegments := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: originalObjectStream,
|
||||
EncryptedMetadata: testrand.Bytes(64),
|
||||
EncryptedMetadataNonce: testrand.Nonce().Bytes(),
|
||||
EncryptedMetadataEncryptedKey: testrand.Bytes(265),
|
||||
},
|
||||
}.Run(ctx, t, db, originalObjectStream, byte(numberOfSegments))
|
||||
|
||||
copyObject1, _, copySegments1 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
copyObject2, _, copySegments2 := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
_, err := db.DeleteObjectExactVersion(ctx, metabase.DeleteObjectExactVersion{
|
||||
Version: originalObj.Version,
|
||||
ObjectLocation: originalObj.Location(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var expectedAncestorStreamID uuid.UUID
|
||||
|
||||
if numberOfSegments > 0 {
|
||||
segments, err := db.TestingAllSegments(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, segments)
|
||||
|
||||
if segments[0].StreamID == copyObject1.StreamID {
|
||||
expectedAncestorStreamID = copyObject1.StreamID
|
||||
} else {
|
||||
expectedAncestorStreamID = copyObject2.StreamID
|
||||
}
|
||||
}
|
||||
|
||||
// set pieces in expected ancestor for verification
for _, segments := range [][]metabase.RawSegment{copySegments1, copySegments2} {
|
||||
for i := range segments {
|
||||
if segments[i].StreamID == expectedAncestorStreamID {
|
||||
segments[i].Pieces = originalSegments[i].Pieces
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// verify that two functioning copies are left and the original object is gone
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(copyObject1),
|
||||
metabase.RawObject(copyObject2),
|
||||
},
|
||||
Segments: append(copySegments1, copySegments2...),
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeleteObjectLastCommitted(t *testing.T) {
|
||||
metabasetest.Run(t, func(ctx *testcontext.Context, t *testing.T, db *metabase.DB) {
|
||||
obj := metabasetest.RandObjectStream()
|
||||
@ -1402,3 +1398,12 @@ func TestDeleteObjectLastCommitted(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func rawSegmentsToDeletedSegmentInfo(segments []metabase.RawSegment) []metabase.DeletedSegmentInfo {
|
||||
result := make([]metabase.DeletedSegmentInfo, len(segments))
|
||||
for i := range segments {
|
||||
result[i].RootPieceID = segments[i].RootPieceID
|
||||
result[i].Pieces = segments[i].Pieces
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -345,7 +345,53 @@ func TestGetObjectLastCommitted(t *testing.T) {
|
||||
copiedObj, _, _ := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObject,
|
||||
CopyObjectStream: &copyObjStream,
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
Version: 1,
|
||||
ObjectLocation: obj.Location(),
|
||||
},
|
||||
Result: metabase.DeleteObjectResult{
|
||||
Objects: []metabase.Object{originalObject},
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.GetObjectLastCommitted{
|
||||
Opts: metabase.GetObjectLastCommitted{
|
||||
ObjectLocation: copiedObj.Location(),
|
||||
},
|
||||
Result: copiedObj,
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{Objects: []metabase.RawObject{
|
||||
{
|
||||
ObjectStream: metabase.ObjectStream{
|
||||
ProjectID: copiedObj.ProjectID,
|
||||
BucketName: copiedObj.BucketName,
|
||||
ObjectKey: copiedObj.ObjectKey,
|
||||
Version: copiedObj.Version,
|
||||
StreamID: copiedObj.StreamID,
|
||||
},
|
||||
CreatedAt: now,
|
||||
Status: metabase.Committed,
|
||||
Encryption: metabasetest.DefaultEncryption,
|
||||
EncryptedMetadata: copiedObj.EncryptedMetadata,
|
||||
EncryptedMetadataNonce: copiedObj.EncryptedMetadataNonce,
|
||||
EncryptedMetadataEncryptedKey: copiedObj.EncryptedMetadataEncryptedKey,
|
||||
},
|
||||
}}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Get latest copied object version with duplicate metadata", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
copyObjStream := metabasetest.RandObjectStream()
|
||||
originalObject := metabasetest.CreateObject(ctx, t, db, obj, 0)
|
||||
|
||||
copiedObj, _, _ := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObject,
|
||||
CopyObjectStream: &copyObjStream,
}.Run(ctx, t, db, true)
|
||||
|
||||
metabasetest.DeleteObjectExactVersion{
|
||||
Opts: metabase.DeleteObjectExactVersion{
|
||||
@ -1114,7 +1160,7 @@ func TestGetLatestObjectLastSegment(t *testing.T) {
|
||||
|
||||
copyObj, _, newSegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
metabasetest.GetLatestObjectLastSegment{
|
||||
Opts: metabase.GetLatestObjectLastSegment{
|
||||
@ -1150,6 +1196,54 @@ func TestGetLatestObjectLastSegment(t *testing.T) {
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
t.Run("Get segment copy with duplicate metadata", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
objStream := metabasetest.RandObjectStream()
|
||||
|
||||
originalObj, originalSegments := metabasetest.CreateTestObject{
|
||||
CommitObject: &metabase.CommitObject{
|
||||
ObjectStream: objStream,
|
||||
EncryptedMetadata: testrand.Bytes(64),
|
||||
EncryptedMetadataNonce: testrand.Nonce().Bytes(),
|
||||
EncryptedMetadataEncryptedKey: testrand.Bytes(265),
|
||||
},
|
||||
}.Run(ctx, t, db, objStream, 1)
|
||||
|
||||
copyObj, _, newSegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
}.Run(ctx, t, db, true)
|
||||
|
||||
metabasetest.GetLatestObjectLastSegment{
|
||||
Opts: metabase.GetLatestObjectLastSegment{
|
||||
ObjectLocation: originalObj.Location(),
|
||||
},
|
||||
Result: originalSegments[0],
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
copySegmentGet := originalSegments[0]
|
||||
copySegmentGet.StreamID = copyObj.StreamID
|
||||
copySegmentGet.EncryptedETag = nil
|
||||
copySegmentGet.InlineData = []byte{}
|
||||
copySegmentGet.EncryptedKey = newSegments[0].EncryptedKey
|
||||
copySegmentGet.EncryptedKeyNonce = newSegments[0].EncryptedKeyNonce
|
||||
|
||||
metabasetest.GetLatestObjectLastSegment{
|
||||
Opts: metabase.GetLatestObjectLastSegment{
|
||||
ObjectLocation: copyObj.Location(),
|
||||
},
|
||||
Result: copySegmentGet,
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
metabasetest.Verify{
|
||||
Objects: []metabase.RawObject{
|
||||
metabase.RawObject(originalObj),
|
||||
metabase.RawObject(copyObj),
|
||||
},
|
||||
Segments: append(metabasetest.SegmentsToRaw(originalSegments), newSegments...),
|
||||
}.Check(ctx, t, db)
|
||||
})
|
||||
|
||||
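The field adjustments above spell out what reading the copy is expected to return when metadata is duplicated: the ancestor's pieces and sizes, but the copy's own stream ID and re-wrapped per-segment key material, with no encrypted ETag or inline data carried over. The snippet below expresses the same expectation over simplified placeholder types, not the real metabase ones:

package main

import "fmt"

type segmentView struct {
	streamID          string
	pieces            []int
	encryptedKey      string
	encryptedKeyNonce string
	encryptedETag     string
	inlineData        []byte
}

// expectedCopyView derives what a read of the copy's last segment should
// return: everything from the ancestor except the stream ID and the
// re-encrypted key material, with ETag and inline data cleared. The pieces
// slice is intentionally shared with the ancestor.
func expectedCopyView(ancestor segmentView, copyStreamID, copyKey, copyNonce string) segmentView {
	view := ancestor
	view.streamID = copyStreamID
	view.encryptedKey = copyKey
	view.encryptedKeyNonce = copyNonce
	view.encryptedETag = ""
	view.inlineData = []byte{}
	return view
}

func main() {
	ancestor := segmentView{streamID: "orig", pieces: []int{1, 2}, encryptedKey: "k1", encryptedKeyNonce: "n1", encryptedETag: "etag"}
	fmt.Printf("%+v\n", expectedCopyView(ancestor, "copy", "k2", "n2"))
}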
t.Run("Get empty inline segment copy", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
|
@ -282,7 +282,51 @@ func TestListSegments(t *testing.T) {
|
||||
_, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObject,
|
||||
CopyObjectStream: &copyStream,
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
expectedSegments := []metabase.Segment{}
|
||||
for _, segment := range copySegments {
|
||||
expectedSegments = append(expectedSegments, metabase.Segment(segment))
|
||||
}
|
||||
|
||||
metabasetest.ListSegments{
|
||||
Opts: metabase.ListSegments{
|
||||
StreamID: copyStream.StreamID,
|
||||
},
|
||||
Result: metabase.ListSegmentsResult{
|
||||
Segments: expectedSegments,
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
|
||||
if numberOfSegments > 0 {
|
||||
expectedSegments[0].Pieces = originalSegments[0].Pieces
|
||||
}
|
||||
|
||||
metabasetest.ListSegments{
|
||||
Opts: metabase.ListSegments{
|
||||
StreamID: copyStream.StreamID,
|
||||
UpdateFirstWithAncestor: true,
|
||||
},
|
||||
Result: metabase.ListSegmentsResult{
|
||||
Segments: expectedSegments,
|
||||
},
|
||||
}.Check(ctx, t, db)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("segments from copy with duplicate metadata", func(t *testing.T) {
|
||||
defer metabasetest.DeleteAll{}.Check(ctx, t, db)
|
||||
|
||||
for _, numberOfSegments := range []byte{0, 1, 2, 10} {
|
||||
originalObjectStream := metabasetest.RandObjectStream()
|
||||
originalObject, originalSegments := metabasetest.CreateTestObject{}.
|
||||
Run(ctx, t, db, originalObjectStream, numberOfSegments)
|
||||
|
||||
copyStream := metabasetest.RandObjectStream()
|
||||
_, _, copySegments := metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObject,
|
||||
CopyObjectStream: &copyStream,
}.Run(ctx, t, db, true)
|
||||
|
||||
expectedSegments := []metabase.Segment{}
|
||||
for _, segment := range copySegments {
|
||||
|
@ -255,12 +255,15 @@ func (db *DB) IterateLoopSegments(ctx context.Context, opts IterateLoopSegments,
|
||||
return err
|
||||
}
|
||||
|
||||
loopIteratorBatchSizeLimit.Ensure(&opts.BatchSize)
|
||||
|
||||
it := &loopSegmentIterator{
|
||||
db: db,
|
||||
|
||||
asOfSystemTime: opts.AsOfSystemTime,
|
||||
asOfSystemInterval: opts.AsOfSystemInterval,
|
||||
batchSize: opts.BatchSize,
|
||||
batchPieces: make([]Pieces, opts.BatchSize),
|
||||
|
||||
curIndex: 0,
|
||||
cursor: loopSegmentIteratorCursor{
|
||||
@ -277,8 +280,6 @@ func (db *DB) IterateLoopSegments(ctx context.Context, opts IterateLoopSegments,
|
||||
it.cursor.EndStreamID = uuid.Max()
|
||||
}
|
||||
|
||||
loopIteratorBatchSizeLimit.Ensure(&it.batchSize)
|
||||
|
||||
it.curRows, err = it.doNextQuery(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -298,7 +299,10 @@ func (db *DB) IterateLoopSegments(ctx context.Context, opts IterateLoopSegments,
|
||||
type loopSegmentIterator struct {
|
||||
db *DB
|
||||
|
||||
batchSize int
|
||||
batchSize int
|
||||
// batchPieces are reused between result pages to reduce memory consumption
|
||||
batchPieces []Pieces
|
||||
|
||||
asOfSystemTime time.Time
|
||||
asOfSystemInterval time.Duration
|
||||
|
||||
@ -399,7 +403,14 @@ func (it *loopSegmentIterator) scanItem(ctx context.Context, item *LoopSegmentEn
|
||||
return Error.New("failed to scan segments: %w", err)
|
||||
}
|
||||
|
||||
item.Pieces, err = it.db.aliasCache.ConvertAliasesToPieces(ctx, item.AliasPieces)
|
||||
// allocate new Pieces only if the existing slice does not have enough capacity
if cap(it.batchPieces[it.curIndex]) < len(item.AliasPieces) {
|
||||
it.batchPieces[it.curIndex] = make(Pieces, len(item.AliasPieces))
|
||||
} else {
|
||||
it.batchPieces[it.curIndex] = it.batchPieces[it.curIndex][:len(item.AliasPieces)]
|
||||
}
|
||||
|
||||
item.Pieces, err = it.db.aliasCache.convertAliasesToPieces(ctx, item.AliasPieces, it.batchPieces[it.curIndex])
|
||||
if err != nil {
|
||||
return Error.New("failed to convert aliases to pieces: %w", err)
|
||||
}
|
||||
|
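The hunk above moves the batch-size clamping ahead of iterator construction and makes scanItem reuse a per-slot Pieces buffer instead of allocating a fresh slice for every scanned row. Below is a minimal, self-contained sketch of that reuse-or-grow pattern; the piece type and fillPieces helper are simplified stand-ins, not the real metabase types.

package main

import "fmt"

type piece struct{ number uint16 }

// fillPieces reuses buf when it already has enough capacity and allocates a
// fresh slice only when the incoming row needs more room, mirroring the
// batchPieces reuse added to loopSegmentIterator.scanItem.
func fillPieces(buf []piece, n int) []piece {
	if cap(buf) < n {
		buf = make([]piece, n)
	} else {
		buf = buf[:n]
	}
	for i := range buf {
		buf[i] = piece{number: uint16(i)}
	}
	return buf
}

func main() {
	// one reusable slot per row position in a batch, like it.batchPieces
	batch := make([][]piece, 2)
	for i, n := range []int{3, 1, 5, 2} {
		slot := i % len(batch)
		batch[slot] = fillPieces(batch[slot], n)
		fmt.Println(slot, len(batch[slot]), cap(batch[slot]))
	}
}

Values that must outlive the current page have to be copied out by the caller, since each slot is overwritten when the next batch is scanned.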
@ -321,7 +321,10 @@ type CreateObjectCopy struct {
|
||||
}
|
||||
|
||||
// Run creates the copy.
|
||||
func (cc CreateObjectCopy) Run(ctx *testcontext.Context, t testing.TB, db *metabase.DB) (copyObj metabase.Object, expectedOriginalSegments []metabase.RawSegment, expectedCopySegments []metabase.RawSegment) {
|
||||
//
// The duplicateMetadata argument is a temporary hack; it should be removed once
// duplicateMetadata is no longer an option.
|
||||
func (cc CreateObjectCopy) Run(ctx *testcontext.Context, t testing.TB, db *metabase.DB, duplicateMetadata bool) (copyObj metabase.Object, expectedOriginalSegments []metabase.RawSegment, expectedCopySegments []metabase.RawSegment) {
|
||||
var copyStream metabase.ObjectStream
|
||||
if cc.CopyObjectStream != nil {
|
||||
copyStream = *cc.CopyObjectStream
|
||||
@ -360,6 +363,11 @@ func (cc CreateObjectCopy) Run(ctx *testcontext.Context, t testing.TB, db *metab
|
||||
} else {
|
||||
expectedCopySegments[i].InlineData = []byte{}
|
||||
}
|
||||
|
||||
if duplicateMetadata {
|
||||
expectedCopySegments[i].Pieces = make(metabase.Pieces, len(expectedOriginalSegments[i].Pieces))
|
||||
copy(expectedCopySegments[i].Pieces, expectedOriginalSegments[i].Pieces)
|
||||
}
|
||||
}
|
||||
|
||||
opts := cc.FinishObject
|
||||
@ -374,6 +382,7 @@ func (cc CreateObjectCopy) Run(ctx *testcontext.Context, t testing.TB, db *metab
|
||||
NewEncryptedMetadataKey: testrand.Bytes(32),
|
||||
}
|
||||
}
|
||||
opts.DuplicateMetadata = duplicateMetadata
|
||||
|
||||
copyObj, err := db.FinishCopyObject(ctx, *opts)
|
||||
require.NoError(t, err)
|
||||
|
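With the extra duplicateMetadata argument, the test helper now gives the expected copy segments their own Pieces (copied from the originals) instead of leaving them empty. A small illustrative sketch of that conditional deep copy, using simplified placeholder types rather than the real metabase ones:

package main

import "fmt"

type piece struct{ storageNode int }
type pieces []piece

// expectedCopyPieces returns what the test helper expects a copied segment's
// Pieces to be: a duplicated slice of the ancestor's pieces when metadata is
// duplicated, and nothing when the copy still references the ancestor via
// segment_copies.
func expectedCopyPieces(original pieces, duplicateMetadata bool) pieces {
	if !duplicateMetadata {
		return nil
	}
	dup := make(pieces, len(original))
	copy(dup, original)
	return dup
}

func main() {
	orig := pieces{{1}, {2}, {3}}
	fmt.Println(expectedCopyPieces(orig, false))
	fmt.Println(expectedCopyPieces(orig, true))
}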
@ -454,29 +454,6 @@ func (step DeletePendingObject) Check(ctx *testcontext.Context, t testing.TB, db
|
||||
require.Zero(t, diff)
|
||||
}
|
||||
|
||||
// DeleteObjectAnyStatusAllVersions is for testing metabase.DeleteObjectAnyStatusAllVersions.
|
||||
type DeleteObjectAnyStatusAllVersions struct {
|
||||
Opts metabase.DeleteObjectAnyStatusAllVersions
|
||||
Result metabase.DeleteObjectResult
|
||||
ErrClass *errs.Class
|
||||
ErrText string
|
||||
}
|
||||
|
||||
// Check runs the test.
|
||||
func (step DeleteObjectAnyStatusAllVersions) Check(ctx *testcontext.Context, t testing.TB, db *metabase.DB) {
|
||||
result, err := db.DeleteObjectAnyStatusAllVersions(ctx, step.Opts)
|
||||
checkError(t, err, step.ErrClass, step.ErrText)
|
||||
|
||||
sortObjects(result.Objects)
|
||||
sortObjects(step.Result.Objects)
|
||||
|
||||
sortDeletedSegments(result.Segments)
|
||||
sortDeletedSegments(step.Result.Segments)
|
||||
|
||||
diff := cmp.Diff(step.Result, result, DefaultTimeDiff())
|
||||
require.Zero(t, diff)
|
||||
}
|
||||
|
||||
// DeleteObjectsAllVersions is for testing metabase.DeleteObjectsAllVersions.
|
||||
type DeleteObjectsAllVersions struct {
|
||||
Opts metabase.DeleteObjectsAllVersions
|
||||
|
@ -31,6 +31,7 @@ import (
|
||||
"storj.io/storj/satellite/metabase/rangedloop"
|
||||
"storj.io/storj/satellite/metabase/rangedloop/rangedlooptest"
|
||||
"storj.io/storj/satellite/metrics"
|
||||
"storj.io/storj/satellite/overlay"
|
||||
"storj.io/storj/satellite/repair/checker"
|
||||
)
|
||||
|
||||
@ -425,6 +426,7 @@ func TestAllInOne(t *testing.T) {
|
||||
log.Named("repair:checker"),
|
||||
satellite.DB.RepairQueue(),
|
||||
satellite.Overlay.Service,
|
||||
overlay.NewPlacementRules().CreateFilters,
|
||||
satellite.Config.Checker,
|
||||
),
|
||||
})
|
||||
|
@ -193,7 +193,7 @@ func TestGetStreamPieceCountByNodeID(t *testing.T) {
|
||||
_, _, _ = metabasetest.CreateObjectCopy{
|
||||
OriginalObject: originalObj,
|
||||
CopyObjectStream: &copyStream,
}.Run(ctx, t, db)
|
||||
}.Run(ctx, t, db, false)
|
||||
|
||||
metabasetest.GetStreamPieceCountByNodeID{
|
||||
Opts: metabase.GetStreamPieceCountByNodeID{
|
||||
|
@ -4,6 +4,7 @@
|
||||
package metainfo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
@ -24,9 +25,11 @@ import (
|
||||
const MaxUserAgentLength = 500
|
||||
|
||||
// ensureAttribution ensures that the bucketName has the partner information specified by project-level user agent, or header user agent.
|
||||
// If `forceBucketUpdate` is true, then the buckets table will be updated if necessary (needed for bucket creation). Otherwise, it is sufficient
|
||||
// to only ensure the attribution exists in the value attributions db.
|
||||
//
|
||||
// Assumes that the user has permissions sufficient for authenticating.
|
||||
func (endpoint *Endpoint) ensureAttribution(ctx context.Context, header *pb.RequestHeader, keyInfo *console.APIKeyInfo, bucketName, projectUserAgent []byte) (err error) {
|
||||
func (endpoint *Endpoint) ensureAttribution(ctx context.Context, header *pb.RequestHeader, keyInfo *console.APIKeyInfo, bucketName, projectUserAgent []byte, forceBucketUpdate bool) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
if header == nil {
|
||||
@ -36,13 +39,15 @@ func (endpoint *Endpoint) ensureAttribution(ctx context.Context, header *pb.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
if conncache := drpccache.FromContext(ctx); conncache != nil {
|
||||
cache := conncache.LoadOrCreate(attributionCheckCacheKey{},
|
||||
func() interface{} {
|
||||
return &attributionCheckCache{}
|
||||
}).(*attributionCheckCache)
|
||||
if !cache.needsCheck(string(bucketName)) {
|
||||
return nil
|
||||
if !forceBucketUpdate {
|
||||
if conncache := drpccache.FromContext(ctx); conncache != nil {
|
||||
cache := conncache.LoadOrCreate(attributionCheckCacheKey{},
|
||||
func() interface{} {
|
||||
return &attributionCheckCache{}
|
||||
}).(*attributionCheckCache)
|
||||
if !cache.needsCheck(string(bucketName)) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +67,7 @@ func (endpoint *Endpoint) ensureAttribution(ctx context.Context, header *pb.Requ
|
||||
return err
|
||||
}
|
||||
|
||||
err = endpoint.tryUpdateBucketAttribution(ctx, header, keyInfo.ProjectID, bucketName, userAgent)
|
||||
err = endpoint.tryUpdateBucketAttribution(ctx, header, keyInfo.ProjectID, bucketName, userAgent, forceBucketUpdate)
|
||||
if errs2.IsRPC(err, rpcstatus.NotFound) || errs2.IsRPC(err, rpcstatus.AlreadyExists) {
|
||||
return nil
|
||||
}
|
||||
@ -110,7 +115,7 @@ func TrimUserAgent(userAgent []byte) ([]byte, error) {
|
||||
return userAgent, nil
|
||||
}
|
||||
|
||||
func (endpoint *Endpoint) tryUpdateBucketAttribution(ctx context.Context, header *pb.RequestHeader, projectID uuid.UUID, bucketName []byte, userAgent []byte) (err error) {
|
||||
func (endpoint *Endpoint) tryUpdateBucketAttribution(ctx context.Context, header *pb.RequestHeader, projectID uuid.UUID, bucketName []byte, userAgent []byte, forceBucketUpdate bool) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
if header == nil {
|
||||
@ -118,26 +123,17 @@ func (endpoint *Endpoint) tryUpdateBucketAttribution(ctx context.Context, header
|
||||
}
|
||||
|
||||
// check if attribution is set for given bucket
|
||||
_, err = endpoint.attributions.Get(ctx, projectID, bucketName)
|
||||
attrInfo, err := endpoint.attributions.Get(ctx, projectID, bucketName)
|
||||
if err == nil {
|
||||
// bucket has already an attribution, no need to update
|
||||
return nil
|
||||
}
|
||||
if !attribution.ErrBucketNotAttributed.Has(err) {
|
||||
// try only to set the attribution, when it's missing
|
||||
if !forceBucketUpdate {
|
||||
// bucket already has an attribution, no need to update
return nil
|
||||
}
|
||||
} else if !attribution.ErrBucketNotAttributed.Has(err) {
|
||||
endpoint.log.Error("error while getting attribution from DB", zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, err.Error())
|
||||
}
|
||||
|
||||
empty, err := endpoint.isBucketEmpty(ctx, projectID, bucketName)
|
||||
if err != nil {
|
||||
endpoint.log.Error("internal", zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error())
|
||||
}
|
||||
if !empty {
|
||||
return rpcstatus.Errorf(rpcstatus.AlreadyExists, "bucket %q is not empty, Partner %q cannot be attributed", bucketName, userAgent)
|
||||
}
|
||||
|
||||
// checks if bucket exists before updates it or makes a new entry
|
||||
bucket, err := endpoint.buckets.GetBucket(ctx, bucketName, projectID)
|
||||
if err != nil {
|
||||
@ -147,8 +143,36 @@ func (endpoint *Endpoint) tryUpdateBucketAttribution(ctx context.Context, header
|
||||
endpoint.log.Error("error while getting bucket", zap.ByteString("bucketName", bucketName), zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, "unable to set bucket attribution")
|
||||
}
|
||||
if bucket.UserAgent != nil {
|
||||
return rpcstatus.Errorf(rpcstatus.AlreadyExists, "bucket %q already has attribution, Partner %q cannot be attributed", bucketName, userAgent)
|
||||
|
||||
if attrInfo != nil {
|
||||
// bucket user agent and value attributions user agent already set
|
||||
if bytes.Equal(bucket.UserAgent, attrInfo.UserAgent) {
|
||||
return nil
|
||||
}
|
||||
// make sure bucket user_agent matches value_attribution
|
||||
userAgent = attrInfo.UserAgent
|
||||
}
|
||||
|
||||
empty, err := endpoint.isBucketEmpty(ctx, projectID, bucketName)
|
||||
if err != nil {
|
||||
endpoint.log.Error("internal", zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, Error.Wrap(err).Error())
|
||||
}
|
||||
if !empty {
|
||||
return rpcstatus.Errorf(rpcstatus.AlreadyExists, "bucket %q is not empty, Partner %q cannot be attributed", bucketName, userAgent)
|
||||
}
|
||||
|
||||
if attrInfo == nil {
|
||||
// update attribution table
|
||||
_, err = endpoint.attributions.Insert(ctx, &attribution.Info{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucketName,
|
||||
UserAgent: userAgent,
|
||||
})
|
||||
if err != nil {
|
||||
endpoint.log.Error("error while inserting attribution to DB", zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// update bucket information
|
||||
@ -159,17 +183,6 @@ func (endpoint *Endpoint) tryUpdateBucketAttribution(ctx context.Context, header
|
||||
return rpcstatus.Error(rpcstatus.Internal, "unable to set bucket attribution")
|
||||
}
|
||||
|
||||
// update attribution table
|
||||
_, err = endpoint.attributions.Insert(ctx, &attribution.Info{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucketName,
|
||||
UserAgent: userAgent,
|
||||
})
|
||||
if err != nil {
|
||||
endpoint.log.Error("error while inserting attribution to DB", zap.Error(err))
|
||||
return rpcstatus.Error(rpcstatus.Internal, err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
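Taken together, the attribution changes above work like this: bucket creation passes forceBucketUpdate and therefore skips the per-connection cache and updates the bucket row, other requests only make sure a value_attributions entry exists, and an existing value_attributions user agent always wins over whatever the request carries. The toy model below sketches that decision flow under those assumptions; the bucket-emptiness check and all error handling are omitted, and the store type is a placeholder, not the real endpoint dependencies.

package main

import "fmt"

// store is a toy stand-in for the value_attributions and bucket_metainfos tables.
type store struct {
	valueAttribution map[string][]byte
	bucketUserAgent  map[string][]byte
}

// tryUpdateBucketAttribution sketches the new flow: record a missing
// attribution, and touch the bucket row only when forced (bucket creation)
// or when there is no attribution yet, keeping the bucket's user agent
// consistent with value_attributions.
func (s *store) tryUpdateBucketAttribution(bucket string, userAgent []byte, forceBucketUpdate bool) {
	attrUA, attributed := s.valueAttribution[bucket]
	if attributed && !forceBucketUpdate {
		return // attribution already recorded, nothing to force
	}
	if attributed {
		if string(s.bucketUserAgent[bucket]) == string(attrUA) {
			return // bucket row already matches value_attributions
		}
		userAgent = attrUA // value_attributions wins over the request's agent
	} else {
		s.valueAttribution[bucket] = userAgent
	}
	s.bucketUserAgent[bucket] = userAgent
}

func main() {
	s := &store{valueAttribution: map[string][]byte{}, bucketUserAgent: map[string][]byte{}}
	s.tryUpdateBucketAttribution("b", []byte("minio"), true)     // bucket creation
	s.tryUpdateBucketAttribution("b", []byte("not minio"), true) // recreate: keeps "minio"
	fmt.Println(string(s.valueAttribution["b"]), string(s.bucketUserAgent["b"]))
}

Running the two calls in main reproduces the recreated-bucket case from TestAttributionDeletedBucketRecreated: a second creation with a different user agent leaves both tables at the original value.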
@ -381,3 +381,146 @@ func TestBucketAttributionConcurrentUpload(t *testing.T) {
|
||||
require.Equal(t, []byte(config.UserAgent), attributionInfo.UserAgent)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttributionDeletedBucketRecreated(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1,
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
satellite := planet.Satellites[0]
|
||||
upl := planet.Uplinks[0]
|
||||
proj := upl.Projects[0].ID
|
||||
bucket := "testbucket"
|
||||
ua1 := []byte("minio")
|
||||
ua2 := []byte("not minio")
|
||||
|
||||
require.NoError(t, satellite.DB.Console().Projects().UpdateUserAgent(ctx, proj, ua1))
|
||||
|
||||
require.NoError(t, upl.CreateBucket(ctx, satellite, bucket))
|
||||
b, err := satellite.DB.Buckets().GetBucket(ctx, []byte(bucket), proj)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ua1, b.UserAgent)
|
||||
|
||||
// test recreate with same user agent
|
||||
require.NoError(t, upl.DeleteBucket(ctx, satellite, bucket))
|
||||
require.NoError(t, upl.CreateBucket(ctx, satellite, bucket))
|
||||
b, err = satellite.DB.Buckets().GetBucket(ctx, []byte(bucket), proj)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ua1, b.UserAgent)
|
||||
|
||||
// test recreate with different user agent
|
||||
// should still have original user agent
|
||||
require.NoError(t, upl.DeleteBucket(ctx, satellite, bucket))
|
||||
upl.Config.UserAgent = string(ua2)
|
||||
require.NoError(t, upl.CreateBucket(ctx, satellite, bucket))
|
||||
b, err = satellite.DB.Buckets().GetBucket(ctx, []byte(bucket), proj)
require.NoError(t, err)
|
||||
require.Equal(t, ua1, b.UserAgent)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAttributionBeginObject(t *testing.T) {
|
||||
testplanet.Run(t, testplanet.Config{
|
||||
SatelliteCount: 1, StorageNodeCount: 0, UplinkCount: 1,
|
||||
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
satellite := planet.Satellites[0]
|
||||
upl := planet.Uplinks[0]
|
||||
proj := upl.Projects[0].ID
|
||||
ua := []byte("minio")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
vaAttrBefore, bktAttrBefore, bktAttrAfter bool
|
||||
}{
|
||||
// test for the existence of user_agent in the buckets table given different preconditions of user_agent
// in value_attributions and bucket_metainfos, to make sure nothing breaks and the outcome is as expected.
|
||||
{
|
||||
name: "attribution exists in VA and bucket",
|
||||
vaAttrBefore: true,
|
||||
bktAttrBefore: true,
|
||||
bktAttrAfter: true,
|
||||
},
|
||||
{
|
||||
name: "attribution exists in VA and NOT bucket",
|
||||
vaAttrBefore: true,
|
||||
bktAttrBefore: false,
|
||||
bktAttrAfter: false,
|
||||
},
|
||||
{
|
||||
name: "attribution exists in bucket and NOT VA",
|
||||
vaAttrBefore: false,
|
||||
bktAttrBefore: true,
|
||||
bktAttrAfter: true,
|
||||
},
|
||||
{
|
||||
name: "attribution exists in neither VA nor buckets",
|
||||
vaAttrBefore: false,
|
||||
bktAttrBefore: false,
|
||||
bktAttrAfter: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(*testing.T) {
|
||||
bucketName := fmt.Sprintf("bucket-%d", i)
|
||||
var expectedBktUA []byte
|
||||
var config uplink.Config
|
||||
if tt.bktAttrBefore || tt.vaAttrBefore {
|
||||
config.UserAgent = string(ua)
|
||||
}
|
||||
if tt.bktAttrAfter {
|
||||
expectedBktUA = ua
|
||||
}
|
||||
|
||||
p, err := config.OpenProject(ctx, upl.Access[satellite.ID()])
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = p.CreateBucket(ctx, bucketName)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, p.Close())
|
||||
|
||||
if !tt.bktAttrBefore && tt.vaAttrBefore {
|
||||
// remove user agent from bucket
|
||||
err = satellite.API.DB.Buckets().UpdateUserAgent(ctx, proj, bucketName, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, err = satellite.API.DB.Attribution().Get(ctx, proj, []byte(bucketName))
|
||||
if !tt.bktAttrBefore && !tt.vaAttrBefore {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
b, err := satellite.API.DB.Buckets().GetBucket(ctx, []byte(bucketName), proj)
|
||||
require.NoError(t, err)
|
||||
if !tt.bktAttrBefore {
|
||||
require.Nil(t, b.UserAgent)
|
||||
} else {
|
||||
require.Equal(t, expectedBktUA, b.UserAgent)
|
||||
}
|
||||
|
||||
config.UserAgent = string(ua)
|
||||
|
||||
p, err = config.OpenProject(ctx, upl.Access[satellite.ID()])
|
||||
require.NoError(t, err)
|
||||
|
||||
upload, err := p.UploadObject(ctx, bucketName, fmt.Sprintf("foobar-%d", i), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = upload.Write([]byte("content"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = upload.Commit()
|
||||
require.NoError(t, err)
|
||||
|
||||
attr, err := satellite.API.DB.Attribution().Get(ctx, proj, []byte(bucketName))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ua, attr.UserAgent)
|
||||
|
||||
b, err = satellite.API.DB.Buckets().GetBucket(ctx, []byte(bucketName), proj)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedBktUA, b.UserAgent)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -141,9 +141,12 @@ type Config struct {
|
||||
RateLimiter RateLimiterConfig `help:"rate limiter configuration"`
|
||||
UploadLimiter UploadLimiterConfig `help:"object upload limiter configuration"`
|
||||
ProjectLimits ProjectLimitConfig `help:"project limit configuration"`
|
||||
|
||||
// TODO remove this flag when server-side copy implementation will be finished
|
||||
ServerSideCopy bool `help:"enable code for server-side copy, deprecated. please leave this to true." default:"true"`
|
||||
ServerSideCopyDisabled bool `help:"disable already enabled server-side copy. this is because once server side copy is enabled, delete code should stay changed, even if you want to disable server side copy" default:"false"`
|
||||
ServerSideCopy bool `help:"enable code for server-side copy, deprecated. please leave this to true." default:"true"`
|
||||
ServerSideCopyDisabled bool `help:"disable already enabled server-side copy. this is because once server side copy is enabled, delete code should stay changed, even if you want to disable server side copy" default:"false"`
|
||||
ServerSideCopyDuplicateMetadata bool `help:"perform server-side copy by duplicating metadata, instead of using segment_copies" default:"false"`
|
||||
|
||||
// TODO remove when we benchmarking are done and decision is made.
|
||||
TestListingQuery bool `default:"false" help:"test the new query for non-recursive listing"`
|
||||
}
|
||||
|
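The new ServerSideCopyDuplicateMetadata flag is wired into metabase.FinishCopyObject further down in this diff (DuplicateMetadata: endpoint.config.ServerSideCopyDuplicateMetadata). The sketch below is only a rough model of the two copy modes it selects between, with simplified placeholder types; it is not the real metabase schema or API.

package main

import "fmt"

// segmentRow is a simplified stand-in for a metabase segment row.
type segmentRow struct {
	streamID string
	pieces   []int
}

// copySegments models the two server-side copy modes: with duplicate metadata
// every copied segment gets its own full row including pieces; without it the
// copy rows stay piece-less and a single ancestor reference is recorded
// (segment_copies) instead.
func copySegments(ancestor []segmentRow, copyStreamID string, duplicateMetadata bool) (rows []segmentRow, ancestorRefs int) {
	for _, seg := range ancestor {
		row := segmentRow{streamID: copyStreamID}
		if duplicateMetadata {
			row.pieces = append([]int(nil), seg.pieces...)
		}
		rows = append(rows, row)
	}
	if !duplicateMetadata {
		ancestorRefs = 1
	}
	return rows, ancestorRefs
}

func main() {
	ancestor := []segmentRow{{streamID: "orig", pieces: []int{1, 2, 3}}}
	rows, refs := copySegments(ancestor, "copy", true)
	fmt.Println(len(rows), refs, rows[0].pieces)
	rows, refs = copySegments(ancestor, "copy", false)
	fmt.Println(len(rows), refs, rows[0].pieces)
}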
@ -71,7 +71,7 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Name)
|
||||
err = endpoint.validateBucketName(req.Name)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -83,7 +83,7 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate
|
||||
return nil, rpcstatus.Error(rpcstatus.Internal, err.Error())
|
||||
} else if exists {
|
||||
// When the bucket exists, try to set the attribution.
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.GetName(), nil); err != nil {
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.GetName(), nil, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, rpcstatus.Error(rpcstatus.AlreadyExists, "bucket already exists")
|
||||
@ -119,7 +119,7 @@ func (endpoint *Endpoint) CreateBucket(ctx context.Context, req *pb.BucketCreate
|
||||
}
|
||||
|
||||
// Once we have created the bucket, we can try setting the attribution.
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.GetName(), project.UserAgent); err != nil {
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.GetName(), project.UserAgent, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ func (endpoint *Endpoint) DeleteBucket(ctx context.Context, req *pb.BucketDelete
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Name)
|
||||
err = endpoint.validateBucketNameLength(req.Name)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
|
@ -115,18 +115,29 @@ func TestBucketNameValidation(t *testing.T) {
|
||||
"192.168.1.234", "testBUCKET",
|
||||
"test/bucket",
|
||||
"testbucket-64-0123456789012345678901234567890123456789012345abcd",
|
||||
"test\\", "test%",
|
||||
}
|
||||
for _, name := range invalidNames {
|
||||
_, err = metainfoClient.BeginObject(ctx, metaclient.BeginObjectParams{
|
||||
Bucket: []byte(name),
|
||||
EncryptedObjectKey: []byte("123"),
|
||||
})
|
||||
require.Error(t, err, "bucket name: %v", name)
|
||||
|
||||
_, err = metainfoClient.CreateBucket(ctx, metaclient.CreateBucketParams{
|
||||
Name: []byte(name),
|
||||
})
|
||||
require.Error(t, err, "bucket name: %v", name)
|
||||
require.True(t, errs2.IsRPC(err, rpcstatus.InvalidArgument))
|
||||
}
|
||||
|
||||
invalidNames = []string{
|
||||
"", "t", "te",
|
||||
"testbucket-64-0123456789012345678901234567890123456789012345abcd",
|
||||
}
|
||||
for _, name := range invalidNames {
|
||||
// BeginObject validates only bucket name length
|
||||
_, err = metainfoClient.BeginObject(ctx, metaclient.BeginObjectParams{
|
||||
Bucket: []byte(name),
|
||||
EncryptedObjectKey: []byte("123"),
|
||||
})
|
||||
require.Error(t, err, "bucket name: %v", name)
|
||||
require.True(t, errs2.IsRPC(err, rpcstatus.InvalidArgument))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
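The endpoint changes that follow replace the single validateBucket call with two narrower helpers: full name validation where a bucket is actually created, and a cheap length-only check everywhere the bucket is looked up in the database anyway. The sketch below shows what such a split can look like; the 3-to-63 length bounds and the character rules are assumptions based on the invalid names exercised by the test above, not necessarily the exact satellite implementation.

package main

import (
	"errors"
	"fmt"
	"net"
	"regexp"
)

var bucketNameRe = regexp.MustCompile(`^[a-z0-9][a-z0-9\-.]*[a-z0-9]$`)

// validateBucketNameLength is the cheap check used by endpoints that hit the
// database anyway and only need to reject obviously malformed input.
func validateBucketNameLength(name []byte) error {
	if len(name) < 3 || len(name) > 63 {
		return errors.New("bucket name must be between 3 and 63 characters")
	}
	return nil
}

// validateBucketName is the full check used on bucket creation.
func validateBucketName(name []byte) error {
	if err := validateBucketNameLength(name); err != nil {
		return err
	}
	if !bucketNameRe.Match(name) {
		return errors.New("bucket name contains invalid characters")
	}
	if net.ParseIP(string(name)) != nil {
		return errors.New("bucket name must not be an IP address")
	}
	return nil
}

func main() {
	for _, name := range []string{"te", "testbucket", "192.168.1.234", "test/bucket"} {
		fmt.Printf("%-16q length-only: %v, full: %v\n",
			name, validateBucketNameLength([]byte(name)), validateBucketName([]byte(name)))
	}
}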
@ -66,7 +66,8 @@ func (endpoint *Endpoint) BeginObject(ctx context.Context, req *pb.ObjectBeginRe
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, "Invalid expiration time")
|
||||
}
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
// basic name validation is enough here because the bucket is checked against the DB later
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -95,7 +96,7 @@ func (endpoint *Endpoint) BeginObject(ctx context.Context, req *pb.ObjectBeginRe
|
||||
return nil, rpcstatus.Error(rpcstatus.Internal, err.Error())
|
||||
}
|
||||
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.Bucket, nil); err != nil {
|
||||
if err := endpoint.ensureAttribution(ctx, req.Header, keyInfo, req.Bucket, nil, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -310,7 +311,7 @@ func (endpoint *Endpoint) GetObject(ctx context.Context, req *pb.ObjectGetReques
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -409,7 +410,7 @@ func (endpoint *Endpoint) DownloadObject(ctx context.Context, req *pb.ObjectDown
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -804,7 +805,7 @@ func (endpoint *Endpoint) ListObjects(ctx context.Context, req *pb.ObjectListReq
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -944,7 +945,7 @@ func (endpoint *Endpoint) ListPendingObjectStreams(ctx context.Context, req *pb.
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1057,7 +1058,7 @@ func (endpoint *Endpoint) BeginDeleteObject(ctx context.Context, req *pb.ObjectB
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1146,7 +1147,7 @@ func (endpoint *Endpoint) GetObjectIPs(ctx context.Context, req *pb.ObjectGetIPs
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1237,7 +1238,7 @@ func (endpoint *Endpoint) UpdateObjectMetadata(ctx context.Context, req *pb.Obje
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.Bucket)
|
||||
err = endpoint.validateBucketNameLength(req.Bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1499,50 +1500,6 @@ func (endpoint *Endpoint) DeleteCommittedObject(
|
||||
return deletedObjects, nil
|
||||
}
|
||||
|
||||
// DeleteObjectAnyStatus deletes all the pieces of the storage nodes that belongs
|
||||
// to the specified object.
|
||||
//
|
||||
// NOTE: this method is exported for being able to individually test it without
|
||||
// having import cycles.
|
||||
// TODO regarding the above note: exporting for testing is fine, but we should name
|
||||
// it something that will definitely never ever be added to the rpc set in DRPC
|
||||
// definitions. If we ever decide to add an RPC method called "DeleteObjectAnyStatus",
|
||||
// DRPC interface definitions is all that is standing in the way from someone
|
||||
// remotely calling this. We should name this InternalDeleteObjectAnyStatus or
|
||||
// something.
|
||||
func (endpoint *Endpoint) DeleteObjectAnyStatus(ctx context.Context, location metabase.ObjectLocation,
|
||||
) (deletedObjects []*pb.Object, err error) {
|
||||
defer mon.Task()(&ctx, location.ProjectID.String(), location.BucketName, location.ObjectKey)(&err)
|
||||
|
||||
var result metabase.DeleteObjectResult
|
||||
if endpoint.config.ServerSideCopy {
|
||||
result, err = endpoint.metabase.DeleteObjectExactVersion(ctx, metabase.DeleteObjectExactVersion{
|
||||
ObjectLocation: location,
|
||||
Version: metabase.DefaultVersion,
|
||||
})
|
||||
} else {
|
||||
result, err = endpoint.metabase.DeleteObjectAnyStatusAllVersions(ctx, metabase.DeleteObjectAnyStatusAllVersions{
|
||||
ObjectLocation: location,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, Error.Wrap(err)
|
||||
}
|
||||
|
||||
deletedObjects, err = endpoint.deleteObjectResultToProto(ctx, result)
|
||||
if err != nil {
|
||||
endpoint.log.Error("failed to convert delete object result",
|
||||
zap.Stringer("project", location.ProjectID),
|
||||
zap.String("bucket", location.BucketName),
|
||||
zap.Binary("object", []byte(location.ObjectKey)),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return deletedObjects, nil
|
||||
}
|
||||
|
||||
// DeletePendingObject deletes all the pieces of the storage nodes that belongs
|
||||
// to the specified pending object.
|
||||
//
|
||||
@ -1615,7 +1572,7 @@ func (endpoint *Endpoint) BeginMoveObject(ctx context.Context, req *pb.ObjectBeg
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
for _, bucket := range [][]byte{req.Bucket, req.NewBucket} {
|
||||
err = endpoint.validateBucket(ctx, bucket)
|
||||
err = endpoint.validateBucketNameLength(bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1761,7 +1718,7 @@ func (endpoint *Endpoint) FinishMoveObject(ctx context.Context, req *pb.ObjectFi
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.NewBucket)
|
||||
err = endpoint.validateBucketNameLength(req.NewBucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1840,7 +1797,7 @@ func (endpoint *Endpoint) BeginCopyObject(ctx context.Context, req *pb.ObjectBeg
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
for _, bucket := range [][]byte{req.Bucket, req.NewBucket} {
|
||||
err = endpoint.validateBucket(ctx, bucket)
|
||||
err = endpoint.validateBucketNameLength(bucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1952,7 +1909,7 @@ func (endpoint *Endpoint) FinishCopyObject(ctx context.Context, req *pb.ObjectFi
|
||||
}
|
||||
endpoint.usageTracking(keyInfo, req.Header, fmt.Sprintf("%T", req))
|
||||
|
||||
err = endpoint.validateBucket(ctx, req.NewBucket)
|
||||
err = endpoint.validateBucketNameLength(req.NewBucket)
|
||||
if err != nil {
|
||||
return nil, rpcstatus.Error(rpcstatus.InvalidArgument, err.Error())
|
||||
}
|
||||
@ -1995,6 +1952,7 @@ func (endpoint *Endpoint) FinishCopyObject(ctx context.Context, req *pb.ObjectFi
|
||||
NewEncryptedMetadata: req.NewEncryptedMetadata,
|
||||
NewEncryptedMetadataKeyNonce: req.NewEncryptedMetadataKeyNonce,
|
||||
NewEncryptedMetadataKey: req.NewEncryptedMetadataKey,
|
||||
DuplicateMetadata: endpoint.config.ServerSideCopyDuplicateMetadata,
|
||||
VerifyLimits: func(encryptedObjectSize int64, nSegments int64) error {
|
||||
return endpoint.addStorageUsageUpToLimit(ctx, keyInfo.ProjectID, encryptedObjectSize, nSegments)
|
||||
},
|
||||
|
@ -22,7 +22,9 @@ import (
|
||||
|
||||
"storj.io/common/errs2"
|
||||
"storj.io/common/identity"
|
||||
"storj.io/common/identity/testidentity"
|
||||
"storj.io/common/memory"
|
||||
"storj.io/common/nodetag"
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/rpc/rpcstatus"
|
||||
"storj.io/common/signing"
|
||||
@ -37,6 +39,9 @@ import (
|
||||
"storj.io/storj/satellite/internalpb"
|
||||
"storj.io/storj/satellite/metabase"
|
||||
"storj.io/storj/satellite/metainfo"
|
||||
"storj.io/storj/satellite/overlay"
|
||||
"storj.io/storj/storagenode"
|
||||
"storj.io/storj/storagenode/contact"
|
||||
"storj.io/uplink"
|
||||
"storj.io/uplink/private/metaclient"
|
||||
"storj.io/uplink/private/object"
|
||||
@ -1585,59 +1590,6 @@ func TestEndpoint_DeletePendingObject(t *testing.T) {
|
||||
testDeleteObject(t, createPendingObject, deletePendingObject)
|
||||
}
|
||||
|
||||
func TestEndpoint_DeleteObjectAnyStatus(t *testing.T) {
|
||||
createCommittedObject := func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, key string, data []byte) {
|
||||
err := planet.Uplinks[0].Upload(ctx, planet.Satellites[0], bucket, key, data)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
deleteCommittedObject := func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, encryptedKey string, streamID uuid.UUID) {
|
||||
projectID := planet.Uplinks[0].Projects[0].ID
|
||||
|
||||
deletedObjects, err := planet.Satellites[0].Metainfo.Endpoint.DeleteObjectAnyStatus(ctx, metabase.ObjectLocation{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucket,
|
||||
ObjectKey: metabase.ObjectKey(encryptedKey),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, deletedObjects, 1)
|
||||
|
||||
}
|
||||
testDeleteObject(t, createCommittedObject, deleteCommittedObject)
|
||||
|
||||
createPendingObject := func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, key string, data []byte) {
|
||||
// TODO This should be replaced by a call to testplanet.Uplink.MultipartUpload when available.
|
||||
project, err := planet.Uplinks[0].OpenProject(ctx, planet.Satellites[0])
|
||||
require.NoError(t, err, "failed to retrieve project")
|
||||
defer func() { require.NoError(t, project.Close()) }()
|
||||
|
||||
_, err = project.EnsureBucket(ctx, bucket)
|
||||
require.NoError(t, err, "failed to create bucket")
|
||||
|
||||
info, err := project.BeginUpload(ctx, bucket, key, &uplink.UploadOptions{})
|
||||
require.NoError(t, err, "failed to start multipart upload")
|
||||
|
||||
upload, err := project.UploadPart(ctx, bucket, key, info.UploadID, 1)
|
||||
require.NoError(t, err, "failed to put object part")
|
||||
_, err = upload.Write(data)
|
||||
require.NoError(t, err, "failed to start multipart upload")
|
||||
require.NoError(t, upload.Commit(), "failed to start multipart upload")
|
||||
}
|
||||
|
||||
deletePendingObject := func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, encryptedKey string, streamID uuid.UUID) {
|
||||
projectID := planet.Uplinks[0].Projects[0].ID
|
||||
|
||||
deletedObjects, err := planet.Satellites[0].Metainfo.Endpoint.DeleteObjectAnyStatus(ctx, metabase.ObjectLocation{
|
||||
ProjectID: projectID,
|
||||
BucketName: bucket,
|
||||
ObjectKey: metabase.ObjectKey(encryptedKey),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, deletedObjects, 1)
|
||||
}
|
||||
|
||||
testDeleteObject(t, createPendingObject, deletePendingObject)
|
||||
}
|
||||
|
||||
func testDeleteObject(t *testing.T,
|
||||
createObject func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, key string, data []byte),
|
||||
deleteObject func(ctx context.Context, t *testing.T, planet *testplanet.Planet, bucket, encryptedKey string, streamID uuid.UUID),
|
||||
@ -2450,3 +2402,100 @@ func TestListUploads(t *testing.T) {
|
||||
require.Equal(t, 1000, items)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlacements(t *testing.T) {
|
||||
ctx := testcontext.New(t)
|
||||
|
||||
satelliteIdentity := signing.SignerFromFullIdentity(testidentity.MustPregeneratedSignedIdentity(0, storj.LatestIDVersion()))
|
||||
|
||||
placementRules := overlay.ConfigurablePlacementRule{}
|
||||
err := placementRules.Set(fmt.Sprintf(`16:tag("%s", "certified","true")`, satelliteIdentity.ID()))
|
||||
require.NoError(t, err)
|
||||
|
||||
testplanet.Run(t,
|
||||
testplanet.Config{
|
||||
SatelliteCount: 1,
|
||||
StorageNodeCount: 12,
|
||||
UplinkCount: 1,
|
||||
Reconfigure: testplanet.Reconfigure{
|
||||
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
|
||||
config.Metainfo.RS.Min = 3
|
||||
config.Metainfo.RS.Repair = 4
|
||||
config.Metainfo.RS.Success = 5
|
||||
config.Metainfo.RS.Total = 6
|
||||
config.Metainfo.MaxInlineSegmentSize = 1
|
||||
config.Placement = placementRules
|
||||
},
|
||||
StorageNode: func(index int, config *storagenode.Config) {
|
||||
if index%2 == 0 {
|
||||
tags := &pb.NodeTagSet{
|
||||
NodeId: testidentity.MustPregeneratedSignedIdentity(index+1, storj.LatestIDVersion()).ID.Bytes(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Tags: []*pb.Tag{
|
||||
{
|
||||
Name: "certified",
|
||||
Value: []byte("true"),
|
||||
},
|
||||
},
|
||||
}
|
||||
signed, err := nodetag.Sign(ctx, tags, satelliteIdentity)
|
||||
require.NoError(t, err)
|
||||
|
||||
config.Contact.Tags = contact.SignedTags(pb.SignedNodeTagSets{
|
||||
Tags: []*pb.SignedNodeTagSet{
|
||||
signed,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
},
|
||||
},
|
||||
},
|
||||
func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
|
||||
satellite := planet.Satellites[0]
|
||||
buckets := satellite.API.Buckets.Service
|
||||
uplink := planet.Uplinks[0]
|
||||
projectID := uplink.Projects[0].ID
|
||||
|
||||
// create a bucket with constrained placement (placement 16 is configured above)
|
||||
createGeofencedBucket(t, ctx, buckets, projectID, "constrained", 16)
|
||||
|
||||
objectNo := 10
|
||||
for i := 0; i < objectNo; i++ {
|
||||
// upload an object to the placement-constrained bucket
|
||||
err := uplink.Upload(ctx, satellite, "constrained", "testobject"+strconv.Itoa(i), make([]byte, 10240))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
apiKey := planet.Uplinks[0].APIKey[planet.Satellites[0].ID()]
|
||||
metainfoClient, err := uplink.DialMetainfo(ctx, satellite, apiKey)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = metainfoClient.Close()
|
||||
}()
|
||||
|
||||
objects, _, err := metainfoClient.ListObjects(ctx, metaclient.ListObjectsParams{
|
||||
Bucket: []byte("constrained"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, objects, objectNo)
|
||||
|
||||
for _, listedObject := range objects {
|
||||
o, err := metainfoClient.DownloadObject(ctx, metaclient.DownloadObjectParams{
|
||||
Bucket: []byte("constrained"),
|
||||
EncryptedObjectKey: listedObject.EncryptedObjectKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, limit := range o.DownloadedSegments[0].Limits {
|
||||
if limit != nil {
|
||||
// identities 2, 4, ..., 10 belong to the untagged storage nodes (identity 0 is the satellite; even-index nodes are certified), so no piece may be placed on them
|
||||
for i := 2; i < 11; i += 2 {
|
||||
require.NotEqual(t, testidentity.MustPregeneratedSignedIdentity(i, storj.LatestIDVersion()).ID, limit.Limit.StorageNodeId)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
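
For reference, the tag-signing flow that feeds the `tag(...)` placement rule above can be condensed as follows. This is only a sketch that reuses the APIs already exercised by the test (nodetag.Sign, contact.SignedTags, the pregenerated test identities); it is not part of the change set.

```go
package main

import (
	"context"
	"time"

	"storj.io/common/identity/testidentity"
	"storj.io/common/nodetag"
	"storj.io/common/pb"
	"storj.io/common/signing"
	"storj.io/common/storj"
	"storj.io/storj/storagenode/contact"
)

func main() {
	ctx := context.Background()

	// the authority whose signature the placement rule trusts
	authority := signing.SignerFromFullIdentity(
		testidentity.MustPregeneratedSignedIdentity(0, storj.LatestIDVersion()))
	nodeIdentity := testidentity.MustPregeneratedSignedIdentity(1, storj.LatestIDVersion())

	// the tag set advertised for one storage node
	tags := &pb.NodeTagSet{
		NodeId:    nodeIdentity.ID.Bytes(),
		Timestamp: time.Now().Unix(),
		Tags:      []*pb.Tag{{Name: "certified", Value: []byte("true")}},
	}

	signed, err := nodetag.Sign(ctx, tags, authority)
	if err != nil {
		panic(err)
	}

	// the node reports the signed tag set via its contact configuration
	_ = contact.SignedTags(pb.SignedNodeTagSets{Tags: []*pb.SignedNodeTagSet{signed}})
}
```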
@ -247,9 +247,7 @@ func (endpoint *Endpoint) checkRate(ctx context.Context, projectID uuid.UUID) (e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
func (endpoint *Endpoint) validateBucketNameLength(bucket []byte) (err error) {
|
||||
if len(bucket) == 0 {
|
||||
return Error.Wrap(buckets.ErrNoBucket.New(""))
|
||||
}
|
||||
@ -258,11 +256,19 @@ func (endpoint *Endpoint) validateBucket(ctx context.Context, bucket []byte) (er
|
||||
return Error.New("bucket name must be at least 3 and no more than 63 characters long")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (endpoint *Endpoint) validateBucketName(bucket []byte) error {
|
||||
if err := endpoint.validateBucketNameLength(bucket); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Regexp not used because benchmark shows it will be slower for valid bucket names
|
||||
// https://gist.github.com/mniewrzal/49de3af95f36e63e88fac24f565e444c
|
||||
labels := bytes.Split(bucket, []byte("."))
|
||||
for _, label := range labels {
|
||||
err = validateBucketLabel(label)
|
||||
err := validateBucketLabel(label)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -284,8 +290,8 @@ func validateBucketLabel(label []byte) error {
|
||||
return Error.New("bucket label must start with a lowercase letter or number")
|
||||
}
|
||||
|
||||
if label[0] == '-' || label[len(label)-1] == '-' {
|
||||
return Error.New("bucket label cannot start or end with a hyphen")
|
||||
if !isLowerLetter(label[len(label)-1]) && !isDigit(label[len(label)-1]) {
|
||||
return Error.New("bucket label must end with a lowercase letter or number")
|
||||
}
|
||||
|
||||
for i := 1; i < len(label)-1; i++ {
|
||||
|
@ -2,7 +2,7 @@
|
||||
// See LICENSE for copying information.
|
||||
|
||||
// Package uploadselection implements node selection logic for uploads.
|
||||
package uploadselection
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
"github.com/spacemonkeygo/monkit/v3"
satellite/nodeselection/filter.go (new file, 181 lines)
@ -0,0 +1,181 @@
|
||||
// Copyright (C) 2023 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
)
|
||||
|
||||
// NodeFilter can decide if a Node should be part of the selection or not.
|
||||
type NodeFilter interface {
|
||||
MatchInclude(node *SelectedNode) bool
|
||||
}
|
||||
|
||||
// NodeFilters is a collection of node filters; a node is included only if every filter votes to include it.
|
||||
type NodeFilters []NodeFilter
|
||||
|
||||
// NodeFilterFunc is a helper to use a plain func as a NodeFilter.
|
||||
type NodeFilterFunc func(node *SelectedNode) bool
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (n NodeFilterFunc) MatchInclude(node *SelectedNode) bool {
|
||||
return n(node)
|
||||
}
|
||||
|
||||
// ExcludeAllFilter will never select any node.
|
||||
type ExcludeAllFilter struct{}
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (ExcludeAllFilter) MatchInclude(node *SelectedNode) bool { return false }
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (n NodeFilters) MatchInclude(node *SelectedNode) bool {
|
||||
for _, filter := range n {
|
||||
if !filter.MatchInclude(node) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// WithCountryFilter is a helper to create a new filter with additional CountryFilter.
|
||||
func (n NodeFilters) WithCountryFilter(permit location.Set) NodeFilters {
|
||||
return append(n, NewCountryFilter(permit))
|
||||
}
|
||||
|
||||
// WithAutoExcludeSubnets is a helper to create a new filter with additional AutoExcludeSubnets.
|
||||
func (n NodeFilters) WithAutoExcludeSubnets() NodeFilters {
|
||||
return append(n, NewAutoExcludeSubnets())
|
||||
}
|
||||
|
||||
// WithExcludedIDs is a helper to create a new filter with an additional ExcludedIDs filter.
|
||||
func (n NodeFilters) WithExcludedIDs(ds []storj.NodeID) NodeFilters {
|
||||
return append(n, ExcludedIDs(ds))
|
||||
}
|
||||
|
||||
var _ NodeFilter = NodeFilters{}
|
||||
|
||||
// CountryCodeExclude is a NodeFilter that excludes all nodes with any of the listed country codes.
|
||||
type CountryCodeExclude []location.CountryCode
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (c CountryCodeExclude) MatchInclude(node *SelectedNode) bool {
|
||||
for _, code := range c {
|
||||
if code == location.None {
|
||||
continue
|
||||
}
|
||||
if node.CountryCode == code {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ NodeFilter = CountryCodeExclude{}
|
||||
|
||||
// CountryFilter can select nodes based on the condition of the country code.
|
||||
type CountryFilter struct {
|
||||
permit location.Set
|
||||
}
|
||||
|
||||
// NewCountryFilter creates a new CountryFilter.
|
||||
func NewCountryFilter(permit location.Set) NodeFilter {
|
||||
return &CountryFilter{
|
||||
permit: permit,
|
||||
}
|
||||
}
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (p *CountryFilter) MatchInclude(node *SelectedNode) bool {
|
||||
return p.permit.Contains(node.CountryCode)
|
||||
}
|
||||
|
||||
var _ NodeFilter = &CountryFilter{}
|
||||
|
||||
// AutoExcludeSubnets picks at most one node per subnet.
//
// Stateful! It should be re-created for each new selection request
// and should only be used as the last filter.
|
||||
type AutoExcludeSubnets struct {
|
||||
seenSubnets map[string]struct{}
|
||||
}
|
||||
|
||||
// NewAutoExcludeSubnets creates an initialized AutoExcludeSubnets.
|
||||
func NewAutoExcludeSubnets() *AutoExcludeSubnets {
|
||||
return &AutoExcludeSubnets{
|
||||
seenSubnets: map[string]struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (a *AutoExcludeSubnets) MatchInclude(node *SelectedNode) bool {
|
||||
if _, found := a.seenSubnets[node.LastNet]; found {
|
||||
return false
|
||||
}
|
||||
a.seenSubnets[node.LastNet] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ NodeFilter = &AutoExcludeSubnets{}
|
||||
|
||||
// ExcludedNetworks excludes nodes on the specified networks (matched by LastNet).
|
||||
type ExcludedNetworks []string
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (e ExcludedNetworks) MatchInclude(node *SelectedNode) bool {
|
||||
for _, id := range e {
|
||||
if id == node.LastNet {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ NodeFilter = ExcludedNetworks{}
|
||||
|
||||
// ExcludedIDs can blacklist NodeIDs.
|
||||
type ExcludedIDs []storj.NodeID
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (e ExcludedIDs) MatchInclude(node *SelectedNode) bool {
|
||||
for _, id := range e {
|
||||
if id == node.ID {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ NodeFilter = ExcludedIDs{}
|
||||
|
||||
// TagFilter matches nodes with specific tags.
|
||||
type TagFilter struct {
|
||||
signer storj.NodeID
|
||||
name string
|
||||
value []byte
|
||||
}
|
||||
|
||||
// NewTagFilter creates a new tag filter.
|
||||
func NewTagFilter(id storj.NodeID, name string, value []byte) TagFilter {
|
||||
return TagFilter{
|
||||
signer: id,
|
||||
name: name,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
// MatchInclude implements NodeFilter interface.
|
||||
func (t TagFilter) MatchInclude(node *SelectedNode) bool {
|
||||
for _, tag := range node.Tags {
|
||||
if tag.Name == t.name && bytes.Equal(tag.Value, t.value) && tag.Signer == t.signer {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _ NodeFilter = TagFilter{}
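
To make the composition model concrete, here is a small illustrative sketch (not part of the change set) that chains several of the filters defined above. It assumes the nodeselection API exactly as added in this file, plus the SelectedNode/NodeTags types added later in this diff, and uses testrand only to fabricate IDs.

```go
package main

import (
	"fmt"

	"storj.io/common/storj"
	"storj.io/common/testrand"
	"storj.io/storj/satellite/nodeselection"
)

func main() {
	signer := testrand.NodeID()   // identity whose tag signatures we trust
	excluded := testrand.NodeID() // a node we never want to pick

	// Every filter in the chain must vote "include" for a node to pass.
	filter := nodeselection.NodeFilters{
		nodeselection.NewTagFilter(signer, "certified", []byte("true")),
	}
	filter = filter.WithExcludedIDs([]storj.NodeID{excluded})
	filter = filter.WithAutoExcludeSubnets() // stateful: admits only the first node seen per subnet

	node := &nodeselection.SelectedNode{
		ID:      testrand.NodeID(),
		LastNet: "192.168.1.0",
		Tags: nodeselection.NodeTags{
			{Signer: signer, Name: "certified", Value: []byte("true")},
		},
	}

	fmt.Println(filter.MatchInclude(node)) // true: tagged, not excluded, first in its subnet
}
```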
satellite/nodeselection/filter_test.go (new file, 175 lines)
@ -0,0 +1,175 @@
|
||||
// Copyright (C) 2023 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"storj.io/common/identity/testidentity"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
)
|
||||
|
||||
func TestNodeFilter_AutoExcludeSubnet(t *testing.T) {
|
||||
|
||||
criteria := NodeFilters{}.WithAutoExcludeSubnets()
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&SelectedNode{
|
||||
LastNet: "192.168.0.1",
|
||||
}))
|
||||
|
||||
assert.False(t, criteria.MatchInclude(&SelectedNode{
|
||||
LastNet: "192.168.0.1",
|
||||
}))
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&SelectedNode{
|
||||
LastNet: "192.168.1.1",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCriteria_ExcludeNodeID(t *testing.T) {
|
||||
included := testrand.NodeID()
|
||||
excluded := testrand.NodeID()
|
||||
|
||||
criteria := NodeFilters{}.WithExcludedIDs([]storj.NodeID{excluded})
|
||||
|
||||
assert.False(t, criteria.MatchInclude(&SelectedNode{
|
||||
ID: excluded,
|
||||
}))
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&SelectedNode{
|
||||
ID: included,
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
func TestCriteria_NodeIDAndSubnet(t *testing.T) {
|
||||
excluded := testrand.NodeID()
|
||||
|
||||
criteria := NodeFilters{}.
|
||||
WithExcludedIDs([]storj.NodeID{excluded}).
|
||||
WithAutoExcludeSubnets()
|
||||
|
||||
// due to node id criteria
|
||||
assert.False(t, criteria.MatchInclude(&SelectedNode{
|
||||
ID: excluded,
|
||||
LastNet: "192.168.0.1",
|
||||
}))
|
||||
|
||||
// should be included as previous one excluded and
|
||||
// not stored for subnet exclusion
|
||||
assert.True(t, criteria.MatchInclude(&SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: "192.168.0.2",
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
func TestCriteria_Geofencing(t *testing.T) {
|
||||
eu := NodeFilters{}.WithCountryFilter(EuCountries)
|
||||
us := NodeFilters{}.WithCountryFilter(location.NewSet(location.UnitedStates))
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
country location.CountryCode
|
||||
criteria NodeFilters
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "US matches US selector",
|
||||
country: location.UnitedStates,
|
||||
criteria: us,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Germany is EU",
|
||||
country: location.Germany,
|
||||
criteria: eu,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "US is not eu",
|
||||
country: location.UnitedStates,
|
||||
criteria: eu,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty country doesn't match region",
|
||||
country: location.CountryCode(0),
|
||||
criteria: eu,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty country doesn't match country",
|
||||
country: location.CountryCode(0),
|
||||
criteria: us,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
assert.Equal(t, c.expected, c.criteria.MatchInclude(&SelectedNode{
|
||||
CountryCode: c.country,
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNodeFilterFullTable checks the performance of rule evaluation across ALL storage nodes.
|
||||
func BenchmarkNodeFilterFullTable(b *testing.B) {
|
||||
filters := NodeFilters{}
|
||||
filters = append(filters, NodeFilterFunc(func(node *SelectedNode) bool {
|
||||
return true
|
||||
}))
|
||||
filters = append(filters, NodeFilterFunc(func(node *SelectedNode) bool {
|
||||
return true
|
||||
}))
|
||||
filters = append(filters, NodeFilterFunc(func(node *SelectedNode) bool {
|
||||
return true
|
||||
}))
|
||||
filters = filters.WithAutoExcludeSubnets()
|
||||
benchmarkFilter(b, filters)
|
||||
}
|
||||
|
||||
func benchmarkFilter(b *testing.B, filters NodeFilters) {
|
||||
nodeNo := 25000
|
||||
if testing.Short() {
|
||||
nodeNo = 20
|
||||
}
|
||||
nodes := generatedSelectedNodes(b, nodeNo)
|
||||
|
||||
b.ResetTimer()
|
||||
c := 0
|
||||
for j := 0; j < b.N; j++ {
|
||||
for n := 0; n < len(nodes); n++ {
|
||||
if filters.MatchInclude(nodes[n]) {
|
||||
c++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func generatedSelectedNodes(b *testing.B, nodeNo int) []*SelectedNode {
|
||||
nodes := make([]*SelectedNode, nodeNo)
|
||||
ctx := testcontext.New(b)
|
||||
for i := 0; i < nodeNo; i++ {
|
||||
node := SelectedNode{}
|
||||
identity, err := testidentity.NewTestIdentity(ctx)
|
||||
require.NoError(b, err)
|
||||
node.ID = identity.ID
|
||||
node.LastNet = fmt.Sprintf("192.168.%d.0", i%256)
|
||||
node.LastIPPort = fmt.Sprintf("192.168.%d.%d:%d", i%256, i%65536, i%1000+1000)
|
||||
node.CountryCode = []location.CountryCode{location.None, location.UnitedStates, location.Germany, location.Hungary, location.Austria}[i%5]
|
||||
nodes[i] = &node
|
||||
}
|
||||
return nodes
|
||||
}
satellite/nodeselection/node.go (new file, 69 lines)
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/common/pb"
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
)
|
||||
|
||||
// NodeTag is a tag associated with a node (approved by signer).
|
||||
type NodeTag struct {
|
||||
NodeID storj.NodeID
|
||||
SignedAt time.Time
|
||||
Signer storj.NodeID
|
||||
Name string
|
||||
Value []byte
|
||||
}
|
||||
|
||||
// NodeTags is a collection of multiple NodeTag.
|
||||
type NodeTags []NodeTag
|
||||
|
||||
// FindBySignerAndName selects the first tag with the given signer and name.
|
||||
func (n NodeTags) FindBySignerAndName(signer storj.NodeID, name string) (NodeTag, error) {
|
||||
for _, tag := range n {
|
||||
if tag.Name == name && signer == tag.Signer {
|
||||
return tag, nil
|
||||
}
|
||||
}
|
||||
return NodeTag{}, errs.New("tags not found")
|
||||
}
|
||||
|
||||
// SelectedNode is used as a result for creating orders limits.
|
||||
type SelectedNode struct {
|
||||
ID storj.NodeID
|
||||
Address *pb.NodeAddress
|
||||
LastNet string
|
||||
LastIPPort string
|
||||
CountryCode location.CountryCode
|
||||
Tags NodeTags
|
||||
}
|
||||
|
||||
// Clone returns a deep clone of the selected node.
|
||||
func (node *SelectedNode) Clone() *SelectedNode {
|
||||
copy := pb.CopyNode(&pb.Node{Id: node.ID, Address: node.Address})
|
||||
tags := make([]NodeTag, len(node.Tags))
|
||||
for ix, tag := range node.Tags {
|
||||
tags[ix] = NodeTag{
|
||||
NodeID: tag.NodeID,
|
||||
SignedAt: tag.SignedAt,
|
||||
Signer: tag.Signer,
|
||||
Name: tag.Name,
|
||||
Value: tag.Value,
|
||||
}
|
||||
}
|
||||
return &SelectedNode{
|
||||
ID: copy.Id,
|
||||
Address: copy.Address,
|
||||
LastNet: node.LastNet,
|
||||
LastIPPort: node.LastIPPort,
|
||||
CountryCode: node.CountryCode,
|
||||
Tags: tags,
|
||||
}
|
||||
}
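
A brief hypothetical usage sketch of the tag helpers above; the tag names and values are invented for illustration, and testrand is used only to fabricate a signer ID.

```go
package main

import (
	"fmt"

	"storj.io/common/testrand"
	"storj.io/storj/satellite/nodeselection"
)

func main() {
	signer := testrand.NodeID()

	tags := nodeselection.NodeTags{
		{Signer: signer, Name: "certified", Value: []byte("true")},
		{Signer: signer, Name: "region", Value: []byte("eu")},
	}

	// FindBySignerAndName returns the first tag matching both signer and name.
	tag, err := tags.FindBySignerAndName(signer, "region")
	if err != nil {
		fmt.Println("tag not found:", err)
		return
	}
	fmt.Println(string(tag.Value)) // eu
}
```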
satellite/nodeselection/region.go (new file, 44 lines)
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2023 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package nodeselection
|
||||
|
||||
import "storj.io/common/storj/location"
|
||||
|
||||
// EuCountries defines the member countries of the European Union.
|
||||
var EuCountries = location.NewSet(
|
||||
location.Austria,
|
||||
location.Belgium,
|
||||
location.Bulgaria,
|
||||
location.Croatia,
|
||||
location.Cyprus,
|
||||
location.Czechia,
|
||||
location.Denmark,
|
||||
location.Estonia,
|
||||
location.Finland,
|
||||
location.France,
|
||||
location.Germany,
|
||||
location.Greece,
|
||||
location.Hungary,
|
||||
location.Ireland,
|
||||
location.Italy,
|
||||
location.Lithuania,
|
||||
location.Latvia,
|
||||
location.Luxembourg,
|
||||
location.Malta,
|
||||
location.Netherlands,
|
||||
location.Poland,
|
||||
location.Portugal,
|
||||
location.Romania,
|
||||
location.Slovenia,
|
||||
location.Slovakia,
|
||||
location.Spain,
|
||||
location.Sweden,
|
||||
)
|
||||
|
||||
// EeaCountries defines the EEA countries.
|
||||
var EeaCountries = EuCountries.With(
|
||||
location.Iceland,
|
||||
location.Liechtenstein,
|
||||
location.Norway,
|
||||
)
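
A minimal sketch of how these region sets combine with the CountryFilter defined in filter.go above; Norway is used only to illustrate the EU/EEA difference, and the snippet assumes the API exactly as added in this diff.

```go
package main

import (
	"fmt"

	"storj.io/common/storj/location"
	"storj.io/storj/satellite/nodeselection"
)

func main() {
	euOnly := nodeselection.NewCountryFilter(nodeselection.EuCountries)
	eeaOnly := nodeselection.NewCountryFilter(nodeselection.EeaCountries)

	norwegian := &nodeselection.SelectedNode{CountryCode: location.Norway}

	fmt.Println(euOnly.MatchInclude(norwegian))  // false: Norway is EEA but not EU
	fmt.Println(eeaOnly.MatchInclude(norwegian)) // true
}
```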
@ -1,14 +1,14 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package uploadselection
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
mathrand "math/rand" // Using mathrand here because crypto-graphic randomness is not required and simplifies code.
|
||||
)
|
||||
|
||||
// SelectByID implements selection from nodes with every node having equal probability.
|
||||
type SelectByID []*Node
|
||||
type SelectByID []*SelectedNode
|
||||
|
||||
var _ Selector = (SelectByID)(nil)
|
||||
|
||||
@ -16,16 +16,16 @@ var _ Selector = (SelectByID)(nil)
|
||||
func (nodes SelectByID) Count() int { return len(nodes) }
|
||||
|
||||
// Select selects upto n nodes.
|
||||
func (nodes SelectByID) Select(n int, criteria Criteria) []*Node {
|
||||
func (nodes SelectByID) Select(n int, nodeFilter NodeFilter) []*SelectedNode {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
selected := []*Node{}
|
||||
selected := []*SelectedNode{}
|
||||
for _, idx := range mathrand.Perm(len(nodes)) {
|
||||
node := nodes[idx]
|
||||
|
||||
if !criteria.MatchInclude(node) {
|
||||
if !nodeFilter.MatchInclude(node) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -46,12 +46,12 @@ var _ Selector = (SelectBySubnet)(nil)
|
||||
// Subnet groups together nodes with the same subnet.
|
||||
type Subnet struct {
|
||||
Net string
|
||||
Nodes []*Node
|
||||
Nodes []*SelectedNode
|
||||
}
|
||||
|
||||
// SelectBySubnetFromNodes creates SelectBySubnet selector from nodes.
|
||||
func SelectBySubnetFromNodes(nodes []*Node) SelectBySubnet {
|
||||
bynet := map[string][]*Node{}
|
||||
func SelectBySubnetFromNodes(nodes []*SelectedNode) SelectBySubnet {
|
||||
bynet := map[string][]*SelectedNode{}
|
||||
for _, node := range nodes {
|
||||
bynet[node.LastNet] = append(bynet[node.LastNet], node)
|
||||
}
|
||||
@ -71,17 +71,17 @@ func SelectBySubnetFromNodes(nodes []*Node) SelectBySubnet {
|
||||
func (subnets SelectBySubnet) Count() int { return len(subnets) }
|
||||
|
||||
// Select selects upto n nodes.
|
||||
func (subnets SelectBySubnet) Select(n int, criteria Criteria) []*Node {
|
||||
func (subnets SelectBySubnet) Select(n int, filter NodeFilter) []*SelectedNode {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
selected := []*Node{}
|
||||
selected := []*SelectedNode{}
|
||||
for _, idx := range mathrand.Perm(len(subnets)) {
|
||||
subnet := subnets[idx]
|
||||
node := subnet.Nodes[mathrand.Intn(len(subnet.Nodes))]
|
||||
|
||||
if !criteria.MatchInclude(node) {
|
||||
if !filter.MatchInclude(node) {
|
||||
continue
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package uploadselection_test
|
||||
package nodeselection_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@ -12,7 +12,7 @@ import (
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
"storj.io/storj/satellite/nodeselection/uploadselection"
|
||||
"storj.io/storj/satellite/nodeselection"
|
||||
)
|
||||
|
||||
func TestSelectByID(t *testing.T) {
|
||||
@ -24,35 +24,26 @@ func TestSelectByID(t *testing.T) {
|
||||
|
||||
// create 3 nodes, 2 with same subnet
|
||||
lastNetDuplicate := "1.0.1"
|
||||
subnetA1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".4:8080",
|
||||
},
|
||||
subnetA1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".4:8080",
|
||||
}
|
||||
subnetA2 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".5:8080",
|
||||
},
|
||||
subnetA2 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".5:8080",
|
||||
}
|
||||
|
||||
lastNetSingle := "1.0.2"
|
||||
subnetB1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetSingle + ".5:8080",
|
||||
},
|
||||
subnetB1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetSingle,
|
||||
LastIPPort: lastNetSingle + ".5:8080",
|
||||
}
|
||||
|
||||
nodes := []*uploadselection.Node{subnetA1, subnetA2, subnetB1}
|
||||
selector := uploadselection.SelectByID(nodes)
|
||||
nodes := []*nodeselection.SelectedNode{subnetA1, subnetA2, subnetB1}
|
||||
selector := nodeselection.SelectByID(nodes)
|
||||
|
||||
const (
|
||||
reqCount = 2
|
||||
@ -63,7 +54,7 @@ func TestSelectByID(t *testing.T) {
|
||||
|
||||
// perform many node selections that selects 2 nodes
|
||||
for i := 0; i < executionCount; i++ {
|
||||
selectedNodes := selector.Select(reqCount, uploadselection.Criteria{})
|
||||
selectedNodes := selector.Select(reqCount, nodeselection.NodeFilters{})
|
||||
require.Len(t, selectedNodes, reqCount)
|
||||
for _, node := range selectedNodes {
|
||||
selectedNodeCount[node.ID]++
|
||||
@ -93,35 +84,26 @@ func TestSelectBySubnet(t *testing.T) {
|
||||
|
||||
// create 3 nodes, 2 with same subnet
|
||||
lastNetDuplicate := "1.0.1"
|
||||
subnetA1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".4:8080",
|
||||
},
|
||||
subnetA1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".4:8080",
|
||||
}
|
||||
subnetA2 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".5:8080",
|
||||
},
|
||||
subnetA2 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".5:8080",
|
||||
}
|
||||
|
||||
lastNetSingle := "1.0.2"
|
||||
subnetB1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetSingle + ".5:8080",
|
||||
},
|
||||
subnetB1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetSingle,
|
||||
LastIPPort: lastNetSingle + ".5:8080",
|
||||
}
|
||||
|
||||
nodes := []*uploadselection.Node{subnetA1, subnetA2, subnetB1}
|
||||
selector := uploadselection.SelectBySubnetFromNodes(nodes)
|
||||
nodes := []*nodeselection.SelectedNode{subnetA1, subnetA2, subnetB1}
|
||||
selector := nodeselection.SelectBySubnetFromNodes(nodes)
|
||||
|
||||
const (
|
||||
reqCount = 2
|
||||
@ -132,7 +114,7 @@ func TestSelectBySubnet(t *testing.T) {
|
||||
|
||||
// perform many node selections that selects 2 nodes
|
||||
for i := 0; i < executionCount; i++ {
|
||||
selectedNodes := selector.Select(reqCount, uploadselection.Criteria{})
|
||||
selectedNodes := selector.Select(reqCount, nodeselection.NodeFilters{})
|
||||
require.Len(t, selectedNodes, reqCount)
|
||||
for _, node := range selectedNodes {
|
||||
selectedNodeCount[node.ID]++
|
||||
@ -174,35 +156,26 @@ func TestSelectBySubnetOneAtATime(t *testing.T) {
|
||||
|
||||
// create 3 nodes, 2 with same subnet
|
||||
lastNetDuplicate := "1.0.1"
|
||||
subnetA1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".4:8080",
|
||||
},
|
||||
subnetA1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".4:8080",
|
||||
}
|
||||
subnetA2 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetDuplicate + ".5:8080",
|
||||
},
|
||||
subnetA2 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".5:8080",
|
||||
}
|
||||
|
||||
lastNetSingle := "1.0.2"
|
||||
subnetB1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: lastNetSingle + ".5:8080",
|
||||
},
|
||||
subnetB1 := &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: lastNetSingle,
|
||||
LastIPPort: lastNetSingle + ".5:8080",
|
||||
}
|
||||
|
||||
nodes := []*uploadselection.Node{subnetA1, subnetA2, subnetB1}
|
||||
selector := uploadselection.SelectBySubnetFromNodes(nodes)
|
||||
nodes := []*nodeselection.SelectedNode{subnetA1, subnetA2, subnetB1}
|
||||
selector := nodeselection.SelectBySubnetFromNodes(nodes)
|
||||
|
||||
const (
|
||||
reqCount = 1
|
||||
@ -213,7 +186,7 @@ func TestSelectBySubnetOneAtATime(t *testing.T) {
|
||||
|
||||
// perform many node selections that selects 1 node
|
||||
for i := 0; i < executionCount; i++ {
|
||||
selectedNodes := selector.Select(reqCount, uploadselection.Criteria{})
|
||||
selectedNodes := selector.Select(reqCount, nodeselection.NodeFilters{})
|
||||
require.Len(t, selectedNodes, reqCount)
|
||||
for _, node := range selectedNodes {
|
||||
selectedNodeCount[node.ID]++
|
||||
@ -247,49 +220,35 @@ func TestSelectFiltered(t *testing.T) {
|
||||
// create 3 nodes, 2 with same subnet
|
||||
lastNetDuplicate := "1.0.1"
|
||||
firstID := testrand.NodeID()
|
||||
subnetA1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: firstID,
|
||||
Address: lastNetDuplicate + ".4:8080",
|
||||
},
|
||||
subnetA1 := &nodeselection.SelectedNode{
|
||||
ID: firstID,
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".4:8080",
|
||||
}
|
||||
|
||||
secondID := testrand.NodeID()
|
||||
subnetA2 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: secondID,
|
||||
Address: lastNetDuplicate + ".5:8080",
|
||||
},
|
||||
subnetA2 := &nodeselection.SelectedNode{
|
||||
ID: secondID,
|
||||
LastNet: lastNetDuplicate,
|
||||
LastIPPort: lastNetDuplicate + ".5:8080",
|
||||
}
|
||||
|
||||
thirdID := testrand.NodeID()
|
||||
lastNetSingle := "1.0.2"
|
||||
subnetB1 := &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: thirdID,
|
||||
Address: lastNetSingle + ".5:8080",
|
||||
},
|
||||
subnetB1 := &nodeselection.SelectedNode{
|
||||
ID: thirdID,
|
||||
LastNet: lastNetSingle,
|
||||
LastIPPort: lastNetSingle + ".5:8080",
|
||||
}
|
||||
|
||||
nodes := []*uploadselection.Node{subnetA1, subnetA2, subnetB1}
|
||||
selector := uploadselection.SelectByID(nodes)
|
||||
nodes := []*nodeselection.SelectedNode{subnetA1, subnetA2, subnetB1}
|
||||
selector := nodeselection.SelectByID(nodes)
|
||||
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{}), 3)
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{ExcludeNodeIDs: []storj.NodeID{firstID}}), 2)
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{}), 3)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}), 3)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithAutoExcludeSubnets()), 2)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}), 3)
|
||||
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{ExcludeNodeIDs: []storj.NodeID{firstID, secondID}}), 1)
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{
|
||||
AutoExcludeSubnets: map[string]struct{}{},
|
||||
}), 2)
|
||||
assert.Len(t, selector.Select(3, uploadselection.Criteria{
|
||||
ExcludeNodeIDs: []storj.NodeID{thirdID},
|
||||
AutoExcludeSubnets: map[string]struct{}{},
|
||||
}), 1)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithExcludedIDs([]storj.NodeID{firstID, secondID})), 1)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithAutoExcludeSubnets()), 2)
|
||||
assert.Len(t, selector.Select(3, nodeselection.NodeFilters{}.WithExcludedIDs([]storj.NodeID{thirdID}).WithAutoExcludeSubnets()), 1)
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package uploadselection
|
||||
package nodeselection
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -10,7 +10,6 @@ import (
|
||||
"github.com/zeebo/errs"
|
||||
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
)
|
||||
|
||||
// ErrNotEnoughNodes is when selecting nodes failed with the given parameters.
|
||||
@ -42,11 +41,11 @@ type Selector interface {
|
||||
Count() int
|
||||
// Select selects up-to n nodes which are included by the criteria.
|
||||
// empty criteria includes all the nodes
|
||||
Select(n int, criteria Criteria) []*Node
|
||||
Select(n int, nodeFilter NodeFilter) []*SelectedNode
|
||||
}
|
||||
|
||||
// NewState returns a state based on the input.
|
||||
func NewState(reputableNodes, newNodes []*Node) *State {
|
||||
func NewState(reputableNodes, newNodes []*SelectedNode) *State {
|
||||
state := &State{}
|
||||
|
||||
state.netByID = map[storj.NodeID]string{}
|
||||
@ -70,15 +69,13 @@ func NewState(reputableNodes, newNodes []*Node) *State {
|
||||
|
||||
// Request contains arguments for State.Request.
|
||||
type Request struct {
|
||||
Count int
|
||||
NewFraction float64
|
||||
ExcludedIDs []storj.NodeID
|
||||
Placement storj.PlacementConstraint
|
||||
ExcludedCountryCodes []string
|
||||
Count int
|
||||
NewFraction float64
|
||||
NodeFilters NodeFilters
|
||||
}
|
||||
|
||||
// Select selects requestedCount nodes where there will be newFraction nodes.
|
||||
func (state *State) Select(ctx context.Context, request Request) (_ []*Node, err error) {
|
||||
func (state *State) Select(ctx context.Context, request Request) (_ []*SelectedNode, err error) {
|
||||
defer mon.Task()(&ctx)(&err)
|
||||
|
||||
state.mu.RLock()
|
||||
@ -87,41 +84,23 @@ func (state *State) Select(ctx context.Context, request Request) (_ []*Node, err
|
||||
totalCount := request.Count
|
||||
newCount := int(float64(totalCount) * request.NewFraction)
|
||||
|
||||
var selected []*Node
|
||||
var selected []*SelectedNode
|
||||
|
||||
var reputableNodes Selector
|
||||
var newNodes Selector
|
||||
|
||||
var criteria Criteria
|
||||
|
||||
if request.ExcludedIDs != nil {
|
||||
criteria.ExcludeNodeIDs = request.ExcludedIDs
|
||||
}
|
||||
|
||||
for _, code := range request.ExcludedCountryCodes {
|
||||
criteria.ExcludedCountryCodes = append(criteria.ExcludedCountryCodes, location.ToCountryCode(code))
|
||||
}
|
||||
|
||||
criteria.Placement = request.Placement
|
||||
|
||||
criteria.AutoExcludeSubnets = make(map[string]struct{})
|
||||
for _, id := range request.ExcludedIDs {
|
||||
if net, ok := state.netByID[id]; ok {
|
||||
criteria.AutoExcludeSubnets[net] = struct{}{}
|
||||
}
|
||||
}
|
||||
reputableNodes = state.distinct.Reputable
|
||||
newNodes = state.distinct.New
|
||||
|
||||
// Get a random selection of new nodes out of the cache first so that if there aren't
|
||||
// enough new nodes on the network, we can fall back to using reputable nodes instead.
|
||||
selected = append(selected,
|
||||
newNodes.Select(newCount, criteria)...)
|
||||
newNodes.Select(newCount, request.NodeFilters)...)
|
||||
|
||||
// Get all the remaining reputable nodes.
|
||||
reputableCount := totalCount - len(selected)
|
||||
selected = append(selected,
|
||||
reputableNodes.Select(reputableCount, criteria)...)
|
||||
reputableNodes.Select(reputableCount, request.NodeFilters)...)
|
||||
|
||||
if len(selected) < totalCount {
|
||||
return selected, ErrNotEnoughNodes.New("requested from cache %d, found %d", totalCount, len(selected))
|
||||
@ -136,3 +115,19 @@ func (state *State) Stats() Stats {
|
||||
|
||||
return state.stats
|
||||
}
|
||||
|
||||
// ExcludeNetworksBasedOnNodes creates a NodeFilter that excludes all nodes sharing a subnet with the specified ones.
|
||||
func (state *State) ExcludeNetworksBasedOnNodes(ds []storj.NodeID) NodeFilter {
|
||||
uniqueExcludedNet := make(map[string]struct{}, len(ds))
|
||||
for _, id := range ds {
|
||||
net := state.netByID[id]
|
||||
uniqueExcludedNet[net] = struct{}{}
|
||||
}
|
||||
excludedNet := make([]string, len(uniqueExcludedNet))
|
||||
i := 0
|
||||
for net := range uniqueExcludedNet {
|
||||
excludedNet[i] = net
|
||||
i++
|
||||
}
|
||||
return ExcludedNetworks(excludedNet)
|
||||
}
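
A sketch of how callers are expected to drive the reshaped Request: ID, subnet, country, and placement constraints now arrive as composed NodeFilters rather than dedicated fields. The helper name selectAvoiding and the counts are hypothetical; the State, Request, and filter APIs are assumed exactly as shown in the hunks above.

```go
package example

import (
	"context"

	"storj.io/common/storj"
	"storj.io/storj/satellite/nodeselection"
)

// selectAvoiding picks 5 nodes (5% new) while avoiding the given node IDs and
// any node that shares a subnet with them.
func selectAvoiding(ctx context.Context, state *nodeselection.State, avoid []storj.NodeID) ([]*nodeselection.SelectedNode, error) {
	filters := nodeselection.NodeFilters{
		// keep the subnets of the already-used nodes out of the new selection
		state.ExcludeNetworksBasedOnNodes(avoid),
	}.WithExcludedIDs(avoid) // and exclude those nodes themselves

	return state.Select(ctx, nodeselection.Request{
		Count:       5,
		NewFraction: 0.05,
		NodeFilters: filters,
	})
}
```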
@ -1,7 +1,7 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package uploadselection_test
|
||||
package nodeselection_test
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
@ -10,10 +10,9 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/testcontext"
|
||||
"storj.io/common/testrand"
|
||||
"storj.io/storj/satellite/nodeselection/uploadselection"
|
||||
"storj.io/storj/satellite/nodeselection"
|
||||
)
|
||||
|
||||
func TestState_SelectNonDistinct(t *testing.T) {
|
||||
@ -29,18 +28,17 @@ func TestState_SelectNonDistinct(t *testing.T) {
|
||||
createRandomNodes(3, "1.0.4", false),
|
||||
)
|
||||
|
||||
state := uploadselection.NewState(reputableNodes, newNodes)
|
||||
require.Equal(t, uploadselection.Stats{
|
||||
state := nodeselection.NewState(reputableNodes, newNodes)
|
||||
require.Equal(t, nodeselection.Stats{
|
||||
New: 5,
|
||||
Reputable: 5,
|
||||
}, state.Stats())
|
||||
|
||||
{ // select 5 non-distinct subnet reputable nodes
|
||||
const selectCount = 5
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: 0,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, selected, selectCount)
|
||||
@ -49,10 +47,9 @@ func TestState_SelectNonDistinct(t *testing.T) {
|
||||
{ // select 6 non-distinct subnet reputable and new nodes (50%)
|
||||
const selectCount = 6
|
||||
const newFraction = 0.5
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: newFraction,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, selected, selectCount)
|
||||
@ -63,10 +60,9 @@ func TestState_SelectNonDistinct(t *testing.T) {
|
||||
{ // select 10 distinct subnet reputable and new nodes (100%), falling back to 5 reputable
|
||||
const selectCount = 10
|
||||
const newFraction = 1.0
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: newFraction,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, selected, selectCount)
|
||||
@ -88,18 +84,17 @@ func TestState_SelectDistinct(t *testing.T) {
|
||||
createRandomNodes(3, "1.0.4", true),
|
||||
)
|
||||
|
||||
state := uploadselection.NewState(reputableNodes, newNodes)
|
||||
require.Equal(t, uploadselection.Stats{
|
||||
state := nodeselection.NewState(reputableNodes, newNodes)
|
||||
require.Equal(t, nodeselection.Stats{
|
||||
New: 2,
|
||||
Reputable: 2,
|
||||
}, state.Stats())
|
||||
|
||||
{ // select 2 distinct subnet reputable nodes
|
||||
const selectCount = 2
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: 0,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, selected, selectCount)
|
||||
@ -107,10 +102,9 @@ func TestState_SelectDistinct(t *testing.T) {
|
||||
|
||||
{ // try to select 5 distinct subnet reputable nodes, but there are only two 2 in the state
|
||||
const selectCount = 5
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: 0,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Len(t, selected, 2)
|
||||
@ -119,10 +113,9 @@ func TestState_SelectDistinct(t *testing.T) {
|
||||
{ // select 4 distinct subnet reputable and new nodes (50%)
|
||||
const selectCount = 4
|
||||
const newFraction = 0.5
|
||||
selected, err := state.Select(ctx, uploadselection.Request{
|
||||
selected, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: newFraction,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, selected, selectCount)
|
||||
@ -144,15 +137,14 @@ func TestState_Select_Concurrent(t *testing.T) {
|
||||
createRandomNodes(3, "1.0.4", false),
|
||||
)
|
||||
|
||||
state := uploadselection.NewState(reputableNodes, newNodes)
|
||||
state := nodeselection.NewState(reputableNodes, newNodes)
|
||||
|
||||
var group errgroup.Group
|
||||
group.Go(func() error {
|
||||
const selectCount = 5
|
||||
nodes, err := state.Select(ctx, uploadselection.Request{
|
||||
nodes, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: 0.5,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.Len(t, nodes, selectCount)
|
||||
return err
|
||||
@ -160,10 +152,9 @@ func TestState_Select_Concurrent(t *testing.T) {
|
||||
|
||||
group.Go(func() error {
|
||||
const selectCount = 4
|
||||
nodes, err := state.Select(ctx, uploadselection.Request{
|
||||
nodes, err := state.Select(ctx, nodeselection.Request{
|
||||
Count: selectCount,
|
||||
NewFraction: 0.5,
|
||||
ExcludedIDs: nil,
|
||||
})
|
||||
require.Len(t, nodes, selectCount)
|
||||
return err
|
||||
@ -172,15 +163,13 @@ func TestState_Select_Concurrent(t *testing.T) {
|
||||
}
|
||||
|
||||
// createRandomNodes creates n random nodes all in the subnet.
|
||||
func createRandomNodes(n int, subnet string, shareNets bool) []*uploadselection.Node {
|
||||
xs := make([]*uploadselection.Node, n)
|
||||
func createRandomNodes(n int, subnet string, shareNets bool) []*nodeselection.SelectedNode {
|
||||
xs := make([]*nodeselection.SelectedNode, n)
|
||||
for i := range xs {
|
||||
addr := subnet + "." + strconv.Itoa(i) + ":8080"
|
||||
xs[i] = &uploadselection.Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: addr,
|
||||
},
|
||||
xs[i] = &nodeselection.SelectedNode{
|
||||
ID: testrand.NodeID(),
|
||||
LastNet: addr,
|
||||
LastIPPort: addr,
|
||||
}
|
||||
if shareNets {
|
||||
@ -193,8 +182,8 @@ func createRandomNodes(n int, subnet string, shareNets bool) []*uploadselection.
|
||||
}
|
||||
|
||||
// joinNodes appends all slices into a single slice.
|
||||
func joinNodes(lists ...[]*uploadselection.Node) []*uploadselection.Node {
|
||||
xs := []*uploadselection.Node{}
|
||||
func joinNodes(lists ...[]*nodeselection.SelectedNode) []*nodeselection.SelectedNode {
|
||||
xs := []*nodeselection.SelectedNode{}
|
||||
for _, list := range lists {
|
||||
xs = append(xs, list...)
|
||||
}
|
||||
@ -202,8 +191,8 @@ func joinNodes(lists ...[]*uploadselection.Node) []*uploadselection.Node {
|
||||
}
|
||||
|
||||
// intersectLists returns nodes that exist in both lists compared by ID.
|
||||
func intersectLists(as, bs []*uploadselection.Node) []*uploadselection.Node {
|
||||
var xs []*uploadselection.Node
|
||||
func intersectLists(as, bs []*nodeselection.SelectedNode) []*nodeselection.SelectedNode {
|
||||
var xs []*nodeselection.SelectedNode
|
||||
|
||||
next:
|
||||
for _, a := range as {
|
@ -1,56 +0,0 @@
|
||||
// Copyright (C) 2021 Storj Labs, Inc.
|
||||
// See LICENSE for copying information
|
||||
|
||||
package uploadselection
|
||||
|
||||
import (
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
)
|
||||
|
||||
// Criteria to filter nodes.
|
||||
type Criteria struct {
|
||||
ExcludeNodeIDs []storj.NodeID
|
||||
AutoExcludeSubnets map[string]struct{} // initialize it with empty map to keep only one node per subnet.
|
||||
Placement storj.PlacementConstraint
|
||||
ExcludedCountryCodes []location.CountryCode
|
||||
}
|
||||
|
||||
// MatchInclude returns with true if node is selected.
|
||||
func (c *Criteria) MatchInclude(node *Node) bool {
|
||||
if ContainsID(c.ExcludeNodeIDs, node.ID) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !c.Placement.AllowedCountry(node.CountryCode) {
|
||||
return false
|
||||
}
|
||||
|
||||
if c.AutoExcludeSubnets != nil {
|
||||
if _, excluded := c.AutoExcludeSubnets[node.LastNet]; excluded {
|
||||
return false
|
||||
}
|
||||
c.AutoExcludeSubnets[node.LastNet] = struct{}{}
|
||||
}
|
||||
|
||||
for _, code := range c.ExcludedCountryCodes {
|
||||
if code.String() == "" {
|
||||
continue
|
||||
}
|
||||
if node.CountryCode == code {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ContainsID returns whether ids contain id.
|
||||
func ContainsID(ids []storj.NodeID, id storj.NodeID) bool {
|
||||
for _, k := range ids {
|
||||
if k == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@ -1,140 +0,0 @@
|
||||
// Copyright (C) 2021 Storj Labs, Inc.
|
||||
// See LICENSE for copying information
|
||||
|
||||
package uploadselection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
"storj.io/common/testrand"
|
||||
)
|
||||
|
||||
func TestCriteria_AutoExcludeSubnet(t *testing.T) {
|
||||
|
||||
criteria := Criteria{
|
||||
AutoExcludeSubnets: map[string]struct{}{},
|
||||
}
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&Node{
|
||||
LastNet: "192.168.0.1",
|
||||
}))
|
||||
|
||||
assert.False(t, criteria.MatchInclude(&Node{
|
||||
LastNet: "192.168.0.1",
|
||||
}))
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&Node{
|
||||
LastNet: "192.168.1.1",
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCriteria_ExcludeNodeID(t *testing.T) {
|
||||
included := testrand.NodeID()
|
||||
excluded := testrand.NodeID()
|
||||
|
||||
criteria := Criteria{
|
||||
ExcludeNodeIDs: []storj.NodeID{excluded},
|
||||
}
|
||||
|
||||
assert.False(t, criteria.MatchInclude(&Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: excluded,
|
||||
Address: "localhost",
|
||||
},
|
||||
}))
|
||||
|
||||
assert.True(t, criteria.MatchInclude(&Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: included,
|
||||
Address: "localhost",
|
||||
},
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
func TestCriteria_NodeIDAndSubnet(t *testing.T) {
|
||||
excluded := testrand.NodeID()
|
||||
|
||||
criteria := Criteria{
|
||||
ExcludeNodeIDs: []storj.NodeID{excluded},
|
||||
AutoExcludeSubnets: map[string]struct{}{},
|
||||
}
|
||||
|
||||
// due to node id criteria
|
||||
assert.False(t, criteria.MatchInclude(&Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: excluded,
|
||||
Address: "192.168.0.1",
|
||||
},
|
||||
}))
|
||||
|
||||
// should be included as previous one excluded and
|
||||
// not stored for subnet exclusion
|
||||
assert.True(t, criteria.MatchInclude(&Node{
|
||||
NodeURL: storj.NodeURL{
|
||||
ID: testrand.NodeID(),
|
||||
Address: "192.168.0.2",
|
||||
},
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
func TestCriteria_Geofencing(t *testing.T) {
|
||||
eu := Criteria{
|
||||
Placement: storj.EU,
|
||||
}
|
||||
|
||||
us := Criteria{
|
||||
Placement: storj.US,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
country location.CountryCode
|
||||
criteria Criteria
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "US matches US selector",
|
||||
country: location.UnitedStates,
|
||||
criteria: us,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Germany is EU",
|
||||
country: location.Germany,
|
||||
criteria: eu,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "US is not eu",
|
||||
country: location.UnitedStates,
|
||||
criteria: eu,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty country doesn't match region",
|
||||
country: location.CountryCode(0),
|
||||
criteria: eu,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty country doesn't match country",
|
||||
country: location.CountryCode(0),
|
||||
criteria: us,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
assert.Equal(t, c.expected, c.criteria.MatchInclude(&Node{
|
||||
CountryCode: c.country,
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
@ -1,27 +0,0 @@
|
||||
// Copyright (C) 2020 Storj Labs, Inc.
|
||||
// See LICENSE for copying information.
|
||||
|
||||
package uploadselection
|
||||
|
||||
import (
|
||||
"storj.io/common/storj"
|
||||
"storj.io/common/storj/location"
|
||||
)
|
||||
|
||||
// Node defines necessary information for node-selection.
|
||||
type Node struct {
|
||||
storj.NodeURL
|
||||
LastNet string
|
||||
LastIPPort string
|
||||
CountryCode location.CountryCode
|
||||
}
|
||||
|
||||
// Clone returns a deep clone of the selected node.
|
||||
func (node *Node) Clone() *Node {
|
||||
return &Node{
|
||||
NodeURL: node.NodeURL,
|
||||
LastNet: node.LastNet,
|
||||
LastIPPort: node.LastIPPort,
|
||||
CountryCode: node.CountryCode,
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
storj "storj.io/common/storj"
|
||||
"storj.io/storj/satellite/nodeselection"
|
||||
overlay "storj.io/storj/satellite/overlay"
|
||||
)
|
||||
|
||||
@ -38,10 +39,10 @@ func (m *MockOverlayForOrders) EXPECT() *MockOverlayForOrdersMockRecorder {
|
||||
}
|
||||
|
||||
// CachedGetOnlineNodesForGet mocks base method.
|
||||
func (m *MockOverlayForOrders) CachedGetOnlineNodesForGet(arg0 context.Context, arg1 []storj.NodeID) (map[storj.NodeID]*overlay.SelectedNode, error) {
|
||||
func (m *MockOverlayForOrders) CachedGetOnlineNodesForGet(arg0 context.Context, arg1 []storj.NodeID) (map[storj.NodeID]*nodeselection.SelectedNode, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CachedGetOnlineNodesForGet", arg0, arg1)
|
||||
ret0, _ := ret[0].(map[storj.NodeID]*overlay.SelectedNode)
|
||||
ret0, _ := ret[0].(map[storj.NodeID]*nodeselection.SelectedNode)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"storj.io/common/storj"
|
||||
"storj.io/storj/satellite/internalpb"
|
||||
"storj.io/storj/satellite/metabase"
|
||||
"storj.io/storj/satellite/nodeselection"
|
||||
"storj.io/storj/satellite/overlay"
|
||||
)
|
||||
|
||||
@ -43,7 +44,7 @@ type Config struct {
|
||||
//
|
||||
//go:generate mockgen -destination mock_test.go -package orders . OverlayForOrders
|
||||
type Overlay interface {
|
||||
CachedGetOnlineNodesForGet(context.Context, []storj.NodeID) (map[storj.NodeID]*overlay.SelectedNode, error)
|
||||
CachedGetOnlineNodesForGet(context.Context, []storj.NodeID) (map[storj.NodeID]*nodeselection.SelectedNode, error)
|
||||
GetOnlineNodesForAuditRepair(context.Context, []storj.NodeID) (map[storj.NodeID]*overlay.NodeReputation, error)
|
||||
Get(ctx context.Context, nodeID storj.NodeID) (*overlay.NodeDossier, error)
|
||||
IsOnline(node *overlay.NodeDossier) bool
|
||||
@ -53,10 +54,11 @@ type Overlay interface {
|
||||
//
|
||||
// architecture: Service
|
||||
type Service struct {
|
||||
log *zap.Logger
|
||||
satellite signing.Signer
|
||||
overlay Overlay
|
||||
orders DB
|
||||
log *zap.Logger
|
||||
satellite signing.Signer
|
||||
overlay Overlay
|
||||
orders DB
|
||||
placementRules overlay.PlacementRules
|
||||
|
||||
encryptionKeys EncryptionKeys
|
||||
|
||||
@ -69,17 +71,18 @@ type Service struct {
|
||||
// NewService creates new service for creating order limits.
|
||||
func NewService(
|
||||
log *zap.Logger, satellite signing.Signer, overlay Overlay,
|
||||
orders DB, config Config,
|
||||
orders DB, placementRules overlay.PlacementRules, config Config,
|
||||
) (*Service, error) {
|
||||
if config.EncryptionKeys.Default.IsZero() {
|
||||
return nil, Error.New("encryption keys must be specified to include encrypted metadata")
|
||||
}
|
||||
|
||||
return &Service{
|
||||
log: log,
satellite: satellite,
overlay: overlay,
orders: orders,
log: log,
satellite: satellite,
overlay: overlay,
orders: orders,
placementRules: placementRules,

encryptionKeys: config.EncryptionKeys,

@ -145,8 +148,9 @@ func (service *Service) CreateGetOrderLimits(ctx context.Context, bucket metabas
}

if segment.Placement != storj.EveryCountry {
filter := service.placementRules(segment.Placement)
for id, node := range nodes {
if !segment.Placement.AllowedCountry(node.CountryCode) {
if !filter.MatchInclude(node) {
delete(nodes, id)
}
}

@ -235,7 +239,7 @@ func getLimitByStorageNodeID(limits []*pb.AddressedOrderLimit, storageNodeID sto
}

// CreatePutOrderLimits creates the order limits for uploading pieces to nodes.
func (service *Service) CreatePutOrderLimits(ctx context.Context, bucket metabase.BucketLocation, nodes []*overlay.SelectedNode, pieceExpiration time.Time, maxPieceSize int64) (_ storj.PieceID, _ []*pb.AddressedOrderLimit, privateKey storj.PiecePrivateKey, err error) {
func (service *Service) CreatePutOrderLimits(ctx context.Context, bucket metabase.BucketLocation, nodes []*nodeselection.SelectedNode, pieceExpiration time.Time, maxPieceSize int64) (_ storj.PieceID, _ []*pb.AddressedOrderLimit, privateKey storj.PiecePrivateKey, err error) {
defer mon.Task()(&ctx)(&err)

signer, err := NewSignerPut(service, pieceExpiration, time.Now(), maxPieceSize, bucket)

@ -254,7 +258,7 @@ func (service *Service) CreatePutOrderLimits(ctx context.Context, bucket metabas
}

// ReplacePutOrderLimits replaces order limits for uploading pieces to nodes.
func (service *Service) ReplacePutOrderLimits(ctx context.Context, rootPieceID storj.PieceID, addressedLimits []*pb.AddressedOrderLimit, nodes []*overlay.SelectedNode, pieceNumbers []int32) (_ []*pb.AddressedOrderLimit, err error) {
func (service *Service) ReplacePutOrderLimits(ctx context.Context, rootPieceID storj.PieceID, addressedLimits []*pb.AddressedOrderLimit, nodes []*nodeselection.SelectedNode, pieceNumbers []int32) (_ []*pb.AddressedOrderLimit, err error) {
defer mon.Task()(&ctx)(&err)

pieceIDDeriver := rootPieceID.Deriver()

@ -457,7 +461,7 @@ func (service *Service) CreateGetRepairOrderLimits(ctx context.Context, segment
}

// CreatePutRepairOrderLimits creates the order limits for uploading the repaired pieces of segment to newNodes.
func (service *Service) CreatePutRepairOrderLimits(ctx context.Context, segment metabase.Segment, getOrderLimits []*pb.AddressedOrderLimit, healthySet map[int32]struct{}, newNodes []*overlay.SelectedNode, optimalThresholdMultiplier float64, numPiecesInExcludedCountries int) (_ []*pb.AddressedOrderLimit, _ storj.PiecePrivateKey, err error) {
func (service *Service) CreatePutRepairOrderLimits(ctx context.Context, segment metabase.Segment, getOrderLimits []*pb.AddressedOrderLimit, healthySet map[int32]struct{}, newNodes []*nodeselection.SelectedNode, optimalThresholdMultiplier float64, numPiecesInExcludedCountries int) (_ []*pb.AddressedOrderLimit, _ storj.PiecePrivateKey, err error) {
defer mon.Task()(&ctx)(&err)

// Create the order limits for being used to upload the repaired pieces

@ -590,7 +594,7 @@ func (service *Service) DecryptOrderMetadata(ctx context.Context, order *pb.Orde
return key.DecryptMetadata(order.SerialNumber, order.EncryptedMetadata)
}

func resolveStorageNode_Selected(node *overlay.SelectedNode, resolveDNS bool) *pb.Node {
func resolveStorageNode_Selected(node *nodeselection.SelectedNode, resolveDNS bool) *pb.Node {
return resolveStorageNode(&pb.Node{
Id: node.ID,
Address: node.Address,
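Aside (not part of the diff): the CreateGetOrderLimits hunk above replaces the hard-coded AllowedCountry check with a filter produced by placementRules. A minimal runnable sketch of that pruning step, using hypothetical local types rather than the real satellite packages:

package main

import "fmt"

// Node and NodeFilter are hypothetical stand-ins for the satellite types used above.
type Node struct {
	ID      string
	Country string
}

// NodeFilter mirrors the role of nodeselection.NodeFilters.MatchInclude.
type NodeFilter func(*Node) bool

// pruneNodes keeps only the nodes accepted by the placement filter,
// mirroring the delete(nodes, id) loop in CreateGetOrderLimits.
func pruneNodes(nodes map[string]*Node, filter NodeFilter) {
	for id, node := range nodes {
		if !filter(node) {
			delete(nodes, id)
		}
	}
}

func main() {
	nodes := map[string]*Node{
		"a": {ID: "a", Country: "DE"},
		"b": {ID: "b", Country: "US"},
	}
	onlyGermany := func(n *Node) bool { return n.Country == "DE" }
	pruneNodes(nodes, onlyGermany)
	fmt.Println(len(nodes)) // 1
}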

@ -19,6 +19,7 @@ import (
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/satellite/metabase"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/orders"
"storj.io/storj/satellite/overlay"
)

@ -30,10 +31,10 @@ func TestGetOrderLimits(t *testing.T) {
bucket := metabase.BucketLocation{ProjectID: testrand.UUID(), BucketName: "bucket1"}

pieces := metabase.Pieces{}
nodes := map[storj.NodeID]*overlay.SelectedNode{}
nodes := map[storj.NodeID]*nodeselection.SelectedNode{}
for i := 0; i < 8; i++ {
nodeID := testrand.NodeID()
nodes[nodeID] = &overlay.SelectedNode{
nodes[nodeID] = &nodeselection.SelectedNode{
ID: nodeID,
Address: &pb.NodeAddress{
Address: fmt.Sprintf("host%d.com", i),

@ -55,14 +56,16 @@ func TestGetOrderLimits(t *testing.T) {
CachedGetOnlineNodesForGet(gomock.Any(), gomock.Any()).
Return(nodes, nil).AnyTimes()

service, err := orders.NewService(zaptest.NewLogger(t), k, overlayService, orders.NewNoopDB(), orders.Config{
EncryptionKeys: orders.EncryptionKeys{
Default: orders.EncryptionKey{
ID: orders.EncryptionKeyID{1, 2, 3, 4, 5, 6, 7, 8},
Key: testrand.Key(),
service, err := orders.NewService(zaptest.NewLogger(t), k, overlayService, orders.NewNoopDB(),
overlay.NewPlacementRules().CreateFilters,
orders.Config{
EncryptionKeys: orders.EncryptionKeys{
Default: orders.EncryptionKey{
ID: orders.EncryptionKeyID{1, 2, 3, 4, 5, 6, 7, 8},
Key: testrand.Key(),
},
},
},
})
})
require.NoError(t, err)

segment := metabase.Segment{

@ -355,7 +355,7 @@ func BenchmarkNodeSelection(b *testing.B) {
}
})

service, err := overlay.NewService(zap.NewNop(), overlaydb, db.NodeEvents(), "", "", overlay.Config{
service, err := overlay.NewService(zap.NewNop(), overlaydb, db.NodeEvents(), overlay.NewPlacementRules().CreateFilters, "", "", overlay.Config{
Node: nodeSelectionConfig,
NodeSelectionCache: overlay.UploadSelectionCacheConfig{
Staleness: time.Hour,

@ -11,6 +11,7 @@ import (

"storj.io/common/storj"
"storj.io/common/sync2"
"storj.io/storj/satellite/nodeselection"
)

// DownloadSelectionDB implements the database for download selection cache.

@ -18,7 +19,7 @@ import (
// architecture: Database
type DownloadSelectionDB interface {
// SelectAllStorageNodesDownload returns nodes that are ready for downloading
SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) ([]*SelectedNode, error)
SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) ([]*nodeselection.SelectedNode, error)
}

// DownloadSelectionCacheConfig contains configuration for the selection cache.

@ -35,15 +36,17 @@ type DownloadSelectionCache struct {
db DownloadSelectionDB
config DownloadSelectionCacheConfig

cache sync2.ReadCacheOf[*DownloadSelectionCacheState]
cache sync2.ReadCacheOf[*DownloadSelectionCacheState]
placementRules PlacementRules
}

// NewDownloadSelectionCache creates a new cache that keeps a list of all the storage nodes that are qualified to download data from.
func NewDownloadSelectionCache(log *zap.Logger, db DownloadSelectionDB, config DownloadSelectionCacheConfig) (*DownloadSelectionCache, error) {
func NewDownloadSelectionCache(log *zap.Logger, db DownloadSelectionDB, placementRules PlacementRules, config DownloadSelectionCacheConfig) (*DownloadSelectionCache, error) {
cache := &DownloadSelectionCache{
log: log,
db: db,
config: config,
log: log,
db: db,
placementRules: placementRules,
config: config,
}
return cache, cache.cache.Init(config.Staleness/2, config.Staleness, cache.read)
}

@ -84,11 +87,11 @@ func (cache *DownloadSelectionCache) GetNodeIPsFromPlacement(ctx context.Context
return nil, Error.Wrap(err)
}

return state.IPsFromPlacement(nodes, placement), nil
return state.FilteredIPs(nodes, cache.placementRules(placement)), nil
}

// GetNodes gets nodes by ID from the cache, and refreshes the cache if it is stale.
func (cache *DownloadSelectionCache) GetNodes(ctx context.Context, nodes []storj.NodeID) (_ map[storj.NodeID]*SelectedNode, err error) {
func (cache *DownloadSelectionCache) GetNodes(ctx context.Context, nodes []storj.NodeID) (_ map[storj.NodeID]*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)

state, err := cache.cache.Get(ctx, time.Now())

@ -110,12 +113,12 @@ func (cache *DownloadSelectionCache) Size(ctx context.Context) (int, error) {
// DownloadSelectionCacheState contains state of download selection cache.
type DownloadSelectionCacheState struct {
// byID returns IP based on storj.NodeID
byID map[storj.NodeID]*SelectedNode // TODO: optimize, avoid pointery structures for performance
byID map[storj.NodeID]*nodeselection.SelectedNode // TODO: optimize, avoid pointery structures for performance
}

// NewDownloadSelectionCacheState creates a new state from the nodes.
func NewDownloadSelectionCacheState(nodes []*SelectedNode) *DownloadSelectionCacheState {
byID := map[storj.NodeID]*SelectedNode{}
func NewDownloadSelectionCacheState(nodes []*nodeselection.SelectedNode) *DownloadSelectionCacheState {
byID := map[storj.NodeID]*nodeselection.SelectedNode{}
for _, n := range nodes {
byID[n.ID] = n
}

@ -140,11 +143,11 @@ func (state *DownloadSelectionCacheState) IPs(nodes []storj.NodeID) map[storj.No
return xs
}

// IPsFromPlacement returns node ip:port for nodes that are in state. Results are filtered out by placement.
func (state *DownloadSelectionCacheState) IPsFromPlacement(nodes []storj.NodeID, placement storj.PlacementConstraint) map[storj.NodeID]string {
// FilteredIPs returns node ip:port for nodes that are in state. Results are filtered by the given filter.
func (state *DownloadSelectionCacheState) FilteredIPs(nodes []storj.NodeID, filter nodeselection.NodeFilters) map[storj.NodeID]string {
xs := make(map[storj.NodeID]string, len(nodes))
for _, nodeID := range nodes {
if n, exists := state.byID[nodeID]; exists && placement.AllowedCountry(n.CountryCode) {
if n, exists := state.byID[nodeID]; exists && filter.MatchInclude(n) {
xs[nodeID] = n.LastIPPort
}
}

@ -152,8 +155,8 @@ func (state *DownloadSelectionCacheState) IPsFromPlacement(nodes []storj.NodeID,
}

// Nodes returns node ip:port for nodes that are in state.
func (state *DownloadSelectionCacheState) Nodes(nodes []storj.NodeID) map[storj.NodeID]*SelectedNode {
xs := make(map[storj.NodeID]*SelectedNode, len(nodes))
func (state *DownloadSelectionCacheState) Nodes(nodes []storj.NodeID) map[storj.NodeID]*nodeselection.SelectedNode {
xs := make(map[storj.NodeID]*nodeselection.SelectedNode, len(nodes))
for _, nodeID := range nodes {
if n, exists := state.byID[nodeID]; exists {
xs[nodeID] = n.Clone() // TODO: optimize the clones
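Aside (not part of the diff): DownloadSelectionCacheState.FilteredIPs now takes a node filter instead of a placement constraint. A self-contained sketch of that lookup with stand-in types (the real code uses nodeselection.NodeFilters and storj.NodeID):

package main

import "fmt"

// SelectedNode and NodeFilter are hypothetical stand-ins for the satellite types.
type SelectedNode struct {
	ID         string
	LastIPPort string
	Country    string
}

type NodeFilter interface{ MatchInclude(*SelectedNode) bool }

// countryFilter accepts only nodes whose country is in the set.
type countryFilter map[string]bool

func (c countryFilter) MatchInclude(n *SelectedNode) bool { return c[n.Country] }

// FilteredIPs mirrors DownloadSelectionCacheState.FilteredIPs: return ip:port
// for the requested IDs that are present in the state and pass the filter.
func FilteredIPs(byID map[string]*SelectedNode, ids []string, filter NodeFilter) map[string]string {
	xs := make(map[string]string, len(ids))
	for _, id := range ids {
		if n, ok := byID[id]; ok && filter.MatchInclude(n) {
			xs[id] = n.LastIPPort
		}
	}
	return xs
}

func main() {
	state := map[string]*SelectedNode{
		"n1": {ID: "n1", LastIPPort: "1.0.1.1:8080", Country: "GB"},
		"n2": {ID: "n2", LastIPPort: "1.0.1.2:8080", Country: "RU"},
	}
	fmt.Println(FilteredIPs(state, []string{"n1", "n2"}, countryFilter{"GB": true}))
}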

@ -16,6 +16,7 @@ import (
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/satellite"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/overlay"
"storj.io/storj/satellite/satellitedb/satellitedbtest"
)

@ -30,6 +31,7 @@ func TestDownloadSelectionCacheState_Refresh(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
cache, err := overlay.NewDownloadSelectionCache(zap.NewNop(),
db.OverlayCache(),
overlay.NewPlacementRules().CreateFilters,
downloadSelectionCacheConfig,
)
require.NoError(t, err)

@ -62,6 +64,7 @@ func TestDownloadSelectionCacheState_GetNodeIPs(t *testing.T) {
satellitedbtest.Run(t, func(ctx *testcontext.Context, t *testing.T, db satellite.DB) {
cache, err := overlay.NewDownloadSelectionCache(zap.NewNop(),
db.OverlayCache(),
overlay.NewPlacementRules().CreateFilters,
downloadSelectionCacheConfig,
)
require.NoError(t, err)

@ -87,7 +90,7 @@ func TestDownloadSelectionCacheState_IPs(t *testing.T) {
ctx := testcontext.New(t)
defer ctx.Cleanup()

node := &overlay.SelectedNode{
node := &nodeselection.SelectedNode{
ID: testrand.NodeID(),
Address: &pb.NodeAddress{
Address: "1.0.1.1:8080",

@ -96,7 +99,7 @@ func TestDownloadSelectionCacheState_IPs(t *testing.T) {
LastIPPort: "1.0.1.1:8080",
}

state := overlay.NewDownloadSelectionCacheState([]*overlay.SelectedNode{node})
state := overlay.NewDownloadSelectionCacheState([]*nodeselection.SelectedNode{node})
require.Equal(t, state.Size(), 1)

ips := state.IPs([]storj.NodeID{testrand.NodeID(), node.ID})

@ -113,6 +116,7 @@ func TestDownloadSelectionCache_GetNodes(t *testing.T) {
// create new cache and select nodes
cache, err := overlay.NewDownloadSelectionCache(zap.NewNop(),
db.OverlayCache(),
overlay.NewPlacementRules().CreateFilters,
overlay.DownloadSelectionCacheConfig{
Staleness: time.Hour,
OnlineWindow: time.Hour,

satellite/overlay/placement.go (new file, 151 lines)
@ -0,0 +1,151 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package overlay

import (
"fmt"
"strconv"
"strings"

"github.com/jtolio/mito"
"github.com/spf13/pflag"
"github.com/zeebo/errs"
"golang.org/x/exp/slices"

"storj.io/common/storj"
"storj.io/common/storj/location"
"storj.io/storj/satellite/nodeselection"
)

// PlacementRules can create a filter based on the placement identifier.
type PlacementRules func(constraint storj.PlacementConstraint) (filter nodeselection.NodeFilters)

// ConfigurablePlacementRule can include the placement definitions for each known identifier.
type ConfigurablePlacementRule struct {
placements map[storj.PlacementConstraint]nodeselection.NodeFilters
}

// String implements pflag.Value.
func (d *ConfigurablePlacementRule) String() string {
parts := []string{}
for id, filter := range d.placements {
// we can hide the internal rules...
if id > 9 {
// TODO: we need proper String implementation for all the used filters
parts = append(parts, fmt.Sprintf("%d:%s", id, filter))
}
}
return strings.Join(parts, ";")
}

// Set implements pflag.Value.
func (d *ConfigurablePlacementRule) Set(s string) error {
if d.placements == nil {
d.placements = make(map[storj.PlacementConstraint]nodeselection.NodeFilters)
}
d.AddLegacyStaticRules()
return d.AddPlacementFromString(s)
}

// Type implements pflag.Value.
func (d *ConfigurablePlacementRule) Type() string {
return "placement-rule"
}

var _ pflag.Value = &ConfigurablePlacementRule{}

// NewPlacementRules creates a fully initialized ConfigurablePlacementRule.
func NewPlacementRules() *ConfigurablePlacementRule {
return &ConfigurablePlacementRule{
placements: map[storj.PlacementConstraint]nodeselection.NodeFilters{},
}
}

// AddLegacyStaticRules initializes all the placement rules defined earlier in static golang code.
func (d *ConfigurablePlacementRule) AddLegacyStaticRules() {
d.placements[storj.EEA] = nodeselection.NodeFilters{nodeselection.NewCountryFilter(nodeselection.EeaCountries)}
d.placements[storj.EU] = nodeselection.NodeFilters{nodeselection.NewCountryFilter(nodeselection.EuCountries)}
d.placements[storj.US] = nodeselection.NodeFilters{nodeselection.NewCountryFilter(location.NewSet(location.UnitedStates))}
d.placements[storj.DE] = nodeselection.NodeFilters{nodeselection.NewCountryFilter(location.NewSet(location.Germany))}
d.placements[storj.NR] = nodeselection.NodeFilters{nodeselection.NewCountryFilter(location.NewFullSet().Without(location.Russia, location.Belarus, location.None))}
}

// AddPlacementRule registers a new placement.
func (d *ConfigurablePlacementRule) AddPlacementRule(id storj.PlacementConstraint, filters nodeselection.NodeFilters) {
d.placements[id] = filters
}

// AddPlacementFromString parses placement definitions from their string representation, in the form id:definition;id:definition;...
func (d *ConfigurablePlacementRule) AddPlacementFromString(definitions string) error {
env := map[any]any{
"country": func(countries ...string) (nodeselection.NodeFilters, error) {
var set location.Set
for _, country := range countries {
code := location.ToCountryCode(country)
if code == location.None {
return nil, errs.New("invalid country code %q", code)
}
set.Include(code)
}
return nodeselection.NodeFilters{nodeselection.NewCountryFilter(set)}, nil
},
"all": func(filters ...nodeselection.NodeFilters) (nodeselection.NodeFilters, error) {
res := nodeselection.NodeFilters{}
for _, filter := range filters {
res = append(res, filter...)
}
return res, nil
},
"tag": func(nodeIDstr string, key string, value any) (nodeselection.NodeFilters, error) {
nodeID, err := storj.NodeIDFromString(nodeIDstr)
if err != nil {
return nil, err
}
var rawValue []byte
switch v := value.(type) {
case string:
rawValue = []byte(v)
case []byte:
rawValue = v
default:
return nil, errs.New("3rd argument of tag() should be string or []byte")
}
res := nodeselection.NodeFilters{
nodeselection.NewTagFilter(nodeID, key, rawValue),
}
return res, nil
},
}
for _, definition := range strings.Split(definitions, ";") {
definition = strings.TrimSpace(definition)
if definition == "" {
continue
}
idDef := strings.SplitN(definition, ":", 2)

val, err := mito.Eval(idDef[1], env)
if err != nil {
return errs.Wrap(err)
}
id, err := strconv.Atoi(idDef[0])
if err != nil {
return errs.Wrap(err)
}
d.placements[storj.PlacementConstraint(id)] = val.(nodeselection.NodeFilters)
}
return nil
}

// CreateFilters implements PlacementCondition.
func (d *ConfigurablePlacementRule) CreateFilters(constraint storj.PlacementConstraint) (filter nodeselection.NodeFilters) {
if constraint == storj.EveryCountry {
return nodeselection.NodeFilters{}
}
if filters, found := d.placements[constraint]; found {
return slices.Clone(filters)
}
return nodeselection.NodeFilters{
nodeselection.ExcludeAllFilter{},
}
}
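Aside (not part of the diff): AddPlacementFromString accepts definitions of the form id:expression;id:expression;..., where each expression (country(...), tag(...), all(...)) is evaluated by the mito engine into NodeFilters. The sketch below only mirrors the outer id:definition split, with a hypothetical parsePlacements helper; it does not evaluate the expressions:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parsePlacements splits "id:definition;id:definition;..." into a map,
// mirroring only the outer loop of AddPlacementFromString; the real code
// hands each definition to the mito expression engine.
func parsePlacements(definitions string) (map[int]string, error) {
	out := map[int]string{}
	for _, def := range strings.Split(definitions, ";") {
		def = strings.TrimSpace(def)
		if def == "" {
			continue
		}
		idDef := strings.SplitN(def, ":", 2)
		if len(idDef) != 2 {
			return nil, fmt.Errorf("invalid placement definition %q", def)
		}
		id, err := strconv.Atoi(idDef[0])
		if err != nil {
			return nil, err
		}
		out[id] = idDef[1]
	}
	return out, nil
}

func main() {
	rules, err := parsePlacements(`11:country("GB");12:all(country("DE"),tag("nodeid","foo","bar"))`)
	fmt.Println(rules, err)
}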

satellite/overlay/placement_test.go (new file, 147 lines)
@ -0,0 +1,147 @@
// Copyright (C) 2023 Storj Labs, Inc.
// See LICENSE for copying information.

package overlay

import (
"testing"

"github.com/stretchr/testify/require"

"storj.io/common/storj"
"storj.io/common/storj/location"
"storj.io/storj/satellite/nodeselection"
)

func TestPlacementFromString(t *testing.T) {
signer, err := storj.NodeIDFromString("12whfK1EDvHJtajBiAUeajQLYcWqxcQmdYQU5zX5cCf6bAxfgu4")
require.NoError(t, err)

t.Run("invalid country-code", func(t *testing.T) {
p := NewPlacementRules()
err := p.AddPlacementFromString(`1:country("ZZZZ")`)
require.Error(t, err)
})

t.Run("single country", func(t *testing.T) {
p := NewPlacementRules()
err := p.AddPlacementFromString(`11:country("GB")`)
require.NoError(t, err)
filters := p.placements[storj.PlacementConstraint(11)]
require.NotNil(t, filters)
require.True(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
}))
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
}))
})

t.Run("tag rule", func(t *testing.T) {
p := NewPlacementRules()
err := p.AddPlacementFromString(`11:tag("12whfK1EDvHJtajBiAUeajQLYcWqxcQmdYQU5zX5cCf6bAxfgu4","foo","bar")`)
require.NoError(t, err)
filters := p.placements[storj.PlacementConstraint(11)]
require.NotNil(t, filters)
require.True(t, filters.MatchInclude(&nodeselection.SelectedNode{
Tags: nodeselection.NodeTags{
{
Signer: signer,
Name: "foo",
Value: []byte("bar"),
},
},
}))
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
}))
})

t.Run("all rules", func(t *testing.T) {
p := NewPlacementRules()
err := p.AddPlacementFromString(`11:all(country("GB"),tag("12whfK1EDvHJtajBiAUeajQLYcWqxcQmdYQU5zX5cCf6bAxfgu4","foo","bar"))`)
require.NoError(t, err)
filters := p.placements[storj.PlacementConstraint(11)]
require.NotNil(t, filters)
require.True(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
Tags: nodeselection.NodeTags{
{
Signer: signer,
Name: "foo",
Value: []byte("bar"),
},
},
}))
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
}))
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
Tags: nodeselection.NodeTags{
{
Signer: signer,
Name: "foo",
Value: []byte("bar"),
},
},
}))
})

t.Run("multi rule", func(t *testing.T) {
p := NewPlacementRules()
err := p.AddPlacementFromString(`11:country("GB");12:country("DE")`)
require.NoError(t, err)

filters := p.placements[storj.PlacementConstraint(11)]
require.NotNil(t, filters)
require.True(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
}))
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
}))

filters = p.placements[storj.PlacementConstraint(12)]
require.NotNil(t, filters)
require.False(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
}))
require.True(t, filters.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
}))

})

t.Run("legacy geofencing rules", func(t *testing.T) {
p := NewPlacementRules()
p.AddLegacyStaticRules()

t.Run("nr", func(t *testing.T) {
nr := p.placements[storj.NR]
require.True(t, nr.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedKingdom,
}))
require.False(t, nr.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Russia,
}))
require.False(t, nr.MatchInclude(&nodeselection.SelectedNode{
CountryCode: 0,
}))
})
t.Run("us", func(t *testing.T) {
us := p.placements[storj.US]
require.True(t, us.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.UnitedStates,
}))
require.False(t, us.MatchInclude(&nodeselection.SelectedNode{
CountryCode: location.Germany,
}))
require.False(t, us.MatchInclude(&nodeselection.SelectedNode{
CountryCode: 0,
}))
})

})

}

@ -26,6 +26,7 @@ import (
"storj.io/common/testcontext"
"storj.io/storj/private/testplanet"
"storj.io/storj/satellite"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/overlay"
"storj.io/storj/satellite/reputation"
)

@ -147,10 +148,10 @@ func TestOnlineOffline(t *testing.T) {
require.Empty(t, offline)
require.Len(t, online, 2)

require.False(t, slices.ContainsFunc(online, func(node overlay.SelectedNode) bool {
require.False(t, slices.ContainsFunc(online, func(node nodeselection.SelectedNode) bool {
return node.ID == unreliableNodeID
}))
require.False(t, slices.ContainsFunc(offline, func(node overlay.SelectedNode) bool {
require.False(t, slices.ContainsFunc(offline, func(node nodeselection.SelectedNode) bool {
return node.ID == unreliableNodeID
}))
})

@ -192,7 +193,7 @@ func TestEnsureMinimumRequested(t *testing.T) {

reputable := map[storj.NodeID]bool{}

countReputable := func(selected []*overlay.SelectedNode) (count int) {
countReputable := func(selected []*nodeselection.SelectedNode) (count int) {
for _, n := range selected {
if reputable[n.ID] {
count++

@ -21,6 +21,7 @@ import (
"storj.io/storj/satellite/geoip"
"storj.io/storj/satellite/metabase"
"storj.io/storj/satellite/nodeevents"
"storj.io/storj/satellite/nodeselection"
)

// ErrEmptyNode is returned when the nodeID is empty.

@ -53,20 +54,18 @@ type DB interface {
// current reputation status.
GetOnlineNodesForAuditRepair(ctx context.Context, nodeIDs []storj.NodeID, onlineWindow time.Duration) (map[storj.NodeID]*NodeReputation, error)
// SelectStorageNodes looks up nodes based on criteria
SelectStorageNodes(ctx context.Context, totalNeededNodes, newNodeCount int, criteria *NodeCriteria) ([]*SelectedNode, error)
SelectStorageNodes(ctx context.Context, totalNeededNodes, newNodeCount int, criteria *NodeCriteria) ([]*nodeselection.SelectedNode, error)
// SelectAllStorageNodesUpload returns all nodes that qualify to store data, organized as reputable nodes and new nodes
SelectAllStorageNodesUpload(ctx context.Context, selectionCfg NodeSelectionConfig) (reputable, new []*SelectedNode, err error)
SelectAllStorageNodesUpload(ctx context.Context, selectionCfg NodeSelectionConfig) (reputable, new []*nodeselection.SelectedNode, err error)
// SelectAllStorageNodesDownload returns nodes that are ready for downloading
SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) ([]*SelectedNode, error)
SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) ([]*nodeselection.SelectedNode, error)

// Get looks up the node by nodeID
Get(ctx context.Context, nodeID storj.NodeID) (*NodeDossier, error)
// KnownReliableInExcludedCountries filters healthy nodes that are in excluded countries.
KnownReliableInExcludedCountries(context.Context, *NodeCriteria, storj.NodeIDList) (storj.NodeIDList, error)
// KnownReliable filters a set of nodes to reliable (online and qualified) nodes.
KnownReliable(ctx context.Context, nodeIDs storj.NodeIDList, onlineWindow, asOfSystemInterval time.Duration) (online []SelectedNode, offline []SelectedNode, err error)
// Reliable returns all nodes that are reliable
Reliable(context.Context, *NodeCriteria) (storj.NodeIDList, error)
KnownReliable(ctx context.Context, nodeIDs storj.NodeIDList, onlineWindow, asOfSystemInterval time.Duration) (online []nodeselection.SelectedNode, offline []nodeselection.SelectedNode, err error)
// Reliable returns all nodes that are reliable (separated by whether they are currently online or offline).
Reliable(ctx context.Context, onlineWindow, asOfSystemInterval time.Duration) (online []nodeselection.SelectedNode, offline []nodeselection.SelectedNode, err error)
// UpdateReputation updates the DB columns for all reputation fields in ReputationStatus.
UpdateReputation(ctx context.Context, id storj.NodeID, request ReputationUpdate) error
// UpdateNodeInfo updates node dossier with info requested from the node itself like node type, email, wallet, capacity, and version.

@ -131,9 +130,15 @@ type DB interface {
OneTimeFixLastNets(ctx context.Context) error

// IterateAllContactedNodes will call cb on all known nodes (used in restore trash contexts).
IterateAllContactedNodes(context.Context, func(context.Context, *SelectedNode) error) error
IterateAllContactedNodes(context.Context, func(context.Context, *nodeselection.SelectedNode) error) error
// IterateAllNodeDossiers will call cb on all known nodes (used for invoice generation).
IterateAllNodeDossiers(context.Context, func(context.Context, *NodeDossier) error) error

// UpdateNodeTags inserts (or refreshes) node tags.
UpdateNodeTags(ctx context.Context, tags nodeselection.NodeTags) error

// GetNodeTags returns all tags for a specific node.
GetNodeTags(ctx context.Context, id storj.NodeID) (nodeselection.NodeTags, error)
}

// DisqualificationReason is disqualification reason enum type.

@ -192,7 +197,6 @@ type NodeCriteria struct {
MinimumVersion string // semver or empty
OnlineWindow time.Duration
AsOfSystemInterval time.Duration // only used for CRDB queries
ExcludedCountries []string
}

// ReputationStatus indicates current reputation status for a node.

@ -273,15 +277,6 @@ type NodeLastContact struct {
LastContactFailure time.Time
}

// SelectedNode is used as a result for creating orders limits.
type SelectedNode struct {
ID storj.NodeID
Address *pb.NodeAddress
LastNet string
LastIPPort string
CountryCode location.CountryCode
}

// NodeReputation is used as a result for creating orders limits for audits.
type NodeReputation struct {
ID storj.NodeID

@ -291,18 +286,6 @@ type NodeReputation struct {
Reputation ReputationStatus
}

// Clone returns a deep clone of the selected node.
func (node *SelectedNode) Clone() *SelectedNode {
copy := pb.CopyNode(&pb.Node{Id: node.ID, Address: node.Address})
return &SelectedNode{
ID: copy.Id,
Address: copy.Address,
LastNet: node.LastNet,
LastIPPort: node.LastIPPort,
CountryCode: node.CountryCode,
}
}

// Service is used to store and handle node information.
//
// architecture: Service

@ -318,13 +301,14 @@ type Service struct {
UploadSelectionCache *UploadSelectionCache
DownloadSelectionCache *DownloadSelectionCache
LastNetFunc LastNetFunc
placementRules PlacementRules
}

// LastNetFunc is the type of a function that will be used to derive a network from an ip and port.
type LastNetFunc func(config NodeSelectionConfig, ip net.IP, port string) (string, error)

// NewService returns a new Service.
func NewService(log *zap.Logger, db DB, nodeEvents nodeevents.DB, satelliteAddr, satelliteName string, config Config) (*Service, error) {
func NewService(log *zap.Logger, db DB, nodeEvents nodeevents.DB, placementRules PlacementRules, satelliteAddr, satelliteName string, config Config) (*Service, error) {
err := config.Node.AsOfSystemTime.isValid()
if err != nil {
return nil, errs.Wrap(err)

@ -338,17 +322,34 @@ func NewService(log *zap.Logger, db DB, nodeEvents nodeevents.DB, satelliteAddr,
}
}

defaultSelection := nodeselection.NodeFilters{}

if len(config.Node.UploadExcludedCountryCodes) > 0 {
set := location.NewFullSet()
for _, country := range config.Node.UploadExcludedCountryCodes {
countryCode := location.ToCountryCode(country)
if countryCode == location.None {
return nil, Error.New("invalid country %q", country)
}
set.Remove(countryCode)
}
defaultSelection = defaultSelection.WithCountryFilter(set)
}

uploadSelectionCache, err := NewUploadSelectionCache(log, db,
config.NodeSelectionCache.Staleness, config.Node,
defaultSelection, placementRules,
)
if err != nil {
return nil, errs.Wrap(err)
}
downloadSelectionCache, err := NewDownloadSelectionCache(log, db, DownloadSelectionCacheConfig{
Staleness: config.NodeSelectionCache.Staleness,
OnlineWindow: config.Node.OnlineWindow,
AsOfSystemTime: config.Node.AsOfSystemTime,
})
downloadSelectionCache, err := NewDownloadSelectionCache(log, db,
placementRules,
DownloadSelectionCacheConfig{
Staleness: config.NodeSelectionCache.Staleness,
OnlineWindow: config.Node.OnlineWindow,
AsOfSystemTime: config.Node.AsOfSystemTime,
})
if err != nil {
return nil, errs.Wrap(err)
}

@ -366,6 +367,8 @@ func NewService(log *zap.Logger, db DB, nodeEvents nodeevents.DB, satelliteAddr,
UploadSelectionCache: uploadSelectionCache,
DownloadSelectionCache: downloadSelectionCache,
LastNetFunc: MaskOffLastNet,

placementRules: placementRules,
}, nil
}

@ -392,7 +395,7 @@ func (service *Service) Get(ctx context.Context, nodeID storj.NodeID) (_ *NodeDo
}

// CachedGetOnlineNodesForGet returns a map of nodes from the download selection cache from the suppliedIDs.
func (service *Service) CachedGetOnlineNodesForGet(ctx context.Context, nodeIDs []storj.NodeID) (_ map[storj.NodeID]*SelectedNode, err error) {
func (service *Service) CachedGetOnlineNodesForGet(ctx context.Context, nodeIDs []storj.NodeID) (_ map[storj.NodeID]*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)
return service.DownloadSelectionCache.GetNodes(ctx, nodeIDs)
}

@ -415,45 +418,8 @@ func (service *Service) IsOnline(node *NodeDossier) bool {
return time.Since(node.Reputation.LastContactSuccess) < service.config.Node.OnlineWindow
}

// GetNodesNetworkInOrder returns the /24 subnet for each storage node, in order. If a
// requested node is not in the database, an empty string will be returned corresponding
// to that node's last_net.
func (service *Service) GetNodesNetworkInOrder(ctx context.Context, nodeIDs []storj.NodeID) (lastNets []string, err error) {
defer mon.Task()(&ctx)(nil)

nodes, err := service.DownloadSelectionCache.GetNodes(ctx, nodeIDs)
if err != nil {
return nil, err
}
lastNets = make([]string, len(nodeIDs))
for i, nodeID := range nodeIDs {
if selectedNode, ok := nodes[nodeID]; ok {
lastNets[i] = selectedNode.LastNet
}
}
return lastNets, nil
}

// GetNodesOutOfPlacement checks if nodes from nodeIDs list are in allowed country according to specified geo placement
// and returns list of node ids which are not.
func (service *Service) GetNodesOutOfPlacement(ctx context.Context, nodeIDs []storj.NodeID, placement storj.PlacementConstraint) (offNodes []storj.NodeID, err error) {
defer mon.Task()(&ctx)(nil)

nodes, err := service.DownloadSelectionCache.GetNodes(ctx, nodeIDs)
if err != nil {
return nil, err
}
offNodes = make([]storj.NodeID, 0, len(nodeIDs))
for _, nodeID := range nodeIDs {
if selectedNode, ok := nodes[nodeID]; ok && !placement.AllowedCountry(selectedNode.CountryCode) {
offNodes = append(offNodes, selectedNode.ID)
}
}
return offNodes, nil
}

// FindStorageNodesForGracefulExit searches the overlay network for nodes that meet the provided requirements for graceful-exit requests.
func (service *Service) FindStorageNodesForGracefulExit(ctx context.Context, req FindStorageNodesRequest) (_ []*SelectedNode, err error) {
func (service *Service) FindStorageNodesForGracefulExit(ctx context.Context, req FindStorageNodesRequest) (_ []*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)
return service.UploadSelectionCache.GetNodes(ctx, req)
}

@ -462,7 +428,7 @@ func (service *Service) FindStorageNodesForGracefulExit(ctx context.Context, req
//
// When enabled it uses the cache to select nodes.
// When the node selection from the cache fails, it falls back to the old implementation.
func (service *Service) FindStorageNodesForUpload(ctx context.Context, req FindStorageNodesRequest) (_ []*SelectedNode, err error) {
func (service *Service) FindStorageNodesForUpload(ctx context.Context, req FindStorageNodesRequest) (_ []*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)
if service.config.Node.AsOfSystemTime.Enabled && service.config.Node.AsOfSystemTime.DefaultInterval < 0 {
req.AsOfSystemInterval = service.config.Node.AsOfSystemTime.DefaultInterval

@ -498,7 +464,7 @@ func (service *Service) FindStorageNodesForUpload(ctx context.Context, req FindS
// FindStorageNodesWithPreferences searches the overlay network for nodes that meet the provided criteria.
//
// This does not use a cache.
func (service *Service) FindStorageNodesWithPreferences(ctx context.Context, req FindStorageNodesRequest, preferences *NodeSelectionConfig) (nodes []*SelectedNode, err error) {
func (service *Service) FindStorageNodesWithPreferences(ctx context.Context, req FindStorageNodesRequest, preferences *NodeSelectionConfig) (nodes []*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)
// TODO: add sanity limits to requested node count
// TODO: add sanity limits to excluded nodes

@ -572,34 +538,20 @@ func (service *Service) InsertOfflineNodeEvents(ctx context.Context, cooldown ti
return count, err
}

// KnownReliableInExcludedCountries filters healthy nodes that are in excluded countries.
func (service *Service) KnownReliableInExcludedCountries(ctx context.Context, nodeIds storj.NodeIDList) (reliableInExcluded storj.NodeIDList, err error) {
defer mon.Task()(&ctx)(&err)

criteria := &NodeCriteria{
OnlineWindow: service.config.Node.OnlineWindow,
ExcludedCountries: service.config.RepairExcludedCountryCodes,
}
return service.db.KnownReliableInExcludedCountries(ctx, criteria, nodeIds)
}

// KnownReliable filters a set of nodes to reliable (online and qualified) nodes.
func (service *Service) KnownReliable(ctx context.Context, nodeIDs storj.NodeIDList) (onlineNodes []SelectedNode, offlineNodes []SelectedNode, err error) {
func (service *Service) KnownReliable(ctx context.Context, nodeIDs storj.NodeIDList) (onlineNodes []nodeselection.SelectedNode, offlineNodes []nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)

// TODO add as of system time
return service.db.KnownReliable(ctx, nodeIDs, service.config.Node.OnlineWindow, 0)
}

// Reliable filters a set of nodes that are reliable, independent of new.
func (service *Service) Reliable(ctx context.Context) (nodes storj.NodeIDList, err error) {
// Reliable returns all nodes that are reliable (separated by whether they are currently online or offline).
func (service *Service) Reliable(ctx context.Context) (online []nodeselection.SelectedNode, offline []nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)

criteria := &NodeCriteria{
OnlineWindow: service.config.Node.OnlineWindow,
}
criteria.ExcludedCountries = service.config.RepairExcludedCountryCodes
return service.db.Reliable(ctx, criteria)
// TODO add as of system time.
return service.db.Reliable(ctx, service.config.Node.OnlineWindow, 0)
}

// UpdateReputation updates the DB columns for any of the reputation fields.

@ -782,28 +734,6 @@ func (service *Service) GetMissingPieces(ctx context.Context, pieces metabase.Pi
return maps.Values(missingPiecesMap), nil
}

// GetReliablePiecesInExcludedCountries returns the list of pieces held by nodes located in excluded countries.
func (service *Service) GetReliablePiecesInExcludedCountries(ctx context.Context, pieces metabase.Pieces) (piecesInExcluded []uint16, err error) {
defer mon.Task()(&ctx)(&err)
var nodeIDs storj.NodeIDList
for _, p := range pieces {
nodeIDs = append(nodeIDs, p.StorageNode)
}
inExcluded, err := service.KnownReliableInExcludedCountries(ctx, nodeIDs)
if err != nil {
return nil, Error.New("error getting nodes %s", err)
}

for _, p := range pieces {
for _, nodeID := range inExcluded {
if nodeID == p.StorageNode {
piecesInExcluded = append(piecesInExcluded, p.Number)
}
}
}
return piecesInExcluded, nil
}

// DQNodesLastSeenBefore disqualifies nodes who have not been contacted since the cutoff time.
func (service *Service) DQNodesLastSeenBefore(ctx context.Context, cutoff time.Time, limit int) (count int, err error) {
defer mon.Task()(&ctx)(&err)

@ -840,7 +770,7 @@ func (service *Service) DisqualifyNode(ctx context.Context, nodeID storj.NodeID,
}

// SelectAllStorageNodesDownload returns nodes that are ready for downloading.
func (service *Service) SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) (_ []*SelectedNode, err error) {
func (service *Service) SelectAllStorageNodesDownload(ctx context.Context, onlineWindow time.Duration, asOf AsOfSystemTimeConfig) (_ []*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)
return service.db.SelectAllStorageNodesDownload(ctx, onlineWindow, asOf)
}

@ -851,6 +781,16 @@ func (service *Service) ResolveIPAndNetwork(ctx context.Context, target string)
return ResolveIPAndNetwork(ctx, target, service.config.Node, service.LastNetFunc)
}

// UpdateNodeTags persists all new and old node tags.
func (service *Service) UpdateNodeTags(ctx context.Context, tags []nodeselection.NodeTag) error {
return service.db.UpdateNodeTags(ctx, tags)
}

// GetNodeTags returns the node tags of a node.
func (service *Service) GetNodeTags(ctx context.Context, id storj.NodeID) (nodeselection.NodeTags, error) {
return service.db.GetNodeTags(ctx, id)
}

// ResolveIPAndNetwork resolves the target address and determines its IP and appropriate last_net, as indicated.
func ResolveIPAndNetwork(ctx context.Context, target string, config NodeSelectionConfig, lastNetFunc LastNetFunc) (ip net.IP, port, network string, err error) {
defer mon.Task()(&ctx)(&err)
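Aside (not part of the diff): in NewService, config.Node.UploadExcludedCountryCodes is now turned into a default country filter by starting from a full country set and removing each excluded code. A rough stand-alone sketch of that conversion, with a hypothetical countrySet standing in for location.Set:

package main

import (
	"fmt"
	"strings"
)

// countrySet is a hypothetical stand-in for location.Set.
type countrySet map[string]bool

func fullSet(all ...string) countrySet {
	s := countrySet{}
	for _, c := range all {
		s[c] = true
	}
	return s
}

// excludedCountriesFilter mirrors the UploadExcludedCountryCodes handling in
// NewService: start with every country and remove the excluded ones.
func excludedCountriesFilter(all countrySet, excluded []string) (countrySet, error) {
	for _, country := range excluded {
		code := strings.ToUpper(strings.TrimSpace(country))
		if code == "" {
			return nil, fmt.Errorf("invalid country %q", country)
		}
		delete(all, code)
	}
	return all, nil
}

func main() {
	set, err := excludedCountriesFilter(fullSet("US", "DE", "RU"), []string{"ru"})
	fmt.Println(set, err) // map[DE:true US:true] <nil>
}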

@ -18,12 +18,12 @@ import (
"storj.io/common/memory"
"storj.io/common/pb"
"storj.io/common/storj"
"storj.io/common/storj/location"
"storj.io/common/testcontext"
"storj.io/common/testrand"
"storj.io/storj/private/testplanet"
"storj.io/storj/satellite"
"storj.io/storj/satellite/nodeevents"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/overlay"
"storj.io/storj/satellite/reputation"
"storj.io/storj/satellite/satellitedb/satellitedbtest"

@ -74,7 +74,7 @@ func testCache(ctx *testcontext.Context, t *testing.T, store overlay.DB, nodeEve

serviceCtx, serviceCancel := context.WithCancel(ctx)
defer serviceCancel()
service, err := overlay.NewService(zaptest.NewLogger(t), store, nodeEvents, "", "", serviceConfig)
service, err := overlay.NewService(zaptest.NewLogger(t), store, nodeEvents, overlay.NewPlacementRules().CreateFilters, "", "", serviceConfig)
require.NoError(t, err)
ctx.Go(func() error { return service.Run(serviceCtx) })
defer ctx.Check(service.Close)

@ -205,7 +205,7 @@ func TestRandomizedSelection(t *testing.T) {

// select numNodesToSelect nodes selectIterations times
for i := 0; i < selectIterations; i++ {
var nodes []*overlay.SelectedNode
var nodes []*nodeselection.SelectedNode
var err error

if i%2 == 0 {

@ -326,7 +326,7 @@ func TestRandomizedSelectionCache(t *testing.T) {

// select numNodesToSelect nodes selectIterations times
for i := 0; i < selectIterations; i++ {
var nodes []*overlay.SelectedNode
var nodes []*nodeselection.SelectedNode
var err error
req := overlay.FindStorageNodesRequest{
RequestedCount: numNodesToSelect,

@ -670,7 +670,7 @@ func TestSuspendedSelection(t *testing.T) {
}
}

var nodes []*overlay.SelectedNode
var nodes []*nodeselection.SelectedNode
var err error

numNodesToSelect := 10

@ -816,50 +816,6 @@ func TestVetAndUnvetNode(t *testing.T) {
})
}

func TestReliable(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 2, UplinkCount: 0,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
service := planet.Satellites[0].Overlay.Service
node := planet.StorageNodes[0]

nodes, err := service.Reliable(ctx)
require.NoError(t, err)
require.Len(t, nodes, 2)

err = planet.Satellites[0].Overlay.Service.TestNodeCountryCode(ctx, node.ID(), "FR")
require.NoError(t, err)

// first node should be excluded from Reliable result because of country code
nodes, err = service.Reliable(ctx)
require.NoError(t, err)
require.Len(t, nodes, 1)
require.NotEqual(t, node.ID(), nodes[0])
})
}

func TestKnownReliableInExcludedCountries(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 2, UplinkCount: 0,
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
service := planet.Satellites[0].Overlay.Service
node := planet.StorageNodes[0]

nodes, err := service.Reliable(ctx)
require.NoError(t, err)
require.Len(t, nodes, 2)

err = planet.Satellites[0].Overlay.Service.TestNodeCountryCode(ctx, node.ID(), "FR")
require.NoError(t, err)

// first node should be excluded from Reliable result because of country code
nodes, err = service.KnownReliableInExcludedCountries(ctx, nodes)
require.NoError(t, err)
require.Len(t, nodes, 1)
require.Equal(t, node.ID(), nodes[0])
})
}

func TestUpdateReputationNodeEvents(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 2, UplinkCount: 0,

@ -1049,47 +1005,3 @@ func TestUpdateCheckInBelowMinVersionEvent(t *testing.T) {
require.True(t, ne2.CreatedAt.After(ne1.CreatedAt))
})
}

func TestService_GetNodesOutOfPlacement(t *testing.T) {
testplanet.Run(t, testplanet.Config{
SatelliteCount: 1, StorageNodeCount: 4, UplinkCount: 1,
Reconfigure: testplanet.Reconfigure{
Satellite: func(log *zap.Logger, index int, config *satellite.Config) {
config.Overlay.Node.AsOfSystemTime.Enabled = false
config.Overlay.Node.AsOfSystemTime.DefaultInterval = 0
},
},
}, func(t *testing.T, ctx *testcontext.Context, planet *testplanet.Planet) {
service := planet.Satellites[0].Overlay.Service

placement := storj.EU

nodeIDs := []storj.NodeID{}
for _, node := range planet.StorageNodes {
nodeIDs = append(nodeIDs, node.ID())

err := service.TestNodeCountryCode(ctx, node.ID(), location.Poland.String())
require.NoError(t, err)
}

require.NoError(t, service.DownloadSelectionCache.Refresh(ctx))

offNodes, err := service.GetNodesOutOfPlacement(ctx, nodeIDs, placement)
require.NoError(t, err)
require.Empty(t, offNodes)

expectedNodeIDs := []storj.NodeID{}
for _, node := range planet.StorageNodes {
expectedNodeIDs = append(expectedNodeIDs, node.ID())
err := service.TestNodeCountryCode(ctx, node.ID(), location.Brazil.String())
require.NoError(t, err)

// we need to refresh cache because node country code was changed
require.NoError(t, service.DownloadSelectionCache.Refresh(ctx))

offNodes, err := service.GetNodesOutOfPlacement(ctx, nodeIDs, placement)
require.NoError(t, err)
require.ElementsMatch(t, expectedNodeIDs, offNodes)
}
})
}

@ -18,6 +18,7 @@ import (
"storj.io/common/storj/location"
"storj.io/common/testcontext"
"storj.io/storj/satellite"
"storj.io/storj/satellite/nodeselection"
"storj.io/storj/satellite/overlay"
"storj.io/storj/satellite/satellitedb/satellitedbtest"
)

@ -101,18 +102,13 @@ func testDatabase(ctx context.Context, t *testing.T, cache overlay.DB) {
storj.NodeID{7}, storj.NodeID{8},
storj.NodeID{9},
}
criteria := &overlay.NodeCriteria{
OnlineWindow: time.Hour,
ExcludedCountries: []string{"FR", "BE"},
}

contains := func(nodeID storj.NodeID) func(node overlay.SelectedNode) bool {
return func(node overlay.SelectedNode) bool {
contains := func(nodeID storj.NodeID) func(node nodeselection.SelectedNode) bool {
return func(node nodeselection.SelectedNode) bool {
return node.ID == nodeID
}
}

online, offline, err := cache.KnownReliable(ctx, nodeIds, criteria.OnlineWindow, criteria.AsOfSystemInterval)
online, offline, err := cache.KnownReliable(ctx, nodeIds, time.Hour, 0)
require.NoError(t, err)

// unreliable nodes shouldn't be in results

@ -123,19 +119,26 @@ func testDatabase(ctx context.Context, t *testing.T, cache overlay.DB) {
require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{9}))) // not in db

require.True(t, slices.ContainsFunc(offline, contains(storj.NodeID{4}))) // offline
// KnownReliable is not excluding by country anymore
require.True(t, slices.ContainsFunc(online, contains(storj.NodeID{7}))) // excluded country

require.Len(t, append(online, offline...), 4)

valid, err := cache.Reliable(ctx, criteria)
online, offline, err = cache.Reliable(ctx, time.Hour, 0)
require.NoError(t, err)

require.NotContains(t, valid, storj.NodeID{2}) // disqualified
require.NotContains(t, valid, storj.NodeID{3}) // unknown audit suspended
require.NotContains(t, valid, storj.NodeID{4}) // offline
require.NotContains(t, valid, storj.NodeID{5}) // gracefully exited
require.NotContains(t, valid, storj.NodeID{6}) // offline suspended
require.NotContains(t, valid, storj.NodeID{7}) // excluded country
require.NotContains(t, valid, storj.NodeID{9}) // not in db
require.Len(t, valid, 2)
require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{2}))) // disqualified
require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{3}))) // unknown audit suspended

require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{5}))) // gracefully exited
require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{6}))) // offline suspended
require.False(t, slices.ContainsFunc(append(online, offline...), contains(storj.NodeID{9}))) // not in db

require.True(t, slices.ContainsFunc(offline, contains(storj.NodeID{4}))) // offline
// Reliable is not excluding by country anymore
require.True(t, slices.ContainsFunc(online, contains(storj.NodeID{7}))) // excluded country

require.Len(t, append(online, offline...), 4)
}

{ // TestUpdateOperator
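Aside (not part of the diff): KnownReliable and Reliable now return two slices, online and offline, and no longer exclude by country. A loose sketch of how such a partition by online window might look, using local stand-in types rather than the real overlay API:

package main

import (
	"fmt"
	"time"
)

// SelectedNode is a hypothetical stand-in for nodeselection.SelectedNode.
type SelectedNode struct {
	ID          string
	LastContact time.Time
}

// splitReliable mirrors the shape of the new Reliable/KnownReliable results:
// reliable nodes are partitioned by whether they were seen within onlineWindow.
func splitReliable(nodes []SelectedNode, onlineWindow time.Duration, now time.Time) (online, offline []SelectedNode) {
	for _, n := range nodes {
		if now.Sub(n.LastContact) < onlineWindow {
			online = append(online, n)
		} else {
			offline = append(offline, n)
		}
	}
	return online, offline
}

func main() {
	now := time.Now()
	nodes := []SelectedNode{
		{ID: "a", LastContact: now.Add(-10 * time.Minute)},
		{ID: "b", LastContact: now.Add(-3 * time.Hour)},
	}
	online, offline := splitReliable(nodes, time.Hour, now)
	fmt.Println(len(online), len(offline)) // 1 1
}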

@ -9,9 +9,8 @@ import (

"go.uber.org/zap"

"storj.io/common/pb"
"storj.io/common/sync2"
"storj.io/storj/satellite/nodeselection/uploadselection"
"storj.io/storj/satellite/nodeselection"
)

// UploadSelectionDB implements the database for upload selection cache.

@ -19,7 +18,7 @@ import (
// architecture: Database
type UploadSelectionDB interface {
// SelectAllStorageNodesUpload returns all nodes that qualify to store data, organized as reputable nodes and new nodes
SelectAllStorageNodesUpload(ctx context.Context, selectionCfg NodeSelectionConfig) (reputable, new []*SelectedNode, err error)
SelectAllStorageNodesUpload(ctx context.Context, selectionCfg NodeSelectionConfig) (reputable, new []*nodeselection.SelectedNode, err error)
}

// UploadSelectionCacheConfig is a configuration for upload selection cache.

@ -36,15 +35,20 @@ type UploadSelectionCache struct {
db UploadSelectionDB
selectionConfig NodeSelectionConfig

cache sync2.ReadCacheOf[*uploadselection.State]
cache sync2.ReadCacheOf[*nodeselection.State]

defaultFilters nodeselection.NodeFilters
placementRules PlacementRules
}

// NewUploadSelectionCache creates a new cache that keeps a list of all the storage nodes that are qualified to store data.
func NewUploadSelectionCache(log *zap.Logger, db UploadSelectionDB, staleness time.Duration, config NodeSelectionConfig) (*UploadSelectionCache, error) {
func NewUploadSelectionCache(log *zap.Logger, db UploadSelectionDB, staleness time.Duration, config NodeSelectionConfig, defaultFilter nodeselection.NodeFilters, placementRules PlacementRules) (*UploadSelectionCache, error) {
cache := &UploadSelectionCache{
log: log,
db: db,
selectionConfig: config,
defaultFilters: defaultFilter,
placementRules: placementRules,
}
return cache, cache.cache.Init(staleness/2, staleness, cache.read)
}

@ -65,7 +69,7 @@ func (cache *UploadSelectionCache) Refresh(ctx context.Context) (err error) {
// refresh calls out to the database and refreshes the cache with the most up-to-date
// data from the nodes table, then sets time that the last refresh occurred so we know when
// to refresh again in the future.
func (cache *UploadSelectionCache) read(ctx context.Context) (_ *uploadselection.State, err error) {
func (cache *UploadSelectionCache) read(ctx context.Context) (_ *nodeselection.State, err error) {
defer mon.Task()(&ctx)(&err)

reputableNodes, newNodes, err := cache.db.SelectAllStorageNodesUpload(ctx, cache.selectionConfig)

@ -73,7 +77,7 @@ func (cache *UploadSelectionCache) read(ctx context.Context) (_ *uploadselection
return nil, Error.Wrap(err)
}

state := uploadselection.NewState(convSelectedNodesToNodes(reputableNodes), convSelectedNodesToNodes(newNodes))
state := nodeselection.NewState(reputableNodes, newNodes)

mon.IntVal("refresh_cache_size_reputable").Observe(int64(len(reputableNodes)))
mon.IntVal("refresh_cache_size_new").Observe(int64(len(newNodes)))

@ -84,7 +88,7 @@ func (cache *UploadSelectionCache) read(ctx context.Context) (_ *uploadselection
// GetNodes selects nodes from the cache that will be used to upload a file.
// Every node selected will be from a distinct network.
// If the cache hasn't been refreshed recently it will do so first.
func (cache *UploadSelectionCache) GetNodes(ctx context.Context, req FindStorageNodesRequest) (_ []*SelectedNode, err error) {
func (cache *UploadSelectionCache) GetNodes(ctx context.Context, req FindStorageNodesRequest) (_ []*nodeselection.SelectedNode, err error) {
defer mon.Task()(&ctx)(&err)

state, err := cache.cache.Get(ctx, time.Now())

@ -92,18 +96,23 @@ func (cache *UploadSelectionCache) GetNodes(ctx context.Context, req FindStorage
return nil, Error.Wrap(err)
}

selected, err := state.Select(ctx, uploadselection.Request{
Count: req.RequestedCount,
NewFraction: cache.selectionConfig.NewNodeFraction,
ExcludedIDs: req.ExcludedIDs,
Placement: req.Placement,
ExcludedCountryCodes: cache.selectionConfig.UploadExcludedCountryCodes,
})
if uploadselection.ErrNotEnoughNodes.Has(err) {
err = ErrNotEnoughNodes.Wrap(err)
filters := cache.placementRules(req.Placement)
if len(req.ExcludedIDs) > 0 {
filters = append(filters, state.ExcludeNetworksBasedOnNodes(req.ExcludedIDs))
}

return convNodesToSelectedNodes(selected), err
filters = append(filters, cache.defaultFilters)
filters = filters.WithAutoExcludeSubnets()

selected, err := state.Select(ctx, nodeselection.Request{
Count: req.RequestedCount,
NewFraction: cache.selectionConfig.NewNodeFraction,
NodeFilters: filters,
})
if nodeselection.ErrNotEnoughNodes.Has(err) {
err = ErrNotEnoughNodes.Wrap(err)
}
return selected, err
}

// Size returns how many reputable nodes and new nodes are in the cache.

@ -115,31 +124,3 @@ func (cache *UploadSelectionCache) Size(ctx context.Context) (reputableNodeCount
stats := state.Stats()
return stats.Reputable, stats.New, nil
}

func convNodesToSelectedNodes(nodes []*uploadselection.Node) (xs []*SelectedNode) {
for _, n := range nodes {
xs = append(xs, &SelectedNode{
ID: n.ID,
Address: pb.NodeFromNodeURL(n.NodeURL).Address,
LastNet: n.LastNet,
LastIPPort: n.LastIPPort,
CountryCode: n.CountryCode,
})
}
return xs
}

func convSelectedNodesToNodes(nodes []*SelectedNode) (xs []*uploadselection.Node) {
for _, n := range nodes {
xs = append(xs, &uploadselection.Node{
NodeURL: (&pb.Node{
Id: n.ID,
Address: n.Address,
}).NodeURL(),
LastNet: n.LastNet,
LastIPPort: n.LastIPPort,
CountryCode: n.CountryCode,
})
}
return xs
}
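Aside (not part of the diff): UploadSelectionCache.GetNodes now composes the request filters from the placement filter, an optional exclusion derived from req.ExcludedIDs, and the cache's default filters. A small sketch of that composition order with stand-in types (WithAutoExcludeSubnets is omitted):

package main

import "fmt"

// SelectedNode, NodeFilter and NodeFilters loosely mirror the nodeselection types.
type SelectedNode struct {
	ID      string
	LastNet string
	Country string
}

type NodeFilter func(*SelectedNode) bool

type NodeFilters []NodeFilter

// MatchInclude accepts a node only if every filter in the chain accepts it.
func (fs NodeFilters) MatchInclude(n *SelectedNode) bool {
	for _, f := range fs {
		if !f(n) {
			return false
		}
	}
	return true
}

// composeFilters mirrors the order used in GetNodes: placement filter first,
// then the exclusion derived from already-used networks, then the default filters.
func composeFilters(placement NodeFilters, excludedNets map[string]bool, defaults NodeFilters) NodeFilters {
	filters := append(NodeFilters{}, placement...)
	if len(excludedNets) > 0 {
		filters = append(filters, func(n *SelectedNode) bool { return !excludedNets[n.LastNet] })
	}
	return append(filters, defaults...)
}

func main() {
	placement := NodeFilters{func(n *SelectedNode) bool { return n.Country == "DE" }}
	defaults := NodeFilters{func(n *SelectedNode) bool { return n.LastNet != "" }}
	filters := composeFilters(placement, map[string]bool{"10.0.0.0/24": true}, defaults)
	fmt.Println(filters.MatchInclude(&SelectedNode{ID: "x", LastNet: "10.0.1.0/24", Country: "DE"})) // true
}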
Some files were not shown because too many files have changed in this diff.