Skip to content

Commit

Permalink
node/bindnode: allow nilable types for IPLD optional/nullable
Browse files Browse the repository at this point in the history
For simplicity, bindnode used to always require a pointer to represent
optional or nullable IPLD types. This is because we need an extra bit to
store whether the value is absent or null.

However, some Go types can already store a "nil" value to represent that
bit without needing that extra pointer. Most notably:

	[]T            versus *[]T
	map[K]V        versus *map[K]V
	datamodel.Node versus *datamodel.Node

Avoiding the extra pointer makes the types easier for humans to deal
with, and it also avoids a potential footgun due to the extra "nil"
state that bindnode doesn't actually use.

Note that we still require pointers for "optional nullable" struct
fields, as those need two extra bits. A TODO is left for that edge case.

Fixes #378.
  • Loading branch information
mvdan authored and rvagg committed Apr 26, 2022
1 parent 9d19e3c commit e69ddc7
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 48 deletions.
6 changes: 3 additions & 3 deletions node/bindnode/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ExampleWrap_withSchema() {
type Person struct {
Name String
Age optional Int
Friends [String]
Friends optional [String]
}
`))
if err != nil {
Expand All @@ -27,8 +27,8 @@ func ExampleWrap_withSchema() {

type Person struct {
Name string
Age *int64 // optional
Friends []string
Age *int64 // optional
Friends []string // optional; no need for a pointer as slices are nilable
}
person := &Person{
Name: "Michael",
Expand Down
57 changes: 42 additions & 15 deletions node/bindnode/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
}
goType = goType.Elem()
if schemaType.ValueIsNullable() {
if goType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("nullable types must be nilable")
} else if ptr {
goType = goType.Elem()
}
goType = goType.Elem()
}
verifyCompatibility(seen, goType, schemaType.ValueType())
case *schema.TypeMap:
Expand Down Expand Up @@ -141,10 +142,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp

elemType := fieldValues.Type.Elem()
if schemaType.ValueIsNullable() {
if elemType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
if ptr, nilable := ptrOrNilable(elemType.Kind()); !nilable {
doPanic("nullable types must be nilable")
} else if ptr {
elemType = elemType.Elem()
}
elemType = elemType.Elem()
}
verifyCompatibility(seen, elemType, schemaType.ValueType())
case *schema.TypeStruct:
Expand All @@ -159,18 +161,31 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
for i, schemaField := range schemaFields {
schemaType := schemaField.Type()
goType := goType.Field(i).Type
// TODO: allow "is nilable" to some degree?
if schemaField.IsNullable() {
switch {
case schemaField.IsOptional() && schemaField.IsNullable():
// TODO: https://github.com/ipld/go-ipld-prime/issues/340 will
// help here, to avoid the double pointer. We can't use nilable
// but non-pointer types because that's just one "nil" state.
if goType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
doPanic("optional and nullable fields must use double pointers (**)")
}
goType = goType.Elem()
}
if schemaField.IsOptional() {
if goType.Kind() != reflect.Ptr {
doPanic("optional types must be pointers")
doPanic("optional and nullable fields must use double pointers (**)")
}
goType = goType.Elem()
case schemaField.IsOptional():
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("optional fields must be nilable")
} else if ptr {
goType = goType.Elem()
}
case schemaField.IsNullable():
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("nullable fields must be nilable")
} else if ptr {
goType = goType.Elem()
}
}
verifyCompatibility(seen, goType, schemaType)
}
Expand All @@ -186,10 +201,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp

for i, schemaType := range schemaMembers {
goType := goType.Field(i).Type
if goType.Kind() != reflect.Ptr {
doPanic("union members must be pointers")
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("union members must be nilable")
} else if ptr {
goType = goType.Elem()
}
goType = goType.Elem()
verifyCompatibility(seen, goType, schemaType)
}
case *schema.TypeLink:
Expand All @@ -206,6 +222,17 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
}
}

func ptrOrNilable(kind reflect.Kind) (ptr, nilable bool) {
switch kind {
case reflect.Ptr:
return true, true
case reflect.Interface, reflect.Map, reflect.Slice:
return false, true
default:
return false, false
}
}

// If we recurse past a large number of levels, we're mostly stuck in a loop.
// Prevent burning CPU or causing OOM crashes.
// If a user really wrote an IPLD schema or Go type with such deep nesting,
Expand Down
94 changes: 70 additions & 24 deletions node/bindnode/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,37 +334,64 @@ func TestPrototypePointerCombinations(t *testing.T) {
})(nil), `{"x":3,"y":4}`},
}

// For each IPLD kind, we test a matrix of combinations for IPLD's optional
// and nullable fields alongside pointer usage on the Go field side.
modifiers := []struct {
schemaField string // "", "optional", "nullable", "optional nullable"
goPointers int // 0 (T), 1 (*T), 2 (**T)
}{
{"", 0}, // regular IPLD field with Go's T
{"", 1}, // regular IPLD field with Go's *T
{"optional", 0}, // optional IPLD field with Go's T (skipped unless T is nilable)
{"optional", 1}, // optional IPLD field with Go's *T
{"nullable", 0}, // nullable IPLD field with Go's T (skipped unless T is nilable)
{"nullable", 1}, // nullable IPLD field with Go's *T
{"optional nullable", 2}, // optional and nullable IPLD field with Go's **T
}
for _, kindTest := range kindTests {
for _, modifier := range []string{"", "optional", "nullable"} {
for _, modifier := range modifiers {
// don't reuse range vars
kindTest := kindTest
modifier := modifier
t.Run(fmt.Sprintf("%s/%s", kindTest.name, modifier), func(t *testing.T) {
goFieldType := reflect.TypeOf(kindTest.fieldPtrType)
switch modifier.goPointers {
case 0:
goFieldType = goFieldType.Elem() // dereference fieldPtrType
case 1:
// fieldPtrType already uses one pointer
case 2:
goFieldType = reflect.PtrTo(goFieldType) // dereference fieldPtrType
}
if modifier.schemaField != "" && !nilable(goFieldType.Kind()) {
continue
}
t.Run(fmt.Sprintf("%s/%s-%dptr", kindTest.name, modifier.schemaField, modifier.goPointers), func(t *testing.T) {
t.Parallel()

var buf bytes.Buffer
err := template.Must(template.New("").Parse(`
type Root struct {
field {{.Modifier}} {{.Type}}
}`)).Execute(&buf, struct {
Type, Modifier string
}{kindTest.schemaType, modifier})
type Root struct {
field {{.Modifier}} {{.Type}}
}`)).Execute(&buf,
struct {
Type, Modifier string
}{kindTest.schemaType, modifier.schemaField})
qt.Assert(t, err, qt.IsNil)
schemaSrc := buf.String()
t.Logf("IPLD schema: %T", schemaSrc)
t.Logf("IPLD schema: %s", schemaSrc)

