Add arguments to lockedgen for using it in other places (#1030)

This commit is contained in:
Egon Elbre 2019-01-11 18:07:26 +02:00 committed by GitHub
parent bde0f09c15
commit 12eec57abf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 67 additions and 14 deletions

View File

@ -23,7 +23,7 @@ var (
Error = errs.Class("satellitedb")
)
//go:generate go run lockedgen/main.go -o locked.go
//go:generate go run ../../scripts/lockedgen.go -o locked.go -p satellitedb -i storj.io/storj/satellite.DB
// DB contains access to different database tables
type DB struct {

View File

@ -1,6 +1,6 @@
// Code generated by lockedgen using 'go generate'. DO NOT EDIT.
// Copyright (C) 2018 Storj Labs, Inc.
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package satellitedb

View File

@ -0,0 +1,17 @@
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package scripts_test
// this ensures that we download the necessary packages for the tools in scripts folder
// without actually being a binary
import (
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
var _ = imports.Process
var _ = packages.LoadImports
var _ = astutil.PathEnclosingInterval

View File

@ -1,6 +1,8 @@
// Copyright (C) 2018 Storj Labs, Inc.
// See LICENSE for copying information.
// +build ignore
package main
import (
@ -11,7 +13,10 @@ import (
"go/token"
"go/types"
"io/ioutil"
"os"
"path"
"sort"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
@ -20,20 +25,38 @@ import (
func main() {
var outputPath string
var packageName string
var typeFullyQualifedName string
flag.StringVar(&outputPath, "o", "", "output file name")
flag.StringVar(&packageName, "p", "", "output package name")
flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for")
flag.Parse()
if outputPath == "" || packageName == "" || typeFullyQualifedName == "" {
fmt.Println("missing argument")
os.Exit(1)
}
var code Code
code.Imports = map[string]bool{}
code.Ignore = map[string]bool{
"error": true,
}
code.IgnoreMethods = map[string]bool{
"BeginTx": true,
}
code.OutputPackage = packageName
code.Config = &packages.Config{
Mode: packages.LoadAllSyntax,
}
code.Package = "storj.io/storj/satellite"
// e.g. storj.io/storj/satellite.DB
p := strings.LastIndexByte(typeFullyQualifedName, '.')
code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite
code.Type = typeFullyQualifedName[p+1:] // DB
code.QualifiedType = path.Base(code.Package) + "." + code.Type
var err error
code.Roots, err = packages.Load(code.Config, code.Package)
@ -72,12 +95,17 @@ type Methods interface {
// Code is the information for generating the code.
type Code struct {
Config *packages.Config
Package string
Roots []*packages.Package
Config *packages.Config
Package string
Type string
QualifiedType string
Roots []*packages.Package
Imports map[string]bool
Ignore map[string]bool
OutputPackage string
Imports map[string]bool
Ignore map[string]bool
IgnoreMethods map[string]bool
Preamble bytes.Buffer
Source bytes.Buffer
@ -95,9 +123,9 @@ func (code *Code) Bytes() []byte {
func (code *Code) PrintPreamble() {
w := &code.Preamble
fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n")
fmt.Fprintf(w, "// Copyright (C) 2018 Storj Labs, Inc.\n")
fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n")
fmt.Fprintf(w, "// See LICENSE for copying information.\n\n")
fmt.Fprintf(w, "package satellitedb\n\n")
fmt.Fprintf(w, "package %v\n\n", code.OutputPackage)
fmt.Fprintf(w, "import (\n")
var imports []string
@ -119,16 +147,16 @@ func (code *Code) PrintLocked() {
code.Printf("// locked implements a locking wrapper around satellite.DB.\n")
code.Printf("type locked struct {\n")
code.Printf(" sync.Locker\n")
code.Printf(" db satellite.DB\n")
code.Printf(" db %v\n", code.QualifiedType)
code.Printf("}\n\n")
code.Printf("// newLocked returns database wrapped with locker.\n")
code.Printf("func newLocked(db satellite.DB) satellite.DB {\n")
code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType)
code.Printf(" return &locked{&sync.Mutex{}, db}\n")
code.Printf("}\n\n")
// find the satellite.DB type info
dbObject := code.Roots[0].Types.Scope().Lookup("DB")
dbObject := code.Roots[0].Types.Scope().Lookup(code.Type)
methods := dbObject.Type().Underlying().(Methods)
for i := 0; i < methods.NumMethods(); i++ {
@ -226,6 +254,10 @@ func (code *Code) IncludeImports(sig *types.Signature) {
// NeedsWrapper checks whether method result needs a wrapper type.
func (code *Code) NeedsWrapper(method *types.Func) bool {
if code.IgnoreMethods[method.Name()] {
return false
}
sig := method.Type().Underlying().(*types.Signature)
return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()]
}
@ -237,6 +269,10 @@ func (code *Code) WrapperTypeName(method *types.Func) string {
// PrintLockedFunc prints a method with locking and defers the actual logic to method.
func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, allowNesting bool) {
if code.IgnoreMethods[method.Name()] {
return
}
sig := method.Type().Underlying().(*types.Signature)
code.IncludeImports(sig)