Skip to content

Commit ea68c39

Browse files
authored
fix: increase performance (#56)
1 parent 31c8451 commit ea68c39

File tree

11 files changed

+165
-226
lines changed

11 files changed

+165
-226
lines changed

musttag.go

+28-29
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"go/token"
99
"go/types"
1010
"path"
11-
"path/filepath"
1211
"reflect"
12+
"regexp"
1313
"strconv"
1414
"strings"
1515

@@ -43,16 +43,23 @@ func New(funcs ...Func) *analysis.Analyzer {
4343
Requires: []*analysis.Analyzer{inspect.Analyzer},
4444
Run: func(pass *analysis.Pass) (any, error) {
4545
l := len(builtins) + len(funcs) + len(flagFuncs)
46-
m := make(map[string]Func, l)
46+
f := make(map[string]Func, l)
47+
4748
toMap := func(slice []Func) {
4849
for _, fn := range slice {
49-
m[fn.Name] = fn
50+
f[fn.Name] = fn
5051
}
5152
}
5253
toMap(builtins)
5354
toMap(funcs)
5455
toMap(flagFuncs)
55-
return run(pass, m)
56+
57+
mainModule, err := getMainModule()
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
return run(pass, mainModule, f)
5663
},
5764
}
5865
}
@@ -81,27 +88,16 @@ func flags(funcs *[]Func) flag.FlagSet {
8188
}
8289

8390
// for tests only.
84-
var (
85-
report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
86-
const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s"
87-
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
88-
}
91+
var report = func(pass *analysis.Pass, st *structType, fn Func, fnPos token.Position) {
92+
const format = "`%s` should be annotated with the `%s` tag as it is passed to `%s` at %s"
93+
pass.Reportf(st.Pos, format, st.Name, fn.Tag, fn.shortName(), fnPos)
94+
}
8995

90-
// HACK: mainModulePackages() does not return packages from `testdata`,
91-
// because it is ignored by the go tool, and thus, by the `go list` command.
92-
// For tests to pass we need to add the packages with tests to the main module manually.
93-
testPackages []string
94-
)
96+
var cleanFullName = regexp.MustCompile(`([^*/(]+/vendor/)`)
9597

9698
// run starts the analysis.
97-
func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
98-
moduleDir, modulePackages, err := mainModule()
99-
if err != nil {
100-
return nil, err
101-
}
102-
for _, pkg := range testPackages {
103-
modulePackages[pkg] = struct{}{}
104-
}
99+
func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (any, error) {
100+
var err error
105101

106102
walk := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
107103
filter := []ast.Node{(*ast.CallExpr)(nil)}
@@ -116,12 +112,13 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
116112
return // not a function call.
117113
}
118114

119-
caller := typeutil.StaticCallee(pass.TypesInfo, call)
120-
if caller == nil {
115+
callee := typeutil.StaticCallee(pass.TypesInfo, call)
116+
if callee == nil {
121117
return // not a static call.
122118
}
123119

124-
fn, ok := funcs[caller.FullName()]
120+
name := cleanFullName.ReplaceAllString(callee.FullName(), "")
121+
fn, ok := funcs[name]
125122
if !ok {
126123
return // the function is not supported.
127124
}
@@ -148,7 +145,7 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
148145
}
149146

150147
checker := checker{
151-
mainModule: modulePackages,
148+
mainModule: mainModule,
152149
seenTypes: make(map[string]struct{}),
153150
}
154151

@@ -164,7 +161,6 @@ func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
164161
}
165162

166163
p := pass.Fset.Position(call.Pos())
167-
p.Filename, _ = filepath.Rel(moduleDir, p.Filename)
168164
report(pass, result, fn, p)
169165
})
170166

