// Copyright (C) 2019 Storj Labs, Inc. // See LICENSE for copying information. package stripepayments import ( "context" "fmt" "time" "github.com/stripe/stripe-go" "github.com/stripe/stripe-go/client" "github.com/zeebo/errs" "go.uber.org/zap" "storj.io/storj/satellite/payments" ) // stripeErr is a wrapper for stripe err var stripeErr = errs.Class("stripe error") // service is payments.Service implementation which // works with stripe network through stripe-go client type service struct { log *zap.Logger client *client.API } // NewService creates new instance of StripeService initialized with API key func NewService(log *zap.Logger, apiKey string) payments.Service { stripe.DefaultLeveledLogger = log.Sugar() sc := new(client.API) sc.Init(apiKey, nil) return &service{ log: log, client: sc, } } // CreateCustomer creates new customer from CustomerParams struct // sets default payment to one of the predefined testing VISA credit cards func (s *service) CreateCustomer(ctx context.Context, params payments.CreateCustomerParams) (*payments.Customer, error) { cparams := &stripe.CustomerParams{ Email: stripe.String(params.Email), Name: stripe.String(params.Name), } // TODO: delete after migrating from test environment err := cparams.SetSource("tok_visa") if err != nil { return nil, stripeErr.Wrap(err) } cus, err := s.client.Customers.New(cparams) if err != nil { return nil, stripeErr.Wrap(err) } return &payments.Customer{ ID: []byte(cus.ID), Name: cus.Name, Email: cus.Email, CreatedAt: time.Unix(cus.Created, 0), }, nil } // GetCustomer retrieves customer object from stripe network func (s *service) GetCustomer(ctx context.Context, id []byte) (*payments.Customer, error) { cus, err := s.client.Customers.Get(string(id), nil) if err != nil { return nil, stripeErr.Wrap(err) } return &payments.Customer{ ID: []byte(cus.ID), Name: cus.Name, Email: cus.Email, CreatedAt: time.Unix(cus.Created, 0), }, nil } // GetCustomerDefaultPaymentMethod retrieves customer default payment method from stripe network func (s *service) GetCustomerDefaultPaymentMethod(ctx context.Context, customerID []byte) (*payments.PaymentMethod, error) { cus, err := s.client.Customers.Get(string(customerID), nil) if err != nil { return nil, stripeErr.Wrap(err) } if cus.DefaultSource == nil { return nil, stripeErr.New("no default payment method attached to customer") } pm, err := s.client.PaymentMethods.Get(cus.DefaultSource.ID, nil) if err != nil { return nil, stripeErr.Wrap(err) } if pm.Type != stripe.PaymentMethodTypeCard { return nil, stripeErr.New("payment method other than cards are not allowed") } return &payments.PaymentMethod{ ID: []byte(pm.ID), CustomerID: []byte(cus.ID), Card: payments.Card{ Country: pm.Card.Country, Brand: string(pm.Card.Brand), Name: pm.BillingDetails.Name, ExpMonth: int64(pm.Card.ExpMonth), ExpYear: int64(pm.Card.ExpYear), LastFour: pm.Card.Last4, }, CreatedAt: time.Unix(pm.Created, 0), }, nil } // GetCustomerPaymentsMethods retrieves all payments method attached to particular customer func (s *service) GetCustomerPaymentsMethods(ctx context.Context, customerID []byte) ([]payments.PaymentMethod, error) { var err error pmparams := &stripe.PaymentMethodListParams{} pmparams.Filters.AddFilter("customer", "", string(customerID)) pmparams.Filters.AddFilter("type", "", "card") iterator := s.client.PaymentMethods.List(pmparams) if err = iterator.Err(); err != nil { return nil, stripeErr.Wrap(err) } var paymentMethods []payments.PaymentMethod for iterator.Next() { pm := iterator.PaymentMethod() if pm.Type != stripe.PaymentMethodTypeCard { continue } paymentMethods = append(paymentMethods, payments.PaymentMethod{ ID: []byte(pm.ID), CustomerID: customerID, Card: payments.Card{ Country: pm.Card.Country, Brand: string(pm.Card.Brand), Name: pm.BillingDetails.Name, ExpMonth: int64(pm.Card.ExpMonth), ExpYear: int64(pm.Card.ExpYear), LastFour: pm.Card.Last4, }, CreatedAt: time.Unix(pm.Created, 0), }) } return paymentMethods, nil } // GetPaymentMethod retrieve payment method object from stripe network func (s *service) GetPaymentMethod(ctx context.Context, id []byte) (*payments.PaymentMethod, error) { pm, err := s.client.PaymentMethods.Get(string(id), nil) if err != nil { return nil, stripeErr.Wrap(err) } if pm.Type != stripe.PaymentMethodTypeCard { return nil, stripeErr.New("payment method other than cards are not allowed") } // TODO: check if name is always returned var customerID []byte if pm.Customer != nil { customerID = []byte(pm.Customer.ID) } return &payments.PaymentMethod{ ID: []byte(pm.ID), CustomerID: customerID, Card: payments.Card{ Country: pm.Card.Country, Brand: string(pm.Card.Brand), Name: pm.BillingDetails.Name, ExpMonth: int64(pm.Card.ExpMonth), ExpYear: int64(pm.Card.ExpYear), LastFour: pm.Card.Last4, }, CreatedAt: time.Unix(pm.Created, 0), }, nil } // CreateProjectInvoice creates new project invoice on stripe network from input params. // Included line items: // - Storage // - Egress // - ObjectsCount // Created invoice has AutoAdvance property set to true, so it will be finalized // (no further editing) and attempted to be paid in 1 hour after creation func (s *service) CreateProjectInvoice(ctx context.Context, params payments.CreateProjectInvoiceParams) (*payments.Invoice, error) { // create line items _, err := s.client.InvoiceItems.New(&stripe.InvoiceItemParams{ Customer: stripe.String(params.CustomerID), Description: stripe.String("Storage"), Quantity: stripe.Int64(int64(params.Storage)), UnitAmount: stripe.Int64(100), Currency: stripe.String(string(stripe.CurrencyUSD)), }) if err != nil { return nil, stripeErr.Wrap(err) } _, err = s.client.InvoiceItems.New(&stripe.InvoiceItemParams{ Customer: stripe.String(params.CustomerID), Description: stripe.String("Egress"), Quantity: stripe.Int64(int64(params.Egress)), UnitAmount: stripe.Int64(100), Currency: stripe.String(string(stripe.CurrencyUSD)), }) if err != nil { return nil, stripeErr.Wrap(err) } _, err = s.client.InvoiceItems.New(&stripe.InvoiceItemParams{ Customer: stripe.String(params.CustomerID), Description: stripe.String("ObjectsCount"), Quantity: stripe.Int64(int64(params.ObjectCount)), UnitAmount: stripe.Int64(100), Currency: stripe.String(string(stripe.CurrencyUSD)), }) if err != nil { return nil, stripeErr.Wrap(err) } // TODO: fetch card info manually? // create invoice invoiceParams := &stripe.InvoiceParams{ Customer: stripe.String(params.CustomerID), DefaultPaymentMethod: stripe.String(params.PaymentMethodID), Description: stripe.String(fmt.Sprintf("Invoice for usage of %s", params.ProjectName)), CustomFields: []*stripe.InvoiceCustomFieldParams{ { Name: stripe.String("Billing period"), Value: stripe.String(timeRangeString(params.StartDate, params.EndDate)), }, { Name: stripe.String("Project Name"), Value: stripe.String(params.ProjectName), }, }, AutoAdvance: stripe.Bool(true), } inv, err := s.client.Invoices.New(invoiceParams) if err != nil { return nil, stripeErr.Wrap(err) } // TODO: check for more items var lineItems []payments.LineItem for _, item := range inv.Lines.Data { lineItems = append(lineItems, payments.LineItem{ Key: item.Description, Quantity: item.Quantity, Amount: item.Amount, }) } var customFields []payments.CustomField for _, field := range inv.CustomFields { customFields = append(customFields, payments.CustomField{ Name: field.Name, Value: field.Value, }) } return &payments.Invoice{ ID: []byte(inv.ID), Amount: inv.AmountDue, Currency: payments.Currency(inv.Currency), LineItems: lineItems, CustomFields: customFields, CreatedAt: time.Unix(inv.Created, 0), }, nil } // GetInvoice retrieves an invoice from stripe network by invoiceID func (s *service) GetInvoice(ctx context.Context, id []byte) (*payments.Invoice, error) { inv, err := s.client.Invoices.Get(string(id), nil) if err != nil { return nil, stripeErr.Wrap(err) } // TODO: check for more items var lineItems []payments.LineItem for _, item := range inv.Lines.Data { lineItems = append(lineItems, payments.LineItem{ Key: item.Description, Quantity: item.Quantity, Amount: item.Amount, }) } var customFields []payments.CustomField for _, field := range inv.CustomFields { customFields = append(customFields, payments.CustomField{ Name: field.Name, Value: field.Value, }) } return &payments.Invoice{ ID: []byte(inv.ID), Amount: inv.AmountDue, Currency: payments.Currency(inv.Currency), LineItems: lineItems, CustomFields: customFields, CreatedAt: time.Unix(inv.Created, 0), }, nil } // timeRangeString helper function to create string representation of time range func timeRangeString(start, end time.Time) string { return fmt.Sprintf("%d/%d/%d - %d/%d/%d", start.UTC().Month(), start.UTC().Day(), start.UTC().Year(), end.UTC().Month(), end.UTC().Day(), end.UTC().Year()) }