Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marshal missing UDT fields as null instead of failing #269

Merged
merged 3 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iterx.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/scylladb/go-reflectx"
)

// DefaultUnsafe enables the behavior of forcing the iterator to ignore
// DefaultUnsafe enables the behavior of forcing queries and iterators to ignore
// missing fields for all queries. See Unsafe below for more information.
var DefaultUnsafe bool

Expand Down
331 changes: 324 additions & 7 deletions iterx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package gocqlx_test

import (
"math/big"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -48,6 +49,328 @@ type FullNamePtrUDT struct {
*FullName
}

func diff(t *testing.T, expected, got interface{}) {
t.Helper()

if d := cmp.Diff(expected, got, diffOpts); d != "" {
t.Errorf("got %+v expected %+v, diff: %s", got, expected, d)
}
}

var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

func TestIterxUDT(t *testing.T) {
session := gocqlxtest.CreateSession(t)
t.Cleanup(func() {
session.Close()
})

if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil {
t.Fatal("create type:", err)
}

if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table (
testuuid timeuuid PRIMARY KEY,
testudt gocqlx_test.UDTTest_Full
)`); err != nil {
t.Fatal("create table:", err)
}

type Full struct {
First string
Second string
}

type Part struct {
First string
}

type Extra struct {
First string
Second string
Third string
}

type FullUDT struct {
gocqlx.UDT
Full
}

type PartUDT struct {
gocqlx.UDT
Part
}

type ExtraUDT struct {
gocqlx.UDT
Extra
}

type FullUDTPtr struct {
gocqlx.UDT
*Full
}

type PartUDTPtr struct {
gocqlx.UDT
*Part
}

type ExtraUDTPtr struct {
gocqlx.UDT
*Extra
}

full := FullUDT{
Full: Full{
First: "John",
Second: "Doe",
},
}

makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} {
b := reflect.New(reflect.StructOf([]reflect.StructField{
{
Name: "TestUUID",
Type: reflect.TypeOf(gocql.UUID{}),
},
{
Name: "TestUDT",
Type: reflect.TypeOf(insert),
},
})).Interface()
reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid))
reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert))
return b
}

tcases := []struct {
name string
insert interface{}
expected interface{}
expectedOnDB FullUDT
}{
{
name: "exact-match",
insert: full,
expectedOnDB: full,
expected: full,
},
{
name: "exact-match-ptr",
insert: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
expectedOnDB: full,
expected: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
},
{
name: "extra-field",
insert: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "extra-field-ptr",
insert: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "", // Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "absent-field",
insert: PartUDT{
Part: Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDT{
Part: Part{
First: "John",
},
},
},
{
name: "absent-field-ptr",
insert: PartUDTPtr{
Part: &Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDTPtr{
Part: &Part{
First: "John",
},
},
},
}

const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)`
const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?`

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
testuuid := gocql.TimeUUID()

if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) {
t.Fatalf("insert and expectedOnDB must have the same type")
}

t.Cleanup(func() {
session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck
})

t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
})

t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("delete", func(t *testing.T) {
if err := session.Query(deleteStmt, nil).Bind(
testuuid,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
})

t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})

t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})
}
}

func TestIterxStruct(t *testing.T) {
session := gocqlxtest.CreateSession(t)
defer session.Close()
Expand Down Expand Up @@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) {
t.Fatal("insert:", err)
}

diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

const stmt = `SELECT * FROM struct_table`

t.Run("get", func(t *testing.T) {
Expand Down Expand Up @@ -484,12 +805,8 @@ func TestIterxUnsafe(t *testing.T) {
})

t.Run("select default unsafe", func(t *testing.T) {
gocqlx.DefaultUnsafe = true
defer func() {
gocqlx.DefaultUnsafe = false
}()
var v []UnsafeTable
err := session.Query(stmt, nil).Iter().Select(&v)
err := session.Query(stmt, nil).Unsafe().Iter().Select(&v)
if err != nil {
t.Fatal("Select() failed:", err)
}
wprzytula marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading
Loading