// Copyright (C) 2018 Storj Labs, Inc. // See LICENSE for copying information. // +build ignore package main import ( "bytes" "flag" "fmt" "go/ast" "go/token" "go/types" "io/ioutil" "os" "path" "sort" "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/imports" ) func main() { var outputPath string var packageName string var typeFullyQualifedName string flag.StringVar(&outputPath, "o", "", "output file name") flag.StringVar(&packageName, "p", "", "output package name") flag.StringVar(&typeFullyQualifedName, "i", "", "interface to generate code for") flag.Parse() if outputPath == "" || packageName == "" || typeFullyQualifedName == "" { fmt.Println("missing argument") os.Exit(1) } var code Code code.Imports = map[string]bool{} code.Ignore = map[string]bool{ "error": true, } code.IgnoreMethods = map[string]bool{ "BeginTx": true, } code.OutputPackage = packageName code.Config = &packages.Config{ Mode: packages.LoadAllSyntax, } code.Wrapped = map[string]bool{} code.AdditionalNesting = map[string]int{"Console": 1} // e.g. storj.io/storj/satellite.DB p := strings.LastIndexByte(typeFullyQualifedName, '.') code.Package = typeFullyQualifedName[:p] // storj.io/storj/satellite code.Type = typeFullyQualifedName[p+1:] // DB code.QualifiedType = path.Base(code.Package) + "." + code.Type var err error code.Roots, err = packages.Load(code.Config, code.Package) if err != nil { panic(err) } code.PrintLocked() code.PrintPreamble() unformatted := code.Bytes() imports.LocalPrefix = "storj.io" formatted, err := imports.Process(outputPath, unformatted, nil) if err != nil { fmt.Println(string(unformatted)) panic(err) } if outputPath == "" { fmt.Println(string(formatted)) return } err = ioutil.WriteFile(outputPath, formatted, 0644) if err != nil { panic(err) } } // Methods is the common interface for types having methods. type Methods interface { Method(i int) *types.Func NumMethods() int } // Code is the information for generating the code. type Code struct { Config *packages.Config Package string Type string QualifiedType string Roots []*packages.Package OutputPackage string Imports map[string]bool Ignore map[string]bool IgnoreMethods map[string]bool Wrapped map[string]bool AdditionalNesting map[string]int Preamble bytes.Buffer Source bytes.Buffer } // Bytes returns all code merged together func (code *Code) Bytes() []byte { var all bytes.Buffer all.Write(code.Preamble.Bytes()) all.Write(code.Source.Bytes()) return all.Bytes() } // PrintPreamble creates package header and imports. func (code *Code) PrintPreamble() { w := &code.Preamble fmt.Fprintf(w, "// Code generated by lockedgen using 'go generate'. DO NOT EDIT.\n\n") fmt.Fprintf(w, "// Copyright (C) 2019 Storj Labs, Inc.\n") fmt.Fprintf(w, "// See LICENSE for copying information.\n\n") fmt.Fprintf(w, "package %v\n\n", code.OutputPackage) fmt.Fprintf(w, "import (\n") var imports []string for imp := range code.Imports { imports = append(imports, imp) } sort.Strings(imports) for _, imp := range imports { fmt.Fprintf(w, " %q\n", imp) } fmt.Fprintf(w, ")\n\n") } // PrintLocked writes locked wrapper and methods. func (code *Code) PrintLocked() { code.Imports["sync"] = true code.Imports["storj.io/statellite"] = true code.Printf("// locked implements a locking wrapper around satellite.DB.\n") code.Printf("type locked struct {\n") code.Printf(" sync.Locker\n") code.Printf(" db %v\n", code.QualifiedType) code.Printf("}\n\n") code.Printf("// newLocked returns database wrapped with locker.\n") code.Printf("func newLocked(db %v) %v {\n", code.QualifiedType, code.QualifiedType) code.Printf(" return &locked{&sync.Mutex{}, db}\n") code.Printf("}\n\n") // find the satellite.DB type info dbObject := code.Roots[0].Types.Scope().Lookup(code.Type) methods := dbObject.Type().Underlying().(Methods) for i := 0; i < methods.NumMethods(); i++ { code.PrintLockedFunc("locked", methods.Method(i), code.AdditionalNesting[methods.Method(i).Name()]+1) } } // Printf writes formatted text to source. func (code *Code) Printf(format string, a ...interface{}) { fmt.Fprintf(&code.Source, format, a...) } // PrintSignature prints method signature. func (code *Code) PrintSignature(sig *types.Signature) { code.PrintSignatureTuple(sig.Params(), true) if sig.Results().Len() > 0 { code.Printf(" ") code.PrintSignatureTuple(sig.Results(), false) } } // PrintSignatureTuple prints method tuple, params or results. func (code *Code) PrintSignatureTuple(tuple *types.Tuple, needsNames bool) { code.Printf("(") defer code.Printf(")") for i := 0; i < tuple.Len(); i++ { if i > 0 { code.Printf(", ") } param := tuple.At(i) if code.PrintName(tuple.At(i), i, needsNames) { code.Printf(" ") } code.PrintType(param.Type()) } } // PrintCall prints a call using the specified signature. func (code *Code) PrintCall(sig *types.Signature) { code.Printf("(") defer code.Printf(")") params := sig.Params() for i := 0; i < params.Len(); i++ { if i != 0 { code.Printf(", ") } code.PrintName(params.At(i), i, true) } } // PrintName prints an appropriate name from signature tuple. func (code *Code) PrintName(v *types.Var, index int, needsNames bool) bool { name := v.Name() if needsNames && name == "" { if v.Type().String() == "context.Context" { code.Printf("ctx") return true } code.Printf("a%d", index) return true } code.Printf("%s", name) return name != "" } // PrintType prints short form of type t. func (code *Code) PrintType(t types.Type) { types.WriteType(&code.Source, t, (*types.Package).Name) } func typeName(typ types.Type) string { var body bytes.Buffer types.WriteType(&body, typ, (*types.Package).Name) return body.String() } // IncludeImports imports all types referenced in the signature. func (code *Code) IncludeImports(sig *types.Signature) { var tmp bytes.Buffer types.WriteSignature(&tmp, sig, func(p *types.Package) string { code.Imports[p.Path()] = true return p.Name() }) } // NeedsWrapper checks whether method result needs a wrapper type. func (code *Code) NeedsWrapper(method *types.Func) bool { if code.IgnoreMethods[method.Name()] { return false } sig := method.Type().Underlying().(*types.Signature) return sig.Results().Len() == 1 && !code.Ignore[sig.Results().At(0).Type().String()] } // WrapperTypeName returns an appropriate name for the wrapper type. func (code *Code) WrapperTypeName(method *types.Func) string { return "locked" + method.Name() } // PrintLockedFunc prints a method with locking and defers the actual logic to method. func (code *Code) PrintLockedFunc(receiverType string, method *types.Func, nestingDepth int) { if code.IgnoreMethods[method.Name()] { return } sig := method.Type().Underlying().(*types.Signature) code.IncludeImports(sig) doc := code.MethodDoc(method) if doc != "" { code.Printf("// %s", code.MethodDoc(method)) } code.Printf("func (m *%s) %s", receiverType, method.Name()) code.PrintSignature(sig) code.Printf(" {\n") code.Printf(" m.Lock(); defer m.Unlock()\n") if !code.NeedsWrapper(method) { code.Printf(" return m.db.%s", method.Name()) code.PrintCall(sig) code.Printf("\n") code.Printf("}\n\n") return } code.Printf(" return &%s{m.Locker, ", code.WrapperTypeName(method)) code.Printf("m.db.%s", method.Name()) code.PrintCall(sig) code.Printf("}\n") code.Printf("}\n\n") if nestingDepth > 0 { code.PrintWrapper(method, nestingDepth-1) } } // PrintWrapper prints wrapper for the result type of method. func (code *Code) PrintWrapper(method *types.Func, nestingDepth int) { sig := method.Type().Underlying().(*types.Signature) results := sig.Results() result := results.At(0).Type() receiverType := code.WrapperTypeName(method) if code.Wrapped[receiverType] { return } code.Wrapped[receiverType] = true code.Printf("// %s implements locking wrapper for %s\n", receiverType, typeName(result)) code.Printf("type %s struct {\n", receiverType) code.Printf(" sync.Locker\n") code.Printf(" db %s\n", typeName(result)) code.Printf("}\n\n") methods := result.Underlying().(Methods) for i := 0; i < methods.NumMethods(); i++ { code.PrintLockedFunc(receiverType, methods.Method(i), nestingDepth) } } // MethodDoc finds documentation for the specified method. func (code *Code) MethodDoc(method *types.Func) string { file := code.FindASTFile(method.Pos()) if file == nil { return "" } path, exact := astutil.PathEnclosingInterval(file, method.Pos(), method.Pos()) if !exact { return "" } for _, p := range path { switch decl := p.(type) { case *ast.Field: return decl.Doc.Text() case *ast.GenDecl: return decl.Doc.Text() case *ast.FuncDecl: return decl.Doc.Text() } } return "" } // FindASTFile finds the *ast.File at the specified position. func (code *Code) FindASTFile(pos token.Pos) *ast.File { seen := map[*packages.Package]bool{} // find searches pos recursively from p and its dependencies. var find func(p *packages.Package) *ast.File find = func(p *packages.Package) *ast.File { if seen[p] { return nil } seen[p] = true for _, file := range p.Syntax { if file.Pos() <= pos && pos <= file.End() { return file } } for _, dep := range p.Imports { if file := find(dep); file != nil { return file } } return nil } for _, root := range code.Roots { if file := find(root); file != nil { return file } } return nil }