From e69ddc72bd47603b2d2e16d5b4bf7c8094c00ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Sun, 10 Apr 2022 12:16:31 +0100 Subject: [PATCH] node/bindnode: allow nilable types for IPLD optional/nullable For simplicity, bindnode used to always require a pointer to represent optional or nullable IPLD types. This is because we need an extra bit to store whether the value is absent or null. However, some Go types can already store a "nil" value to represent that bit without needing that extra pointer. Most notably: []T versus *[]T map[K]V versus *map[K]V datamodel.Node versus *datamodel.Node Avoiding the extra pointer makes the types easier for humans to deal with, and it also avoids a potential footgun due to the extra "nil" state that bindnode doesn't actually use. Note that we still require pointers for "optional nullable" struct fields, as those need two extra bits. A TODO is left for that edge case. Fixes #378. --- node/bindnode/example_test.go | 6 +-- node/bindnode/infer.go | 57 +++++++++++++++------ node/bindnode/infer_test.go | 94 ++++++++++++++++++++++++++--------- node/bindnode/node.go | 26 +++++++--- 4 files changed, 135 insertions(+), 48 deletions(-) diff --git a/node/bindnode/example_test.go b/node/bindnode/example_test.go index 69f01b52..8c8f3fd6 100644 --- a/node/bindnode/example_test.go +++ b/node/bindnode/example_test.go @@ -17,7 +17,7 @@ func ExampleWrap_withSchema() { type Person struct { Name String Age optional Int - Friends [String] + Friends optional [String] } `)) if err != nil { @@ -27,8 +27,8 @@ func ExampleWrap_withSchema() { type Person struct { Name string - Age *int64 // optional - Friends []string + Age *int64 // optional + Friends []string // optional; no need for a pointer as slices are nilable } person := &Person{ Name: "Michael", diff --git a/node/bindnode/infer.go b/node/bindnode/infer.go index 22a31152..92f38b68 100644 --- a/node/bindnode/infer.go +++ b/node/bindnode/infer.go @@ -108,10 +108,11 @@ func 
verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp } goType = goType.Elem() if schemaType.ValueIsNullable() { - if goType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("nullable types must be nilable") + } else if ptr { + goType = goType.Elem() } - goType = goType.Elem() } verifyCompatibility(seen, goType, schemaType.ValueType()) case *schema.TypeMap: @@ -141,10 +142,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp elemType := fieldValues.Type.Elem() if schemaType.ValueIsNullable() { - if elemType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + if ptr, nilable := ptrOrNilable(elemType.Kind()); !nilable { + doPanic("nullable types must be nilable") + } else if ptr { + elemType = elemType.Elem() } - elemType = elemType.Elem() } verifyCompatibility(seen, elemType, schemaType.ValueType()) case *schema.TypeStruct: @@ -159,18 +161,31 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp for i, schemaField := range schemaFields { schemaType := schemaField.Type() goType := goType.Field(i).Type - // TODO: allow "is nilable" to some degree? - if schemaField.IsNullable() { + switch { + case schemaField.IsOptional() && schemaField.IsNullable(): + // TODO: https://github.com/ipld/go-ipld-prime/issues/340 will + // help here, to avoid the double pointer. We can't use nilable + // but non-pointer types because that's just one "nil" state. 
if goType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + doPanic("optional and nullable fields must use double pointers (**)") } goType = goType.Elem() - } - if schemaField.IsOptional() { if goType.Kind() != reflect.Ptr { - doPanic("optional types must be pointers") + doPanic("optional and nullable fields must use double pointers (**)") } goType = goType.Elem() + case schemaField.IsOptional(): + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("optional fields must be nilable") + } else if ptr { + goType = goType.Elem() + } + case schemaField.IsNullable(): + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("nullable fields must be nilable") + } else if ptr { + goType = goType.Elem() + } } verifyCompatibility(seen, goType, schemaType) } @@ -186,10 +201,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp for i, schemaType := range schemaMembers { goType := goType.Field(i).Type - if goType.Kind() != reflect.Ptr { - doPanic("union members must be pointers") + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("union members must be nilable") + } else if ptr { + goType = goType.Elem() } - goType = goType.Elem() verifyCompatibility(seen, goType, schemaType) } case *schema.TypeLink: @@ -206,6 +222,17 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp } } +func ptrOrNilable(kind reflect.Kind) (ptr, nilable bool) { + switch kind { + case reflect.Ptr: + return true, true + case reflect.Interface, reflect.Map, reflect.Slice: + return false, true + default: + return false, false + } +} + // If we recurse past a large number of levels, we're mostly stuck in a loop. // Prevent burning CPU or causing OOM crashes. 
// If a user really wrote an IPLD schema or Go type with such deep nesting, diff --git a/node/bindnode/infer_test.go b/node/bindnode/infer_test.go index f2bbd6c3..057daf59 100644 --- a/node/bindnode/infer_test.go +++ b/node/bindnode/infer_test.go @@ -334,37 +334,64 @@ func TestPrototypePointerCombinations(t *testing.T) { })(nil), `{"x":3,"y":4}`}, } + // For each IPLD kind, we test a matrix of combinations for IPLD's optional + // and nullable fields alongside pointer usage on the Go field side. + modifiers := []struct { + schemaField string // "", "optional", "nullable", "optional nullable" + goPointers int // 0 (T), 1 (*T), 2 (**T) + }{ + {"", 0}, // regular IPLD field with Go's T + {"", 1}, // regular IPLD field with Go's *T + {"optional", 0}, // optional IPLD field with Go's T (skipped unless T is nilable) + {"optional", 1}, // optional IPLD field with Go's *T + {"nullable", 0}, // nullable IPLD field with Go's T (skipped unless T is nilable) + {"nullable", 1}, // nullable IPLD field with Go's *T + {"optional nullable", 2}, // optional and nullable IPLD field with Go's **T + } for _, kindTest := range kindTests { - for _, modifier := range []string{"", "optional", "nullable"} { + for _, modifier := range modifiers { // don't reuse range vars kindTest := kindTest modifier := modifier - t.Run(fmt.Sprintf("%s/%s", kindTest.name, modifier), func(t *testing.T) { + goFieldType := reflect.TypeOf(kindTest.fieldPtrType) + switch modifier.goPointers { + case 0: + goFieldType = goFieldType.Elem() // dereference fieldPtrType + case 1: + // fieldPtrType already uses one pointer + case 2: + goFieldType = reflect.PtrTo(goFieldType) // add a second pointer on top of fieldPtrType + } + if modifier.schemaField != "" && !nilable(goFieldType.Kind()) { + continue + } + t.Run(fmt.Sprintf("%s/%s-%dptr", kindTest.name, modifier.schemaField, modifier.goPointers), func(t *testing.T) { t.Parallel() var buf bytes.Buffer err := template.Must(template.New("").Parse(` - type Root struct { - field {{.Modifier}} 
{{.Type}} - }`)).Execute(&buf, struct { - Type, Modifier string - }{kindTest.schemaType, modifier}) + type Root struct { + field {{.Modifier}} {{.Type}} + }`)).Execute(&buf, + struct { + Type, Modifier string + }{kindTest.schemaType, modifier.schemaField}) qt.Assert(t, err, qt.IsNil) schemaSrc := buf.String() - t.Logf("IPLD schema: %T", schemaSrc) + t.Logf("IPLD schema: %s", schemaSrc) - // *struct { Field {{.fieldPtrType}} } - ptrType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{ - {Name: "Field", Type: reflect.TypeOf(kindTest.fieldPtrType)}, + // *struct { Field {{.goFieldType}} } + goType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{ + {Name: "Field", Type: goFieldType}, }))).Interface() - t.Logf("Go type: %T", ptrType) + t.Logf("Go type: %T", goType) ts, err := ipld.LoadSchemaBytes([]byte(schemaSrc)) qt.Assert(t, err, qt.IsNil) schemaType := ts.TypeByName("Root") qt.Assert(t, schemaType, qt.Not(qt.IsNil)) - proto := bindnode.Prototype(ptrType, schemaType) + proto := bindnode.Prototype(goType, schemaType) wantEncodedBytes, err := json.Marshal(map[string]interface{}{"field": json.RawMessage(kindTest.fieldDagJSON)}) qt.Assert(t, err, qt.IsNil) wantEncoded := string(wantEncodedBytes) @@ -377,26 +404,32 @@ func TestPrototypePointerCombinations(t *testing.T) { // Assigning with the missing field should only work with optional. nb := proto.NewBuilder() err = dagjson.Decode(nb, strings.NewReader(`{}`)) - if modifier == "optional" { + switch modifier.schemaField { + case "optional", "optional nullable": qt.Assert(t, err, qt.IsNil) node := nb.Build() // The resulting node should be non-nil with a nil field. nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) - } else { + default: qt.Assert(t, err, qt.Not(qt.IsNil)) } // Assigning with a null field should only work with nullable. 
nb = proto.NewBuilder() err = dagjson.Decode(nb, strings.NewReader(`{"field":null}`)) - if modifier == "nullable" { + switch modifier.schemaField { + case "nullable", "optional nullable": qt.Assert(t, err, qt.IsNil) node := nb.Build() // The resulting node should be non-nil with a nil field. nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) - qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) - } else { + if modifier.schemaField == "nullable" { + qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) + } else { + qt.Assert(t, nodeVal.Elem().FieldByName("Field").Elem().IsNil(), qt.IsTrue) + } + default: qt.Assert(t, err, qt.Not(qt.IsNil)) } }) @@ -404,6 +437,15 @@ func TestPrototypePointerCombinations(t *testing.T) { } } +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + default: + return false + } +} + func assembleAsKind(proto datamodel.NodePrototype, schemaType schema.Type, asKind datamodel.Kind) (ipld.Node, error) { nb := proto.NewBuilder() switch asKind { @@ -895,13 +937,13 @@ var verifyTests = []struct { Keys []string Values map[string]*datamodel.Node })(nil), + (*struct { + Keys []string + Values map[string]datamodel.Node + })(nil), }, badTypes: []verifyBadType{ {(*string)(nil), `.*type Root .* type string: kind mismatch;.*`}, - {(*struct { - Keys []string - Values map[string]datamodel.Node - })(nil), `.*type Root .*: nullable types must be pointers`}, }, }, { @@ -918,6 +960,10 @@ var verifyTests = []struct { List *[]string String *string })(nil), + (*struct { + List []string + String *string + })(nil), (*struct { List *[]namedString String *namedString @@ -927,9 +973,9 @@ var verifyTests = []struct { {(*string)(nil), `.*type Root .* type string: kind mismatch;.*`}, {(*struct{ List *[]string })(nil), `.*type Root .*: 1 vs 2 members`}, {(*struct { - List *[]string + List []string String string - })(nil), `.*type Root .*: union members must be 
pointers`}, + })(nil), `.*type Root .*: union members must be nilable`}, {(*struct { List *[]string String *int diff --git a/node/bindnode/node.go b/node/bindnode/node.go index 9bda3f63..83fe021c 100644 --- a/node/bindnode/node.go +++ b/node/bindnode/node.go @@ -149,13 +149,17 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) { if fval.IsNil() { return datamodel.Absent, nil } - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + fval = fval.Elem() + } } if field.IsNullable() { if fval.IsNil() { return datamodel.Null, nil } - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + fval = fval.Elem() + } } if _, ok := field.Type().(*schema.TypeAny); ok { return nonPtrVal(fval).Interface().(datamodel.Node), nil @@ -822,8 +826,14 @@ func (w *_structAssembler) AssembleValue() datamodel.NodeAssembler { w.doneFields[ftyp.Index[0]] = true fval := w.val.FieldByIndex(ftyp.Index) if field.IsOptional() { - fval.Set(reflect.New(fval.Type().Elem())) - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + // ptrVal = new(T); val = *ptrVal + fval.Set(reflect.New(fval.Type().Elem())) + fval = fval.Elem() + } else { + // val = *new(T) + fval.Set(reflect.New(fval.Type()).Elem()) + } } // TODO: reuse same assembler for perf? return &_assembler{ @@ -1087,13 +1097,17 @@ func (w *_structIterator) Next() (key, value datamodel.Node, _ error) { if val.IsNil() { return key, datamodel.Absent, nil } - val = val.Elem() + if val.Kind() == reflect.Ptr { + val = val.Elem() + } } if field.IsNullable() { if val.IsNil() { return key, datamodel.Null, nil } - val = val.Elem() + if val.Kind() == reflect.Ptr { + val = val.Elem() + } } if _, ok := field.Type().(*schema.TypeAny); ok { return key, nonPtrVal(val).Interface().(datamodel.Node), nil