@@ -181,7 +177,7 @@ type structType struct {
181177

182178
// checker parses and checks struct types.
183179
type checker struct {
184-
mainModule map[string]struct{} // do not check types outside of the main module; see issue #17.
180+
mainModule string
185181
seenTypes map[string]struct{} // prevent panic on recursive types; see issue #16.
186182
}
187183

@@ -202,13 +198,16 @@ func (c *checker) parseStructType(t types.Type, pos token.Pos) (*structType, boo
202198
if pkg == nil {
203199
return nil, false
204200
}
205-
if _, ok := c.mainModule[pkg.Path()]; !ok {
201+
202+
if !strings.HasPrefix(pkg.Path(), c.mainModule) {
206203
return nil, false
207204
}
205+
208206
s, ok := t.Underlying().(*types.Struct)
209207
if !ok {
210208
return nil, false
211209
}
210+
212211
return &structType{
213212
Struct: s,
214213
Pos: t.Obj().Pos(),

musttag_test.go

+18-110
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"go/token"
55
"io"
66
"os"
7+
"os/exec"
78
"path/filepath"
89
"testing"
910

@@ -14,13 +15,6 @@ import (
1415
)
1516

1617
func TestAnalyzer(t *testing.T) {
17-
// NOTE: analysistest does not yet support modules;
18-
// see https://github.com/golang/go/issues/37054 for details.
19-
// To be able to run tests with external dependencies,
20-
// we first need to write a GOPATH-like tree of stubs.
21-
prepareTestFiles(t)
22-
testPackages = []string{"tests", "builtins"}
23-
2418
testdata := analysistest.TestData()
2519

2620
t.Run("tests", func(t *testing.T) {
@@ -31,11 +25,15 @@ func TestAnalyzer(t *testing.T) {
3125
pass.Reportf(st.Pos, fn.shortName())
3226
}
3327

28+
setupTestData(t, testdata, "tests")
29+
3430
analyzer := New()
3531
analysistest.Run(t, testdata, analyzer, "tests")
3632
})
3733

3834
t.Run("builtins", func(t *testing.T) {
35+
setupTestData(t, testdata, "builtins")
36+
3937
analyzer := New(
4038
Func{Name: "example.com/custom.Marshal", Tag: "custom", ArgPos: 0},
4139
Func{Name: "example.com/custom.Unmarshal", Tag: "custom", ArgPos: 1},
@@ -44,6 +42,8 @@ func TestAnalyzer(t *testing.T) {
4442
})
4543

4644
t.Run("bad Func.ArgPos", func(t *testing.T) {
45+
setupTestData(t, testdata, "tests")
46+
4747
analyzer := New(
4848
Func{Name: "encoding/json.Marshal", Tag: "json", ArgPos: 10},
4949
)
@@ -77,111 +77,19 @@ type nopT struct{}
7777

7878
func (nopT) Errorf(string, ...any) {}
7979

80-
func prepareTestFiles(t *testing.T) {
81-
testdata := analysistest.TestData()
82-
83-
t.Cleanup(func() {
84-
err := os.RemoveAll(filepath.Join(testdata, "src"))
85-
assert.NoErr[F](t, err)
86-
})
87-
88-
hardlink := func(dir, file string) {
89-
target := filepath.Join(testdata, "src", dir, file)
90-
91-
err := os.MkdirAll(filepath.Dir(target), 0o777)
92-
assert.NoErr[F](t, err)
80+
// NOTE: analysistest does not yet support modules;
81+
// see https://github.com/golang/go/issues/37054 for details.
82+
func setupTestData(t *testing.T, testDataDir, dir string) {
83+
t.Helper()
9384

94-
err = os.Link(filepath.Join(testdata, file), target)
95-
assert.NoErr[F](t, err)
85+
err := os.Chdir(filepath.Join(testDataDir, "src", dir))
86+
if err != nil {
87+
t.Fatal(err)
9688
}
9789

98-
hardlink("tests", "tests.go")
99-
hardlink("builtins", "builtins.go")
100-
101-
for file, data := range stubs {
102-
target := filepath.Join(testdata, "src", file)
103-
104-
err := os.MkdirAll(filepath.Dir(target), 0o777)
105-
assert.NoErr[F](t, err)
106-
107-
err = os.WriteFile(target, []byte(data), 0o666)
108-
assert.NoErr[F](t, err)
90+
output, err := exec.Command("go", "mod", "vendor").CombinedOutput()
91+
if err != nil {
92+
t.Log(string(output))
93+
t.Fatal(err)
10994
}
11095
}
111-
112-
var stubs = map[string]string{
113-
"gopkg.in/yaml.v3/yaml.go": `package yaml
114-
import "io"
115-
func Marshal(_ any) ([]byte, error) { return nil, nil }
116-
func Unmarshal(_ []byte, _ any) error { return nil }
117-
type Encoder struct{}
118-
func NewEncoder(_ io.Writer) *Encoder { return nil }
119-
func (*Encoder) Encode(_ any) error { return nil }
120-
type Decoder struct{}
121-
func NewDecoder(_ io.Reader) *Decoder { return nil }
122-
func (*Decoder) Decode(_ any) error { return nil }`,
123-
124-
"github.com/BurntSushi/toml/toml.go": `package toml
125-
import "io"
126-
import "io/fs"
127-
func Unmarshal(_ []byte, _ any) error { return nil }
128-
type MetaData struct{}
129-
func Decode(_ string, _ any) (MetaData, error) { return MetaData{}, nil }
130-
func DecodeFS(_ fs.FS, _ string, _ any) (MetaData, error) { return MetaData{}, nil }
131-
func DecodeFile(_ string, _ any) (MetaData, error) { return MetaData{}, nil }
132-
type Encoder struct{}
133-
func NewEncoder(_ io.Writer) *Encoder { return nil }
134-
func (*Encoder) Encode(_ any) error { return nil }
135-
type Decoder struct{}
136-
func NewDecoder(_ io.Reader) *Decoder { return nil }
137-
func (*Decoder) Decode(_ any) error { return nil }`,
138-
139-
"github.com/mitchellh/mapstructure/mapstructure.go": `package mapstructure
140-
type Metadata struct{}
141-
func Decode(_, _ any) error { return nil }
142-
func DecodeMetadata(_, _ any, _ *Metadata) error { return nil }
143-
func WeakDecode(_, _ any) error { return nil }
144-
func WeakDecodeMetadata(_, _ any, _ *Metadata) error { return nil }`,
145-
146-
"github.com/jmoiron/sqlx/sqlx.go": `package sqlx
147-
import "context"
148-
type Queryer interface{}
149-
type QueryerContext interface{}
150-
type rowsi interface{}
151-
func Get(Queryer, any, string, ...any) error { return nil }
152-
func GetContext(context.Context, QueryerContext, any, string, ...any) error { return nil }
153-
func Select(Queryer, any, string, ...any) error { return nil }
154-
func SelectContext(context.Context, QueryerContext, any, string, ...any) error { return nil }
155-
func StructScan(rowsi, any) error { return nil }
156-
type Conn struct{}
157-
func (*Conn) GetContext(context.Context, any, string, ...any) error { return nil }
158-
func (*Conn) SelectContext(context.Context, any, string, ...any) error { return nil }
159-
type DB struct{}
160-
func (*DB) Get(any, string, ...any) error { return nil }
161-
func (*DB) GetContext(context.Context, any, string, ...any) error { return nil }
162-
func (*DB) Select(any, string, ...any) error { return nil }
163-
func (*DB) SelectContext(context.Context, any, string, ...any) error { return nil }
164-
type NamedStmt struct{}
165-
func (n *NamedStmt) Get(any, any) error { return nil }
166-
func (n *NamedStmt) GetContext(context.Context, any, any) error { return nil }
167-
func (n *NamedStmt) Select(any, any) error { return nil }
168-
func (n *NamedStmt) SelectContext(context.Context, any, any) error { return nil }
169-
type Row struct{}
170-
func (*Row) StructScan(any) error { return nil }
171-
type Rows struct{}
172-
func (*Rows) StructScan(any) error { return nil }
173-
type Stmt struct{}
174-
func (*Stmt) Get(any, ...any) error { return nil }
175-
func (*Stmt) GetContext(context.Context, any, ...any) error { return nil }
176-
func (*Stmt) Select(any, ...any) error { return nil }
177-
func (*Stmt) SelectContext(context.Context, any, ...any) error { return nil }
178-
type Tx struct{}
179-
func (*Tx) Get(any, string, ...any) error { return nil }
180-
func (*Tx) GetContext(context.Context, any, string, ...any) error { return nil }
181-
func (*Tx) Select(any, string, ...any) error { return nil }
182-
func (*Tx) SelectContext(context.Context, any, string, ...any) error { return nil }`,
183-
184-
"example.com/custom/custom.go": `package custom
185-
func Marshal(_ any) ([]byte, error) { return nil, nil }
186-
func Unmarshal(_ []byte, _ any) error { return nil }`,
187-
}

testdata/src/builtins/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
vendor

0 commit comments

Comments
 (0)