diff --git a/decode.go b/decode.go index d29509ce..09523315 100644 --- a/decode.go +++ b/decode.go @@ -305,10 +305,10 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { } func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { - if k := rv.Type().Key().Kind(); k != reflect.String { - return fmt.Errorf( - "toml: cannot decode to a map with non-string key type (%s in %q)", - k, rv.Type()) + keyType := rv.Type().Key().Kind() + if keyType != reflect.String && keyType != reflect.Interface { + return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)", + keyType, rv.Type()) } tmap, ok := mapping.(map[string]interface{}) @@ -334,7 +334,14 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { md.context = md.context[0 : len(md.context)-1] rvkey := indirect(reflect.New(rv.Type().Key())) - rvkey.SetString(k) + + switch keyType { + case reflect.Interface: + rvkey.Set(reflect.ValueOf(k)) + case reflect.String: + rvkey.SetString(k) + } + rv.SetMapIndex(rvkey, rvval) } return nil diff --git a/decode_test.go b/decode_test.go index fd9bfee3..e27c4351 100644 --- a/decode_test.go +++ b/decode_test.go @@ -430,38 +430,56 @@ func TestDecodeSizedInts(t *testing.T) { type NopUnmarshalTOML int -func (NopUnmarshalTOML) UnmarshalTOML(p interface{}) error { return nil } +func (n *NopUnmarshalTOML) UnmarshalTOML(p interface{}) error { + *n = 42 + return nil +} func TestDecodeTypes(t *testing.T) { - type mystr string + type ( + mystr string + myiface interface{} + ) for _, tt := range []struct { - v interface{} - want string + v interface{} + want string + wantErr string }{ - {new(map[string]bool), ""}, - {new(map[mystr]bool), ""}, - {new(NopUnmarshalTOML), ""}, + {new(map[string]bool), "&map[F:true]", ""}, + {new(map[mystr]bool), "&map[F:true]", ""}, + {new(NopUnmarshalTOML), "42", ""}, + {new(map[interface{}]bool), "&map[F:true]", ""}, + {new(map[myiface]bool), "&map[F:true]", ""}, - {3, `toml: cannot decode to non-pointer "int"`}, - {map[string]interface{}{}, `toml: cannot decode to non-pointer "map[string]interface {}"`}, + {3, "", `toml: cannot decode to non-pointer "int"`}, + {map[string]interface{}{}, "", `toml: cannot decode to non-pointer "map[string]interface {}"`}, - {(*int)(nil), `toml: cannot decode to nil value of "*int"`}, - {(*Unmarshaler)(nil), `toml: cannot decode to nil value of "*toml.Unmarshaler"`}, - {nil, `toml: cannot decode to non-pointer `}, + {(*int)(nil), "", `toml: cannot decode to nil value of "*int"`}, + {(*Unmarshaler)(nil), "", `toml: cannot decode to nil value of "*toml.Unmarshaler"`}, + {nil, "", `toml: cannot decode to non-pointer `}, - {new(map[int]string), "toml: cannot decode to a map with non-string key type"}, - {new(map[interface{}]string), "toml: cannot decode to a map with non-string key type"}, + {new(map[int]string), "", "toml: cannot decode to a map with non-string key type"}, - {new(struct{ F int }), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`}, - {new(map[string]int), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`}, - {new(int), `toml: cannot decode to type int`}, - {new([]int), "toml: cannot decode to type []int"}, + {new(struct{ F int }), "", `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`}, + {new(map[string]int), "", `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`}, + {new(int), "", `toml: cannot decode to type int`}, + {new([]int), "", "toml: cannot decode to type []int"}, } { t.Run(fmt.Sprintf("%T", tt.v), func(t *testing.T) { _, err := Decode(`F = true`, tt.v) - if !errorContains(err, tt.want) { - t.Errorf("wrong error\nhave: %q\nwant: %q", err, tt.want) + if !errorContains(err, tt.wantErr) { + t.Fatalf("wrong error\nhave: %q\nwant: %q", err, tt.wantErr) + } + + if err == nil { + have := fmt.Sprintf("%v", tt.v) + if n, ok := tt.v.(*NopUnmarshalTOML); ok { + have = fmt.Sprintf("%v", *n) + } + if have != tt.want { + t.Errorf("\nhave: %s\nwant: %s", have, tt.want) + } } }) }