From 4d8c8ba37754cc929c65e88a9125dbb58456a64d Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Fri, 14 Jun 2024 11:28:30 -0400 Subject: [PATCH] Marshal/Unmarshal missing UDT fields as null instead of failing in unsafe mode We can't return an error in case a field is added to the UDT, otherwise existing code would break by simply altering the UDT in the database. For extra fields at the end of the UDT put nulls to be in line with gocql, but also python-driver and java-driver. In gocql it was fixed in https://github.com/scylladb/gocql/commit/d2ed1bb74f3118a83a352e9ce912be765001efa4 --- go.sum | 5 - iterx_test.go | 325 +++++++++++++++++++++++++++++++++++++++++++++++++- queryx.go | 4 + udt.go | 24 ++-- 4 files changed, 338 insertions(+), 20 deletions(-) diff --git a/go.sum b/go.sum index 00894b5..13ad649 100644 --- a/go.sum +++ b/go.sum @@ -3,13 +3,8 @@ github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCS github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537 h1:NaMut1fdw76YYX/TPinSAbai4DShF5tPort3bHpET6g= -github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1 h1:px9qUCy/RNJNsfCam4m2IxWGxNuimkrioEF0vrrbPsg= github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= -github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= diff --git a/iterx_test.go b/iterx_test.go index 757f20c..c8f74ef 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -9,6 +9,7 @@ package gocqlx_test import ( "math/big" + "reflect" "strings" "testing" "time" @@ -47,6 +48,328 @@ type FullNamePtrUDT struct { *FullName } +func diff(t *testing.T, expected, got interface{}) { + t.Helper() + + if d := cmp.Diff(expected, got, diffOpts); d != "" { + t.Errorf("got %+v expected %+v, diff: %s", got, expected, d) + } +} + +var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) + +func TestIterxUDT(t *testing.T) { + session := CreateSession(t) + t.Cleanup(func() { + session.Close() + }) + + if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil { + t.Fatal("create type:", err) + } + + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table ( + testuuid timeuuid PRIMARY KEY, + testudt gocqlx_test.UDTTest_Full + )`); err != nil { + t.Fatal("create table:", err) + } + + type Full struct { + First string + Second string + } + + type Part struct { + First string + } + + type Extra struct { + First string + Second string + Third string + } + + type FullUDT struct { + gocqlx.UDT + Full + } + + type PartUDT struct { + gocqlx.UDT + Part + } + + type ExtraUDT struct { + gocqlx.UDT + Extra + } + + type FullUDTPtr struct { + gocqlx.UDT + *Full + } + + type PartUDTPtr struct { + gocqlx.UDT + *Part + } + + type ExtraUDTPtr struct { + gocqlx.UDT + *Extra + } + + full := FullUDT{ + Full: Full{ + First: "John", + Second: "Doe", + }, + } + + makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} { + b := reflect.New(reflect.StructOf([]reflect.StructField{ + { + Name: "TestUUID", + Type: reflect.TypeOf(gocql.UUID{}), + }, + { + Name: "TestUDT", + Type: reflect.TypeOf(insert), + }, + })).Interface() + reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid)) + reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert)) + return b + } + + tcases := []struct { + name string + insert interface{} + expectedOnDB FullUDT + expected interface{} + }{ + { + name: "exact-match", + insert: full, + expectedOnDB: full, + expected: full, + }, + { + name: "exact-match-ptr", + insert: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + expectedOnDB: full, + expected: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + }, + { + name: "extra-field", + insert: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "", //Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "extra-field-ptr", + insert: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "", //Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "absent-field", + insert: PartUDT{ + Part: Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDT{ + Part: Part{ + First: "John", + }, + }, + }, + { + name: "absent-field-ptr", + insert: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + }, + } + + const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)` + const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?` + + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + testuuid := gocql.TimeUUID() + + if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) { + t.Fatalf("insert and expectedOnDB must have the same type") + } + + t.Cleanup(func() { + session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() + }) + + t.Run("insert-bind", func(t *testing.T) { + if err := session.Query(insertStmt, nil).Unsafe().Bind( + testuuid, + tc.insert, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := FullUDT{} + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expectedOnDB, v) + }) + + t.Run("scan", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("get", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("delete", func(t *testing.T) { + if err := session.Query(deleteStmt, nil).Bind( + testuuid, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + }) + + t.Run("insert-bind-struct", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("insert-bind-struct-map", func(t *testing.T) { + t.Run("empty-map", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(b, nil).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("empty-struct", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(struct{}{}, map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + + t.Run("insert-bind-map", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindMap(map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + } +} + func TestIterxStruct(t *testing.T) { session := CreateSession(t) defer session.Close() @@ -149,8 +472,6 @@ func TestIterxStruct(t *testing.T) { t.Fatal("insert:", err) } - diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) - const stmt = `SELECT * FROM struct_table` t.Run("get", func(t *testing.T) { diff --git a/queryx.go b/queryx.go index f5e8736..dec924b 100644 --- a/queryx.go +++ b/queryx.go @@ -216,6 +216,10 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx { return q } +func (q *Queryx) Scan(v ...interface{}) error { + return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...) +} + // Err returns any binding errors. func (q *Queryx) Err() error { return q.err diff --git a/udt.go b/udt.go index a874989..9bd57f9 100644 --- a/udt.go +++ b/udt.go @@ -39,26 +39,24 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { value, ok := u.field[name] - - var data []byte - var err error if ok { - data, err = gocql.Marshal(info, value.Interface()) - if err != nil { - return nil, err - } + return gocql.Marshal(info, value.Interface()) } - - return data, err + if u.unsafe { + return nil, nil + } + return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) } func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { value, ok := u.field[name] - if !ok && !u.unsafe { - return fmt.Errorf("missing name %q in %s", name, u.value.Type()) + if ok { + return gocql.Unmarshal(info, data, value.Addr().Interface()) } - - return gocql.Unmarshal(info, data, value.Addr().Interface()) + if u.unsafe { + return nil + } + return fmt.Errorf("missing name %q in %s", name, u.value.Type()) } // udtWrapValue adds UDT wrapper if needed.