Skip to content

Commit 06fb30b

Browse files
authored
Decode: fix reuse of slice for array tables (#934)
When decoding into a non-empty slice, it needs to be emptied so that only the tables contained in the document are present in the resulting value. Arrays are not impacted because their unmarshal offset is tracked separately. Fixes #931
1 parent 2e087bd commit 06fb30b

File tree

3 files changed

+119
-31
lines changed

3 files changed

+119
-31
lines changed

internal/tracker/seen.go

+33-29
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {
149149

150150
// CheckExpression takes a top-level node and checks that it does not contain
151151
// keys that have been seen in previous calls, and validates that types are
152-
// consistent.
153-
func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
152+
// consistent. It returns true if it is the first time this node's key is seen.
153+
// Useful to clear array tables on first use.
154+
func (s *SeenTracker) CheckExpression(node *unstable.Node) (bool, error) {
154155
if s.entries == nil {
155156
s.reset()
156157
}
@@ -166,7 +167,7 @@ func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
166167
}
167168
}
168169

169-
func (s *SeenTracker) checkTable(node *unstable.Node) error {
170+
func (s *SeenTracker) checkTable(node *unstable.Node) (bool, error) {
170171
if s.currentIdx >= 0 {
171172
s.setExplicitFlag(s.currentIdx)
172173
}
@@ -192,7 +193,7 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
192193
} else {
193194
entry := s.entries[idx]
194195
if entry.kind == valueKind {
195-
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
196+
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
196197
}
197198
}
198199
parentIdx = idx
@@ -201,25 +202,27 @@ func (s *SeenTracker) checkTable(node *unstable.Node) error {
201202
k := it.Node().Data
202203
idx := s.find(parentIdx, k)
203204

205+
first := false
204206
if idx >= 0 {
205207
kind := s.entries[idx].kind
206208
if kind != tableKind {
207-
return fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
209+
return false, fmt.Errorf("toml: key %s should be a table, not a %s", string(k), kind)
208210
}
209211
if s.entries[idx].explicit {
210-
return fmt.Errorf("toml: table %s already exists", string(k))
212+
return false, fmt.Errorf("toml: table %s already exists", string(k))
211213
}
212214
s.entries[idx].explicit = true
213215
} else {
214216
idx = s.create(parentIdx, k, tableKind, true, false)
217+
first = true
215218
}
216219

217220
s.currentIdx = idx
218221

219-
return nil
222+
return first, nil
220223
}
221224

222-
func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
225+
func (s *SeenTracker) checkArrayTable(node *unstable.Node) (bool, error) {
223226
if s.currentIdx >= 0 {
224227
s.setExplicitFlag(s.currentIdx)
225228
}
@@ -242,7 +245,7 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
242245
} else {
243246
entry := s.entries[idx]
244247
if entry.kind == valueKind {
245-
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
248+
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
246249
}
247250
}
248251

@@ -252,22 +255,23 @@ func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
252255
k := it.Node().Data
253256
idx := s.find(parentIdx, k)
254257

255-
if idx >= 0 {
258+
firstTime := idx < 0
259+
if firstTime {
260+
idx = s.create(parentIdx, k, arrayTableKind, true, false)
261+
} else {
256262
kind := s.entries[idx].kind
257263
if kind != arrayTableKind {
258-
return fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
264+
return false, fmt.Errorf("toml: key %s already exists as a %s, but should be an array table", kind, string(k))
259265
}
260266
s.clear(idx)
261-
} else {
262-
idx = s.create(parentIdx, k, arrayTableKind, true, false)
263267
}
264268

265269
s.currentIdx = idx
266270

267-
return nil
271+
return firstTime, nil
268272
}
269273

270-
func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
274+
func (s *SeenTracker) checkKeyValue(node *unstable.Node) (bool, error) {
271275
parentIdx := s.currentIdx
272276
it := node.Key()
273277

@@ -281,11 +285,11 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
281285
} else {
282286
entry := s.entries[idx]
283287
if it.IsLast() {
284-
return fmt.Errorf("toml: key %s is already defined", string(k))
288+
return false, fmt.Errorf("toml: key %s is already defined", string(k))
285289
} else if entry.kind != tableKind {
286-
return fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
290+
return false, fmt.Errorf("toml: expected %s to be a table, not a %s", string(k), entry.kind)
287291
} else if entry.explicit {
288-
return fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
292+
return false, fmt.Errorf("toml: cannot redefine table %s that has already been explicitly defined", string(k))
289293
}
290294
}
291295

@@ -303,30 +307,30 @@ func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
303307
return s.checkArray(value)
304308
}
305309

306-
return nil
310+
return false, nil
307311
}
308312

