Skip to content

Commit d56d9f6

Browse files
committed
Allow MarshalTOML and MarshalText to be used on the document type itself
Fixes #383
1 parent 2967a1e commit d56d9f6

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

encode.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ func NewEncoder(w io.Writer) *Encoder {
136136
// document.
137137
func (enc *Encoder) Encode(v interface{}) error {
138138
rv := eindirect(reflect.ValueOf(v))
139+
140+
// XXX
141+
139142
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
140143
return err
141144
}
@@ -693,8 +696,11 @@ func (enc *Encoder) newline() {
693696
// v v v v vv
694697
// key = {k = 1, k2 = 2}
695698
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
699+
/// Marshaler used on top-level document; call eElement() to just call
700+
/// Marshal{TOML,Text}.
696701
if len(key) == 0 {
697-
encPanic(errNoKey)
702+
enc.eElement(val)
703+
return
698704
}
699705
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
700706
enc.eElement(val)

encode_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,6 +1261,43 @@ c = 3
12611261
}
12621262
}
12631263

1264+
type (
1265+
Doc1 struct{ N string }
1266+
Doc2 struct{ N string }
1267+
)
1268+
1269+
func (d Doc1) MarshalTOML() ([]byte, error) { return []byte(`marshal_toml = "` + d.N + `"`), nil }
1270+
func (d Doc2) MarshalText() ([]byte, error) { return []byte(`marshal_text = "` + d.N + `"`), nil }
1271+
1272+
// MarshalTOML and MarshalText on the top level type, rather than a field.
1273+
func TestMarshalDoc(t *testing.T) {
1274+
t.Run("toml", func(t *testing.T) {
1275+
var buf bytes.Buffer
1276+
err := NewEncoder(&buf).Encode(Doc1{"asd"})
1277+
if err != nil {
1278+
t.Fatal(err)
1279+
}
1280+
1281+
want := `marshal_toml = "asd"`
1282+
if want != buf.String() {
1283+
t.Errorf("\nhave: %s\nwant: %s\n", buf.String(), want)
1284+
}
1285+
})
1286+
1287+
t.Run("text", func(t *testing.T) {
1288+
var buf bytes.Buffer
1289+
err := NewEncoder(&buf).Encode(Doc2{"asd"})
1290+
if err != nil {
1291+
t.Fatal(err)
1292+
}
1293+
1294+
want := `"marshal_text = \"asd\""`
1295+
if want != buf.String() {
1296+
t.Errorf("\nhave: %s\nwant: %s\n", buf.String(), want)
1297+
}
1298+
})
1299+
}
1300+
12641301
func encodeExpected(t *testing.T, label string, val interface{}, want string, wantErr error) {
12651302
t.Helper()
12661303
t.Run(label, func(t *testing.T) {

0 commit comments

Comments
 (0)