09b0c2a630
* init marketing service Fix linting error Create offerdb implementation Create offers service Add update method Create offer table and migration Fix linting error fix conflicts Insert new data Change duration to have clear indication to be based on days add error wrapper Change from using uuid to int for id field * Create Marketing service * make error virable name more readable * add condition in update service method to check offer status * generate lock file Change get to listAllOffers * Add method for getting current offer wip * add check for expires_at in update method * Fix conflicts * add copyright header * Fix linting error * only allow update to active offers * add isDefault argument to GetCurrent * Update lock file * add migration file * finish migrate for adding credit_in_cents for both award and invitee * save 100 years as expiration date for default offers * create crud test for offers * add GetCurrent test * modify doc * Fix GetCurrent to work with default offer * fix linting issue * add more tests and address feedbacks * fix migration file * add type column back to match with mockup design * add type column back to match with mockup design * move doc changes to new pr * add comments * change GetCurrent to GetCurrentByType * fix typo
391 lines
9.5 KiB
Go
391 lines
9.5 KiB
Go
// Copyright (C) 2019 Storj Labs, Inc.
|
|
// See LICENSE for copying information.
|
|
|
|
// +build ignore
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/token"
|
|
"go/types"
|
|
"io/ioutil"
|
|
"os"
|
|
"path"
|
|
"sort"
|
|
"strings"
|
|
|
|
"golang.org/x/tools/go/ast/astutil"
|
|
"golang.org/x/tools/go/packages"
|
|
"golang.org/x/tools/imports"
|
|
)
|
|
|
|
func main() {
|
|
var outputPath string
|
|
var packageName string
|
|
var typeFullyQualifedName string
|
|
|
|
flag.StringVar(&outputPath, "o", "", "output file name")
|
|
flag.StringVar(&packageName, "p", "", "output package name")
|
|
flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
|
|
flag.Parse()
|
|
|
|
if outputPath == "" || packageName == "" || typeFullyQualifedName == "" {
|
|
fmt.Println("missing argument")
|
|
os.Exit(1)
|
|
}
|
|
|
|
var code Code
|
|
|
|
code.Imports = map[string]bool{}
|
|
code.Ignore = map[string]bool{
|
|
"error": true,
|
|
}
|
|
code.IgnoreMethods = map[string]bool{
|
|
"BeginTx": true,
|
|
}
|
|
code.OutputPackage = packageName
|
|
code.Config = &packages.Config{
|
|
Mode: packages.LoadAllSyntax,
|
|
}
|
|
code.Wrapped = map[string]bool{}
|
|
code.AdditionalNesting = map[string]int{"Console": 1, "Marketing": 1}
|
|
|
|
// e.g. storj.io/storj/satellite.DB
|
|
p := strings.LastIndexByte(typeFullyQualifedName, '.')
|
|
code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
|
|
code.Type = typeFullyQualifedName[p+1:] // DB
|
|
code.QualifiedType = path.Base(code.Package) + "." + code.Type
|
|
|
|
var err error
|
|
code.Roots, err = packages.Load(code.Config, code.Package)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
code.PrintLocked()
|
|
code.PrintPreamble()
|
|
|
|
unformatted := code.Bytes()
|
|
|
|
imports.LocalPrefix = "storj.io"
|
|
formatted, err := imports.Process(outputPath, unformatted, nil)
|
|
if err != nil {
|
|
fmt.Println(string(unformatted))
|
|
panic(err)
|
|
}
|
|
|
|
if outputPath == "" {
|
|
fmt.Println(string(formatted))
|
|
return
|
|
}
|
|
|
|
err = ioutil.WriteFile(outputPath, formatted, 0644)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
// Methods abstracts over go/types values that expose an indexed method set,
// such as *types.Interface and *types.Named.
type Methods interface {
	// NumMethods reports how many methods the set contains.
	NumMethods() int
	// Method returns the i-th method, for 0 <= i < NumMethods().
	Method(i int) *types.Func
}
|
|
|
|
// Code is the information for generating the code.
|
|
type Code struct {
|
|
Config *packages.Config
|
|
Package string
|
|
Type string
|
|
QualifiedType string
|
|
Roots []*packages.Package
|
|
|
|
OutputPackage string
|
|
|
|
Imports map[string]bool
|
|
Ignore map[string]bool
|
|
IgnoreMethods map[string]bool
|
|
Wrapped map[string]bool
|
|
AdditionalNesting map[string]int
|
|
|
|
Preamble bytes.Buffer
|
|
Source bytes.Buffer
|
|
}
|
|
|
|
// Bytes returns all code merged together
|
|
func (code *Code) Bytes() []byte {
|
|
var all bytes.Buffer
|
|
all.Write(code.Preamble.Bytes())
|
|
all.Write(code.Source.Bytes())
|
|
return all.Bytes()
|
|
}
|
|
|
|
// PrintPreamble creates package header and imports.
|
|
func (code *Code) PrintPreamble() {
|
|
w := &code.Preamble
|
|
fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n")
|
|
fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n")
|
|
fmt.Fprintf(w, "// See LICENSE for copying information.\n\n")
|
|
fmt.Fprintf(w, "package %v\n\n", code.OutputPackage)
|
|
fmt.Fprintf(w, "import (\n")
|
|
|
|
var imports []string
|
|
for imp := range code.Imports {
|
|
imports = append(imports, imp)
|
|
}
|
|
sort.Strings(imports)
|
|
for _, imp := range imports {
|
|
fmt.Fprintf(w, " %q\n", imp)
|
|
}
|
|
fmt.Fprintf(w, ")\n\n")
|
|
}
|
|
|
|
// PrintLocked writes locked wrapper and methods.
|
|
func (code *Code) PrintLocked() {
|
|
code.Imports["sync"] = true
|
|
code.Imports["storj.io/statellite"] = true
|
|
|
|
code.Printf("// locked implements a locking wrapper around satellite.DB.\n")
|
|
code.Printf("type locked struct {\n")
|
|
code.Printf(" sync.Locker\n")
|
|
code.Printf(" db %v\n", code.QualifiedType)
|
|
code.Printf("}\n\n")
|
|
|
|
code.Printf("// newLocked returns database wrapped with locker.\n")
|
|
code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
|
|
code.Printf(" return &locked{&sync.Mutex{}, db}\n")
|
|
code.Printf("}\n\n")
|
|
|
|
// find the satellite.DB type info
|
|
dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
|
|
methods := dbObject.Type().Underlying().(Methods)
|
|
|
|
for i := 0; i < methods.NumMethods(); i++ {
|
|
code.PrintLockedFunc("locked", methods.Method(i), code.AdditionalNesting[methods.Method(i).Name()]+1)
|
|
}
|
|
}
|
|
|
|
// Printf writes formatted text to source.
|
|
func (code *Code) Printf(format string, a ...interface{}) {
|
|
fmt.Fprintf(&code.Source, format, a...)
|
|
}
|
|
|
|
// PrintSignature prints method signature.
|
|
func (code *Code) PrintSignature(sig *types.Signature) {
|
|
code.PrintSignatureTuple(sig.Params(), true)
|
|
if sig.Results().Len() > 0 {
|
|
code.Printf(" ")
|
|
code.PrintSignatureTuple(sig.Results(), false)
|
|
}
|
|
}
|
|
|
|
// PrintSignatureTuple prints method tuple, params or results.
|
|
func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) {
|
|
code.Printf("(")
|
|
defer code.Printf(")")
|
|
|
|
for i := 0; i < tuple.Len(); i++ {
|
|
if i > 0 {
|
|
code.Printf(", ")
|
|
}
|
|
|
|
param := tuple.At(i)
|
|
if code.PrintName(tuple.At(i), i, needsNames) {
|
|
code.Printf(" ")
|
|
}
|
|
code.PrintType(param.Type())
|
|
}
|
|
}
|
|
|
|
// PrintCall prints a call using the specified signature.
|
|
func (code *Code) PrintCall(sig *types.Signature) {
|
|
code.Printf("(")
|
|
defer code.Printf(")")
|
|
|
|
params := sig.Params()
|
|
for i := 0; i < params.Len(); i++ {
|
|
if i != 0 {
|
|
code.Printf(", ")
|
|
}
|
|
code.PrintName(params.At(i), i, true)
|
|
}
|
|
}
|
|
|
|
// PrintName prints an appropriate name from signature tuple.
|
|
func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool {
|
|
name := v.Name()
|
|
if needsNames && name == "" {
|
|
if v.Type().String() == "context.Context" {
|
|
code.Printf("ctx")
|
|
return true
|
|
}
|
|
code.Printf("a%d", index)
|
|
return true
|
|
}
|
|
code.Printf("%s", name)
|
|
return name != ""
|
|
}
|
|
|
|
// PrintType prints short form of type t.
|
|
func (code *Code) PrintType(t types.Type) {
|
|
types.WriteType(&code.Source, t, (*types.Package).Name)
|
|
}
|
|
|
|
// typeName renders typ as Go source, qualifying named types with their
// package name, e.g. "*console.Users".
func typeName(typ types.Type) string {
	qualifier := func(pkg *types.Package) string { return pkg.Name() }

	var out bytes.Buffer
	types.WriteType(&out, typ, qualifier)
	return out.String()
}
|
|
|
|
// IncludeImports imports all types referenced in the signature.
|
|
func (code *Code) IncludeImports(sig *types.Signature) {
|
|
var tmp bytes.Buffer
|
|
types.WriteSignature(&tmp, sig, func(p *types.Package) string {
|
|
code.Imports[p.Path()] = true
|
|
return p.Name()
|
|
})
|
|
}
|
|
|
|
// NeedsWrapper checks whether method result needs a wrapper type.
|
|
func (code *Code) NeedsWrapper(method *types.Func) bool {
|
|
if code.IgnoreMethods[method.Name()] {
|
|
return false
|
|
}
|
|
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()]
|
|
}
|
|
|
|
// WrapperTypeName returns an appropriate name for the wrapper type.
|
|
func (code *Code) WrapperTypeName(method *types.Func) string {
|
|
return "locked" + method.Name()
|
|
}
|
|
|
|
// PrintLockedFunc prints a method with locking and defers the actual logic to method.
|
|
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) {
|
|
if code.IgnoreMethods[method.Name()] {
|
|
return
|
|
}
|
|
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
code.IncludeImports(sig)
|
|
|
|
doc := strings.TrimSpace(code.MethodDoc(method))
|
|
if doc != "" {
|
|
for _, line := range strings.Split(doc, "\n") {
|
|
code.Printf("// %s\n", line)
|
|
}
|
|
}
|
|
code.Printf("func (m *%s) %s", receiverType, method.Name())
|
|
code.PrintSignature(sig)
|
|
code.Printf(" {\n")
|
|
|
|
code.Printf(" m.Lock(); defer m.Unlock()\n")
|
|
if !code.NeedsWrapper(method) {
|
|
code.Printf(" return m.db.%s", method.Name())
|
|
code.PrintCall(sig)
|
|
code.Printf("\n")
|
|
code.Printf("}\n\n")
|
|
return
|
|
}
|
|
|
|
code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method))
|
|
code.Printf("m.db.%s", method.Name())
|
|
code.PrintCall(sig)
|
|
code.Printf("}\n")
|
|
code.Printf("}\n\n")
|
|
|
|
if nestingDepth > 0 {
|
|
code.PrintWrapper(method, nestingDepth-1)
|
|
}
|
|
}
|
|
|
|
// PrintWrapper prints wrapper for the result type of method.
|
|
func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) {
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
results := sig.Results()
|
|
result := results.At(0).Type()
|
|
|
|
receiverType := code.WrapperTypeName(method)
|
|
|
|
if code.Wrapped[receiverType] {
|
|
return
|
|
}
|
|
code.Wrapped[receiverType] = true
|
|
|
|
code.Printf("// %s implements locking wrapper for %s\n", receiverType, typeName(result))
|
|
code.Printf("type %s struct {\n", receiverType)
|
|
code.Printf(" sync.Locker\n")
|
|
code.Printf(" db %s\n", typeName(result))
|
|
code.Printf("}\n\n")
|
|
|
|
methods := result.Underlying().(Methods)
|
|
for i := 0; i < methods.NumMethods(); i++ {
|
|
code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth)
|
|
}
|
|
}
|
|
|
|
// MethodDoc finds documentation for the specified method.
|
|
func (code *Code) MethodDoc(method *types.Func) string {
|
|
file := code.FindASTFile(method.Pos())
|
|
if file == nil {
|
|
return ""
|
|
}
|
|
|
|
path, exact := astutil.PathEnclosingInterval(file, method.Pos(), method.Pos())
|
|
if !exact {
|
|
return ""
|
|
}
|
|
|
|
for _, p := range path {
|
|
switch decl := p.(type) {
|
|
case *ast.Field:
|
|
return decl.Doc.Text()
|
|
case *ast.GenDecl:
|
|
return decl.Doc.Text()
|
|
case *ast.FuncDecl:
|
|
return decl.Doc.Text()
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// FindASTFile finds the *ast.File at the specified position.
|
|
func (code *Code) FindASTFile(pos token.Pos) *ast.File {
|
|
seen := map[*packages.Package]bool{}
|
|
|
|
// find searches pos recursively from p and its dependencies.
|
|
var find func(p *packages.Package) *ast.File
|
|
find = func(p *packages.Package) *ast.File {
|
|
if seen[p] {
|
|
return nil
|
|
}
|
|
seen[p] = true
|
|
|
|
for _, file := range p.Syntax {
|
|
if file.Pos() <= pos && pos <= file.End() {
|
|
return file
|
|
}
|
|
}
|
|
|
|
for _, dep := range p.Imports {
|
|
if file := find(dep); file != nil {
|
|
return file
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, root := range code.Roots {
|
|
if file := find(root); file != nil {
|
|
return file
|
|
}
|
|
}
|
|
return nil
|
|
}
|