Skip to content

Commit b0e65b3

Browse files
committed
fix
Signed-off-by: Yuri Shkuro <[email protected]>
1 parent cd1c879 commit b0e65b3

File tree

2 files changed

+107
-31
lines changed

2 files changed

+107
-31
lines changed

mapstructure.go

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -442,22 +442,15 @@ func (d *Decoder) Decode(input interface{}) error {
442442
return err
443443
}
444444

445-
// A comparison input == nil will fail if input is a typed nil.
446-
// This function converts a typed nil to an actual, untyped nil.
447-
func toRealNil(input interface{}) interface{} {
445+
// isNil returns true if the input is nil or a typed nil pointer.
446+
func isNil(input interface{}) bool {
448447
if input == nil {
449-
return nil
448+
return true
450449
}
451450
val := reflect.ValueOf(input)
452451
k := val.Kind()
453-
if (k == reflect.Ptr ||
454-
k == reflect.Interface ||
455-
k == reflect.Map ||
456-
k == reflect.Slice ||
457-
k == reflect.Array) && val.IsNil() {
458-
return nil
459-
}
460-
return input
452+
return (k == reflect.Ptr ||
453+
/*k == reflect.Interface || k == reflect.Map || k == reflect.Slice*/ false) && val.IsNil()
461454
}
462455

463456
// Decodes an unknown data type into a specific reflection value.
@@ -467,8 +460,14 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
467460
outputKind = getKind(outVal)
468461
decodeNil = d.config.DecodeNil && d.cachedDecodeHook != nil
469462
)
470-
input = toRealNil(input)
471-
if input == nil || !inputVal.IsValid() {
463+
if input != nil {
464+
// We need to check here if input is a typed nil. Typed nils won't
465+
// match the "input == nil" below so we check that here.
466+
if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() {
467+
input = nil
468+
}
469+
}
470+
if input == nil {
472471
// If the data is nil, then we don't set anything, unless ZeroFields is set
473472
// to true.
474473
if d.config.ZeroFields {
@@ -482,29 +481,41 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
482481
return nil
483482
}
484483
}
484+
if !inputVal.IsValid() {
485+
if !decodeNil {
486+
// If the input value is invalid, then we just set the value
487+
// to be the zero value.
488+
outVal.Set(reflect.Zero(outVal.Type()))
489+
if d.config.Metadata != nil && name != "" {
490+
d.config.Metadata.Keys = append(d.config.Metadata.Keys, name)
491+
}
492+
return nil
493+
}
494+
// Hooks need a valid inputVal, so reset it to zero value of outVal type.
495+
switch outputKind {
496+
case reflect.Struct, reflect.Map:
497+
// create empty map
498+
var mapVal map[string]interface{}
499+
inputVal = reflect.ValueOf(mapVal)
500+
// inputVal = reflect.MakeMap(reflect.TypeOf(mapVal))
501+
case reflect.Slice, reflect.Array:
502+
// create nil slice
503+
var sliceVal []interface{}
504+
inputVal = reflect.ValueOf(sliceVal)
505+
default:
506+
inputVal = reflect.Zero(outVal.Type())
507+
}
508+
}
485509

486510
if d.cachedDecodeHook != nil {
487511
// We have a DecodeHook, so let's pre-process the input.
488-
if !inputVal.IsValid() {
489-
// Hooks need a valid inputVal, so reset it to zero value of outVal type.
490-
switch outputKind {
491-
case reflect.Struct, reflect.Map:
492-
var mapVal map[string]interface{}
493-
inputVal = reflect.ValueOf(mapVal)
494-
case reflect.Slice, reflect.Array:
495-
var sliceVal []interface{}
496-
inputVal = reflect.ValueOf(sliceVal)
497-
default:
498-
inputVal = reflect.Zero(outVal.Type())
499-
}
500-
}
501512
var err error
502513
input, err = d.cachedDecodeHook(inputVal, outVal)
503514
if err != nil {
504515
return fmt.Errorf("error decoding '%s': %w", name, err)
505516
}
506517
}
507-
if toRealNil(input) == nil {
518+
if isNil(input) {
508519
return nil
509520
}
510521

@@ -789,8 +800,8 @@ func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) e
789800
}
790801
default:
791802
return fmt.Errorf(
792-
"'%s' expected type '%s', got unconvertible type '%s', value: '%v'",
793-
name, val.Type(), dataVal.Type(), data)
803+
"'%s' expected type '%s', got unconvertible type '%#v', value: '%#v'",
804+
name, val, dataVal, data)
794805
}
795806

