Skip to content

Commit ffbd31e

Browse files
augustoromanbradfitz
authored andcommitted
encoding/json: allow non-string type keys for (un-)marshal
This CL allows JSON-encoding & -decoding maps whose keys are types that implement encoding.TextMarshaler / TextUnmarshaler. During encode, the map keys are marshaled upfront so that they can be sorted. Fixes #12146 Change-Id: I43809750a7ad82a3603662f095c7baf75fd172da Reviewed-on: https://go-review.googlesource.com/20356 Run-TryBot: Caleb Spare <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent acefcb7 commit ffbd31e

File tree

4 files changed

+123
-40
lines changed

4 files changed

+123
-40
lines changed

Diff for: src/encoding/json/decode.go

+24-8
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ import (
6161
// If the JSON array is smaller than the Go array,
6262
// the additional Go array elements are set to zero values.
6363
//
64-
// To unmarshal a JSON object into a string-keyed map, Unmarshal first
65-
// establishes a map to use, If the map is nil, Unmarshal allocates a new map.
66-
// Otherwise Unmarshal reuses the existing map, keeping existing entries.
67-
// Unmarshal then stores key-value pairs from the JSON object into the map.
64+
// To unmarshal a JSON object into a map, Unmarshal first establishes a map to
65+
// use, If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal
66+
// reuses the existing map, keeping existing entries. Unmarshal then stores key-
67+
// value pairs from the JSON object into the map. The map's key type must
68+
// either be a string or implement encoding.TextUnmarshaler.
6869
//
6970
// If a JSON value is not appropriate for a given target type,
7071
// or if a JSON number overflows the target type, Unmarshal
@@ -549,6 +550,7 @@ func (d *decodeState) array(v reflect.Value) {
549550
}
550551

551552
var nullLiteral = []byte("null")
553+
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
552554

553555
// object consumes an object from d.data[d.off-1:], decoding into the value v.
554556
// the first byte ('{') of the object has been read already.
@@ -577,12 +579,15 @@ func (d *decodeState) object(v reflect.Value) {
577579
return
578580
}
579581

580-
// Check type of target: struct or map[string]T
582+
// Check type of target:
583+
// struct or
584+
// map[string]T or map[encoding.TextUnmarshaler]T
581585
switch v.Kind() {
582586
case reflect.Map:
583-
// map must have string kind
587+
// Map key must either have string kind or be an encoding.TextUnmarshaler.
584588
t := v.Type()
585-
if t.Key().Kind() != reflect.String {
589+
if t.Key().Kind() != reflect.String &&
590+
!reflect.PtrTo(t.Key()).Implements(textUnmarshalerType) {
586591
d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)})
587592
d.off--
588593
d.next() // skip over { } in input
@@ -687,7 +692,18 @@ func (d *decodeState) object(v reflect.Value) {
687692
// Write value back to map;
688693
// if using struct, subv points into struct already.
689694
if v.Kind() == reflect.Map {
690-
kv := reflect.ValueOf(key).Convert(v.Type().Key())
695+
kt := v.Type().Key()
696+
var kv reflect.Value
697+
switch {
698+
case kt.Kind() == reflect.String:
699+
kv = reflect.ValueOf(key).Convert(v.Type().Key())
700+
case reflect.PtrTo(kt).Implements(textUnmarshalerType):
701+
kv = reflect.New(v.Type().Key())
702+
d.literalStore(item, kv, true)
703+
kv = kv.Elem()
704+
default:
705+
panic("json: Unexpected key type") // should never occur
706+
}
691707
v.SetMapIndex(kv, subv)
692708
}
693709

Diff for: src/encoding/json/decode_test.go

+44-18
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package json
77
import (
88
"bytes"
99
"encoding"
10+
"errors"
1011
"fmt"
1112
"image"
1213
"net"
@@ -68,16 +69,20 @@ type ustruct struct {
6869
}
6970

7071
type unmarshalerText struct {
71-
T bool
72+
A, B string
7273
}
7374

7475
// needed for re-marshaling tests
75-
func (u *unmarshalerText) MarshalText() ([]byte, error) {
76-
return []byte(""), nil
76+
func (u unmarshalerText) MarshalText() ([]byte, error) {
77+
return []byte(u.A + ":" + u.B), nil
7778
}
7879

7980
func (u *unmarshalerText) UnmarshalText(b []byte) error {
80-
*u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
81+
pos := bytes.Index(b, []byte(":"))
82+
if pos == -1 {
83+
return errors.New("missing separator")
84+
}
85+
u.A, u.B = string(b[:pos]), string(b[pos+1:])
8186
return nil
8287
}
8388

@@ -95,12 +100,16 @@ var (
95100
umslicep = new([]unmarshaler)
96101
umstruct = ustruct{unmarshaler{true}}
97102

98-
um0T, um1T unmarshalerText // target2 of unmarshaling
99-
umpT = &um1T
100-
umtrueT = unmarshalerText{true}
101-
umsliceT = []unmarshalerText{{true}}
102-
umslicepT = new([]unmarshalerText)
103-
umstructT = ustructText{unmarshalerText{true}}
103+
um0T, um1T unmarshalerText // target2 of unmarshaling
104+
umpType = &um1T
105+
umtrueXY = unmarshalerText{"x", "y"}
106+
umsliceXY = []unmarshalerText{{"x", "y"}}
107+
umslicepType = new([]unmarshalerText)
108+
umstructType = new(ustructText)
109+
umstructXY = ustructText{unmarshalerText{"x", "y"}}
110+
111+
ummapType = map[unmarshalerText]bool{}
112+
ummapXY = map[unmarshalerText]bool{unmarshalerText{"x", "y"}: true}
104113
)
105114

106115
// Test data structures for anonymous fields.
@@ -302,14 +311,19 @@ var unmarshalTests = []unmarshalTest{
302311
{in: `{"T":false}`, ptr: &ump, out: &umtrue},
303312
{in: `[{"T":false}]`, ptr: &umslice, out: umslice},
304313
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
305-
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
314+
{in: `{"M":{"T":"x:y"}}`, ptr: &umstruct, out: umstruct},
306315

307316
// UnmarshalText interface test
308-
{in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
309-
{in: `"X"`, ptr: &umpT, out: &umtrueT},
310-
{in: `["X"]`, ptr: &umsliceT, out: umsliceT},
311-
{in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
312-
{in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
317+
{in: `"x:y"`, ptr: &um0T, out: umtrueXY},
318+
{in: `"x:y"`, ptr: &umpType, out: &umtrueXY},
319+
{in: `["x:y"]`, ptr: &umsliceXY, out: umsliceXY},
320+
{in: `["x:y"]`, ptr: &umslicepType, out: &umsliceXY},
321+
{in: `{"M":"x:y"}`, ptr: umstructType, out: umstructXY},
322+
323+
// Map keys can be encoding.TextUnmarshalers
324+
{in: `{"x:y":true}`, ptr: &ummapType, out: ummapXY},
325+
// If multiple values for the same key exists, only the most recent value is used.
326+
{in: `{"x:y":false,"x:y":true}`, ptr: &ummapType, out: ummapXY},
313327

314328
// Overwriting of data.
315329
// This is different from package xml, but it's what we've always done.
@@ -426,11 +440,23 @@ var unmarshalTests = []unmarshalTest{
426440
out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld",
427441
},
428442

429-
// issue 8305
443+
// Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now.
430444
{
431445
in: `{"2009-11-10T23:00:00Z": "hello world"}`,
432446
ptr: &map[time.Time]string{},
433-
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[time.Time]string{}), 1},
447+
out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"},
448+
},
449+
450+
// issue 8305
451+
{
452+
in: `{"2009-11-10T23:00:00Z": "hello world"}`,
453+
ptr: &map[Point]string{},
454+
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[Point]string{}), 1},
455+
},
456+
{
457+
in: `{"asdf": "hello world"}`,
458+
ptr: &map[unmarshaler]string{},
459+
err: &UnmarshalTypeError{"object", reflect.TypeOf(map[unmarshaler]string{}), 1},
434460
},
435461
}
436462

Diff for: src/encoding/json/encode.go

+39-14
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ import (
116116
// an anonymous struct field in both current and earlier versions, give the field
117117
// a JSON tag of "-".
118118
//
119-
// Map values encode as JSON objects.
120-
// The map's key type must be string; the map keys are used as JSON object
119+
// Map values encode as JSON objects. The map's key type must either be a string
120+
// or implement encoding.TextMarshaler. The map keys are used as JSON object
121121
// keys, subject to the UTF-8 coercion described for string values above.
122122
//
123123
// Pointer values encode as the value pointed to.
@@ -611,21 +611,31 @@ func (me *mapEncoder) encode(e *encodeState, v reflect.Value, _ bool) {
611611
return
612612
}
613613
e.WriteByte('{')
614-
var sv stringValues = v.MapKeys()
615-
sort.Sort(sv)
616-
for i, k := range sv {
614+
615+
// Extract and sort the keys.
616+
keys := v.MapKeys()
617+
sv := make([]reflectWithString, len(keys))
618+
for i, v := range keys {
619+
sv[i].v = v
620+
if err := sv[i].resolve(); err != nil {
621+
e.error(&MarshalerError{v.Type(), err})
622+
}
623+
}
624+
sort.Sort(byString(sv))
625+
626+
for i, kv := range sv {
617627
if i > 0 {
618628
e.WriteByte(',')
619629
}
620-
e.string(k.String())
630+
e.string(kv.s)
621631
e.WriteByte(':')
622-
me.elemEnc(e, v.MapIndex(k), false)
632+
me.elemEnc(e, v.MapIndex(kv.v), false)
623633
}
624634
e.WriteByte('}')
625635
}
626636

627637
func newMapEncoder(t reflect.Type) encoderFunc {
628-
if t.Key().Kind() != reflect.String {
638+
if t.Key().Kind() != reflect.String && !t.Key().Implements(textMarshalerType) {
629639
return unsupportedTypeEncoder
630640
}
631641
me := &mapEncoder{typeEncoder(t.Elem())}
@@ -775,14 +785,29 @@ func typeByIndex(t reflect.Type, index []int) reflect.Type {
775785
return t
776786
}
777787

778-
// stringValues is a slice of reflect.Value holding *reflect.StringValue.
788+
type reflectWithString struct {
789+
v reflect.Value
790+
s string
791+
}
792+
793+
func (w *reflectWithString) resolve() error {
794+
if w.v.Kind() == reflect.String {
795+
w.s = w.v.String()
796+
return nil
797+
}
798+
buf, err := w.v.Interface().(encoding.TextMarshaler).MarshalText()
799+
w.s = string(buf)
800+
return err
801+
}
802+
803+
// byString is a slice of reflectWithString where the reflect.Value is either
804+
// a string or an encoding.TextMarshaler.
779805
// It implements the methods to sort by string.
780-
type stringValues []reflect.Value
806+
type byString []reflectWithString
781807

782-
func (sv stringValues) Len() int { return len(sv) }
783-
func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
784-
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
785-
func (sv stringValues) get(i int) string { return sv[i].String() }
808+
func (sv byString) Len() int { return len(sv) }
809+
func (sv byString) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
810+
func (sv byString) Less(i, j int) bool { return sv[i].s < sv[j].s }
786811

787812
// NOTE: keep in sync with stringBytes below.
788813
func (e *encodeState) string(s string) int {

Diff for: src/encoding/json/encode_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,19 @@ func TestEncodeString(t *testing.T) {
536536
}
537537
}
538538
}
539+
540+
func TestTextMarshalerMapKeysAreSorted(t *testing.T) {
541+
b, err := Marshal(map[unmarshalerText]int{
542+
{"x", "y"}: 1,
543+
{"y", "x"}: 2,
544+
{"a", "z"}: 3,
545+
{"z", "a"}: 4,
546+
})
547+
if err != nil {
548+
t.Fatalf("Failed to Marshal text.Marshaler: %v", err)
549+
}
550+
const want = `{"a:z":3,"x:y":1,"y:x":2,"z:a":4}`
551+
if string(b) != want {
552+
t.Errorf("Marshal map with text.Marshaler keys: got %#q, want %#q", b, want)
553+
}
554+
}

0 commit comments

Comments
 (0)