Skip to content

Commit

Permalink
Allow using interface{} as map keys when decoding
Browse files Browse the repository at this point in the history
Ref: #181
  • Loading branch information
arp242 committed Jun 25, 2022
1 parent f0ccf71 commit e2f6fa2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 25 deletions.
17 changes: 12 additions & 5 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
}

func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
if k := rv.Type().Key().Kind(); k != reflect.String {
return fmt.Errorf(
"toml: cannot decode to a map with non-string key type (%s in %q)",
k, rv.Type())
keyType := rv.Type().Key().Kind()
if keyType != reflect.String && keyType != reflect.Interface {
return fmt.Errorf("toml: cannot decode to a map with non-string key type (%s in %q)",
keyType, rv.Type())
}

tmap, ok := mapping.(map[string]interface{})
Expand All @@ -334,7 +334,14 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
md.context = md.context[0 : len(md.context)-1]

rvkey := indirect(reflect.New(rv.Type().Key()))
rvkey.SetString(k)

switch keyType {
case reflect.Interface:
rvkey.Set(reflect.ValueOf(k))
case reflect.String:
rvkey.SetString(k)
}

rv.SetMapIndex(rvkey, rvval)
}
return nil
Expand Down
58 changes: 38 additions & 20 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,38 +430,56 @@ func TestDecodeSizedInts(t *testing.T) {

type NopUnmarshalTOML int

func (NopUnmarshalTOML) UnmarshalTOML(p interface{}) error { return nil }
func (n *NopUnmarshalTOML) UnmarshalTOML(p interface{}) error {
*n = 42
return nil
}

func TestDecodeTypes(t *testing.T) {
type mystr string
type (
mystr string
myiface interface{}
)

for _, tt := range []struct {
v interface{}
want string
v interface{}
want string
wantErr string
}{
{new(map[string]bool), ""},
{new(map[mystr]bool), ""},
{new(NopUnmarshalTOML), ""},
{new(map[string]bool), "&map[F:true]", ""},
{new(map[mystr]bool), "&map[F:true]", ""},
{new(NopUnmarshalTOML), "42", ""},
{new(map[interface{}]bool), "&map[F:true]", ""},
{new(map[myiface]bool), "&map[F:true]", ""},

{3, `toml: cannot decode to non-pointer "int"`},
{map[string]interface{}{}, `toml: cannot decode to non-pointer "map[string]interface {}"`},
{3, "", `toml: cannot decode to non-pointer "int"`},
{map[string]interface{}{}, "", `toml: cannot decode to non-pointer "map[string]interface {}"`},

{(*int)(nil), `toml: cannot decode to nil value of "*int"`},
{(*Unmarshaler)(nil), `toml: cannot decode to nil value of "*toml.Unmarshaler"`},
{nil, `toml: cannot decode to non-pointer <nil>`},
{(*int)(nil), "", `toml: cannot decode to nil value of "*int"`},
{(*Unmarshaler)(nil), "", `toml: cannot decode to nil value of "*toml.Unmarshaler"`},
{nil, "", `toml: cannot decode to non-pointer <nil>`},

{new(map[int]string), "toml: cannot decode to a map with non-string key type"},
{new(map[interface{}]string), "toml: cannot decode to a map with non-string key type"},
{new(map[int]string), "", "toml: cannot decode to a map with non-string key type"},

{new(struct{ F int }), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(map[string]int), `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(int), `toml: cannot decode to type int`},
{new([]int), "toml: cannot decode to type []int"},
{new(struct{ F int }), "", `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(map[string]int), "", `toml: line 1 (last key "F"): incompatible types: TOML value has type bool; destination has type integer`},
{new(int), "", `toml: cannot decode to type int`},
{new([]int), "", "toml: cannot decode to type []int"},
} {
t.Run(fmt.Sprintf("%T", tt.v), func(t *testing.T) {
_, err := Decode(`F = true`, tt.v)
if !errorContains(err, tt.want) {
t.Errorf("wrong error\nhave: %q\nwant: %q", err, tt.want)
if !errorContains(err, tt.wantErr) {
t.Fatalf("wrong error\nhave: %q\nwant: %q", err, tt.wantErr)
}

if err == nil {
have := fmt.Sprintf("%v", tt.v)
if n, ok := tt.v.(*NopUnmarshalTOML); ok {
have = fmt.Sprintf("%v", *n)
}
if have != tt.want {
t.Errorf("\nhave: %s\nwant: %s", have, tt.want)
}
}
})
}
Expand Down

0 comments on commit e2f6fa2

Please sign in to comment.