// *struct { Field {{.fieldPtrType}} }
ptrType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{
{Name: "Field", Type: reflect.TypeOf(kindTest.fieldPtrType)},
// *struct { Field {{.goFieldType}} }
goType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{
{Name: "Field", Type: goFieldType},
}))).Interface()
t.Logf("Go type: %T", ptrType)
t.Logf("Go type: %T", goType)

ts, err := ipld.LoadSchemaBytes([]byte(schemaSrc))
qt.Assert(t, err, qt.IsNil)
schemaType := ts.TypeByName("Root")
qt.Assert(t, schemaType, qt.Not(qt.IsNil))

proto := bindnode.Prototype(ptrType, schemaType)
proto := bindnode.Prototype(goType, schemaType)
wantEncodedBytes, err := json.Marshal(map[string]interface{}{"field": json.RawMessage(kindTest.fieldDagJSON)})
qt.Assert(t, err, qt.IsNil)
wantEncoded := string(wantEncodedBytes)
Expand All @@ -377,33 +404,48 @@ func TestPrototypePointerCombinations(t *testing.T) {
// Assigning with the missing field should only work with optional.
nb := proto.NewBuilder()
err = dagjson.Decode(nb, strings.NewReader(`{}`))
if modifier == "optional" {
switch modifier.schemaField {
case "optional", "optional nullable":
qt.Assert(t, err, qt.IsNil)
node := nb.Build()
// The resulting node should be non-nil with a nil field.
nodeVal := reflect.ValueOf(bindnode.Unwrap(node))
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
default:
qt.Assert(t, err, qt.Not(qt.IsNil))
}

// Assigning with a null field should only work with nullable.
nb = proto.NewBuilder()
err = dagjson.Decode(nb, strings.NewReader(`{"field":null}`))
if modifier == "nullable" {
switch modifier.schemaField {
case "nullable", "optional nullable":
qt.Assert(t, err, qt.IsNil)
node := nb.Build()
// The resulting node should be non-nil with a nil field.
nodeVal := reflect.ValueOf(bindnode.Unwrap(node))
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
if modifier.schemaField == "nullable" {
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
qt.Assert(t, nodeVal.Elem().FieldByName("Field").Elem().IsNil(), qt.IsTrue)
}
default:
qt.Assert(t, err, qt.Not(qt.IsNil))
}
})
}
}
}

