diff --git a/node/bindnode/example_test.go b/node/bindnode/example_test.go index 69f01b52..8c8f3fd6 100644 --- a/node/bindnode/example_test.go +++ b/node/bindnode/example_test.go @@ -17,7 +17,7 @@ func ExampleWrap_withSchema() { type Person struct { Name String Age optional Int - Friends [String] + Friends optional [String] } `)) if err != nil { @@ -27,8 +27,8 @@ func ExampleWrap_withSchema() { type Person struct { Name string - Age *int64 // optional - Friends []string + Age *int64 // optional + Friends []string // optional; no need for a pointer as slices are nilable } person := &Person{ Name: "Michael", diff --git a/node/bindnode/infer.go b/node/bindnode/infer.go index 22a31152..92f38b68 100644 --- a/node/bindnode/infer.go +++ b/node/bindnode/infer.go @@ -108,10 +108,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp } goType = goType.Elem() if schemaType.ValueIsNullable() { - if goType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("nullable types must be nilable") + } else if ptr { + goType = goType.Elem() } - goType = goType.Elem() } verifyCompatibility(seen, goType, schemaType.ValueType()) case *schema.TypeMap: @@ -141,10 +142,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp elemType := fieldValues.Type.Elem() if schemaType.ValueIsNullable() { - if elemType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + if ptr, nilable := ptrOrNilable(elemType.Kind()); !nilable { + doPanic("nullable types must be nilable") + } else if ptr { + elemType = elemType.Elem() } - elemType = elemType.Elem() } verifyCompatibility(seen, elemType, schemaType.ValueType()) case *schema.TypeStruct: @@ -159,18 +161,31 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp for i, schemaField := range schemaFields { schemaType := schemaField.Type() goType := goType.Field(i).Type - // TODO: allow "is nilable" to some degree? - if schemaField.IsNullable() { + switch { + case schemaField.IsOptional() && schemaField.IsNullable(): + // TODO: https://github.com/ipld/go-ipld-prime/issues/340 will + // help here, to avoid the double pointer. We can't use nilable + // but non-pointer types because that's just one "nil" state. if goType.Kind() != reflect.Ptr { - doPanic("nullable types must be pointers") + doPanic("optional and nullable fields must use double pointers (**)") } goType = goType.Elem() - } - if schemaField.IsOptional() { if goType.Kind() != reflect.Ptr { - doPanic("optional types must be pointers") + doPanic("optional and nullable fields must use double pointers (**)") } goType = goType.Elem() + case schemaField.IsOptional(): + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("optional fields must be nilable") + } else if ptr { + goType = goType.Elem() + } + case schemaField.IsNullable(): + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("nullable fields must be nilable") + } else if ptr { + goType = goType.Elem() + } } verifyCompatibility(seen, goType, schemaType) } @@ -186,10 +201,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp for i, schemaType := range schemaMembers { goType := goType.Field(i).Type - if goType.Kind() != reflect.Ptr { - doPanic("union members must be pointers") + if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable { + doPanic("union members must be nilable") + } else if ptr { + goType = goType.Elem() } - goType = goType.Elem() verifyCompatibility(seen, goType, schemaType) } case *schema.TypeLink: @@ -206,6 +222,17 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp } } +func ptrOrNilable(kind reflect.Kind) (ptr, nilable bool) { + switch kind { + case reflect.Ptr: + return true, true + case reflect.Interface, reflect.Map, reflect.Slice: + return false, true + default: + return false, false + } +} + // If we recurse past a large number of levels, we're mostly stuck in a loop. // Prevent burning CPU or causing OOM crashes. // If a user really wrote an IPLD schema or Go type with such deep nesting, diff --git a/node/bindnode/infer_test.go b/node/bindnode/infer_test.go index f2bbd6c3..057daf59 100644 --- a/node/bindnode/infer_test.go +++ b/node/bindnode/infer_test.go @@ -334,37 +334,64 @@ func TestPrototypePointerCombinations(t *testing.T) { })(nil), `{"x":3,"y":4}`}, } + // For each IPLD kind, we test a matrix of combinations for IPLD's optional + // and nullable fields alongside pointer usage on the Go field side. + modifiers := []struct { + schemaField string // "", "optional", "nullable", "optional nullable" + goPointers int // 0 (T), 1 (*T), 2 (**T) + }{ + {"", 0}, // regular IPLD field with Go's T + {"", 1}, // regular IPLD field with Go's *T + {"optional", 0}, // optional IPLD field with Go's T (skipped unless T is nilable) + {"optional", 1}, // optional IPLD field with Go's *T + {"nullable", 0}, // nullable IPLD field with Go's T (skipped unless T is nilable) + {"nullable", 1}, // nullable IPLD field with Go's *T + {"optional nullable", 2}, // optional and nullable IPLD field with Go's **T + } for _, kindTest := range kindTests { - for _, modifier := range []string{"", "optional", "nullable"} { + for _, modifier := range modifiers { // don't reuse range vars kindTest := kindTest modifier := modifier - t.Run(fmt.Sprintf("%s/%s", kindTest.name, modifier), func(t *testing.T) { + goFieldType := reflect.TypeOf(kindTest.fieldPtrType) + switch modifier.goPointers { + case 0: + goFieldType = goFieldType.Elem() // dereference fieldPtrType + case 1: + // fieldPtrType already uses one pointer + case 2: + goFieldType = reflect.PtrTo(goFieldType) // dereference fieldPtrType + } + if modifier.schemaField != "" && !nilable(goFieldType.Kind()) { + continue + } + t.Run(fmt.Sprintf("%s/%s-%dptr", kindTest.name, modifier.schemaField, modifier.goPointers), func(t *testing.T) { t.Parallel() var buf bytes.Buffer err := template.Must(template.New("").Parse(` - type Root struct { - field {{.Modifier}} {{.Type}} - }`)).Execute(&buf, struct { - Type, Modifier string - }{kindTest.schemaType, modifier}) + type Root struct { + field {{.Modifier}} {{.Type}} + }`)).Execute(&buf, + struct { + Type, Modifier string + }{kindTest.schemaType, modifier.schemaField}) qt.Assert(t, err, qt.IsNil) schemaSrc := buf.String() - t.Logf("IPLD schema: %T", schemaSrc) + t.Logf("IPLD schema: %s", schemaSrc) - // *struct { Field {{.fieldPtrType}} } - ptrType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{ - {Name: "Field", Type: reflect.TypeOf(kindTest.fieldPtrType)}, + // *struct { Field {{.goFieldType}} } + goType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{ + {Name: "Field", Type: goFieldType}, }))).Interface() - t.Logf("Go type: %T", ptrType) + t.Logf("Go type: %T", goType) ts, err := ipld.LoadSchemaBytes([]byte(schemaSrc)) qt.Assert(t, err, qt.IsNil) schemaType := ts.TypeByName("Root") qt.Assert(t, schemaType, qt.Not(qt.IsNil)) - proto := bindnode.Prototype(ptrType, schemaType) + proto := bindnode.Prototype(goType, schemaType) wantEncodedBytes, err := json.Marshal(map[string]interface{}{"field": json.RawMessage(kindTest.fieldDagJSON)}) qt.Assert(t, err, qt.IsNil) wantEncoded := string(wantEncodedBytes) @@ -377,26 +404,32 @@ func TestPrototypePointerCombinations(t *testing.T) { // Assigning with the missing field should only work with optional. nb := proto.NewBuilder() err = dagjson.Decode(nb, strings.NewReader(`{}`)) - if modifier == "optional" { + switch modifier.schemaField { + case "optional", "optional nullable": qt.Assert(t, err, qt.IsNil) node := nb.Build() // The resulting node should be non-nil with a nil field. nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) - } else { + default: qt.Assert(t, err, qt.Not(qt.IsNil)) } // Assigning with a null field should only work with nullable. nb = proto.NewBuilder() err = dagjson.Decode(nb, strings.NewReader(`{"field":null}`)) - if modifier == "nullable" { + switch modifier.schemaField { + case "nullable", "optional nullable": qt.Assert(t, err, qt.IsNil) node := nb.Build() // The resulting node should be non-nil with a nil field. nodeVal := reflect.ValueOf(bindnode.Unwrap(node)) - qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) - } else { + if modifier.schemaField == "nullable" { + qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue) + } else { + qt.Assert(t, nodeVal.Elem().FieldByName("Field").Elem().IsNil(), qt.IsTrue) + } + default: qt.Assert(t, err, qt.Not(qt.IsNil)) } }) @@ -404,6 +437,15 @@ func TestPrototypePointerCombinations(t *testing.T) { } } +func nilable(kind reflect.Kind) bool { + switch kind { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + default: + return false + } +} + func assembleAsKind(proto datamodel.NodePrototype, schemaType schema.Type, asKind datamodel.Kind) (ipld.Node, error) { nb := proto.NewBuilder() switch asKind { @@ -895,13 +937,13 @@ var verifyTests = []struct { Keys []string Values map[string]*datamodel.Node })(nil), + (*struct { + Keys []string + Values map[string]datamodel.Node + })(nil), }, badTypes: []verifyBadType{ {(*string)(nil), `.*type Root .* type string: kind mismatch;.*`}, - {(*struct { - Keys []string - Values map[string]datamodel.Node - })(nil), `.*type Root .*: nullable types must be pointers`}, }, }, { @@ -918,6 +960,10 @@ var verifyTests = []struct { List *[]string String *string })(nil), + (*struct { + List []string + String *string + })(nil), (*struct { List *[]namedString String *namedString @@ -927,9 +973,9 @@ var verifyTests = []struct { {(*string)(nil), `.*type Root .* type string: kind mismatch;.*`}, {(*struct{ List *[]string })(nil), `.*type Root .*: 1 vs 2 members`}, {(*struct { - List *[]string + List []string String string - })(nil), `.*type Root .*: union members must be pointers`}, + })(nil), `.*type Root .*: union members must be nilable`}, {(*struct { List *[]string String *int diff --git a/node/bindnode/node.go b/node/bindnode/node.go index 9bda3f63..83fe021c 100644 --- a/node/bindnode/node.go +++ b/node/bindnode/node.go @@ -149,13 +149,17 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) { if fval.IsNil() { return datamodel.Absent, nil } - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + fval = fval.Elem() + } } if field.IsNullable() { if fval.IsNil() { return datamodel.Null, nil } - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + fval = fval.Elem() + } } if _, ok := field.Type().(*schema.TypeAny); ok { return nonPtrVal(fval).Interface().(datamodel.Node), nil @@ -822,8 +826,14 @@ func (w *_structAssembler) AssembleValue() datamodel.NodeAssembler { w.doneFields[ftyp.Index[0]] = true fval := w.val.FieldByIndex(ftyp.Index) if field.IsOptional() { - fval.Set(reflect.New(fval.Type().Elem())) - fval = fval.Elem() + if fval.Kind() == reflect.Ptr { + // ptrVal = new(T); val = *ptrVal + fval.Set(reflect.New(fval.Type().Elem())) + fval = fval.Elem() + } else { + // val = *new(T) + fval.Set(reflect.New(fval.Type()).Elem()) + } } // TODO: reuse same assembler for perf? return &_assembler{ @@ -1087,13 +1097,17 @@ func (w *_structIterator) Next() (key, value datamodel.Node, _ error) { if val.IsNil() { return key, datamodel.Absent, nil } - val = val.Elem() + if val.Kind() == reflect.Ptr { + val = val.Elem() + } } if field.IsNullable() { if val.IsNil() { return key, datamodel.Null, nil } - val = val.Elem() + if val.Kind() == reflect.Ptr { + val = val.Elem() + } } if _, ok := field.Type().(*schema.TypeAny); ok { return key, nonPtrVal(val).Interface().(datamodel.Node), nil