Skip to content

Commit c29fc28

Browse files
Merge pull request #45 from yurishkuro/decode-nil-fix-42
Pass appropriate empty Value to hooks
2 parents abbd7b4 + b5334ce commit c29fc28

File tree

2 files changed

+98
-23
lines changed

2 files changed

+98
-23
lines changed

mapstructure.go

+33-22
Original file line numberDiff line numberDiff line change
@@ -442,21 +442,26 @@ func (d *Decoder) Decode(input interface{}) error {
442442
return err
443443
}
444444

445+
// isNil returns true if the input is nil or a typed nil pointer.
446+
func isNil(input interface{}) bool {
447+
if input == nil {
448+
return true
449+
}
450+
val := reflect.ValueOf(input)
451+
return val.Kind() == reflect.Ptr && val.IsNil()
452+
}
453+
445454
// Decodes an unknown data type into a specific reflection value.
446455
func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error {
447-
var inputVal reflect.Value
448-
if input != nil {
449-
inputVal = reflect.ValueOf(input)
450-
451-
// We need to check here if input is a typed nil. Typed nils won't
452-
// match the "input == nil" below so we check that here.
453-
if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() {
454-
input = nil
455-
}
456+
var (
457+
inputVal = reflect.ValueOf(input)
458+
outputKind = getKind(outVal)
459+
decodeNil = d.config.DecodeNil && d.cachedDecodeHook != nil
460+
)
461+
if isNil(input) {
462+
// Typed nils won't match the "input == nil" below, so reset input.
463+
input = nil
456464
}
457-
458-
decodeNil := d.config.DecodeNil && d.config.DecodeHook != nil
459-
460465
if input == nil {
461466
// If the data is nil, then we don't set anything, unless ZeroFields is set
462467
// to true.
@@ -467,12 +472,10 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
467472
d.config.Metadata.Keys = append(d.config.Metadata.Keys, name)
468473
}
469474
}
470-
471475
if !decodeNil {
472476
return nil
473477
}
474478
}
475-
476479
if !inputVal.IsValid() {
477480
if !decodeNil {
478481
// If the input value is invalid, then we just set the value
@@ -483,11 +486,17 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
483486
}
484487
return nil
485488
}
486-
487-
// If we get here, we have an untyped nil so the type of the input is assumed.
488-
// We do this because all subsequent code requires a valid value for inputVal.
489-
var mapVal map[string]interface{}
490-
inputVal = reflect.MakeMap(reflect.TypeOf(mapVal))
489+
// Hooks need a valid inputVal, so reset it to zero value of outVal type.
490+
switch outputKind {
491+
case reflect.Struct, reflect.Map:
492+
var mapVal map[string]interface{}
493+
inputVal = reflect.ValueOf(mapVal) // create nil map pointer
494+
case reflect.Slice, reflect.Array:
495+
var sliceVal []interface{}
496+
inputVal = reflect.ValueOf(sliceVal) // create nil slice pointer
497+
default:
498+
inputVal = reflect.Zero(outVal.Type())
499+
}
491500
}
492501

493502
if d.cachedDecodeHook != nil {
@@ -498,9 +507,11 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
498507
return fmt.Errorf("error decoding '%s': %w", name, err)
499508
}
500509
}
510+
if isNil(input) {
511+
return nil
512+
}
501513

502514
var err error
503-
outputKind := getKind(outVal)
504515
addMetaKey := true
505516
switch outputKind {
506517
case reflect.Bool:
@@ -781,8 +792,8 @@ func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) e
781792
}
782793
default:
783794
return fmt.Errorf(
784-
"'%s' expected type '%s', got unconvertible type '%s', value: '%v'",
785-
name, val.Type(), dataVal.Type(), data)
795+
"'%s' expected type '%s', got unconvertible type '%#v', value: '%#v'",
796+
name, val, dataVal, data)
786797
}
787798

788799
return nil

mapstructure_test.go

+65-1
Original file line numberDiff line numberDiff line change
@@ -3083,7 +3083,7 @@ func TestDecoder_IgnoreUntaggedFieldsWithStruct(t *testing.T) {
30833083
}
30843084
}
30853085

3086-
func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
3086+
func TestDecoder_DecodeNilOption(t *testing.T) {
30873087
t.Parallel()
30883088

30893089
type Transformed struct {
@@ -3100,6 +3100,9 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
31003100
appendHook := func(from reflect.Value, to reflect.Value) (interface{}, error) {
31013101
if from.Kind() == reflect.Map {
31023102
stringMap := from.Interface().(map[string]interface{})
3103+
if stringMap == nil {
3104+
stringMap = make(map[string]interface{})
3105+
}
31033106
stringMap["when"] = "see you later"
31043107
return stringMap, nil
31053108
}
@@ -3248,6 +3251,67 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) {
32483251
}
32493252
}
32503253

3254+
func TestDecoder_ExpandNilStructPointersHookFunc(t *testing.T) {
3255+
// a decoder hook that expands nil pointers in a struct to their zero value
3256+
// if the input map contains the corresponding key.
3257+
decodeHook := func(from reflect.Value, to reflect.Value) (any, error) {
3258+
if from.Kind() == reflect.Map && to.Kind() == reflect.Map {
3259+
toElem := to.Type().Elem()
3260+
if toElem.Kind() == reflect.Ptr && toElem.Elem().Kind() == reflect.Struct {
3261+
fromRange := from.MapRange()
3262+
for fromRange.Next() {
3263+
fromKey := fromRange.Key()
3264+
fromValue := fromRange.Value()
3265+
if fromValue.IsNil() {
3266+
newFromValue := reflect.New(toElem.Elem())
3267+
from.SetMapIndex(fromKey, newFromValue)
3268+
}
3269+
}
3270+
}
3271+
}
3272+
return from.Interface(), nil
3273+
}
3274+
type Struct struct {
3275+
Name string
3276+
}
3277+
type TestConfig struct {
3278+
Boolean *bool `mapstructure:"boolean"`
3279+
Struct *Struct `mapstructure:"struct"`
3280+
MapStruct map[string]*Struct `mapstructure:"map_struct"`
3281+
}
3282+
stringMap := map[string]any{
3283+
"boolean": nil,
3284+
"struct": nil,
3285+
"map_struct": map[string]any{
3286+
"struct": nil,
3287+
},
3288+
}
3289+
var result TestConfig
3290+
decoder, err := NewDecoder(&DecoderConfig{
3291+
Result: &result,
3292+
DecodeNil: true,
3293+
DecodeHook: decodeHook,
3294+
})
3295+
if err != nil {
3296+
t.Fatalf("err: %s", err)
3297+
}
3298+
if err := decoder.Decode(stringMap); err != nil {
3299+
t.Fatalf("got an err: %s", err)
3300+
}
3301+
if result.Boolean != nil {
3302+
t.Errorf("nil Boolean expected, got '%#v'", result.Boolean)
3303+
}
3304+
if result.Struct != nil {
3305+
t.Errorf("nil Struct expected, got '%#v'", result.Struct)
3306+
}
3307+
if len(result.MapStruct) == 0 {
3308+
t.Fatalf("not-empty MapStruct expected, got '%#v'", result.MapStruct)
3309+
}
3310+
if _, ok := result.MapStruct["struct"]; !ok {
3311+
t.Errorf("MapStruct['struct'] expected")
3312+
}
3313+
}
3314+
32513315
func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) {
32523316
var result Slice
32533317
err := Decode(input, &result)

0 commit comments

Comments
 (0)