From b7afc715b566b948526ad840fc20d53b642f6b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Thu, 25 Apr 2019 19:47:36 +0700 Subject: [PATCH] join all standard library imports This code is a bit tricky since we have to be careful to not break other imports, and to not move comments around too much. Fixes #8. --- internal/gofumpt.go | 46 ++++++++++++++++++++++++++++++++ testdata/scripts/std-imports.txt | 39 +++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 testdata/scripts/std-imports.txt diff --git a/internal/gofumpt.go b/internal/gofumpt.go index f63dc5f..959d017 100644 --- a/internal/gofumpt.go +++ b/internal/gofumpt.go @@ -187,6 +187,7 @@ func (f *fumpter) visit(node ast.Node) { pos = comments[0].Pos() } + // multiline top-level declarations should be separated multi := f.posLine(decl.Pos()) < f.posLine(decl.End()) if (multi && lastMulti) && f.posLine(lastEnd)+1 == f.posLine(pos) { @@ -222,6 +223,9 @@ func (f *fumpter) visit(node ast.Node) { } case *ast.GenDecl: + if node.Tok == token.IMPORT && node.Lparen.IsValid() { + f.joinStdImports(node) + } if len(node.Specs) == 1 && node.Lparen.IsValid() { // If the single spec has any comment, it must go before // the entire declaration now. @@ -329,3 +333,45 @@ func (f *fumpter) visit(node ast.Node) { f.removeLines(openLine, closeLine) } } + +// joinStdImports ensures that all standard library imports are together and at +// the top of the imports list. +func (f *fumpter) joinStdImports(d *ast.GenDecl) { + var std, other []ast.Spec + for _, spec := range d.Specs { + spec := spec.(*ast.ImportSpec) + // First, separate the non-std imports. + if strings.Contains(spec.Path.Value, ".") { + other = append(other, spec) + continue + } + if len(other) > 0 { + // If we're moving this std import further up, reset its + // position, to avoid breaking comments. + setPos(reflect.ValueOf(spec), d.Pos()) + } + std = append(std, spec) + } + // Finally, join the imports, keeping std at the top. + d.Specs = append(std, other...) +} + +var posType = reflect.TypeOf(token.NoPos) + +// setPos recursively sets all position fields in the node v to pos. +func setPos(v reflect.Value, pos token.Pos) { + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if !v.IsValid() { + return + } + if v.Type() == posType { + v.Set(reflect.ValueOf(pos)) + } + if v.Kind() == reflect.Struct { + for i := 0; i < v.NumField(); i++ { + setPos(v.Field(i), pos) + } + } +} diff --git a/testdata/scripts/std-imports.txt b/testdata/scripts/std-imports.txt new file mode 100644 index 0000000..ed5893a --- /dev/null +++ b/testdata/scripts/std-imports.txt @@ -0,0 +1,39 @@ +[gofumports] skip 'don''t add or remove imports' + +gofumpt -w foo.go . +cmp foo.go foo.go.golden + +-- foo.go -- +package p + +import ( + "io" + + _ "bufio" // for a side effect +) + +import ( + "os" + + "foo.localhost/other" + + bytes_ "bytes" + + "io" +) +-- foo.go.golden -- +package p + +import ( + "io" + + _ "bufio" // for a side effect +) + +import ( + bytes_ "bytes" + "io" + "os" + + "foo.localhost/other" +)