Skip to content

Commit 0ae83fe

Browse files
committed
replace some Interface() check when encode private embed struct
1 parent eaf0d98 commit 0ae83fe

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

encode.go

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
6464
"\x7f", `\u007f`,
6565
)
6666

67+
var (
68+
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
69+
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
70+
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
71+
)
72+
6773
// Marshaler is the interface implemented by types that can marshal themselves
6874
// into valid TOML.
6975
type Marshaler interface {
@@ -154,12 +160,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
154160
// If we can marshal the type to text, then we use that. This prevents the
155161
// encoder for handling these types as generic structs (or whatever the
156162
// underlying type of a TextMarshaler is).
157-
switch t := rv.Interface().(type) {
158-
case encoding.TextMarshaler, Marshaler:
163+
switch {
164+
case isMarshaler(rv):
159165
enc.writeKeyValue(key, rv, false)
160166
return
161-
case Primitive: // TODO: #76 would make this superfluous after implemented.
162-
enc.encode(key, reflect.ValueOf(t.undecoded))
167+
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
168+
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
163169
return
164170
}
165171

@@ -429,11 +435,19 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
429435
rt = rv.Type()
430436
fieldsDirect, fieldsSub [][]int
431437
addFields func(rt reflect.Type, rv reflect.Value, start []int)
438+
ptrto func(t reflect.Type) reflect.Type
432439
)
440+
ptrto = func(t reflect.Type) reflect.Type {
441+
if t.Kind() == reflect.Ptr {
442+
return ptrto(t.Elem())
443+
}
444+
return t
445+
}
433446
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
434447
for i := 0; i < rt.NumField(); i++ {
435448
f := rt.Field(i)
436-
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
449+
isEmbed := f.Anonymous && ptrto(f.Type).Kind() == reflect.Struct
450+
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
437451
continue
438452
}
439453
opts := getOptions(f.Tag)
@@ -447,7 +461,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
447461
// not anonymous, like encoding/json does.
448462
//
449463
// Non-struct anonymous fields use the normal encoding logic.
450-
if f.Anonymous {
464+
if isEmbed {
451465
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
452466
addFields(frv.Type(), frv, append(start, f.Index...))
453467
continue
@@ -531,7 +545,7 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
531545
}
532546

533547
if rv.Kind() == reflect.Struct {
534-
if _, ok := rv.Interface().(time.Time); ok {
548+
if rv.Type() == timeType {
535549
return tomlDatetime
536550
}
537551
if isMarshaler(rv) {
@@ -572,13 +586,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
572586
}
573587

574588
func isMarshaler(rv reflect.Value) bool {
575-
switch rv.Interface().(type) {
576-
case encoding.TextMarshaler:
577-
return true
578-
case Marshaler:
579-
return true
580-
}
581-
return false
589+
return rv.Type().Implements(marshalText) ||
590+
rv.Type().Implements(marshalToml)
582591
}
583592

584593
// isTableArray reports if all entries in the array or slice are a table.

0 commit comments

Comments
 (0)