func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
default:
return false
}
}

func assembleAsKind(proto datamodel.NodePrototype, schemaType schema.Type, asKind datamodel.Kind) (ipld.Node, error) {
nb := proto.NewBuilder()
switch asKind {
Expand Down Expand Up @@ -895,13 +937,13 @@ var verifyTests = []struct {
Keys []string
Values map[string]*datamodel.Node
})(nil),
(*struct {
Keys []string
Values map[string]datamodel.Node
})(nil),
},
badTypes: []verifyBadType{
{(*string)(nil), `.*type Root .* type string: kind mismatch;.*`},
{(*struct {
Keys []string
Values map[string]datamodel.Node
})(nil), `.*type Root .*: nullable types must be pointers`},
},
},
{
Expand All @@ -918,6 +960,10 @@ var verifyTests = []struct {
List *[]string
String *string
})(nil),
(*struct {
List []string
String *string
})(nil),
(*struct {
List *[]namedString
String *namedString
Expand All @@ -927,9 +973,9 @@ var verifyTests = []struct {
{(*string)(nil), `.*type Root .* type string: kind mismatch;.*`},
{(*struct{ List *[]string })(nil), `.*type Root .*: 1 vs 2 members`},
{(*struct {
List *[]string
List []string
String string
})(nil), `.*type Root .*: union members must be pointers`},
})(nil), `.*type Root .*: union members must be nilable`},
{(*struct {
List *[]string
String *int
Expand Down
26 changes: 20 additions & 6 deletions node/bindnode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,17 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) {
if fval.IsNil() {
return datamodel.Absent, nil
}
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
fval = fval.Elem()
}
}
if field.IsNullable() {
if fval.IsNil() {
return datamodel.Null, nil
}
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
fval = fval.Elem()
}
}
if _, ok := field.Type().(*schema.TypeAny); ok {
return nonPtrVal(fval).Interface().(datamodel.Node), nil
Expand Down Expand Up @@ -822,8 +826,14 @@ func (w *_structAssembler) AssembleValue() datamodel.NodeAssembler {
w.doneFields[ftyp.Index[0]] = true
fval := w.val.FieldByIndex(ftyp.Index)
if field.IsOptional() {
fval.Set(reflect.New(fval.Type().Elem()))
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
// ptrVal = new(T); val = *ptrVal
fval.Set(reflect.New(fval.Type().Elem()))
fval = fval.Elem()
} else {
// val = *new(T)
fval.Set(reflect.New(fval.Type()).Elem())
}
}
// TODO: reuse same assembler for perf?
return &_assembler{
Expand Down Expand Up @@ -1087,13 +1097,17 @@ func (w *_structIterator) Next() (key, value datamodel.Node, _ error) {
if val.IsNil() {
return key, datamodel.Absent, nil
}
val = val.Elem()
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
}
if field.IsNullable() {
if val.IsNil() {
return key, datamodel.Null, nil
}
val = val.Elem()
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
}
if _, ok := field.Type().(*schema.TypeAny); ok {
return key, nonPtrVal(val).Interface().(datamodel.Node), nil
Expand Down

0 comments on commit e69ddc7

Please sign in to comment.