309-
func (s *SeenTracker) checkArray(node *unstable.Node) error {
313+
func (s *SeenTracker) checkArray(node *unstable.Node) (first bool, err error) {
310314
it := node.Children()
311315
for it.Next() {
312316
n := it.Node()
313317
switch n.Kind {
314318
case unstable.InlineTable:
315-
err := s.checkInlineTable(n)
319+
first, err = s.checkInlineTable(n)
316320
if err != nil {
317-
return err
321+
return false, err
318322
}
319323
case unstable.Array:
320-
err := s.checkArray(n)
324+
first, err = s.checkArray(n)
321325
if err != nil {
322-
return err
326+
return false, err
323327
}
324328
}
325329
}
326-
return nil
330+
return first, nil
327331
}
328332

329-
func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
333+
func (s *SeenTracker) checkInlineTable(node *unstable.Node) (first bool, err error) {
330334
if pool.New == nil {
331335
pool.New = func() interface{} {
332336
return &SeenTracker{}
@@ -339,9 +343,9 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
339343
it := node.Children()
340344
for it.Next() {
341345
n := it.Node()
342-
err := s.checkKeyValue(n)
346+
first, err = s.checkKeyValue(n)
343347
if err != nil {
344-
return err
348+
return false, err
345349
}
346350
}
347351

@@ -352,5 +356,5 @@ func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
352356
// redefinition of its keys: check* functions cannot walk into
353357
// a value.
354358
pool.Put(s)
355-
return nil
359+
return first, nil
356360
}

unmarshaler.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ type decoder struct {
127127
// need to be skipped.
128128
skipUntilTable bool
129129

130+
// Flag indicating that the current array/slice table should be cleared because
131+
// it is the first encounter of an array table.
132+
clearArrayTable bool
133+
130134
// Tracks position in Go arrays.
131135
// This is used when decoding [[array tables]] into Go arrays. Given array
132136
// tables are separate TOML expression, we need to keep track of where we
@@ -246,9 +250,10 @@ Rules for the unmarshal code:
246250
func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
247251
var x reflect.Value
248252
var err error
253+
var first bool // used for to clear array tables on first use
249254

250255
if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
251-
err = d.seen.CheckExpression(expr)
256+
first, err = d.seen.CheckExpression(expr)
252257
if err != nil {
253258
return err
254259
}
@@ -267,6 +272,7 @@ func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) err
267272
case unstable.ArrayTable:
268273
d.skipUntilTable = false
269274
d.strict.EnterArrayTable(expr)
275+
d.clearArrayTable = first
270276
x, err = d.handleArrayTable(expr.Key(), v)
271277
default:
272278
panic(fmt.Errorf("parser should not permit expression of kind %s at document root", expr.Kind))
@@ -307,6 +313,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
307313
reflect.Copy(nelem, elem)
308314
elem = nelem
309315
}
316+
if d.clearArrayTable && elem.Len() > 0 {
317+
elem.SetLen(0)
318+
d.clearArrayTable = false
319+
}
310320
}
311321
return d.handleArrayTableCollectionLast(key, elem)
312322
case reflect.Ptr:
@@ -325,6 +335,10 @@ func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflec
325335

326336
return v, nil
327337
case reflect.Slice:
338+
if d.clearArrayTable && v.Len() > 0 {
339+
v.SetLen(0)
340+
d.clearArrayTable = false
341+
}
328342
elemType := v.Type().Elem()
329343
var elem reflect.Value
330344
if elemType.Kind() == reflect.Interface {
@@ -576,7 +590,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
576590
break
577591
}
578592

579-
err := d.seen.CheckExpression(expr)
593+
_, err := d.seen.CheckExpression(expr)
580594
if err != nil {
581595
return reflect.Value{}, err
582596
}

unmarshaler_test.go

+70
Original file line numberDiff line numberDiff line change
@@ -2823,6 +2823,76 @@ blah.a = "def"`)
28232823
require.Equal(t, "def", cfg.A)
28242824
}
28252825

2826+
func TestIssue931(t *testing.T) {
2827+
type item struct {
2828+
Name string
2829+
}
2830+
2831+
type items struct {
2832+
Slice []item
2833+
}
2834+
2835+
its := items{[]item{{"a"}, {"b"}}}
2836+
2837+
b := []byte(`
2838+
[[Slice]]
2839+
Name = 'c'
2840+
2841+
[[Slice]]
2842+
Name = 'd'
2843+
`)
2844+
2845+
toml.Unmarshal(b, &its)
2846+
require.Equal(t, items{[]item{{"c"}, {"d"}}}, its)
2847+
}
2848+
2849+
func TestIssue931Interface(t *testing.T) {
2850+
type items struct {
2851+
Slice interface{}
2852+
}
2853+
2854+
type item = map[string]interface{}
2855+
2856+
its := items{[]interface{}{item{"Name": "a"}, item{"Name": "b"}}}
2857+
2858+
b := []byte(`
2859+
[[Slice]]
2860+
Name = 'c'
2861+
2862+
[[Slice]]
2863+
Name = 'd'
2864+
`)
2865+
2866+
toml.Unmarshal(b, &its)
2867+
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
2868+
}
2869+
2870+
func TestIssue931SliceInterface(t *testing.T) {
2871+
type items struct {
2872+
Slice []interface{}
2873+
}
2874+
2875+
type item = map[string]interface{}
2876+
2877+
its := items{
2878+
[]interface{}{
2879+
item{"Name": "a"},
2880+
item{"Name": "b"},
2881+
},
2882+
}
2883+
2884+
b := []byte(`
2885+
[[Slice]]
2886+
Name = 'c'
2887+
2888+
[[Slice]]
2889+
Name = 'd'
2890+
`)
2891+
2892+
toml.Unmarshal(b, &its)
2893+
require.Equal(t, items{[]interface{}{item{"Name": "c"}, item{"Name": "d"}}}, its)
2894+
}
2895+
28262896
func TestUnmarshalDecodeErrors(t *testing.T) {
28272897
examples := []struct {
28282898
desc string

0 commit comments

Comments
 (0)