storj/cmd/satellite/billing.go

104 lines
2.7 KiB
Go
Raw Normal View History

// Copyright (C) 2020 Storj Labs, Inc.
// See LICENSE for copying information.
package main
import (
"strconv"
"strings"
"time"
"github.com/zeebo/errs"
"go.uber.org/zap"
"storj.io/storj/private/dbutil"
"storj.io/storj/satellite"
"storj.io/storj/satellite/payments/stripecoinpayments"
"storj.io/storj/satellite/satellitedb"
"storj.io/storj/satellite/satellitedb/dbx"
)
// runBillingCmd opens the satellite database configured in runCfg.Database
// twice: once through satellitedb (used to build the payments service) and
// once as a direct dbx connection for custom queries. It then invokes cmdFunc
// with the payments service and the dbx handle.
//
// Only Postgres and Cockroach connection strings are accepted for the direct
// connection; any other implementation yields an error.
//
// The result is named so that the deferred Close calls can fold their errors
// into the value returned to the caller. With an unnamed result the deferred
// assignments would only touch a local variable and Close failures would be
// lost.
func runBillingCmd(cmdFunc func(*stripecoinpayments.Service, *dbx.DB) error) (err error) {
	// Open SatelliteDB for the Payment Service
	logger := zap.L()
	db, err := satellitedb.New(logger.Named("db"), runCfg.Database, satellitedb.Options{})
	if err != nil {
		return errs.New("error connecting to master database on satellite: %+v", err)
	}
	defer func() {
		err = errs.Combine(err, db.Close())
	}()

	// Open direct DB connection to execute custom queries
	driver, source, implementation, err := dbutil.SplitConnStr(runCfg.Database)
	if err != nil {
		return err
	}
	if implementation != dbutil.Postgres && implementation != dbutil.Cockroach {
		return errs.New("unsupported driver %q", driver)
	}

	dbxDB, err := dbx.Open(driver, source)
	if err != nil {
		return err
	}
	defer func() {
		err = errs.Combine(err, dbxDB.Close())
	}()

	// NOTE(review): source is derived from the connection string and may
	// embed credentials; confirm debug logs are handled accordingly.
	logger.Debug("Connected to:", zap.String("db source", source))

	payments, err := setupPayments(logger, db)
	if err != nil {
		return err
	}

	return cmdFunc(payments, dbxDB)
}
// setupPayments constructs the stripecoinpayments service from the payments
// section of runCfg, wiring in the stores exposed by db. Whatever error
// stripecoinpayments.NewService reports is returned unchanged.
func setupPayments(log *zap.Logger, db satellite.DB) (*stripecoinpayments.Service, error) {
	cfg := runCfg.Payments

	// The Stripe client is created first and then handed to the service.
	client := stripecoinpayments.NewStripeClient(log, cfg.StripeCoinPayments)

	return stripecoinpayments.NewService(
		log.Named("payments.stripe:service"),
		client,
		cfg.StripeCoinPayments,
		db.StripeCoinPayments(),
		db.Console().Projects(),
		db.ProjectAccounting(),
		cfg.StorageTBPrice,
		cfg.EgressTBPrice,
		cfg.ObjectPrice,
		cfg.BonusRate,
		cfg.CouponValue,
		cfg.CouponDuration,
		cfg.CouponProjectLimit,
		cfg.MinCoinPayment,
	)
}
// parseBillingPeriod parses a billing period given as "mm/yyyy" and returns
// midnight UTC on the first day of that month.
//
// An error is returned when s does not consist of exactly two "/"-separated
// parts, when either part is not an integer that fits in an int, or when the
// month/year pair does not survive a round trip through time.Date (for
// example a month outside 1-12). On every error the returned time is the zero
// time.Time.
func parseBillingPeriod(s string) (time.Time, error) {
	values := strings.Split(s, "/")
	if len(values) != 2 {
		return time.Time{}, errs.New("invalid date format %s, use mm/yyyy", s)
	}

	// Atoi reports a range error for values that do not fit in an int, rather
	// than parsing as int64 and truncating on conversion.
	month, err := strconv.Atoi(values[0])
	if err != nil {
		return time.Time{}, errs.New("can not parse month: %v", err)
	}
	year, err := strconv.Atoi(values[1])
	if err != nil {
		return time.Time{}, errs.New("can not parse year: %v", err)
	}

	date := time.Date(year, time.Month(month), 1, 0, 0, 0, 0, time.UTC)

	// time.Date normalizes out-of-range fields (month 13 rolls over into the
	// next year), so a result that differs from the input means the input was
	// not a valid month/year pair.
	if date.Year() != year || date.Month() != time.Month(month) || date.Day() != 1 {
		return time.Time{}, errs.New("dates mismatch have %s result %s", s, date)
	}

	return date, nil
}