diff --git a/mapstructure.go b/mapstructure.go index 05bc140..e77e63b 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -442,21 +442,26 @@ func (d *Decoder) Decode(input interface{}) error { return err } +// isNil returns true if the input is nil or a typed nil pointer. +func isNil(input interface{}) bool { + if input == nil { + return true + } + val := reflect.ValueOf(input) + return val.Kind() == reflect.Ptr && val.IsNil() +} + // Decodes an unknown data type into a specific reflection value. func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) error { - var inputVal reflect.Value - if input != nil { - inputVal = reflect.ValueOf(input) - - // We need to check here if input is a typed nil. Typed nils won't - // match the "input == nil" below so we check that here. - if inputVal.Kind() == reflect.Ptr && inputVal.IsNil() { - input = nil - } + var ( + inputVal = reflect.ValueOf(input) + outputKind = getKind(outVal) + decodeNil = d.config.DecodeNil && d.cachedDecodeHook != nil + ) + if isNil(input) { + // Typed nils won't match the "input == nil" below, so reset input. + input = nil } - - decodeNil := d.config.DecodeNil && d.config.DecodeHook != nil - if input == nil { // If the data is nil, then we don't set anything, unless ZeroFields is set // to true. @@ -467,12 +472,10 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e d.config.Metadata.Keys = append(d.config.Metadata.Keys, name) } } - if !decodeNil { return nil } } - if !inputVal.IsValid() { if !decodeNil { // If the input value is invalid, then we just set the value @@ -483,11 +486,17 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e } return nil } - - // If we get here, we have an untyped nil so the type of the input is assumed. - // We do this because all subsequent code requires a valid value for inputVal. - var mapVal map[string]interface{} - inputVal = reflect.MakeMap(reflect.TypeOf(mapVal)) + // Hooks need a valid inputVal, so reset it to zero value of outVal type. + switch outputKind { + case reflect.Struct, reflect.Map: + var mapVal map[string]interface{} + inputVal = reflect.ValueOf(mapVal) // create nil map pointer + case reflect.Slice, reflect.Array: + var sliceVal []interface{} + inputVal = reflect.ValueOf(sliceVal) // create nil slice pointer + default: + inputVal = reflect.Zero(outVal.Type()) + } } if d.cachedDecodeHook != nil { @@ -498,9 +507,11 @@ func (d *Decoder) decode(name string, input interface{}, outVal reflect.Value) e return fmt.Errorf("error decoding '%s': %w", name, err) } } + if isNil(input) { + return nil + } var err error - outputKind := getKind(outVal) addMetaKey := true switch outputKind { case reflect.Bool: @@ -781,8 +792,8 @@ func (d *Decoder) decodeBool(name string, data interface{}, val reflect.Value) e } default: return fmt.Errorf( - "'%s' expected type '%s', got unconvertible type '%s', value: '%v'", - name, val.Type(), dataVal.Type(), data) + "'%s' expected type '%s', got unconvertible type '%#v', value: '%#v'", + name, val, dataVal, data) } return nil diff --git a/mapstructure_test.go b/mapstructure_test.go index e30ff47..519e722 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -3083,7 +3083,7 @@ func TestDecoder_IgnoreUntaggedFieldsWithStruct(t *testing.T) { } } -func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) { +func TestDecoder_DecodeNilOption(t *testing.T) { t.Parallel() type Transformed struct { @@ -3100,6 +3100,9 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) { appendHook := func(from reflect.Value, to reflect.Value) (interface{}, error) { if from.Kind() == reflect.Map { stringMap := from.Interface().(map[string]interface{}) + if stringMap == nil { + stringMap = make(map[string]interface{}) + } stringMap["when"] = "see you later" return stringMap, nil } @@ -3248,6 +3251,67 @@ func TestDecoder_CanPerformDecodingForNilInputs(t *testing.T) { } } +func TestDecoder_ExpandNilStructPointersHookFunc(t *testing.T) { + // a decoder hook that expands nil pointers in a struct to their zero value + // if the input map contains the corresponding key. + decodeHook := func(from reflect.Value, to reflect.Value) (any, error) { + if from.Kind() == reflect.Map && to.Kind() == reflect.Map { + toElem := to.Type().Elem() + if toElem.Kind() == reflect.Ptr && toElem.Elem().Kind() == reflect.Struct { + fromRange := from.MapRange() + for fromRange.Next() { + fromKey := fromRange.Key() + fromValue := fromRange.Value() + if fromValue.IsNil() { + newFromValue := reflect.New(toElem.Elem()) + from.SetMapIndex(fromKey, newFromValue) + } + } + } + } + return from.Interface(), nil + } + type Struct struct { + Name string + } + type TestConfig struct { + Boolean *bool `mapstructure:"boolean"` + Struct *Struct `mapstructure:"struct"` + MapStruct map[string]*Struct `mapstructure:"map_struct"` + } + stringMap := map[string]any{ + "boolean": nil, + "struct": nil, + "map_struct": map[string]any{ + "struct": nil, + }, + } + var result TestConfig + decoder, err := NewDecoder(&DecoderConfig{ + Result: &result, + DecodeNil: true, + DecodeHook: decodeHook, + }) + if err != nil { + t.Fatalf("err: %s", err) + } + if err := decoder.Decode(stringMap); err != nil { + t.Fatalf("got an err: %s", err) + } + if result.Boolean != nil { + t.Errorf("nil Boolean expected, got '%#v'", result.Boolean) + } + if result.Struct != nil { + t.Errorf("nil Struct expected, got '%#v'", result.Struct) + } + if len(result.MapStruct) == 0 { + t.Fatalf("not-empty MapStruct expected, got '%#v'", result.MapStruct) + } + if _, ok := result.MapStruct["struct"]; !ok { + t.Errorf("MapStruct['struct'] expected") + } +} + func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { var result Slice err := Decode(input, &result)