Skip to content

Commit f16d6ea

Browse files
authored
feat: check whether nested types implement a Marshaler interface (#84)
1 parent 1964f4d commit f16d6ea

File tree

2 files changed

+47
-25
lines changed

2 files changed

+47
-25
lines changed

musttag.go

+28-25
Original file line numberDiff line numberDiff line change
@@ -119,22 +119,14 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
119119
return // no type info found.
120120
}
121121

122-
// TODO: check nested structs too.
123-
if implementsInterface(typ, fn.ifaceWhitelist, pass.Pkg.Imports()) {
124-
return // the type implements a Marshaler interface; see issue #64.
125-
}
126-
127122
checker := checker{
128-
mainModule: mainModule,
129-
seenTypes: make(map[string]struct{}),
130-
}
131-
132-
styp, ok := checker.parseStruct(typ)
133-
if !ok {
134-
return // not a struct.
123+
mainModule: mainModule,
124+
seenTypes: make(map[string]struct{}),
125+
ifaceWhitelist: fn.ifaceWhitelist,
126+
imports: pass.Pkg.Imports(),
135127
}
136128

137-
if valid := checker.checkStruct(styp, fn.Tag); valid {
129+
if valid := checker.checkType(typ, fn.Tag); valid {
138130
return // nothing to report.
139131
}
140132

@@ -145,8 +137,28 @@ func run(pass *analysis.Pass, mainModule string, funcs map[string]Func) (_ any,
145137
}
146138

147139
type checker struct {
148-
mainModule string
149-
seenTypes map[string]struct{}
140+
mainModule string
141+
seenTypes map[string]struct{}
142+
ifaceWhitelist []string
143+
imports []*types.Package
144+
}
145+
146+
func (c *checker) checkType(typ types.Type, tag string) bool {
147+
if _, ok := c.seenTypes[typ.String()]; ok {
148+
return true // already checked.
149+
}
150+
c.seenTypes[typ.String()] = struct{}{}
151+
152+
if implementsInterface(typ, c.ifaceWhitelist, c.imports) {
153+
return true // the type implements a Marshaler interface; see issue #64.
154+
}
155+
156+
styp, ok := c.parseStruct(typ)
157+
if !ok {
158+
return true // not a struct.
159+
}
160+
161+
return c.checkStruct(styp, tag)
150162
}
151163

152164
func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
@@ -186,8 +198,6 @@ func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
186198
}
187199

188200
func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) {
189-
c.seenTypes[styp.String()] = struct{}{}
190-
191201
for i := 0; i < styp.NumFields(); i++ {
192202
field := styp.Field(i)
193203
if !field.Exported() {
@@ -201,14 +211,7 @@ func (c *checker) checkStruct(styp *types.Struct, tag string) (valid bool) {
201211
}
202212
}
203213

204-
nested, ok := c.parseStruct(field.Type())
205-
if !ok {
206-
continue
207-
}
208-
if _, ok := c.seenTypes[nested.String()]; ok {
209-
continue
210-
}
211-
if valid := c.checkStruct(nested, tag); !valid {
214+
if valid := c.checkType(field.Type(), tag); !valid {
212215
return false
213216
}
214217
}

testdata/src/tests/tests.go

+19
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,22 @@ func shouldBeIgnored() {
130130
json.Marshal(0) // a non-struct argument.
131131
json.Marshal(nil) // nil argument, see issue #20.
132132
}
133+
134+
type WithInterface struct {
135+
NoTag string
136+
}
137+
138+
func (w WithInterface) MarshalJSON() ([]byte, error) {
139+
return json.Marshal(w.NoTag)
140+
}
141+
142+
func nestedTypeWithInterface() {
143+
type Foo struct {
144+
Nested WithInterface `json:"nested"`
145+
}
146+
var foo Foo
147+
json.Marshal(foo) // no error
148+
json.Marshal(&foo) // no error
149+
json.Marshal(Foo{}) // no error
150+
json.Marshal(&Foo{}) // no error
151+
}

0 commit comments

Comments
 (0)