From 12ad5ccc490b706b8b2493fda55d61b82dc5b3e3 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Wed, 5 Jun 2024 13:57:55 +0200 Subject: [PATCH 1/3] Marshal missing UDT fields as null instead of failing We can't return an error in case a field is added to the UDT, otherwise existing code would break by simply altering the UDT in the database. For extra fields at the end of the UDT put nulls to be in line with gocql, but also python-driver and java-driver. In gocql it was fixed in https://github.com/scylladb/gocql/commit/d2ed1bb74f3118a83a352e9ce912be765001efa4 --- udt.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/udt.go b/udt.go index 1da938e..63551f7 100644 --- a/udt.go +++ b/udt.go @@ -39,11 +39,17 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { value, ok := u.field[name] - if !ok { - return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) + + var data []byte + var err error + if ok { + data, err = gocql.Marshal(info, value.Interface()) + if err != nil { + return nil, err + } } - return gocql.Marshal(info, value.Interface()) + return data, err } func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { From 8238a44e10be4ce7f1edae28bd3faa2a5a77cc97 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Fri, 14 Jun 2024 11:21:57 -0400 Subject: [PATCH 2/3] Introduce `Unsafe` method on `Queryx` It enables local control over `unsafe` mode for .Bind methods of `Queryx` and iterators spawn by it. --- iterx.go | 2 +- iterx_test.go | 6 +----- queryx.go | 16 +++++++++++++--- session.go | 2 ++ 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/iterx.go b/iterx.go index 5f99d06..9ebfedf 100644 --- a/iterx.go +++ b/iterx.go @@ -13,7 +13,7 @@ import ( "github.com/scylladb/go-reflectx" ) -// DefaultUnsafe enables the behavior of forcing the iterator to ignore +// DefaultUnsafe enables the behavior of forcing queries and iterators to ignore // missing fields for all queries. See Unsafe below for more information. var DefaultUnsafe bool diff --git a/iterx_test.go b/iterx_test.go index 59ecfa5..d8652a0 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -484,12 +484,8 @@ func TestIterxUnsafe(t *testing.T) { }) t.Run("select default unsafe", func(t *testing.T) { - gocqlx.DefaultUnsafe = true - defer func() { - gocqlx.DefaultUnsafe = false - }() var v []UnsafeTable - err := session.Query(stmt, nil).Iter().Select(&v) + err := session.Query(stmt, nil).Unsafe().Iter().Select(&v) if err != nil { t.Fatal("Select() failed:", err) } diff --git a/queryx.go b/queryx.go index 7d8809f..331512c 100644 --- a/queryx.go +++ b/queryx.go @@ -94,7 +94,8 @@ type Queryx struct { tr Transformer Mapper *reflectx.Mapper *gocql.Query - Names []string + Names []string + unsafe bool } // Query creates a new Queryx from gocql.Query using a default mapper. @@ -106,6 +107,7 @@ func Query(q *gocql.Query, names []string) *Queryx { Names: names, Mapper: DefaultMapper, tr: DefaultBindTransformer, + unsafe: DefaultUnsafe, } } @@ -209,7 +211,7 @@ func (q *Queryx) bindMapArgs(arg map[string]interface{}) ([]interface{}, error) // Bind sets query arguments of query. This can also be used to rebind new query arguments // to an existing query instance. func (q *Queryx) Bind(v ...interface{}) *Queryx { - q.Query.Bind(udtWrapSlice(q.Mapper, DefaultUnsafe, v)...) + q.Query.Bind(udtWrapSlice(q.Mapper, q.unsafe, v)...) return q } @@ -342,6 +344,14 @@ func (q *Queryx) Iter() *Iterx { return &Iterx{ Iter: q.Query.Iter(), Mapper: q.Mapper, - unsafe: DefaultUnsafe, + unsafe: q.unsafe, } } + +// Unsafe forces the query and iterators to ignore missing fields. By default when scanning +// a struct if result row has a column that cannot be mapped to any destination +// field an error is reported. With unsafe such columns are ignored. +func (q *Queryx) Unsafe() *Queryx { + q.unsafe = true + return q +} diff --git a/session.go b/session.go index aef880a..cb82c3c 100644 --- a/session.go +++ b/session.go @@ -51,6 +51,7 @@ func (s Session) ContextQuery(ctx context.Context, stmt string, names []string) Names: names, Mapper: s.Mapper, tr: DefaultBindTransformer, + unsafe: DefaultUnsafe, } } @@ -65,6 +66,7 @@ func (s Session) Query(stmt string, names []string) *Queryx { Names: names, Mapper: s.Mapper, tr: DefaultBindTransformer, + unsafe: DefaultUnsafe, } } From a12aa3e7df4ae92b7e76a159b58cc7a43f3fcd16 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Fri, 14 Jun 2024 11:28:30 -0400 Subject: [PATCH 3/3] Marshal/Unmarshal missing UDT fields as null instead of failing in unsafe mode We can't return an error in case a field is added to the UDT, otherwise existing code would break by simply altering the UDT in the database. For extra fields at the end of the UDT put nulls to be in line with gocql, but also python-driver and java-driver. In gocql it was fixed in https://github.com/scylladb/gocql/commit/d2ed1bb74f3118a83a352e9ce912be765001efa4 --- iterx_test.go | 325 +++++++++++++++++++++++++++++++++++++++++++++++++- queryx.go | 7 ++ udt.go | 24 ++-- 3 files changed, 341 insertions(+), 15 deletions(-) diff --git a/iterx_test.go b/iterx_test.go index d8652a0..96de405 100644 --- a/iterx_test.go +++ b/iterx_test.go @@ -9,6 +9,7 @@ package gocqlx_test import ( "math/big" + "reflect" "strings" "testing" "time" @@ -48,6 +49,328 @@ type FullNamePtrUDT struct { *FullName } +func diff(t *testing.T, expected, got interface{}) { + t.Helper() + + if d := cmp.Diff(expected, got, diffOpts); d != "" { + t.Errorf("got %+v expected %+v, diff: %s", got, expected, d) + } +} + +var diffOpts = cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) + +func TestIterxUDT(t *testing.T) { + session := gocqlxtest.CreateSession(t) + t.Cleanup(func() { + session.Close() + }) + + if err := session.ExecStmt(`CREATE TYPE gocqlx_test.UDTTest_Full (first text, second text)`); err != nil { + t.Fatal("create type:", err) + } + + if err := session.ExecStmt(`CREATE TABLE gocqlx_test.udt_table ( + testuuid timeuuid PRIMARY KEY, + testudt gocqlx_test.UDTTest_Full + )`); err != nil { + t.Fatal("create table:", err) + } + + type Full struct { + First string + Second string + } + + type Part struct { + First string + } + + type Extra struct { + First string + Second string + Third string + } + + type FullUDT struct { + gocqlx.UDT + Full + } + + type PartUDT struct { + gocqlx.UDT + Part + } + + type ExtraUDT struct { + gocqlx.UDT + Extra + } + + type FullUDTPtr struct { + gocqlx.UDT + *Full + } + + type PartUDTPtr struct { + gocqlx.UDT + *Part + } + + type ExtraUDTPtr struct { + gocqlx.UDT + *Extra + } + + full := FullUDT{ + Full: Full{ + First: "John", + Second: "Doe", + }, + } + + makeStruct := func(testuuid gocql.UUID, insert interface{}) interface{} { + b := reflect.New(reflect.StructOf([]reflect.StructField{ + { + Name: "TestUUID", + Type: reflect.TypeOf(gocql.UUID{}), + }, + { + Name: "TestUDT", + Type: reflect.TypeOf(insert), + }, + })).Interface() + reflect.ValueOf(b).Elem().FieldByName("TestUUID").Set(reflect.ValueOf(testuuid)) + reflect.ValueOf(b).Elem().FieldByName("TestUDT").Set(reflect.ValueOf(insert)) + return b + } + + tcases := []struct { + name string + insert interface{} + expected interface{} + expectedOnDB FullUDT + }{ + { + name: "exact-match", + insert: full, + expectedOnDB: full, + expected: full, + }, + { + name: "exact-match-ptr", + insert: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + expectedOnDB: full, + expected: FullUDTPtr{ + Full: &Full{ + First: "John", + Second: "Doe", + }, + }, + }, + { + name: "extra-field", + insert: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDT{ + Extra: Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "extra-field-ptr", + insert: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "Smith", + }, + }, + expectedOnDB: full, + expected: ExtraUDTPtr{ + Extra: &Extra{ + First: "John", + Second: "Doe", + Third: "", // Since the UDT has only 2 fields, the third field should be empty + }, + }, + }, + { + name: "absent-field", + insert: PartUDT{ + Part: Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDT{ + Part: Part{ + First: "John", + }, + }, + }, + { + name: "absent-field-ptr", + insert: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + expectedOnDB: FullUDT{ + Full: Full{ + First: "John", + Second: "", + }, + }, + expected: PartUDTPtr{ + Part: &Part{ + First: "John", + }, + }, + }, + } + + const insertStmt = `INSERT INTO udt_table (testuuid, testudt) VALUES (?, ?)` + const deleteStmt = `DELETE FROM udt_table WHERE testuuid = ?` + + for _, tc := range tcases { + t.Run(tc.name, func(t *testing.T) { + testuuid := gocql.TimeUUID() + + if reflect.TypeOf(tc.insert) != reflect.TypeOf(tc.expected) { + t.Fatalf("insert and expectedOnDB must have the same type") + } + + t.Cleanup(func() { + session.Query(deleteStmt, nil).Bind(testuuid).ExecRelease() // nolint:errcheck + }) + + t.Run("insert-bind", func(t *testing.T) { + if err := session.Query(insertStmt, nil).Unsafe().Bind( + testuuid, + tc.insert, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := FullUDT{} + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(&v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expectedOnDB, v) + }) + + t.Run("scan", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Scan(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("get", func(t *testing.T) { + v := reflect.New(reflect.TypeOf(tc.expected)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Unsafe().Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, tc.expected, reflect.ValueOf(v).Elem().Interface()) + }) + + t.Run("delete", func(t *testing.T) { + if err := session.Query(deleteStmt, nil).Bind( + testuuid, + ).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + }) + + t.Run("insert-bind-struct", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe().BindStruct(b).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("insert-bind-struct-map", func(t *testing.T) { + t.Run("empty-map", func(t *testing.T) { + b := makeStruct(testuuid, tc.insert) + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(b, nil).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + + t.Run("empty-struct", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindStructMap(struct{}{}, map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + + t.Run("insert-bind-map", func(t *testing.T) { + if err := session.Query(insertStmt, []string{"test_uuid", "test_udt"}).Unsafe(). + BindMap(map[string]interface{}{ + "test_uuid": testuuid, + "test_udt": tc.insert, + }).ExecRelease(); err != nil { + t.Fatal(err.Error()) + } + + // Make sure the UDT was inserted correctly + v := reflect.New(reflect.TypeOf(tc.expectedOnDB)).Interface() + if err := session.Query(`SELECT testudt FROM udt_table where testuuid = ?`, nil).Bind(testuuid).Get(v); err != nil { + t.Fatal(err.Error()) + } + diff(t, &tc.expectedOnDB, v) + }) + }) + } +} + func TestIterxStruct(t *testing.T) { session := gocqlxtest.CreateSession(t) defer session.Close() @@ -153,8 +476,6 @@ func TestIterxStruct(t *testing.T) { t.Fatal("insert:", err) } - diffOpts := cmpopts.IgnoreUnexported(big.Int{}, inf.Dec{}) - const stmt = `SELECT * FROM struct_table` t.Run("get", func(t *testing.T) { diff --git a/queryx.go b/queryx.go index 331512c..9b67f9b 100644 --- a/queryx.go +++ b/queryx.go @@ -215,6 +215,13 @@ func (q *Queryx) Bind(v ...interface{}) *Queryx { return q } +// Scan executes the query, copies the columns of the first selected +// row into the values pointed at by dest and discards the rest. If no rows +// were selected, ErrNotFound is returned. +func (q *Queryx) Scan(v ...interface{}) error { + return q.Query.Scan(udtWrapSlice(q.Mapper, q.unsafe, v)...) +} + // Err returns any binding errors. func (q *Queryx) Err() error { return q.err diff --git a/udt.go b/udt.go index 63551f7..2785151 100644 --- a/udt.go +++ b/udt.go @@ -39,26 +39,24 @@ func makeUDT(value reflect.Value, mapper *reflectx.Mapper, unsafe bool) udt { func (u udt) MarshalUDT(name string, info gocql.TypeInfo) ([]byte, error) { value, ok := u.field[name] - - var data []byte - var err error if ok { - data, err = gocql.Marshal(info, value.Interface()) - if err != nil { - return nil, err - } + return gocql.Marshal(info, value.Interface()) } - - return data, err + if u.unsafe { + return nil, nil + } + return nil, fmt.Errorf("missing name %q in %s", name, u.value.Type()) } func (u udt) UnmarshalUDT(name string, info gocql.TypeInfo, data []byte) error { value, ok := u.field[name] - if !ok && !u.unsafe { - return fmt.Errorf("missing name %q in %s", name, u.value.Type()) + if ok { + return gocql.Unmarshal(info, data, value.Addr().Interface()) } - - return gocql.Unmarshal(info, data, value.Addr().Interface()) + if u.unsafe { + return nil + } + return fmt.Errorf("missing name %q in %s", name, u.value.Type()) } // udtWrapValue adds UDT wrapper if needed.