Skip to content

Commit

Permalink
Avoid infinite recursion when normalizing a recursive type (#1213)
Browse files Browse the repository at this point in the history
## Changes

This is a follow-up to #1211 prompted by the addition of a recursive
type in the Go SDK v0.31.0 (`jobs.ForEachTask`).

When populating missing fields with their zero values we must not
inadvertently recurse into a recursive type.

## Tests

New unit test fails with a stack overflow if the fix if the check is
disabled.
  • Loading branch information
pietern authored Feb 16, 2024
1 parent 788ec81 commit ea8daf1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
41 changes: 24 additions & 17 deletions libs/dyn/convert/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package convert
import (
"fmt"
"reflect"
"slices"
"strconv"

"github.com/databricks/cli/libs/diag"
Expand Down Expand Up @@ -31,21 +32,21 @@ func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag
}
}

return n.normalizeType(reflect.TypeOf(dst), src)
return n.normalizeType(reflect.TypeOf(dst), src, []reflect.Type{})
}

func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
for typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}

switch typ.Kind() {
case reflect.Struct:
return n.normalizeStruct(typ, src)
return n.normalizeStruct(typ, src, append(seen, typ))
case reflect.Map:
return n.normalizeMap(typ, src)
return n.normalizeMap(typ, src, append(seen, typ))
case reflect.Slice:
return n.normalizeSlice(typ, src)
return n.normalizeSlice(typ, src, append(seen, typ))
case reflect.String:
return n.normalizeString(typ, src)
case reflect.Bool:
Expand All @@ -67,7 +68,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic {
}
}

func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
Expand All @@ -86,7 +87,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
}

// Normalize the value according to the field type.
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v)
v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand Down Expand Up @@ -115,20 +116,26 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
ftyp = ftyp.Elem()
}

// Skip field if we have already seen its type to avoid infinite recursion
// when filling in the zero value of a recursive type.
if slices.Contains(seen, ftyp) {
continue
}

var v dyn.Value
switch ftyp.Kind() {
case reflect.Struct, reflect.Map:
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}))
v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}), seen)
case reflect.Slice:
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}))
v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}), seen)
case reflect.String:
v, _ = n.normalizeType(ftyp, dyn.V(""))
v, _ = n.normalizeType(ftyp, dyn.V(""), seen)
case reflect.Bool:
v, _ = n.normalizeType(ftyp, dyn.V(false))
v, _ = n.normalizeType(ftyp, dyn.V(false), seen)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)))
v, _ = n.normalizeType(ftyp, dyn.V(int64(0)), seen)
case reflect.Float32, reflect.Float64:
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)))
v, _ = n.normalizeType(ftyp, dyn.V(float64(0)), seen)
default:
// Skip fields for which we do not have a natural [dyn.Value] equivalent.
// For example, we don't handle reflect.Complex* and reflect.Uint* types.
Expand All @@ -147,15 +154,15 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}

func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
case dyn.KindMap:
out := make(map[string]dyn.Value)
for k, v := range src.MustMap() {
// Normalize the value according to the map element type.
v, err := n.normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand All @@ -175,15 +182,15 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val
return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src))
}

func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) {
func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) {
var diags diag.Diagnostics

switch src.Kind() {
case dyn.KindSequence:
out := make([]dyn.Value, 0, len(src.MustSequence()))
for _, v := range src.MustSequence() {
// Normalize the value according to the slice element type.
v, err := n.normalizeType(typ.Elem(), v)
v, err := n.normalizeType(typ.Elem(), v, seen)
if err != nil {
diags = diags.Extend(err)
// Skip the element if it cannot be normalized.
Expand Down
31 changes: 31 additions & 0 deletions libs/dyn/convert/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,37 @@ func TestNormalizeStructIncludeMissingFields(t *testing.T) {
}), vout)
}

func TestNormalizeStructIncludeMissingFieldsOnRecursiveType(t *testing.T) {
type Tmp struct {
// Verify that structs are recursively normalized if not set.
Ptr *Tmp `json:"ptr"`

// Verify that primitive types are zero-initialized if not set.
String string `json:"string"`
}

var typ Tmp
vin := dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"string": dyn.V("already set"),
}),
}),
})
vout, err := Normalize(typ, vin, IncludeMissingFields)
assert.Empty(t, err)
assert.Equal(t, dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
"ptr": dyn.V(map[string]dyn.Value{
// Note: the ptr field is not zero-initialized because that would recurse.
"string": dyn.V("already set"),
}),
"string": dyn.V(""),
}),
"string": dyn.V(""),
}), vout)
}

func TestNormalizeMap(t *testing.T) {
var typ map[string]string
vin := dyn.V(map[string]dyn.Value{
Expand Down

0 comments on commit ea8daf1

Please sign in to comment.