Skip to content

Commit 27a84af

Browse files
authored
fix: incorrect report position and panic on invalid Func.ArgPos (#18)
1 parent 9229084 commit 27a84af

File tree

3 files changed

+157
-147
lines changed

3 files changed

+157
-147
lines changed

musttag.go

Lines changed: 82 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ var builtin = []Func{
5757
{Name: "github.com/mitchellh/mapstructure.WeakDecodeMetadata", Tag: "mapstructure", ArgPos: 1},
5858
}
5959

60-
// flags creates a flag set for the analyzer. The funcs slice will be filled
61-
// with custom functions passed via CLI flags.
60+
// flags creates a flag set for the analyzer.
61+
// The funcs slice will be filled with custom functions passed via CLI flags.
6262
func flags(funcs *[]Func) flag.FlagSet {
6363
fs := flag.NewFlagSet("musttag", flag.ContinueOnError)
6464
fs.Func("fn", "report custom function (name:tag:argpos)", func(s string) error {
@@ -80,8 +80,9 @@ func flags(funcs *[]Func) flag.FlagSet {
8080
return *fs
8181
}
8282

83-
// New creates a new musttag analyzer. To report a custom function provide its
84-
// description via Func, it will be added to the builtin ones.
83+
// New creates a new musttag analyzer.
84+
// To report a custom function provide its description via Func,
85+
// it will be added to the builtin ones.
8586
func New(funcs ...Func) *analysis.Analyzer {
8687
var flagFuncs []Func
8788
return &analysis.Analyzer{
@@ -112,125 +113,135 @@ var (
112113

113114
// reportf is a wrapper for pass.Reportf (as a variable, so it could be mocked in tests).
114115
reportf = func(pass *analysis.Pass, pos token.Pos, fn Func) {
116+
// TODO(junk1tm): print the name of the struct type as well?
115117
pass.Reportf(pos, "exported fields should be annotated with the %q tag", fn.Tag)
116118
}
117119
)
118120

119121
// run starts the analysis.
120122
func run(pass *analysis.Pass, funcs map[string]Func) (any, error) {
121-
insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
122-
123-
filter := []ast.Node{
124-
(*ast.CallExpr)(nil),
125-
}
126-
127123
type report struct {
128-
pos token.Pos
129-
tag string
124+
pos token.Pos // the position for report.
125+
tag string // the missing struct tag.
130126
}
131-
reported := make(map[report]struct{})
132127

133-
insp.Preorder(filter, func(n ast.Node) {
134-
call := n.(*ast.CallExpr)
128+
// store previous reports to prevent reporting
129+
// the same struct more than once (if reportOnce is true).
130+
reports := make(map[report]struct{})
131+
132+
walk := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
133+
filter := []ast.Node{(*ast.CallExpr)(nil)}
134+
135+
walk.Preorder(filter, func(n ast.Node) {
136+
call, ok := n.(*ast.CallExpr)
137+
if !ok {
138+
return // not a function call.
139+
}
135140

136141
callee := typeutil.StaticCallee(pass.TypesInfo, call)
137142
if callee == nil {
138-
return
143+
return // not a static call.
139144
}
140145

141146
fn, ok := funcs[callee.FullName()]
142147
if !ok {
143-
return
148+
return // the function is not supported.
149+
}
150+
151+
if len(call.Args) <= fn.ArgPos {
152+
return // TODO(junk1tm): return a proper error.
153+
}
154+
155+
arg := call.Args[fn.ArgPos]
156+
if unary, ok := arg.(*ast.UnaryExpr); ok {
157+
arg = unary.X // e.g. json.Marshal(&foo)
158+
}
159+
160+
initialPos := token.NoPos
161+
switch arg := arg.(type) {
162+
case *ast.Ident: // e.g. json.Marshal(foo)
163+
initialPos = arg.Obj.Pos()
164+
case *ast.CompositeLit: // e.g. json.Marshal(struct{}{})
165+
initialPos = arg.Pos()
144166
}
145167

146-
s, pos, ok := structAndPos(pass, call.Args[fn.ArgPos])
168+
t := pass.TypesInfo.TypeOf(arg)
169+
s, ok := parseStruct(t, initialPos)
147170
if !ok {
148-
return
171+
return // not a struct argument.
149172
}
150173

151-
if ok := checkStruct(s, fn.Tag, &pos); ok {
152-
return
174+
reportPos, ok := checkStruct(s, fn.Tag)
175+
if ok {
176+
return // nothing to report.
153177
}
154178

155-
r := report{pos, fn.Tag}
156-
if _, ok := reported[r]; ok && reportOnce {
157-
return
179+
r := report{reportPos, fn.Tag}
180+
if _, ok := reports[r]; ok && reportOnce {
181+
return // already reported.
158182
}
159183

160-
reportf(pass, pos, fn)
161-
reported[r] = struct{}{}
184+
reportf(pass, reportPos, fn)
185+
reports[r] = struct{}{}
162186
})
163187

164188
return nil, nil
165189
}
166190

167-
// structAndPos analyses the given expression and returns the struct to check
168-
// and the position to report if needed.
169-
func structAndPos(pass *analysis.Pass, expr ast.Expr) (*types.Struct, token.Pos, bool) {
170-
t := pass.TypesInfo.TypeOf(expr)
171-
if ptr, ok := t.(*types.Pointer); ok {
191+
// structInfo expands types.Struct with its position in the source code.
192+
// If the struct is anonymous, Pos points to the corresponding identifier.
193+
type structInfo struct {
194+
*types.Struct
195+
Pos token.Pos
196+
}
197+
198+
// parseStruct parses the given types.Type, returning the underlying struct type.
199+
// If it's a named type, the result will contain the position of its declaration,
200+
// or the given token.Pos otherwise.
201+
func parseStruct(t types.Type, pos token.Pos) (*structInfo, bool) {
202+
for {
203+
// unwrap pointers (if any) first.
204+
ptr, ok := t.(*types.Pointer)
205+
if !ok {
206+
break
207+
}
172208
t = ptr.Elem()
173209
}
174210

175211
switch t := t.(type) {
176-
case *types.Named: // named type
177-
s, ok := t.Underlying().(*types.Struct)
178-
if ok {
179-
return s, t.Obj().Pos(), true
180-
}
181-
182-
case *types.Struct: // anonymous struct
183-
if unary, ok := expr.(*ast.UnaryExpr); ok {
184-
expr = unary.X // &x
185-
}
186-
//nolint:gocritic // commentedOutCode: these are examples
187-
switch arg := expr.(type) {
188-
case *ast.Ident: // var x struct{}; json.Marshal(x)
189-
return t, arg.Obj.Pos(), true
190-
case *ast.CompositeLit: // json.Marshal(struct{}{})
191-
return t, arg.Pos(), true
212+
case *types.Named: // a struct of the named type.
213+
if s, ok := t.Underlying().(*types.Struct); ok {
214+
return &structInfo{Struct: s, Pos: t.Obj().Pos()}, true
192215
}
216+
case *types.Struct: // an anonymous struct.
217+
return &structInfo{Struct: t, Pos: pos}, true
193218
}
194219

195-
return nil, 0, false
220+
return nil, false
196221
}
197222

198-
// checkStruct checks that exported fields of the given struct are annotated
199-
// with the tag and updates the position to report in case a nested struct of a
200-
// named type is found.
201-
func checkStruct(s *types.Struct, tag string, pos *token.Pos) (ok bool) {
223+
// checkStruct recursively checks the given struct and returns the position for report,
224+
// in case one of its fields is missing the tag.
225+
func checkStruct(s *structInfo, tag string) (token.Pos, bool) {
202226
for i := 0; i < s.NumFields(); i++ {
203227
if !s.Field(i).Exported() {
204228
continue
205229
}
206230

207231
st := reflect.StructTag(s.Tag(i))
208-
if _, ok := st.Lookup(tag); !ok {
209-
// it's ok for embedded types not to be tagged,
210-
// see https://github.com/junk1tm/musttag/issues/12
211-
if !s.Field(i).Embedded() {
212-
return false
213-
}
232+
if _, ok := st.Lookup(tag); !ok && !s.Field(i).Embedded() {
233+
return s.Pos, false
214234
}
215235

216-
// check if the field is a nested struct.
217236
t := s.Field(i).Type()
218-
if ptr, ok := t.(*types.Pointer); ok {
219-
t = ptr.Elem()
220-
}
221-
nested, ok := t.Underlying().(*types.Struct)
237+
nested, ok := parseStruct(t, s.Pos) // TODO(junk1tm): or s.Field(i).Pos()?
222238
if !ok {
223239
continue
224240
}
225-
if ok := checkStruct(nested, tag, pos); ok {
226-
continue
227-
}
228-
// update the position to point to the named type.
229-
if named, ok := t.(*types.Named); ok {
230-
*pos = named.Obj().Pos()
241+
if pos, ok := checkStruct(nested, tag); !ok {
242+
return pos, false
231243
}
232-
return false
233244
}
234245

235-
return true
246+
return token.NoPos, true
236247
}

musttag_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func TestAnalyzer(t *testing.T) {
5454

5555
func TestFlags(t *testing.T) {
5656
analyzer := New()
57-
analyzer.Flags.SetOutput(io.Discard)
57+
analyzer.Flags.SetOutput(io.Discard) // TODO(junk1tm): does not work, the usage is still printed.
5858

5959
t.Run("ok", func(t *testing.T) {
6060
err := analyzer.Flags.Parse([]string{"-fn=test.Test:test:0"})

testdata/src/tests/tests.go

Lines changed: 74 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -287,81 +287,80 @@ func nestedAnonymousType() {
287287

288288
// embedded types should not be reported.
289289
func embeddedType() {
290-
type Y struct { /* want
291-
`\Qjson.Marshal`
292-
`\Qjson.MarshalIndent`
293-
`\Qjson.Unmarshal`
294-
`\Qjson.Encoder.Encode`
295-
`\Qjson.Decoder.Decode`
296-
297-
`\Qxml.Marshal`
298-
`\Qxml.MarshalIndent`
299-
`\Qxml.Unmarshal`
300-
`\Qxml.Encoder.Encode`
301-
`\Qxml.Decoder.Decode`
302-
`\Qxml.Encoder.EncodeElement`
303-
`\Qxml.Decoder.DecodeElement`
304-
305-
`\Qyaml.v3.Marshal`
306-
`\Qyaml.v3.Unmarshal`
307-
`\Qyaml.v3.Encoder.Encode`
308-
`\Qyaml.v3.Decoder.Decode`
309-
310-
`\Qtoml.Unmarshal`
311-
`\Qtoml.Decode`
312-
`\Qtoml.DecodeFS`
313-
`\Qtoml.DecodeFile`
314-
`\Qtoml.Encoder.Encode`
315-
`\Qtoml.Decoder.Decode`
316-
317-
`\Qmapstructure.Decode`
318-
`\Qmapstructure.DecodeMetadata`
319-
`\Qmapstructure.WeakDecode`
320-
`\Qmapstructure.WeakDecodeMetadata`
321-
322-
`\Qcustom.Marshal`
323-
`\Qcustom.Unmarshal` */
324-
NoTag int
325-
}
326-
327-
var x struct {
328-
Y
329-
Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"`
330-
}
331-
332-
json.Marshal(x)
333-
json.MarshalIndent(x, "", "")
334-
json.Unmarshal(nil, &x)
335-
json.NewEncoder(nil).Encode(x)
336-
json.NewDecoder(nil).Decode(&x)
337-
338-
xml.Marshal(x)
339-
xml.MarshalIndent(x, "", "")
340-
xml.Unmarshal(nil, &x)
341-
xml.NewEncoder(nil).Encode(x)
342-
xml.NewDecoder(nil).Decode(&x)
343-
xml.NewEncoder(nil).EncodeElement(x, xmlSE)
344-
xml.NewDecoder(nil).DecodeElement(&x, &xmlSE)
345-
346-
yaml.Marshal(x)
347-
yaml.Unmarshal(nil, &x)
348-
yaml.NewEncoder(nil).Encode(x)
349-
yaml.NewDecoder(nil).Decode(&x)
350-
351-
toml.Unmarshal(nil, &x)
352-
toml.Decode("", &x)
353-
toml.DecodeFS(nil, "", &x)
354-
toml.DecodeFile("", &x)
355-
toml.NewEncoder(nil).Encode(x)
356-
toml.NewDecoder(nil).Decode(&x)
357-
358-
mapstructure.Decode(nil, &x)
359-
mapstructure.DecodeMetadata(nil, &x, nil)
360-
mapstructure.WeakDecode(nil, &x)
361-
mapstructure.WeakDecodeMetadata(nil, &x, nil)
362-
363-
custom.Marshal(x)
364-
custom.Unmarshal(nil, &x)
290+
type Y struct { /* want
291+
`\Qjson.Marshal`
292+
`\Qjson.MarshalIndent`
293+
`\Qjson.Unmarshal`
294+
`\Qjson.Encoder.Encode`
295+
`\Qjson.Decoder.Decode`
296+
297+
`\Qxml.Marshal`
298+
`\Qxml.MarshalIndent`
299+
`\Qxml.Unmarshal`
300+
`\Qxml.Encoder.Encode`
301+
`\Qxml.Decoder.Decode`
302+
`\Qxml.Encoder.EncodeElement`
303+
`\Qxml.Decoder.DecodeElement`
304+
305+
`\Qyaml.v3.Marshal`
306+
`\Qyaml.v3.Unmarshal`
307+
`\Qyaml.v3.Encoder.Encode`
308+
`\Qyaml.v3.Decoder.Decode`
309+
310+
`\Qtoml.Unmarshal`
311+
`\Qtoml.Decode`
312+
`\Qtoml.DecodeFS`
313+
`\Qtoml.DecodeFile`
314+
`\Qtoml.Encoder.Encode`
315+
`\Qtoml.Decoder.Decode`
316+
317+
`\Qmapstructure.Decode`
318+
`\Qmapstructure.DecodeMetadata`
319+
`\Qmapstructure.WeakDecode`
320+
`\Qmapstructure.WeakDecodeMetadata`
321+
322+
`\Qcustom.Marshal`
323+
`\Qcustom.Unmarshal` */
324+
NoTag int
325+
}
326+
var x struct {
327+
Y
328+
Z int `json:"z" xml:"z" yaml:"z" toml:"z" mapstructure:"z" custom:"z"`
329+
}
330+
331+
json.Marshal(x)
332+
json.MarshalIndent(x, "", "")
333+
json.Unmarshal(nil, &x)
334+
json.NewEncoder(nil).Encode(x)
335+
json.NewDecoder(nil).Decode(&x)
336+
337+
xml.Marshal(x)
338+
xml.MarshalIndent(x, "", "")
339+
xml.Unmarshal(nil, &x)
340+
xml.NewEncoder(nil).Encode(x)
341+
xml.NewDecoder(nil).Decode(&x)
342+
xml.NewEncoder(nil).EncodeElement(x, xmlSE)
343+
xml.NewDecoder(nil).DecodeElement(&x, &xmlSE)
344+
345+
yaml.Marshal(x)
346+
yaml.Unmarshal(nil, &x)
347+
yaml.NewEncoder(nil).Encode(x)
348+
yaml.NewDecoder(nil).Decode(&x)
349+
350+
toml.Unmarshal(nil, &x)
351+
toml.Decode("", &x)
352+
toml.DecodeFS(nil, "", &x)
353+
toml.DecodeFile("", &x)
354+
toml.NewEncoder(nil).Encode(x)
355+
toml.NewDecoder(nil).Decode(&x)
356+
357+
mapstructure.Decode(nil, &x)
358+
mapstructure.DecodeMetadata(nil, &x, nil)
359+
mapstructure.WeakDecode(nil, &x)
360+
mapstructure.WeakDecodeMetadata(nil, &x, nil)
361+
362+
custom.Marshal(x)
363+
custom.Unmarshal(nil, &x)
365364
}
366365

367366
// all good, nothing to report.

0 commit comments

Comments
 (0)