From ea8daf1f97d66ec88d72fd7b9a7d286a80d68f90 Mon Sep 17 00:00:00 2001 From: Pieter Noordhuis Date: Fri, 16 Feb 2024 13:56:02 +0100 Subject: [PATCH] Avoid infinite recursion when normalizing a recursive type (#1213) ## Changes This is a follow-up to #1211 prompted by the addition of a recursive type in the Go SDK v0.31.0 (`jobs.ForEachTask`). When populating missing fields with their zero values we must not inadvertently recurse into a recursive type. ## Tests New unit test fails with a stack overflow if the fix if the check is disabled. --- libs/dyn/convert/normalize.go | 41 +++++++++++++++++------------- libs/dyn/convert/normalize_test.go | 31 ++++++++++++++++++++++ 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/libs/dyn/convert/normalize.go b/libs/dyn/convert/normalize.go index 26df09578d..e0dfbda23d 100644 --- a/libs/dyn/convert/normalize.go +++ b/libs/dyn/convert/normalize.go @@ -3,6 +3,7 @@ package convert import ( "fmt" "reflect" + "slices" "strconv" "github.com/databricks/cli/libs/diag" @@ -31,21 +32,21 @@ func Normalize(dst any, src dyn.Value, opts ...NormalizeOption) (dyn.Value, diag } } - return n.normalizeType(reflect.TypeOf(dst), src) + return n.normalizeType(reflect.TypeOf(dst), src, []reflect.Type{}) } -func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { +func (n normalizeOptions) normalizeType(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) { for typ.Kind() == reflect.Pointer { typ = typ.Elem() } switch typ.Kind() { case reflect.Struct: - return n.normalizeStruct(typ, src) + return n.normalizeStruct(typ, src, append(seen, typ)) case reflect.Map: - return n.normalizeMap(typ, src) + return n.normalizeMap(typ, src, append(seen, typ)) case reflect.Slice: - return n.normalizeSlice(typ, src) + return n.normalizeSlice(typ, src, append(seen, typ)) case reflect.String: return n.normalizeString(typ, src) case reflect.Bool: @@ -67,7 +68,7 @@ func typeMismatch(expected dyn.Kind, src dyn.Value) diag.Diagnostic { } } -func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { +func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) { var diags diag.Diagnostics switch src.Kind() { @@ -86,7 +87,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn. } // Normalize the value according to the field type. - v, err := n.normalizeType(typ.FieldByIndex(index).Type, v) + v, err := n.normalizeType(typ.FieldByIndex(index).Type, v, seen) if err != nil { diags = diags.Extend(err) // Skip the element if it cannot be normalized. @@ -115,20 +116,26 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn. ftyp = ftyp.Elem() } + // Skip field if we have already seen its type to avoid infinite recursion + // when filling in the zero value of a recursive type. + if slices.Contains(seen, ftyp) { + continue + } + var v dyn.Value switch ftyp.Kind() { case reflect.Struct, reflect.Map: - v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{})) + v, _ = n.normalizeType(ftyp, dyn.V(map[string]dyn.Value{}), seen) case reflect.Slice: - v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{})) + v, _ = n.normalizeType(ftyp, dyn.V([]dyn.Value{}), seen) case reflect.String: - v, _ = n.normalizeType(ftyp, dyn.V("")) + v, _ = n.normalizeType(ftyp, dyn.V(""), seen) case reflect.Bool: - v, _ = n.normalizeType(ftyp, dyn.V(false)) + v, _ = n.normalizeType(ftyp, dyn.V(false), seen) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v, _ = n.normalizeType(ftyp, dyn.V(int64(0))) + v, _ = n.normalizeType(ftyp, dyn.V(int64(0)), seen) case reflect.Float32, reflect.Float64: - v, _ = n.normalizeType(ftyp, dyn.V(float64(0))) + v, _ = n.normalizeType(ftyp, dyn.V(float64(0)), seen) default: // Skip fields for which we do not have a natural [dyn.Value] equivalent. // For example, we don't handle reflect.Complex* and reflect.Uint* types. @@ -147,7 +154,7 @@ func (n normalizeOptions) normalizeStruct(typ reflect.Type, src dyn.Value) (dyn. return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src)) } -func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { +func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) { var diags diag.Diagnostics switch src.Kind() { @@ -155,7 +162,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val out := make(map[string]dyn.Value) for k, v := range src.MustMap() { // Normalize the value according to the map element type. - v, err := n.normalizeType(typ.Elem(), v) + v, err := n.normalizeType(typ.Elem(), v, seen) if err != nil { diags = diags.Extend(err) // Skip the element if it cannot be normalized. @@ -175,7 +182,7 @@ func (n normalizeOptions) normalizeMap(typ reflect.Type, src dyn.Value) (dyn.Val return dyn.InvalidValue, diags.Append(typeMismatch(dyn.KindMap, src)) } -func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.Value, diag.Diagnostics) { +func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value, seen []reflect.Type) (dyn.Value, diag.Diagnostics) { var diags diag.Diagnostics switch src.Kind() { @@ -183,7 +190,7 @@ func (n normalizeOptions) normalizeSlice(typ reflect.Type, src dyn.Value) (dyn.V out := make([]dyn.Value, 0, len(src.MustSequence())) for _, v := range src.MustSequence() { // Normalize the value according to the slice element type. - v, err := n.normalizeType(typ.Elem(), v) + v, err := n.normalizeType(typ.Elem(), v, seen) if err != nil { diags = diags.Extend(err) // Skip the element if it cannot be normalized. diff --git a/libs/dyn/convert/normalize_test.go b/libs/dyn/convert/normalize_test.go index d59cc3b351..82abc82600 100644 --- a/libs/dyn/convert/normalize_test.go +++ b/libs/dyn/convert/normalize_test.go @@ -189,6 +189,37 @@ func TestNormalizeStructIncludeMissingFields(t *testing.T) { }), vout) } +func TestNormalizeStructIncludeMissingFieldsOnRecursiveType(t *testing.T) { + type Tmp struct { + // Verify that structs are recursively normalized if not set. + Ptr *Tmp `json:"ptr"` + + // Verify that primitive types are zero-initialized if not set. + String string `json:"string"` + } + + var typ Tmp + vin := dyn.V(map[string]dyn.Value{ + "ptr": dyn.V(map[string]dyn.Value{ + "ptr": dyn.V(map[string]dyn.Value{ + "string": dyn.V("already set"), + }), + }), + }) + vout, err := Normalize(typ, vin, IncludeMissingFields) + assert.Empty(t, err) + assert.Equal(t, dyn.V(map[string]dyn.Value{ + "ptr": dyn.V(map[string]dyn.Value{ + "ptr": dyn.V(map[string]dyn.Value{ + // Note: the ptr field is not zero-initialized because that would recurse. + "string": dyn.V("already set"), + }), + "string": dyn.V(""), + }), + "string": dyn.V(""), + }), vout) +} + func TestNormalizeMap(t *testing.T) { var typ map[string]string vin := dyn.V(map[string]dyn.Value{