2019-01-24 20:15:10 +00:00
|
|
|
// Copyright (C) 2019 Storj Labs, Inc.
|
2019-01-02 17:53:27 +00:00
|
|
|
// See LICENSE for copying information.
|
|
|
|
|
2019-01-11 16:07:26 +00:00
|
|
|
// +build ignore
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
package main
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"flag"
|
|
|
|
"fmt"
|
|
|
|
"go/ast"
|
|
|
|
"go/token"
|
|
|
|
"go/types"
|
|
|
|
"io/ioutil"
|
2019-01-11 16:07:26 +00:00
|
|
|
"os"
|
|
|
|
"path"
|
2019-01-02 17:53:27 +00:00
|
|
|
"sort"
|
2019-01-11 16:07:26 +00:00
|
|
|
"strings"
|
2019-01-02 17:53:27 +00:00
|
|
|
|
|
|
|
"golang.org/x/tools/go/ast/astutil"
|
|
|
|
"golang.org/x/tools/go/packages"
|
|
|
|
"golang.org/x/tools/imports"
|
|
|
|
)
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
var outputPath string
|
2019-01-11 16:07:26 +00:00
|
|
|
var packageName string
|
|
|
|
var typeFullyQualifedName string
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
flag.StringVar(&outputPath, "o", "", "output file name")
|
2019-01-11 16:07:26 +00:00
|
|
|
flag.StringVar(&packageName, "p", "", "output package name")
|
|
|
|
flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
|
2019-01-02 17:53:27 +00:00
|
|
|
flag.Parse()
|
|
|
|
|
2019-01-11 16:07:26 +00:00
|
|
|
if outputPath == "" || packageName == "" || typeFullyQualifedName == "" {
|
|
|
|
fmt.Println("missing argument")
|
|
|
|
os.Exit(1)
|
|
|
|
}
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
var code Code
|
|
|
|
|
|
|
|
code.Imports = map[string]bool{}
|
|
|
|
code.Ignore = map[string]bool{
|
|
|
|
"error": true,
|
|
|
|
}
|
2019-01-11 16:07:26 +00:00
|
|
|
code.IgnoreMethods = map[string]bool{
|
|
|
|
"BeginTx": true,
|
|
|
|
}
|
|
|
|
code.OutputPackage = packageName
|
2019-01-02 17:53:27 +00:00
|
|
|
code.Config = &packages.Config{
|
|
|
|
Mode: packages.LoadAllSyntax,
|
|
|
|
}
|
2019-01-15 18:29:52 +00:00
|
|
|
code.Wrapped = map[string]bool{}
|
|
|
|
code.AdditionalNesting = map[string]int{"Console": 1}
|
2019-01-11 16:07:26 +00:00
|
|
|
|
|
|
|
// e.g. storj.io/storj/satellite.DB
|
|
|
|
p := strings.LastIndexByte(typeFullyQualifedName, '.')
|
|
|
|
code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
|
|
|
|
code.Type = typeFullyQualifedName[p+1:] // DB
|
|
|
|
code.QualifiedType = path.Base(code.Package) + "." + code.Type
|
2019-01-02 17:53:27 +00:00
|
|
|
|
|
|
|
var err error
|
|
|
|
code.Roots, err = packages.Load(code.Config, code.Package)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
code.PrintLocked()
|
|
|
|
code.PrintPreamble()
|
|
|
|
|
|
|
|
unformatted := code.Bytes()
|
|
|
|
|
|
|
|
imports.LocalPrefix = "storj.io"
|
|
|
|
formatted, err := imports.Process(outputPath, unformatted, nil)
|
|
|
|
if err != nil {
|
|
|
|
fmt.Println(string(unformatted))
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if outputPath == "" {
|
|
|
|
fmt.Println(string(formatted))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
err = ioutil.WriteFile(outputPath, formatted, 0644)
|
|
|
|
if err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Methods abstracts over values that expose an indexed list of methods,
// so the generator can walk them with a plain counting loop.
type Methods interface {
	// NumMethods reports how many methods are available.
	NumMethods() int
	// Method returns the method at index i.
	Method(i int) *types.Func
}
|
|
|
|
|
|
|
|
// Code is the information for generating the code.
|
|
|
|
type Code struct {
|
2019-01-11 16:07:26 +00:00
|
|
|
Config *packages.Config
|
|
|
|
Package string
|
|
|
|
Type string
|
|
|
|
QualifiedType string
|
|
|
|
Roots []*packages.Package
|
2019-01-02 17:53:27 +00:00
|
|
|
|
2019-01-11 16:07:26 +00:00
|
|
|
OutputPackage string
|
|
|
|
|
2019-01-15 18:29:52 +00:00
|
|
|
Imports map[string]bool
|
|
|
|
Ignore map[string]bool
|
|
|
|
IgnoreMethods map[string]bool
|
|
|
|
Wrapped map[string]bool
|
|
|
|
AdditionalNesting map[string]int
|
2019-01-02 17:53:27 +00:00
|
|
|
|
|
|
|
Preamble bytes.Buffer
|
|
|
|
Source bytes.Buffer
|
|
|
|
}
|
|
|
|
|
|
|
|
// Bytes returns all code merged together
|
|
|
|
func (code *Code) Bytes() []byte {
|
|
|
|
var all bytes.Buffer
|
|
|
|
all.Write(code.Preamble.Bytes())
|
|
|
|
all.Write(code.Source.Bytes())
|
|
|
|
return all.Bytes()
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintPreamble creates package header and imports.
|
|
|
|
func (code *Code) PrintPreamble() {
|
|
|
|
w := &code.Preamble
|
|
|
|
fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n")
|
2019-01-11 16:07:26 +00:00
|
|
|
fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n")
|
2019-01-02 17:53:27 +00:00
|
|
|
fmt.Fprintf(w, "// See LICENSE for copying information.\n\n")
|
2019-01-11 16:07:26 +00:00
|
|
|
fmt.Fprintf(w, "package %v\n\n", code.OutputPackage)
|
2019-01-02 17:53:27 +00:00
|
|
|
fmt.Fprintf(w, "import (\n")
|
|
|
|
|
|
|
|
var imports []string
|
|
|
|
for imp := range code.Imports {
|
|
|
|
imports = append(imports, imp)
|
|
|
|
}
|
|
|
|
sort.Strings(imports)
|
|
|
|
for _, imp := range imports {
|
|
|
|
fmt.Fprintf(w, " %q\n", imp)
|
|
|
|
}
|
|
|
|
fmt.Fprintf(w, ")\n\n")
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintLocked writes locked wrapper and methods.
|
|
|
|
func (code *Code) PrintLocked() {
|
|
|
|
code.Imports["sync"] = true
|
|
|
|
code.Imports["storj.io/statellite"] = true
|
|
|
|
|
|
|
|
code.Printf("// locked implements a locking wrapper around satellite.DB.\n")
|
|
|
|
code.Printf("type locked struct {\n")
|
|
|
|
code.Printf(" sync.Locker\n")
|
2019-01-11 16:07:26 +00:00
|
|
|
code.Printf(" db %v\n", code.QualifiedType)
|
2019-01-02 17:53:27 +00:00
|
|
|
code.Printf("}\n\n")
|
|
|
|
|
|
|
|
code.Printf("// newLocked returns database wrapped with locker.\n")
|
2019-01-11 16:07:26 +00:00
|
|
|
code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
|
2019-01-02 17:53:27 +00:00
|
|
|
code.Printf(" return &locked{&sync.Mutex{}, db}\n")
|
|
|
|
code.Printf("}\n\n")
|
|
|
|
|
|
|
|
// find the satellite.DB type info
|
2019-01-11 16:07:26 +00:00
|
|
|
dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
|
2019-01-02 17:53:27 +00:00
|
|
|
methods := dbObject.Type().Underlying().(Methods)
|
|
|
|
|
|
|
|
for i := 0; i < methods.NumMethods(); i++ {
|
2019-01-15 18:29:52 +00:00
|
|
|
code.PrintLockedFunc("locked", methods.Method(i), code.AdditionalNesting[methods.Method(i).Name()]+1)
|
2019-01-02 17:53:27 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Printf writes formatted text to source.
|
|
|
|
func (code *Code) Printf(format string, a ...interface{}) {
|
|
|
|
fmt.Fprintf(&code.Source, format, a...)
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintSignature prints method signature.
|
|
|
|
func (code *Code) PrintSignature(sig *types.Signature) {
|
|
|
|
code.PrintSignatureTuple(sig.Params(), true)
|
|
|
|
if sig.Results().Len() > 0 {
|
|
|
|
code.Printf(" ")
|
|
|
|
code.PrintSignatureTuple(sig.Results(), false)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintSignatureTuple prints method tuple, params or results.
|
|
|
|
func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) {
|
|
|
|
code.Printf("(")
|
|
|
|
defer code.Printf(")")
|
|
|
|
|
|
|
|
for i := 0; i < tuple.Len(); i++ {
|
|
|
|
if i > 0 {
|
|
|
|
code.Printf(", ")
|
|
|
|
}
|
|
|
|
|
|
|
|
param := tuple.At(i)
|
|
|
|
if code.PrintName(tuple.At(i), i, needsNames) {
|
|
|
|
code.Printf(" ")
|
|
|
|
}
|
|
|
|
code.PrintType(param.Type())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintCall prints a call using the specified signature.
|
|
|
|
func (code *Code) PrintCall(sig *types.Signature) {
|
|
|
|
code.Printf("(")
|
|
|
|
defer code.Printf(")")
|
|
|
|
|
|
|
|
params := sig.Params()
|
|
|
|
for i := 0; i < params.Len(); i++ {
|
|
|
|
if i != 0 {
|
|
|
|
code.Printf(", ")
|
|
|
|
}
|
|
|
|
code.PrintName(params.At(i), i, true)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintName prints an appropriate name from signature tuple.
|
|
|
|
func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool {
|
|
|
|
name := v.Name()
|
|
|
|
if needsNames && name == "" {
|
|
|
|
if v.Type().String() == "context.Context" {
|
|
|
|
code.Printf("ctx")
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
code.Printf("a%d", index)
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
code.Printf("%s", name)
|
|
|
|
return name != ""
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintType prints short form of type t.
|
|
|
|
func (code *Code) PrintType(t types.Type) {
|
|
|
|
types.WriteType(&code.Source, t, (*types.Package).Name)
|
|
|
|
}
|
|
|
|
|
|
|
|
// typeName returns the string form of typ with named types qualified by
// their package name rather than the full import path.
func typeName(typ types.Type) string {
	byPackageName := func(pkg *types.Package) string {
		return pkg.Name()
	}
	return types.TypeString(typ, byPackageName)
}
|
|
|
|
|
|
|
|
// IncludeImports imports all types referenced in the signature.
|
|
|
|
func (code *Code) IncludeImports(sig *types.Signature) {
|
|
|
|
var tmp bytes.Buffer
|
|
|
|
types.WriteSignature(&tmp, sig, func(p *types.Package) string {
|
|
|
|
code.Imports[p.Path()] = true
|
|
|
|
return p.Name()
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// NeedsWrapper checks whether method result needs a wrapper type.
|
|
|
|
func (code *Code) NeedsWrapper(method *types.Func) bool {
|
2019-01-11 16:07:26 +00:00
|
|
|
if code.IgnoreMethods[method.Name()] {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
|
|
return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()]
|
|
|
|
}
|
|
|
|
|
2019-01-15 18:29:52 +00:00
|
|
|
// WrapperTypeName returns an appropriate name for the wrapper type.
|
2019-01-02 17:53:27 +00:00
|
|
|
func (code *Code) WrapperTypeName(method *types.Func) string {
|
|
|
|
return "locked" + method.Name()
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintLockedFunc prints a method with locking and defers the actual logic to method.
|
2019-01-15 18:29:52 +00:00
|
|
|
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) {
|
2019-01-11 16:07:26 +00:00
|
|
|
if code.IgnoreMethods[method.Name()] {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
|
|
code.IncludeImports(sig)
|
|
|
|
|
|
|
|
doc := code.MethodDoc(method)
|
|
|
|
if doc != "" {
|
|
|
|
code.Printf("// %s", code.MethodDoc(method))
|
|
|
|
}
|
|
|
|
code.Printf("func (m *%s) %s", receiverType, method.Name())
|
|
|
|
code.PrintSignature(sig)
|
|
|
|
code.Printf(" {\n")
|
|
|
|
|
|
|
|
code.Printf(" m.Lock(); defer m.Unlock()\n")
|
2019-01-15 18:29:52 +00:00
|
|
|
if !code.NeedsWrapper(method) {
|
2019-01-02 17:53:27 +00:00
|
|
|
code.Printf(" return m.db.%s", method.Name())
|
|
|
|
code.PrintCall(sig)
|
|
|
|
code.Printf("\n")
|
2019-01-15 18:29:52 +00:00
|
|
|
code.Printf("}\n\n")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method))
|
|
|
|
code.Printf("m.db.%s", method.Name())
|
|
|
|
code.PrintCall(sig)
|
|
|
|
code.Printf("}\n")
|
|
|
|
code.Printf("}\n\n")
|
|
|
|
|
|
|
|
if nestingDepth > 0 {
|
|
|
|
code.PrintWrapper(method, nestingDepth-1)
|
2019-01-02 17:53:27 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// PrintWrapper prints wrapper for the result type of method.
|
2019-01-15 18:29:52 +00:00
|
|
|
func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) {
|
2019-01-02 17:53:27 +00:00
|
|
|
sig := method.Type().Underlying().(*types.Signature)
|
|
|
|
results := sig.Results()
|
|
|
|
result := results.At(0).Type()
|
|
|
|
|
|
|
|
receiverType := code.WrapperTypeName(method)
|
2019-01-15 18:29:52 +00:00
|
|
|
|
|
|
|
if code.Wrapped[receiverType] {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
code.Wrapped[receiverType] = true
|
|
|
|
|
2019-01-02 17:53:27 +00:00
|
|
|
code.Printf("// %s implements locking wrapper for %s\n", receiverType, typeName(result))
|
|
|
|
code.Printf("type %s struct {\n", receiverType)
|
|
|
|
code.Printf(" sync.Locker\n")
|
|
|
|
code.Printf(" db %s\n", typeName(result))
|
|
|
|
code.Printf("}\n\n")
|
|
|
|
|
|
|
|
methods := result.Underlying().(Methods)
|
|
|
|
for i := 0; i < methods.NumMethods(); i++ {
|
2019-01-15 18:29:52 +00:00
|
|
|
code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth)
|
2019-01-02 17:53:27 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// MethodDoc finds documentation for the specified method.
|
|
|
|
func (code *Code) MethodDoc(method *types.Func) string {
|
|
|
|
file := code.FindASTFile(method.Pos())
|
|
|
|
if file == nil {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
path, exact := astutil.PathEnclosingInterval(file, method.Pos(), method.Pos())
|
|
|
|
if !exact {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, p := range path {
|
|
|
|
switch decl := p.(type) {
|
|
|
|
case *ast.Field:
|
|
|
|
return decl.Doc.Text()
|
|
|
|
case *ast.GenDecl:
|
|
|
|
return decl.Doc.Text()
|
|
|
|
case *ast.FuncDecl:
|
|
|
|
return decl.Doc.Text()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
// FindASTFile finds the *ast.File at the specified position.
|
|
|
|
func (code *Code) FindASTFile(pos token.Pos) *ast.File {
|
|
|
|
seen := map[*packages.Package]bool{}
|
|
|
|
|
|
|
|
// find searches pos recursively from p and its dependencies.
|
|
|
|
var find func(p *packages.Package) *ast.File
|
|
|
|
find = func(p *packages.Package) *ast.File {
|
|
|
|
if seen[p] {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
seen[p] = true
|
|
|
|
|
|
|
|
for _, file := range p.Syntax {
|
|
|
|
if file.Pos() <= pos && pos <= file.End() {
|
|
|
|
return file
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, dep := range p.Imports {
|
|
|
|
if file := find(dep); file != nil {
|
|
|
|
return file
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, root := range code.Roots {
|
|
|
|
if file := find(root); file != nil {
|
|
|
|
return file
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|