storj/scripts/lockedgen.go

391 lines
9.5 KiB
Go
Raw Normal View History

2019-01-24 20:15:10 +00:00
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
// +build ignore
package main
import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"io/ioutil"
"os"
"path"
"sort"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
// main parses the command-line flags, loads the package that declares the
// interface named by -i, generates a locking wrapper for it in package -p and
// writes the goimports-formatted result to -o, or to stdout when -o is empty.
func main() {
	var outputPath string
	var packageName string
	var typeFullyQualifedName string
	flag.StringVar(&outputPath, "o", "", "output file name (prints to stdout when empty)")
	flag.StringVar(&packageName, "p", "", "output package name")
	flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
	flag.Parse()

	// -o is optional: without it the generated code is printed to stdout.
	if packageName == "" || typeFullyQualifedName == "" {
		fmt.Fprintln(os.Stderr, "missing argument")
		flag.Usage()
		os.Exit(1)
	}

	// e.g. storj.io/storj/satellite.DB
	// The last '.' separates the import path from the type name; import paths
	// may themselves contain dots, so LastIndexByte is required.
	p := strings.LastIndexByte(typeFullyQualifedName, '.')
	if p <= 0 || p == len(typeFullyQualifedName)-1 {
		fmt.Fprintf(os.Stderr, "invalid interface %q: expected <import path>.<type name>\n", typeFullyQualifedName)
		os.Exit(1)
	}

	var code Code
	code.Imports = map[string]bool{}
	code.Ignore = map[string]bool{
		"error": true,
	}
	code.IgnoreMethods = map[string]bool{
		"BeginTx": true,
	}
	code.OutputPackage = packageName
	code.Config = &packages.Config{
		Mode: packages.LoadAllSyntax,
	}
	code.Wrapped = map[string]bool{}
	code.AdditionalNesting = map[string]int{"Console": 1, "Marketing": 1}

	code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
	code.Type = typeFullyQualifedName[p+1:]  // DB
	code.QualifiedType = path.Base(code.Package) + "." + code.Type

	var err error
	code.Roots, err = packages.Load(code.Config, code.Package)
	if err != nil {
		panic(err)
	}
	if len(code.Roots) == 0 {
		fmt.Fprintf(os.Stderr, "package %q not found\n", code.Package)
		os.Exit(1)
	}

	// The body must be generated first: it populates code.Imports, which the
	// preamble then renders.
	code.PrintLocked()
	code.PrintPreamble()

	unformatted := code.Bytes()
	imports.LocalPrefix = "storj.io"
	formatted, err := imports.Process(outputPath, unformatted, nil)
	if err != nil {
		// Dump the raw output to help locate the generation error.
		fmt.Fprintln(os.Stderr, string(unformatted))
		panic(err)
	}

	if outputPath == "" {
		fmt.Println(string(formatted))
		return
	}

	err = ioutil.WriteFile(outputPath, formatted, 0644)
	if err != nil {
		panic(err)
	}
}
// Methods is satisfied by go/types values that expose an indexed method set,
// such as *types.Interface.
type Methods interface {
	NumMethods() int
	Method(i int) *types.Func
}
// Code holds the inputs, generation settings and output buffers used while
// producing the locking wrapper.
type Code struct {
	// Config is the go/packages configuration used to load Package.
	Config *packages.Config
	// Roots are the packages returned by loading Package.
	Roots []*packages.Package

	// Package is the import path of the package declaring Type.
	Package string
	// Type is the name of the interface to wrap.
	Type string
	// QualifiedType is Type prefixed with the last element of Package,
	// e.g. "satellite.DB".
	QualifiedType string

	// OutputPackage is the package clause of the generated file.
	OutputPackage string
	// Imports is the set of import paths written to the preamble.
	Imports map[string]bool

	// Ignore lists result types (by their types.Type String form) that are
	// returned directly instead of being wrapped.
	Ignore map[string]bool
	// IgnoreMethods lists method names for which no method is generated.
	IgnoreMethods map[string]bool
	// AdditionalNesting adds extra wrapper depth for the named top-level methods.
	AdditionalNesting map[string]int
	// Wrapped records wrapper type names already emitted so each is printed once.
	Wrapped map[string]bool

	// Preamble receives the header comments, package clause and import block.
	Preamble bytes.Buffer
	// Source receives the generated declarations.
	Source bytes.Buffer
}
// Bytes returns a fresh slice holding the preamble followed by the generated
// source.
func (code *Code) Bytes() []byte {
	merged := make([]byte, 0, code.Preamble.Len()+code.Source.Len())
	merged = append(merged, code.Preamble.Bytes()...)
	return append(merged, code.Source.Bytes()...)
}
// PrintPreamble writes the generated-code marker, the copyright notice, the
// package clause and an import block listing every key of code.Imports in
// sorted order into code.Preamble.
func (code *Code) PrintPreamble() {
	out := &code.Preamble

	for _, line := range []string{
		"// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n",
		"// Copyright (C) 2019 Storj Labs, Inc.\n",
		"// See LICENSE for copying information.\n\n",
	} {
		out.WriteString(line)
	}
	fmt.Fprintf(out, "package %v\n\n", code.OutputPackage)

	// Map iteration order is random; sort for deterministic output.
	paths := make([]string, 0, len(code.Imports))
	for importPath := range code.Imports {
		paths = append(paths, importPath)
	}
	sort.Strings(paths)

	out.WriteString("import (\n")
	for _, importPath := range paths {
		fmt.Fprintf(out, " %q\n", importPath)
	}
	out.WriteString(")\n\n")
}
// PrintLocked writes the top-level `locked` wrapper type, its newLocked
// constructor, and a locking method for every method of the target interface
// (code.Type in code.Package). It panics with a descriptive message when the
// package was not loaded, the type is missing, or the type is not an interface.
func (code *Code) PrintLocked() {
	if len(code.Roots) == 0 || code.Roots[0].Types == nil {
		panic(fmt.Sprintf("lockedgen: package %q was not loaded", code.Package))
	}
	dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
	if dbObject == nil {
		panic(fmt.Sprintf("lockedgen: type %s not found in package %q", code.Type, code.Package))
	}
	methods, ok := dbObject.Type().Underlying().(Methods)
	if !ok {
		panic(fmt.Sprintf("lockedgen: %s is not an interface", code.QualifiedType))
	}

	code.Imports["sync"] = true
	// The generated code refers to the wrapped type as code.QualifiedType, so
	// import the package that declares it.
	code.Imports[code.Package] = true

	code.Printf("// locked implements a locking wrapper around %v.\n", code.QualifiedType)
	code.Printf("type locked struct {\n")
	code.Printf(" sync.Locker\n")
	code.Printf(" db %v\n", code.QualifiedType)
	code.Printf("}\n\n")
	code.Printf("// newLocked returns database wrapped with locker.\n")
	code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
	code.Printf(" return &locked{&sync.Mutex{}, db}\n")
	code.Printf("}\n\n")

	for i := 0; i < methods.NumMethods(); i++ {
		method := methods.Method(i)
		// Every top-level method gets at least one level of result wrapping;
		// AdditionalNesting raises it for specific methods.
		code.PrintLockedFunc("locked", method, code.AdditionalNesting[method.Name()]+1)
	}
}
// Printf formats according to format and appends the result to code.Source.
func (code *Code) Printf(format string, a ...interface{}) {
	src := &code.Source
	fmt.Fprintf(src, format, a...)
}
// PrintSignature prints the parameter list of sig and, when sig has results, a
// space followed by the result list. Parameters are always named (see
// PrintName) so that PrintCall can forward them. A variadic final parameter is
// printed as "...T" so the generated method matches the interface.
func (code *Code) PrintSignature(sig *types.Signature) {
	code.printTuple(sig.Params(), true, sig.Variadic())
	if sig.Results().Len() > 0 {
		code.Printf(" ")
		code.PrintSignatureTuple(sig.Results(), false)
	}
}

