Skip to content

Commit

Permalink
Marshal/Unmarshal missing UDT fields as null instead of failing in un…
Browse files Browse the repository at this point in the history
…safe mode

We can't return an error in case a field is added to the UDT,
otherwise existing code would break by simply altering the UDT in the
database. For extra fields at the end of the UDT put nulls to be in
line with gocql, but also python-driver and java-driver.

In gocql it was fixed in scylladb/gocql@d2ed1bb
  • Loading branch information
dkropachev committed Jun 14, 2024
1 parent 50c2977 commit 4d8c8ba
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 20 deletions.
5 changes: 0 additions & 5 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCS
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537 h1:NaMut1fdw76YYX/TPinSAbai4DShF5tPort3bHpET6g=
github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1 h1:px9qUCy/RNJNsfCam4m2IxWGxNuimkrioEF0vrrbPsg=
github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
Expand Down
325 changes: 323 additions & 2 deletions iterx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package gocqlx_test

import (
"math/big"
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -47,6 +48,328 @@ type FullNamePtrUDT struct {
*FullName
}

func diff(t *testing.T, expected, got interface{}) {
t.Helper()

if d := cmp.Diff(expected, got, diffOpts); d != "" {
t.Errorf("got %+v expected %+v, diff: %s", got, expected, d)
}
}

var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

func TestIterxUDT(t *testing.T) {
session := CreateSession(t)
t.Cleanup(func() {
session.Close()
})

if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil {
t.Fatal("create type:", err)
}

if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table (
testuuid timeuuid PRIMARY KEY,
testudt gocqlx_test.UDTTest_Full
)`); err != nil {
t.Fatal("create table:", err)
}

type Full struct {
First string
Second string
}

type Part struct {
First string
}

type Extra struct {
First string
Second string
Third string
}

type FullUDT struct {
gocqlx.UDT
Full
}

type PartUDT struct {
gocqlx.UDT
Part
}

type ExtraUDT struct {
gocqlx.UDT
Extra
}

type FullUDTPtr struct {
gocqlx.UDT
*Full
}

type PartUDTPtr struct {
gocqlx.UDT
*Part
}

type ExtraUDTPtr struct {
gocqlx.UDT
*Extra
}

full := FullUDT{
Full: Full{
First: "John",
Second: "Doe",
},
}

makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} {
b := reflect.New(reflect.StructOf([]reflect.StructField{
{
Name: "TestUUID",
Type: reflect.TypeOf(gocql.UUID{}),
},
{
Name: "TestUDT",
Type: reflect.TypeOf(insert),
},
})).Interface()
reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid))
reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert))
return b
}

tcases := []struct {
name string
insert interface{}
expectedOnDB FullUDT
expected interface{}
}{
{
name: "exact-match",
insert: full,
expectedOnDB: full,
expected: full,
},
{
name: "exact-match-ptr",
insert: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
expectedOnDB: full,
expected: FullUDTPtr{
Full: &Full{
First: "John",
Second: "Doe",
},
},
},
{
name: "extra-field",
insert: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDT{
Extra: Extra{
First: "John",
Second: "Doe",
Third: "", //Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "extra-field-ptr",
insert: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "Smith",
},
},
expectedOnDB: full,
expected: ExtraUDTPtr{
Extra: &Extra{
First: "John",
Second: "Doe",
Third: "", //Since the UDT has only 2 fields, the third field should be empty
},
},
},
{
name: "absent-field",
insert: PartUDT{
Part: Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDT{
Part: Part{
First: "John",
},
},
},
{
name: "absent-field-ptr",
insert: PartUDTPtr{
Part: &Part{
First: "John",
},
},
expectedOnDB: FullUDT{
Full: Full{
First: "John",
Second: "",
},
},
expected: PartUDTPtr{
Part: &Part{
First: "John",
},
},
},
}

const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)`
const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?`

for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
testuuid := gocql.TimeUUID()

if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) {
t.Fatalf("insert and expectedOnDB must have the same type")
}

t.Cleanup(func() {
session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease()
})

t.Run("insert-bind", func(t *testing.T) {
if err := session.Query(insertStmt, nil).Unsafe().Bind(
testuuid,
tc.insert,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := FullUDT{}
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expectedOnDB, v)
})

t.Run("scan", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("get", func(t *testing.T) {
v := reflect.New(reflect.TypeOf(tc.expected)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface())
})

t.Run("delete", func(t *testing.T) {
if err := session.Query(deleteStmt, nil).Bind(
testuuid,
).ExecRelease(); err != nil {
t.Fatal(err.Error())
}
})

t.Run("insert-bind-struct", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("insert-bind-struct-map", func(t *testing.T) {
t.Run("empty-map", func(t *testing.T) {
b := makeStruct(testuuid, tc.insert)
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(b, nil).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})

t.Run("empty-struct", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindStructMap(struct{}{}, map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})

t.Run("insert-bind-map", func(t *testing.T) {
if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().
BindMap(map[string]interface{}{
"test_uuid": testuuid,
"test_udt": tc.insert,
}).ExecRelease(); err != nil {
t.Fatal(err.Error())
}

// Make sure the UDT was inserted correctly
v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface()
if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil {
t.Fatal(err.Error())
}
diff(t, &tc.expectedOnDB, v)
})
})
}
}

func TestIterxStruct(t *testing.T) {
session := CreateSession(t)
defer session.Close()
Expand Down Expand Up @@ -149,8 +472,6 @@ func TestIterxStruct(t *testing.T) {
t.Fatal("insert:", err)
}

diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{})

const stmt = `SELECT * FROM struct_table`

t.Run("get", func(t *testing.T) {
Expand Down
4 changes: 4 additions & 0 deletions queryx.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx {
return q
}

func (q *Queryx) Scan(v ...interface{}) error {
return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...)
}

// Err returns any binding errors.
func (q *Queryx) Err() error {
return q.err
Expand Down
Loading

0 comments on commit 4d8c8ba

Please sign in to comment.