Skip to content

Commit 4026f8f

Browse files
feat: check whether structs within arrays/slices/maps implement a Marshaler interface (#87)
1 parent a6b60d7 commit 4026f8f

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

musttag.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,6 @@ func (c *checker) checkType(typ types.Type, tag string) bool {
149149
}
150150
c.seenTypes[typ.String()] = struct{}{}
151151

152-
if implementsInterface(typ, c.ifaceWhitelist, c.imports) {
153-
return true // the type implements a Marshaler interface; see issue #64.
154-
}
155-
156152
styp, ok := c.parseStruct(typ)
157153
if !ok {
158154
return true // not a struct.
@@ -161,7 +157,19 @@ func (c *checker) checkType(typ types.Type, tag string) bool {
161157
return c.checkStruct(styp, tag)
162158
}
163159

160+
// recursively unwrap a type until we get to an underlying
161+
// raw struct type that should have its fields checked
162+
//
163+
// SomeStruct -> struct{SomeStructField: ... }
164+
// []*SomeStruct -> struct{SomeStructField: ... }
165+
// ...
166+
//
167+
// exits early if it hits a type that implements a whitelisted interface
164168
func (c *checker) parseStruct(typ types.Type) (*types.Struct, bool) {
169+
if implementsInterface(typ, c.ifaceWhitelist, c.imports) {
170+
return nil, false // the type implements a Marshaler interface; see issue #64.
171+
}
172+
165173
switch typ := typ.(type) {
166174
case *types.Pointer:
167175
return c.parseStruct(typ.Elem())

testdata/src/tests/tests.go

+11
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,14 @@ func ignoredNestedType() {
164164
json.Marshal(Foo{}) // no error
165165
json.Marshal(&Foo{}) // no error
166166
}
167+
168+
func interfaceSliceType() {
169+
type WithMarshallableSlice struct {
170+
List []Marshaler `json:"marshallable"`
171+
}
172+
var withMarshallableSlice WithMarshallableSlice
173+
174+
json.Marshal(withMarshallableSlice)
175+
json.MarshalIndent(withMarshallableSlice, "", "")
176+
json.NewEncoder(nil).Encode(withMarshallableSlice)
177+
}

0 commit comments

Comments
 (0)