Skip to content

Commit befed6c

Browse files
Merge pull request #33 from cedric-cordenier/main
Port Succo's fix to cache decodeHooks
2 parents fd1be46 + 4807a3a commit befed6c

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

decode_hooks.go

+36-4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ func typedDecodeHook(h DecodeHookFunc) DecodeHookFunc {
3636
return nil
3737
}
3838

39+
// cachedDecodeHook takes a raw DecodeHookFunc (an interface{}) and turns
40+
// it into a closure to be used directly
41+
// if the type fails to convert we return a closure always erroring to keep the previous behaviour
42+
func cachedDecodeHook(raw DecodeHookFunc) func(from reflect.Value, to reflect.Value) (interface{}, error) {
43+
switch f := typedDecodeHook(raw).(type) {
44+
case DecodeHookFuncType:
45+
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
46+
return f(from.Type(), to.Type(), from.Interface())
47+
}
48+
case DecodeHookFuncKind:
49+
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
50+
return f(from.Kind(), to.Kind(), from.Interface())
51+
}
52+
case DecodeHookFuncValue:
53+
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
54+
return f(from, to)
55+
}
56+
default:
57+
return func(from reflect.Value, to reflect.Value) (interface{}, error) {
58+
return nil, errors.New("invalid decode hook signature")
59+
}
60+
}
61+
}
62+
3963
// DecodeHookExec executes the given decode hook. This should be used
4064
// since it'll naturally degrade to the older backwards compatible DecodeHookFunc
4165
// that took reflect.Kind instead of reflect.Type.
@@ -61,13 +85,17 @@ func DecodeHookExec(
6185
// The composed funcs are called in order, with the result of the
6286
// previous transformation.
6387
func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
88+
cached := make([]func(from reflect.Value, to reflect.Value) (interface{}, error), 0, len(fs))
89+
for _, f := range fs {
90+
cached = append(cached, cachedDecodeHook(f))
91+
}
6492
return func(f reflect.Value, t reflect.Value) (interface{}, error) {
6593
var err error
6694
data := f.Interface()
6795

6896
newFrom := f
69-
for _, f1 := range fs {
70-
data, err = DecodeHookExec(f1, newFrom, t)
97+
for _, c := range cached {
98+
data, err = c(newFrom, t)
7199
if err != nil {
72100
return nil, err
73101
}
@@ -81,13 +109,17 @@ func ComposeDecodeHookFunc(fs ...DecodeHookFunc) DecodeHookFunc {
81109
// OrComposeDecodeHookFunc executes all input hook functions until one of them returns no error. In that case its value is returned.
82110
// If all hooks return an error, OrComposeDecodeHookFunc returns an error concatenating all error messages.
83111
func OrComposeDecodeHookFunc(ff ...DecodeHookFunc) DecodeHookFunc {
112+
cached := make([]func(from reflect.Value, to reflect.Value) (interface{}, error), 0, len(ff))
113+
for _, f := range ff {
114+
cached = append(cached, cachedDecodeHook(f))
115+
}
84116
return func(a, b reflect.Value) (interface{}, error) {
85117
var allErrs string
86118
var out interface{}
87119
var err error
88120

89-
for _, f := range ff {
90-
out, err = DecodeHookExec(f, a, b)
121+
for _, c := range cached {
122+
out, err = c(a, b)
91123
if err != nil {
92124
allErrs += err.Error() + "\n"
93125
continue

mapstructure.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ type DecoderConfig struct {
283283
// structure. The top-level Decode method is just a convenience that sets
284284
// up the most basic Decoder.
285285
type Decoder struct {
286-
config *DecoderConfig
286+
config *DecoderConfig
287+
cachedDecodeHook func(from reflect.Value, to reflect.Value) (interface{}, error)
287288
}
288289

289290
// Metadata contains information about decoding a structure that
@@ -408,6 +409,9 @@ func NewDecoder(config *DecoderConfig) (*Decoder, error) {
408409
result := &Decoder{
409410
config: config,
410411
}
412+
if config.DecodeHook != nil {
413+
result.cachedDecodeHook = cachedDecodeHook(config.DecodeHook)
414+
}
411415

412416
return result, nil
413417
}
@@ -462,10 +466,10 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e
462466
return nil
463467
}
464468

465-
if d.config.DecodeHook != nil {
469+
if d.cachedDecodeHook != nil {
466470
// We have a DecodeHook, so let's pre-process the input.
467471
var err error
468-
input, err = DecodeHookExec(d.config.DecodeHook, inputVal, outVal)
472+
input, err = d.cachedDecodeHook(inputVal, outVal)
469473
if err != nil {
470474
return fmt.Errorf("error decoding '%s': %w", name, err)
471475
}

0 commit comments

Comments
 (0)