storj/scripts/check-imports.go

242 lines
4.8 KiB
Go
Raw Normal View History

2019-01-24 20:15:10 +00:00
// Copyright (C) 2019 Storj Labs, Inc.
2018-11-30 13:40:13 +00:00
// See LICENSE for copying information.
// +build ignore
2018-11-30 15:02:01 +00:00
2018-11-30 13:40:13 +00:00
package main
import (
2019-01-08 14:05:14 +00:00
"bytes"
2018-11-30 13:40:13 +00:00
"flag"
"fmt"
"go/ast"
"go/token"
2019-01-08 14:05:14 +00:00
"io"
2018-11-30 13:40:13 +00:00
"os"
"runtime"
"sort"
"strconv"
"strings"
"golang.org/x/tools/go/packages"
)
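/*
check-imports verifies that the imports of every storj.io package are divided
into blank-line-separated groups, in this order:

	std packages
	external packages
	storj.io packages

A group may be omitted, but the groups that are present must appear in this
order and must not mix kinds of imports.
*/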
var (
	// race, when set on the command line, makes main pass "-race" in the
	// build flags used to load packages (presumably so files constrained by
	// the race build tag are checked as well).
	race = flag.Bool("race", false, "load with race tag")
)
2018-11-30 13:40:13 +00:00
func main() {
flag.Parse()
pkgNames := flag.Args()
if len(pkgNames) == 0 {
pkgNames = []string{"."}
}
var buildFlags []string
if *race {
buildFlags = append(buildFlags, "-race")
}
2018-11-30 13:40:13 +00:00
roots, err := packages.Load(&packages.Config{
Mode: packages.LoadAllSyntax,
Env: os.Environ(),
BuildFlags: buildFlags,
2018-11-30 13:40:13 +00:00
}, pkgNames...)
if err != nil {
panic(err)
}
2019-01-08 14:05:14 +00:00
fmt.Println("checking import order:")
2018-11-30 13:40:13 +00:00
seen := map[*packages.Package]bool{}
pkgs := []*packages.Package{}
var visit func(*packages.Package)
visit = func(p *packages.Package) {
if seen[p] {
return
}
includeStd(p)
if strings.HasPrefix(p.ID, "storj.io") {
pkgs = append(pkgs, p)
}
seen[p] = true
for _, pkg := range p.Imports {
visit(pkg)
}
}
for _, pkg := range roots {
visit(pkg)
}
sort.Slice(pkgs, func(i, k int) bool { return pkgs[i].ID < pkgs[k].ID })
2019-01-08 14:05:14 +00:00
correct := true
incorrectPkgs := []string{}
2018-11-30 13:40:13 +00:00
for _, pkg := range pkgs {
2019-01-08 14:05:14 +00:00
if !correctPackage(pkg) {
incorrectPkgs = append(incorrectPkgs, pkg.String())
2019-01-08 14:05:14 +00:00
correct = false
}
}
if !correct {
fmt.Fprintln(os.Stderr, "Error: imports are not in the correct order for package/s: ", incorrectPkgs)
fmt.Fprintln(os.Stderr, "Correct order should be: std packages -> external packages -> storj.io packages.")
2019-01-08 14:05:14 +00:00
os.Exit(1)
2018-11-30 13:40:13 +00:00
}
}
2019-01-08 14:05:14 +00:00
func correctPackage(pkg *packages.Package) bool {
correct := true
2018-11-30 13:40:13 +00:00
for i, file := range pkg.Syntax {
2019-01-08 14:05:14 +00:00
path := pkg.CompiledGoFiles[i]
if !correctImports(pkg.Fset, path, file) {
if !isGenerated(path) { // ignore generated files
fmt.Fprintln(os.Stderr, path)
correct = false
} else {
fmt.Fprintln(os.Stderr, "(ignoring generated)", path)
}
}
2018-11-30 13:40:13 +00:00
}
2019-01-08 14:05:14 +00:00
return correct
2018-11-30 13:40:13 +00:00
}
2019-01-08 14:05:14 +00:00
func correctImports(fset *token.FileSet, name string, f *ast.File) bool {
2018-11-30 13:40:13 +00:00
for _, d := range f.Decls {
d, ok := d.(*ast.GenDecl)
if !ok || d.Tok != token.IMPORT {
// Not an import declaration, so we're done.
// Imports are always first.
break
}
if !d.Lparen.IsValid() {
// Not a block: sorted by default.
continue
}
// Identify and sort runs of specs on successive lines.
lastGroup := 0
specgroups := [][]ast.Spec{}
for i, s := range d.Specs {
if i > lastGroup && fset.Position(s.Pos()).Line > 1+fset.Position(d.Specs[i-1].End()).Line {
// i begins a new run. End this one.
specgroups = append(specgroups, d.Specs[lastGroup:i])
lastGroup = i
}
}
specgroups = append(specgroups, d.Specs[lastGroup:])
if !correctOrder(specgroups) {
2019-01-08 14:05:14 +00:00
return false
2018-11-30 13:40:13 +00:00
}
}
2019-01-08 14:05:14 +00:00
return true
2018-11-30 13:40:13 +00:00
}
// correctOrder reports whether the import groups of one declaration follow
// the layout std -> external -> storj.io.
//
// An optional leading std group is stripped first; it may contain only std
// imports. An optional trailing storj.io group is stripped next; it may
// contain only storj.io imports. Whatever remains must be either nothing or
// a single group of external imports.
func correctOrder(specgroups [][]ast.Spec) bool {
	groups := specgroups

	// Leading std group.
	if len(groups) > 0 {
		std, other, storj := countGroup(groups[0])
		if std > 0 {
			if other > 0 || storj > 0 {
				return false
			}
			groups = groups[1:]
		}
	}

	// Trailing storj.io group.
	if n := len(groups); n > 0 {
		std, other, storj := countGroup(groups[n-1])
		if storj > 0 {
			if std > 0 || other > 0 {
				return false
			}
			groups = groups[:n-1]
		}
	}

	// At most one external group may sit in the middle.
	switch len(groups) {
	case 0:
		return true
	case 1:
		std, other, storj := countGroup(groups[0])
		return other > 0 && std == 0 && storj == 0
	default:
		return false
	}
}
// countGroup classifies each import spec in p and returns how many are
// standard-library imports (paths recorded in stdlib), storj.io imports
// (paths starting with "storj.io/"), and everything else. It panics if an
// import path literal cannot be unquoted.
func countGroup(p []ast.Spec) (std, other, storj int) {
	for _, spec := range p {
		path, err := strconv.Unquote(spec.(*ast.ImportSpec).Path.Value)
		if err != nil {
			panic(err)
		}
		switch {
		case strings.HasPrefix(path, "storj.io/"):
			storj++
		case stdlib[path]:
			std++
		default:
			other++
		}
	}
	return std, other, storj
}
var (
	// root is the GOROOT reported by the runtime at startup; includeStd
	// treats packages whose files live under it as standard library.
	root = runtime.GOROOT()

	// stdlib holds the IDs of packages classified as standard library,
	// filled in by includeStd while main walks the import graph.
	stdlib = map[string]bool{}
)
// includeStd records p's ID in stdlib when p appears to be a standard-library
// package: either it has no Go files, or its first Go file lies inside the
// GOROOT captured in root.
func includeStd(p *packages.Package) {
	// NOTE(review): any package without GoFiles (presumably including ones
	// that failed to load) is classified as std — confirm this is intended.
	if len(p.GoFiles) == 0 || inGoroot(p.GoFiles[0]) {
		stdlib[p.ID] = true
	}
}

// inGoroot reports whether path lies inside root. A bare
// strings.HasPrefix(path, root) would also accept sibling directories that
// merely share the prefix (e.g. GOROOT /usr/local/go vs a package under
// /usr/local/gopath), so a path separator is required right after root.
func inGoroot(path string) bool {
	sep := string(os.PathSeparator)
	return strings.HasPrefix(path, strings.TrimSuffix(root, sep)+sep)
}
2019-01-08 14:05:14 +00:00
// isGenerated reports whether the file at path looks machine-generated,
// i.e. whether its first 256 bytes contain "AUTOGENERATED" or
// "Code generated". If the file cannot be opened or read, the error is
// printed to stderr and the file is treated as not generated.
func isGenerated(path string) bool {
	file, err := os.Open(path)
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to read %v: %v\n", path, err)
		return false
	}
	defer func() {
		if err := file.Close(); err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}()

	// A single Read may legally return fewer bytes than requested even when
	// more are available, so use io.ReadFull to fill the header. A file
	// shorter than the header yields io.ErrUnexpectedEOF (or io.EOF when
	// empty); neither is a failure here.
	var header [256]byte
	n, err := io.ReadFull(file, header[:])
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		fmt.Fprintf(os.Stderr, "failed to read %v: %v\n", path, err)
		return false
	}
	return bytes.Contains(header[:n], []byte(`AUTOGENERATED`)) ||
		bytes.Contains(header[:n], []byte(`Code generated`))
}