796807
return nil

mapstructure_test.go

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3083,12 +3083,13 @@ func TestDecoder_IgnoreUntaggedFieldsWithStruct(t *testing.T) {
30833083
}
30843084
}
30853085

3086-
func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
3086+
func TestDecoder_DecodeNilOption(t *testing.T) {
30873087
t.Parallel()
30883088

30893089
type Transformed struct {
30903090
Message string
30913091
When string
3092+
Boolean *bool //
30923093
}
30933094

30943095
helloHook := func(reflect.Type, reflect.Type, interface{}) (interface{}, error) {
@@ -3100,6 +3101,9 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
31003101
appendHook := func(from reflect.Value, to reflect.Value) (interface{}, error) {
31013102
if from.Kind() == reflect.Map {
31023103
stringMap := from.Interface().(map[string]interface{})
3104+
if stringMap == nil {
3105+
stringMap = make(map[string]interface{})
3106+
}
31033107
stringMap["when"] = "see you later"
31043108
return stringMap, nil
31053109
}
@@ -3248,6 +3252,67 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
32483252
}
32493253
}
32503254

3255+
func TestDecoder_ExpandNilStructPointersHookFunc(t *testing.T) {
3256+
// a decoder hook that expands nil pointers in a struct to their zero value
3257+
// if the input map contains the corresponding key.
3258+
decodeHook := func(from reflect.Value, to reflect.Value) (any, error) {
3259+
if from.Kind() == reflect.Map && to.Kind() == reflect.Map {
3260+
toElem := to.Type().Elem()
3261+
if toElem.Kind() == reflect.Ptr && toElem.Elem().Kind() == reflect.Struct {
3262+
fromRange := from.MapRange()
3263+
for fromRange.Next() {
3264+
fromKey := fromRange.Key()
3265+
fromValue := fromRange.Value()
3266+
if fromValue.IsNil() {
3267+
newFromValue := reflect.New(toElem.Elem())
3268+
from.SetMapIndex(fromKey, newFromValue)
3269+
}
3270+
}
3271+
}
3272+
}
3273+
return from.Interface(), nil
3274+
}
3275+
type Struct struct {
3276+
Name string
3277+
}
3278+
type TestConfig struct {
3279+
Boolean *bool `mapstructure:"boolean"`
3280+
Struct *Struct `mapstructure:"struct"`
3281+
MapStruct map[string]*Struct `mapstructure:"map_struct"`
3282+
}
3283+
stringMap := map[string]any{
3284+
"boolean": nil,
3285+
"struct": nil,
3286+
"map_struct": map[string]any{
3287+
"struct": nil,
3288+
},
3289+
}
3290+
var result TestConfig
3291+
decoder, err := NewDecoder(&DecoderConfig{
3292+
Result: &result,
3293+
DecodeNil: true,
3294+
DecodeHook: decodeHook,
3295+
})
3296+
if err != nil {
3297+
t.Fatalf("err: %s", err)
3298+
}
3299+
if err := decoder.Decode(stringMap); err != nil {
3300+
t.Fatalf("got an err: %s", err)
3301+
}
3302+
if result.Boolean != nil {
3303+
t.Errorf("nil Boolean expected, got '%#v'", result.Boolean)
3304+
}
3305+
if result.Struct != nil {
3306+
t.Errorf("nil Struct expected, got '%#v'", result.Struct)
3307+
}
3308+
if len(result.MapStruct) == 0 {
3309+
t.Fatalf("not-empty MapStruct expected, got '%#v'", result.MapStruct)
3310+
}
3311+
if _, ok := result.MapStruct["struct"]; !ok {
3312+
t.Errorf("MapStruct['struct'] expected")
3313+
}
3314+
}
3315+
32513316
func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) {
32523317
var result Slice
32533318
err := Decode(input, &result)

0 commit comments

Comments
 (0)