// PrintSignatureTuple prints a parenthesized params or results tuple.
// needsNames requests synthetic names for unnamed entries.
func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) {
	code.printTuple(tuple, needsNames, false)
}

// printTuple prints tuple as "(name T, ...)". When variadic is set, the final
// entry is printed as "...Elem". go/types represents a variadic parameter by
// its slice type, and printing "[]Elem" would produce a method that does not
// satisfy the interface.
func (code *Code) printTuple(tuple *types.Tuple, needsNames, variadic bool) {
	code.Printf("(")
	defer code.Printf(")")
	last := tuple.Len() - 1
	for i := 0; i <= last; i++ {
		if i > 0 {
			code.Printf(", ")
		}
		param := tuple.At(i)
		if code.PrintName(param, i, needsNames) {
			code.Printf(" ")
		}
		if slice, ok := param.Type().(*types.Slice); ok && variadic && i == last {
			code.Printf("...")
			code.PrintType(slice.Elem())
			continue
		}
		code.PrintType(param.Type())
	}
}

// PrintCall prints a parenthesized argument list that forwards every parameter
// of sig, using the same names PrintSignature declares. A variadic final
// parameter is spread with "..." so it is passed through unchanged.
func (code *Code) PrintCall(sig *types.Signature) {
	code.Printf("(")
	defer code.Printf(")")
	params := sig.Params()
	for i := 0; i < params.Len(); i++ {
		if i != 0 {
			code.Printf(", ")
		}
		code.PrintName(params.At(i), i, true)
	}
	if sig.Variadic() && params.Len() > 0 {
		code.Printf("...")
	}
}
// PrintName prints the name of v, the entry at position index of its tuple,
// and reports whether anything was printed.
//
// When needsNames is set and v is unnamed or blank ("_"), a synthetic name is
// printed instead: "ctx" for a context.Context, otherwise "a<index>". A blank
// name cannot be used, because the generated body forwards parameters by name
// and "_" is not a valid value expression.
func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool {
	name := v.Name()
	if needsNames && (name == "" || name == "_") {
		if v.Type().String() == "context.Context" {
			code.Printf("ctx")
			return true
		}
		code.Printf("a%d", index)
		return true
	}
	code.Printf("%s", name)
	return name != ""
}
// PrintType appends t to code.Source, qualifying named types by their package
// name rather than their full import path.
func (code *Code) PrintType(t types.Type) {
	byPackageName := func(pkg *types.Package) string { return pkg.Name() }
	types.WriteType(&code.Source, t, byPackageName)
}
// typeName renders typ as Go source text, qualifying named types by their
// package name rather than their import path (e.g. "*satellite.DB").
func typeName(typ types.Type) string {
	qualifier := func(pkg *types.Package) string {
		return pkg.Name()
	}
	buf := new(bytes.Buffer)
	types.WriteType(buf, typ, qualifier)
	return buf.String()
}
// IncludeImports adds to code.Imports the import path of every package whose
// named types appear in sig's parameters or results.
func (code *Code) IncludeImports(sig *types.Signature) {
	record := func(pkg *types.Package) string {
		code.Imports[pkg.Path()] = true
		return pkg.Name()
	}
	// Only the qualifier's side effect is needed; the rendered text is discarded.
	_ = types.TypeString(sig, record)
}
// NeedsWrapper reports whether the result of method must be wrapped in a
// generated locking wrapper type. All of these must hold:
//   - the method is not listed in code.IgnoreMethods,
//   - the method returns exactly one value,
//   - that value's type (by its String form) is not listed in code.Ignore,
//   - that value's underlying type is an interface.
//
// The interface requirement matters because PrintWrapper can only wrap
// interfaces. Plain values such as int or struct pointers are returned
// directly.
func (code *Code) NeedsWrapper(method *types.Func) bool {
	if code.IgnoreMethods[method.Name()] {
		return false
	}
	sig := method.Type().Underlying().(*types.Signature)
	if sig.Results().Len() != 1 {
		return false
	}
	result := sig.Results().At(0).Type()
	if code.Ignore[result.String()] {
		return false
	}
	_, isInterface := result.Underlying().(*types.Interface)
	return isInterface
}
// WrapperTypeName returns the name of the generated wrapper type for the
// result of method: the method name prefixed with "locked".
func (code *Code) WrapperTypeName(method *types.Func) string {
	const prefix = "locked"
	return prefix + method.Name()
}
// PrintLockedFunc prints a method on receiverType that takes the lock, then
// delegates to the same method of m.db. The interface method's doc comment is
// copied over.
//
// When the method's result needs a wrapper (see NeedsWrapper), the result is
// returned inside that wrapper, sharing the same Locker. If nestingDepth > 0,
// the wrapper type is also printed via PrintWrapper with nestingDepth-1.
// NOTE(review): when nestingDepth is 0, the wrapper type is referenced but not
// printed by this call.
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) {
	name := method.Name()
	if code.IgnoreMethods[name] {
		return
	}

	sig := method.Type().Underlying().(*types.Signature)
	code.IncludeImports(sig)

	if doc := strings.TrimSpace(code.MethodDoc(method)); doc != "" {
		for _, line := range strings.Split(doc, "\n") {
			code.Printf("// %s\n", line)
		}
	}

	code.Printf("func (m *%s) %s", receiverType, name)
	code.PrintSignature(sig)
	code.Printf(" {\n")
	code.Printf(" m.Lock(); defer m.Unlock()\n")

	wrap := code.NeedsWrapper(method)
	if wrap {
		code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method))
		code.Printf("m.db.%s", name)
		code.PrintCall(sig)
		code.Printf("}\n")
	} else {
		code.Printf(" return m.db.%s", name)
		code.PrintCall(sig)
		code.Printf("\n")
	}
	code.Printf("}\n\n")

	if wrap && nestingDepth > 0 {
		code.PrintWrapper(method, nestingDepth-1)
	}
}
// PrintWrapper prints the locking wrapper type for the single result of
// method. It then prints a locking method on that wrapper for each method of
// the result type, passing nestingDepth down to PrintLockedFunc. Each wrapper
// type is emitted at most once, tracked in code.Wrapped. It panics with a
// descriptive message if the result type is not an interface.
func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) {
	receiverType := code.WrapperTypeName(method)
	if code.Wrapped[receiverType] {
		return
	}

	sig := method.Type().Underlying().(*types.Signature)
	result := sig.Results().At(0).Type()
	methods, ok := result.Underlying().(Methods)
	if !ok {
		panic(fmt.Sprintf("lockedgen: cannot wrap result %s of %s: not an interface",
			typeName(result), method.FullName()))
	}
	code.Wrapped[receiverType] = true

	resultName := typeName(result)
	code.Printf("// %s implements locking wrapper for %s\n", receiverType, resultName)
	code.Printf("type %s struct {\n", receiverType)
	code.Printf(" sync.Locker\n")
	code.Printf(" db %s\n", resultName)
	code.Printf("}\n\n")

	for i := 0; i < methods.NumMethods(); i++ {
		code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth)
	}
}
// MethodDoc returns the doc comment text of method's declaration. The comment
// comes from the innermost enclosing *ast.Field, *ast.GenDecl or *ast.FuncDecl.
// It returns "" when the declaring file is not found or the position does not
// map exactly onto a node.
func (code *Code) MethodDoc(method *types.Func) string {
	pos := method.Pos()
	file := code.FindASTFile(pos)
	if file == nil {
		return ""
	}

	enclosing, exact := astutil.PathEnclosingInterval(file, pos, pos)
	if !exact {
		return ""
	}

	for _, node := range enclosing {
		var doc *ast.CommentGroup
		switch n := node.(type) {
		case *ast.Field:
			doc = n.Doc
		case *ast.GenDecl:
			doc = n.Doc
		case *ast.FuncDecl:
			doc = n.Doc
		default:
			continue
		}
		// (*ast.CommentGroup).Text is nil-safe and yields "" for no comment.
		return doc.Text()
	}
	return ""
}
// FindASTFile returns the parsed file, from code.Roots or any of their
// transitive imports, whose source range contains pos. It returns nil if no
// such file exists. Each package is visited at most once.
func (code *Code) FindASTFile(pos token.Pos) *ast.File {
	visited := make(map[*packages.Package]bool)
	pending := append([]*packages.Package(nil), code.Roots...)

	for len(pending) > 0 {
		pkg := pending[len(pending)-1]
		pending = pending[:len(pending)-1]
		if visited[pkg] {
			continue
		}
		visited[pkg] = true

		for _, file := range pkg.Syntax {
			if file.Pos() <= pos && pos <= file.End() {
				return file
			}
		}
		for _, dep := range pkg.Imports {
			pending = append(pending, dep)
		}
	}
	return nil
}