Skip to content

Commit f0ccf71

Browse files
authored
Merge pull request #358 from kkHAIKE/encode_MarshalText_fix
change eindirect behave match with indirect from decode
2 parents 0a9f2b0 + c03a31c commit f0ccf71

File tree

2 files changed

+144
-76
lines changed

2 files changed

+144
-76
lines changed

encode.go

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ var dblQuotedReplacer = strings.NewReplacer(
6464
"\x7f", `\u007f`,
6565
)
6666

67+
var (
68+
marshalToml = reflect.TypeOf((*Marshaler)(nil)).Elem()
69+
marshalText = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
70+
timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
71+
)
72+
6773
// Marshaler is the interface implemented by types that can marshal themselves
6874
// into valid TOML.
6975
type Marshaler interface {
@@ -154,12 +160,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
154160
// If we can marshal the type to text, then we use that. This prevents the
155161
// encoder for handling these types as generic structs (or whatever the
156162
// underlying type of a TextMarshaler is).
157-
switch t := rv.Interface().(type) {
158-
case encoding.TextMarshaler, Marshaler:
163+
switch {
164+
case isMarshaler(rv):
159165
enc.writeKeyValue(key, rv, false)
160166
return
161-
case Primitive: // TODO: #76 would make this superfluous after implemented.
162-
enc.encode(key, reflect.ValueOf(t.undecoded))
167+
case rv.Type() == primitiveType: // TODO: #76 would make this superfluous after implemented.
168+
enc.encode(key, reflect.ValueOf(rv.Interface().(Primitive).undecoded))
163169
return
164170
}
165171

@@ -318,7 +324,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
318324
length := rv.Len()
319325
enc.wf("[")
320326
for i := 0; i < length; i++ {
321-
elem := rv.Index(i)
327+
elem := eindirect(rv.Index(i))
322328
enc.eElement(elem)
323329
if i != length-1 {
324330
enc.wf(", ")
@@ -332,7 +338,7 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
332338
encPanic(errNoKey)
333339
}
334340
for i := 0; i < rv.Len(); i++ {
335-
trv := rv.Index(i)
341+
trv := eindirect(rv.Index(i))
336342
if isNil(trv) {
337343
continue
338344
}
@@ -357,7 +363,7 @@ func (enc *Encoder) eTable(key Key, rv reflect.Value) {
357363
}
358364

359365
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
360-
switch rv := eindirect(rv); rv.Kind() {
366+
switch rv.Kind() {
361367
case reflect.Map:
362368
enc.eMap(key, rv, inline)
363369
case reflect.Struct:
@@ -379,7 +385,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
379385
var mapKeysDirect, mapKeysSub []string
380386
for _, mapKey := range rv.MapKeys() {
381387
k := mapKey.String()
382-
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
388+
if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) {
383389
mapKeysSub = append(mapKeysSub, k)
384390
} else {
385391
mapKeysDirect = append(mapKeysDirect, k)
@@ -389,7 +395,7 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
389395
var writeMapKeys = func(mapKeys []string, trailC bool) {
390396
sort.Strings(mapKeys)
391397
for i, mapKey := range mapKeys {
392-
val := rv.MapIndex(reflect.ValueOf(mapKey))
398+
val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey)))
393399
if isNil(val) {
394400
continue
395401
}
@@ -417,6 +423,13 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
417423

418424
const is32Bit = (32 << (^uint(0) >> 63)) == 32
419425

426+
func pointerTo(t reflect.Type) reflect.Type {
427+
if t.Kind() == reflect.Ptr {
428+
return pointerTo(t.Elem())
429+
}
430+
return t
431+
}
432+
420433
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
421434
// Write keys for fields directly under this key first, because if we write
422435
// a field that creates a new table then all keys under it will be in that
@@ -433,35 +446,25 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
433446
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
434447
for i := 0; i < rt.NumField(); i++ {
435448
f := rt.Field(i)
436-
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
449+
isEmbed := f.Anonymous && pointerTo(f.Type).Kind() == reflect.Struct
450+
if f.PkgPath != "" && !isEmbed { /// Skip unexported fields.
437451
continue
438452
}
439453
opts := getOptions(f.Tag)
440454
if opts.skip {
441455
continue
442456
}
443457

444-
frv := rv.Field(i)
458+
frv := eindirect(rv.Field(i))
445459

446460
// Treat anonymous struct fields with tag names as though they are
447461
// not anonymous, like encoding/json does.
448462
//
449463
// Non-struct anonymous fields use the normal encoding logic.
450-
if f.Anonymous {
451-
t := f.Type
452-
switch t.Kind() {
453-
case reflect.Struct:
454-
if getOptions(f.Tag).name == "" {
455-
addFields(t, frv, append(start, f.Index...))
456-
continue
457-
}
458-
case reflect.Ptr:
459-
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
460-
if !frv.IsNil() {
461-
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
462-
}
463-
continue
464-
}
464+
if isEmbed {
465+
if getOptions(f.Tag).name == "" && frv.Kind() == reflect.Struct {
466+
addFields(frv.Type(), frv, append(start, f.Index...))
467+
continue
465468
}
466469
}
467470

@@ -487,7 +490,7 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
487490
writeFields := func(fields [][]int) {
488491
for _, fieldIndex := range fields {
489492
fieldType := rt.FieldByIndex(fieldIndex)
490-
fieldVal := rv.FieldByIndex(fieldIndex)
493+
fieldVal := eindirect(rv.FieldByIndex(fieldIndex))
491494

492495
if isNil(fieldVal) { /// Don't write anything for nil fields.
493496
continue
@@ -540,6 +543,21 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
540543
if isNil(rv) || !rv.IsValid() {
541544
return nil
542545
}
546+
547+
if rv.Kind() == reflect.Struct {
548+
if rv.Type() == timeType {
549+
return tomlDatetime
550+
}
551+
if isMarshaler(rv) {
552+
return tomlString
553+
}
554+
return tomlHash
555+
}
556+
557+
if isMarshaler(rv) {
558+
return tomlString
559+
}
560+
543561
switch rv.Kind() {
544562
case reflect.Bool:
545563
return tomlBool
@@ -561,42 +579,14 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
561579
return tomlString
562580
case reflect.Map:
563581
return tomlHash
564-
case reflect.Struct:
565-
if _, ok := rv.Interface().(time.Time); ok {
566-
return tomlDatetime
567-
}
568-
if isMarshaler(rv) {
569-
return tomlString
570-
}
571-
return tomlHash
572582
default:
573-
if isMarshaler(rv) {
574-
return tomlString
575-
}
576-
577583
encPanic(errors.New("unsupported type: " + rv.Kind().String()))
578584
panic("unreachable")
579585
}
580586
}
581587

582588
func isMarshaler(rv reflect.Value) bool {
583-
switch rv.Interface().(type) {
584-
case encoding.TextMarshaler:
585-
return true
586-
case Marshaler:
587-
return true
588-
}
589-
590-
// Someone used a pointer receiver: we can make it work for pointer values.
591-
if rv.CanAddr() {
592-
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
593-
return true
594-
}
595-
if _, ok := rv.Addr().Interface().(Marshaler); ok {
596-
return true
597-
}
598-
}
599-
return false
589+
return rv.Type().Implements(marshalText) || rv.Type().Implements(marshalToml)
600590
}
601591

602592
// isTableArray reports if all entries in the array or slice are a table.
@@ -605,19 +595,19 @@ func isTableArray(arr reflect.Value) bool {
605595
return false
606596
}
607597

608-
/// Don't allow nil.
598+
ret := true
609599
for i := 0; i < arr.Len(); i++ {
610-
if tomlTypeOfGo(arr.Index(i)) == nil {
600+
tt := tomlTypeOfGo(eindirect(arr.Index(i)))
601+
// Don't allow nil.
602+
if tt == nil {
611603
encPanic(errArrayNilElement)
612604
}
613-
}
614605

615-
for i := 0; i < arr.Len(); i++ {
616-
if !typeEqual(tomlHash, tomlTypeOfGo(arr.Index(i))) {
617-
return false
606+
if ret && !typeEqual(tomlHash, tt) {
607+
ret = false
618608
}
619609
}
620-
return true
610+
return ret
621611
}
622612

623613
type tagOptions struct {
@@ -715,13 +705,25 @@ func encPanic(err error) {
715705
panic(tomlEncodeError{err})
716706
}
717707

708+
// Resolve any level of pointers to the actual value (e.g. **string → string).
718709
func eindirect(v reflect.Value) reflect.Value {
719-
switch v.Kind() {
720-
case reflect.Ptr, reflect.Interface:
721-
return eindirect(v.Elem())
722-
default:
710+
if v.Kind() != reflect.Ptr && v.Kind() != reflect.Interface {
711+
if isMarshaler(v) {
712+
return v
713+
}
714+
if v.CanAddr() { /// Special case for marshalers; see #358.
715+
if pv := v.Addr(); isMarshaler(pv) {
716+
return pv
717+
}
718+
}
719+
return v
720+
}
721+
722+
if v.IsNil() {
723723
return v
724724
}
725+
726+
return eindirect(v.Elem())
725727
}
726728

727729
func isNil(rv reflect.Value) bool {

0 commit comments

Comments
 (0)