From f443712bbc7a822eacd4e3d7568c38fa3f0f02f3 Mon Sep 17 00:00:00 2001 From: Vivek Menezes Date: Thu, 3 Sep 2015 22:42:39 -0400 Subject: [PATCH] Add Timestamp sql parameter. Allow casting of integer to interval type. Get rid of sql/driver -> sql/parser dependency. --- sql/driver/conn.go | 12 +- sql/driver/driver_test.go | 86 ++++++++------ sql/driver/wire.go | 54 +-------- sql/driver/wire.pb.go | 207 ++++++++++++++++++++++++++++++++-- sql/driver/wire.proto | 13 +++ sql/parser/eval.go | 7 ++ sql/parser/eval_test.go | 4 + sql/parser/type_check.go | 2 +- sql/parser/type_check_test.go | 1 - sql/server.go | 53 ++++++++- 10 files changed, 336 insertions(+), 103 deletions(-) diff --git a/sql/driver/conn.go b/sql/driver/conn.go index aeaea3ec3f53..6755de196020 100644 --- a/sql/driver/conn.go +++ b/sql/driver/conn.go @@ -73,13 +73,9 @@ func (c *conn) Query(stmt string, args []driver.Value) (*rows, error) { case string: param.StringVal = &value case time.Time: - // TODO(vivek): pass in time as an input that can be interpreted - // by the server. - time, err := value.MarshalBinary() - if err != nil { - return nil, err - } - param.BytesVal = time + // Send absolute time devoid of time-zone. + t := Datum_Timestamp{Sec: value.Unix(), Nsec: uint32(value.Nanosecond())} + param.TimeVal = &t } params = append(params, param) } @@ -141,6 +137,8 @@ func (c *conn) send(args Request) (*rows, error) { t[j] = datum.BytesVal } else if datum.StringVal != nil { t[j] = []byte(*datum.StringVal) + } else if datum.TimeVal != nil { + t[j] = time.Unix((*datum.TimeVal).Sec, int64((*datum.TimeVal).Nsec)).UTC() } if !driver.IsScanValue(t[j]) { panic(fmt.Sprintf("unsupported type %T returned by database", t[j])) diff --git a/sql/driver/driver_test.go b/sql/driver/driver_test.go index 0d8ff32fc719..64510af2889d 100644 --- a/sql/driver/driver_test.go +++ b/sql/driver/driver_test.go @@ -20,7 +20,9 @@ package driver_test import ( "database/sql" "fmt" + "log" "testing" + "time" "github.com/cockroachdb/cockroach/server" "github.com/cockroachdb/cockroach/util/leaktest" @@ -64,7 +66,11 @@ func asResultSlice(src [][]string) resultSlice { for i, subSlice := range src { result[i] = make([]*string, len(subSlice)) for j := range subSlice { - result[i][j] = &subSlice[j] + if subSlice[j] == "" { + result[i][j] = nil + } else { + result[i][j] = &subSlice[j] + } } } return result @@ -137,69 +143,79 @@ func TestPlaceholders(t *testing.T) { s, db := setup(t) defer cleanup(s, db) + timeVal := time.Date(2015, time.August, 30, 3, 34, 45, 345670000, time.UTC) + intervalVal, err := time.ParseDuration("34h2s") + if err != nil { + log.Fatal(err) + } + if _, err := db.Exec(`CREATE DATABASE t`); err != nil { t.Fatal(err) } - if _, err := db.Exec(`CREATE TABLE t.kv (k CHAR PRIMARY KEY, v CHAR)`); err != nil { + if _, err := db.Exec(`CREATE TABLE t.alltypes (a BIGINT PRIMARY KEY, b FLOAT, c TEXT, d BOOLEAN, e TIMESTAMP, f DATE, g INTERVAL)`); err != nil { t.Fatal(err) } - if _, err := db.Exec(`INSERT INTO t.kv VALUES ($1, $2), ($3, $4)`, "a", "b", "c", nil); err != nil { + // Insert values for all the different types. + if _, err := db.Exec(`INSERT INTO t.alltypes (a, b, c, d, e, f, g) VALUES ($1, $2, $3, $4, $5, $5::DATE, $6::INTERVAL)`, 123, 3.4, "blah", true, timeVal, intervalVal); err != nil { t.Fatal(err) } - - if rows, err := db.Query("SELECT * FROM t.kv"); err != nil { + // Insert a row with NULL values + if _, err := db.Exec(`INSERT INTO t.alltypes (a, b, c, d, e, f, g) VALUES ($1, $2, $3, $4, $5, $6, $7)`, 456, nil, nil, nil, nil, nil, nil); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE a IN ($1)", 123); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE b IN ($1)", 3.4); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE c IN ($1)", "blah"); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE d IN ($1)", true); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE e IN ($1)", timeVal); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE f IN ($1::DATE)", timeVal); err != nil { + t.Fatal(err) + } + if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE g IN ($1::INTERVAL)", intervalVal); err != nil { + t.Fatal(err) + } + if rows, err := db.Query("SELECT * FROM t.alltypes"); err != nil { t.Fatal(err) } else { results := readAll(t, rows) expectedResults := asResultSlice([][]string{ - {"k", "v"}, - {"a", "b"}, - {"c", ""}, + {"a", "b", "c", "d", "e", "f", "g"}, + {"123", "3.4", "blah", "true", "2015-08-30 03:34:45.34567+00:00", "2015-08-30", "34h0m2s"}, + {"456", "", "", "", "", "", ""}, }) - expectedResults[2][1] = nil if err := verifyResults(expectedResults, results); err != nil { t.Fatal(err) } } - - if _, err := db.Exec(`DELETE FROM t.kv WHERE k IN ($1)`, "c"); err != nil { + // Delete a row using a placeholder param. + if _, err := db.Exec(`DELETE FROM t.alltypes WHERE a IN ($1)`, 123); err != nil { t.Fatal(err) } - - if rows, err := db.Query("SELECT * FROM t.kv"); err != nil { + if rows, err := db.Query("SELECT * FROM t.alltypes"); err != nil { t.Fatal(err) } else { results := readAll(t, rows) expectedResults := asResultSlice([][]string{ - {"k", "v"}, - {"a", "b"}, + {"a", "b", "c", "d", "e", "f", "g"}, + {"456", "", "", "", "", "", ""}, }) if err := verifyResults(expectedResults, results); err != nil { t.Fatal(err) } } - - if _, err := db.Exec(`CREATE TABLE t.alltypes (a BIGINT PRIMARY KEY, b FLOAT, c TEXT, d BOOLEAN)`); err != nil { - t.Fatal(err) - } - if _, err := db.Exec(`INSERT INTO t.alltypes (a, b, c, d) VALUES ($1, $2, $3, $4)`, 123, 3.4, "blah", true); err != nil { - t.Fatal(err) - } - if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE a IN ($1)", 123); err != nil { - t.Fatal(err) - } - if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE b IN ($1)", 3.4); err != nil { - t.Fatal(err) - } - if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE c IN ($1)", "blah"); err != nil { - t.Fatal(err) - } - if _, err := db.Query("SELECT a, b FROM t.alltypes WHERE d IN ($1)", true); err != nil { - t.Fatal(err) - } } -func TestinConnectionSettings(t *testing.T) { +func TestConnectionSettings(t *testing.T) { defer leaktest.AfterTest(t) s := server.StartTestServer(nil) url := "https://root@" + s.ServingAddr() + "?certs=test_certs" diff --git a/sql/driver/wire.go b/sql/driver/wire.go index 5589fcf361a7..6ce028dcad89 100644 --- a/sql/driver/wire.go +++ b/sql/driver/wire.go @@ -20,14 +20,15 @@ package driver import ( "fmt" "strconv" - - "github.com/cockroachdb/cockroach/sql/parser" + "time" ) const ( // Endpoint is the URL path prefix which accepts incoming // HTTP requests for the SQL API. Endpoint = "/sql/" + + timestampWithOffsetZoneFormat = "2006-01-02 15:04:05.999999999-07:00" ) func (d Datum) String() string { @@ -48,6 +49,8 @@ func (d Datum) String() string { return string(t) case *string: return *t + case *Datum_Timestamp: + return time.Unix((*t).Sec, int64((*t).Nsec)).UTC().Format(timestampWithOffsetZoneFormat) default: panic(fmt.Sprintf("unexpected type %T", t)) } @@ -62,50 +65,3 @@ func (Request) Method() Method { func (Request) CreateReply() Response { return Response{} } - -// GetParameters returns the Params slice as a `parameters`. -func (r Request) GetParameters() Parameters { - return Parameters(r.Params) -} - -// Parameters implements the parser.Args interface. -type Parameters []Datum - -// Arg implements the parser.Args interface. -func (p Parameters) Arg(name string) (parser.Datum, bool) { - if len(name) == 0 { - // This shouldn't happen unless the parser let through an invalid parameter - // specification. - panic(fmt.Sprintf("invalid empty parameter name")) - } - if ch := name[0]; ch < '0' || ch > '9' { - // TODO(pmattis): Add support for named parameters (vs the numbered - // parameter support below). - return nil, false - } - i, err := strconv.ParseInt(name, 10, 0) - if err != nil { - return nil, false - } - if i < 1 || int(i) > len(p) { - return nil, false - } - arg := p[i-1].GetValue() - if arg == nil { - return parser.DNull, true - } - switch t := arg.(type) { - case *bool: - return parser.DBool(*t), true - case *int64: - return parser.DInt(*t), true - case *float64: - return parser.DFloat(*t), true - case []byte: - return parser.DString(t), true - case *string: - return parser.DString(*t), true - default: - panic(fmt.Sprintf("unexpected type %T", t)) - } -} diff --git a/sql/driver/wire.pb.go b/sql/driver/wire.pb.go index 792ec42d5c54..5d2ad3ae44a8 100644 --- a/sql/driver/wire.pb.go +++ b/sql/driver/wire.pb.go @@ -30,11 +30,12 @@ var _ = proto.Marshal var _ = math.Inf type Datum struct { - BoolVal *bool `protobuf:"varint,1,opt,name=bool_val" json:"bool_val,omitempty"` - IntVal *int64 `protobuf:"varint,2,opt,name=int_val" json:"int_val,omitempty"` - FloatVal *float64 `protobuf:"fixed64,3,opt,name=float_val" json:"float_val,omitempty"` - BytesVal []byte `protobuf:"bytes,4,opt,name=bytes_val" json:"bytes_val,omitempty"` - StringVal *string `protobuf:"bytes,5,opt,name=string_val" json:"string_val,omitempty"` + BoolVal *bool `protobuf:"varint,1,opt,name=bool_val" json:"bool_val,omitempty"` + IntVal *int64 `protobuf:"varint,2,opt,name=int_val" json:"int_val,omitempty"` + FloatVal *float64 `protobuf:"fixed64,3,opt,name=float_val" json:"float_val,omitempty"` + BytesVal []byte `protobuf:"bytes,4,opt,name=bytes_val" json:"bytes_val,omitempty"` + StringVal *string `protobuf:"bytes,5,opt,name=string_val" json:"string_val,omitempty"` + TimeVal *Datum_Timestamp `protobuf:"bytes,6,opt,name=time_val" json:"time_val,omitempty"` } func (m *Datum) Reset() { *m = Datum{} } @@ -75,6 +76,40 @@ func (m *Datum) GetStringVal() string { return "" } +func (m *Datum) GetTimeVal() *Datum_Timestamp { + if m != nil { + return m.TimeVal + } + return nil +} + +// Timestamp represents an absolute timestamp devoid of time-zone. +type Datum_Timestamp struct { + // The time in seconds since, January 1, 1970 UTC (Unix time). + Sec int64 `protobuf:"varint,1,opt,name=sec" json:"sec"` + // nsec specifies a non-negative nanosecond offset within sec. + // It must be in the range [0, 999999999]. + Nsec uint32 `protobuf:"varint,2,opt,name=nsec" json:"nsec"` +} + +func (m *Datum_Timestamp) Reset() { *m = Datum_Timestamp{} } +func (m *Datum_Timestamp) String() string { return proto.CompactTextString(m) } +func (*Datum_Timestamp) ProtoMessage() {} + +func (m *Datum_Timestamp) GetSec() int64 { + if m != nil { + return m.Sec + } + return 0 +} + +func (m *Datum_Timestamp) GetNsec() uint32 { + if m != nil { + return m.Nsec + } + return 0 +} + // A Result is a collection of rows. type Result struct { // Error is non-nil if an error occurred while executing the statement. @@ -249,6 +284,40 @@ func (m *Datum) MarshalTo(data []byte) (int, error) { i = encodeVarintWire(data, i, uint64(len(*m.StringVal))) i += copy(data[i:], *m.StringVal) } + if m.TimeVal != nil { + data[i] = 0x32 + i++ + i = encodeVarintWire(data, i, uint64(m.TimeVal.Size())) + n1, err := m.TimeVal.MarshalTo(data[i:]) + if err != nil { + return 0, err + } + i += n1 + } + return i, nil +} + +func (m *Datum_Timestamp) Marshal() (data []byte, err error) { + size := m.Size() + data = make([]byte, size) + n, err := m.MarshalTo(data) + if err != nil { + return nil, err + } + return data[:n], nil +} + +func (m *Datum_Timestamp) MarshalTo(data []byte) (int, error) { + var i int + _ = i + var l int + _ = l + data[i] = 0x8 + i++ + i = encodeVarintWire(data, i, uint64(m.Sec)) + data[i] = 0x10 + i++ + i = encodeVarintWire(data, i, uint64(m.Nsec)) return i, nil } @@ -271,11 +340,11 @@ func (m *Result) MarshalTo(data []byte) (int, error) { data[i] = 0xa i++ i = encodeVarintWire(data, i, uint64(m.Error.Size())) - n1, err := m.Error.MarshalTo(data[i:]) + n2, err := m.Error.MarshalTo(data[i:]) if err != nil { return 0, err } - i += n1 + i += n2 } if len(m.Columns) > 0 { for _, s := range m.Columns { @@ -464,6 +533,18 @@ func (m *Datum) Size() (n int) { l = len(*m.StringVal) n += 1 + l + sovWire(uint64(l)) } + if m.TimeVal != nil { + l = m.TimeVal.Size() + n += 1 + l + sovWire(uint64(l)) + } + return n +} + +func (m *Datum_Timestamp) Size() (n int) { + var l int + _ = l + n += 1 + sovWire(uint64(m.Sec)) + n += 1 + sovWire(uint64(m.Nsec)) return n } @@ -566,6 +647,9 @@ func (this *Datum) GetValue() interface{} { if this.StringVal != nil { return this.StringVal } + if this.TimeVal != nil { + return this.TimeVal + } return nil } @@ -581,6 +665,8 @@ func (this *Datum) SetValue(value interface{}) bool { this.BytesVal = vt case *string: this.StringVal = vt + case *Datum_Timestamp: + this.TimeVal = vt default: return false } @@ -711,6 +797,113 @@ func (m *Datum) Unmarshal(data []byte) error { s := string(data[iNdEx:postIndex]) m.StringVal = &s iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TimeVal", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + msglen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthWire + } + postIndex := iNdEx + msglen + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.TimeVal == nil { + m.TimeVal = &Datum_Timestamp{} + } + if err := m.TimeVal.Unmarshal(data[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + var sizeOfWire int + for { + sizeOfWire++ + wire >>= 7 + if wire == 0 { + break + } + } + iNdEx -= sizeOfWire + skippy, err := skipWire(data[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthWire + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + return nil +} +func (m *Datum_Timestamp) Unmarshal(data []byte) error { + l := len(data) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Sec", wireType) + } + m.Sec = 0 + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + m.Sec |= (int64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Nsec", wireType) + } + m.Nsec = 0 + for shift := uint(0); ; shift += 7 { + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := data[iNdEx] + iNdEx++ + m.Nsec |= (uint32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } default: var sizeOfWire int for { diff --git a/sql/driver/wire.proto b/sql/driver/wire.proto index bdcdebd8d3ca..9cfa905acb7d 100644 --- a/sql/driver/wire.proto +++ b/sql/driver/wire.proto @@ -29,6 +29,15 @@ option (gogoproto.goproto_unrecognized_all) = false; message Datum { option (gogoproto.goproto_stringer) = false; + + // Timestamp represents an absolute timestamp devoid of time-zone. + message Timestamp { + // The time in seconds since, January 1, 1970 UTC (Unix time). + optional int64 sec = 1 [(gogoproto.nullable) = false]; + // nsec specifies a non-negative nanosecond offset within sec. + // It must be in the range [0, 999999999]. + optional uint32 nsec = 2 [(gogoproto.nullable) = false]; + } // Using explicit proto types provides convenient access when using json. If // we used a Kind+Bytes approach the json interface would involve base64 @@ -40,6 +49,10 @@ message Datum { double float_val = 3; bytes bytes_val = 4; string string_val = 5; + Timestamp time_val = 6; + // TODO(vivek): Add additional types like Date and Interval that are + // supported by Cockroach but not supported by the Go sql driver. These + // will be useful for drivers in other languages. } // TODO(pmattis): How to add end-to-end checksumming? Just adding a checksum diff --git a/sql/parser/eval.go b/sql/parser/eval.go index 7b0da05a204d..ab94a48f2b8d 100644 --- a/sql/parser/eval.go +++ b/sql/parser/eval.go @@ -929,6 +929,9 @@ func init() { cmpOps[cmpArgs{In, intType, tupleType}] = evalTupleIN cmpOps[cmpArgs{In, floatType, tupleType}] = evalTupleIN cmpOps[cmpArgs{In, stringType, tupleType}] = evalTupleIN + cmpOps[cmpArgs{In, dateType, tupleType}] = evalTupleIN + cmpOps[cmpArgs{In, timestampType, tupleType}] = evalTupleIN + cmpOps[cmpArgs{In, intervalType, tupleType}] = evalTupleIN cmpOps[cmpArgs{In, tupleType, tupleType}] = evalTupleIN } @@ -1475,6 +1478,10 @@ func evalCastExpr(expr *CastExpr) (Datum, error) { // TODO(vivek): we might consider using the postgres format as well. d, err := time.ParseDuration(string(d.(DString))) return DInterval{Duration: d}, err + + case DInt: + // An integer duration represents a duration in nanoseconds. + return DInterval{Duration: time.Duration(d.(DInt))}, nil } // TODO(pmattis): unimplemented. // case *DecimalType: diff --git a/sql/parser/eval_test.go b/sql/parser/eval_test.go index 3e815e00a69e..b55b322c6b2c 100644 --- a/sql/parser/eval_test.go +++ b/sql/parser/eval_test.go @@ -136,6 +136,9 @@ func TestEvalExpr(t *testing.T) { {`1 NOT IN (2, 3, 4)`, `true`}, {`1+1 IN (2, 3, 4)`, `true`}, {`'a0' IN ('a'||0::char, 'b'||1::char, 'c'||2::char)`, `true`}, + {`'2012-09-21'::date IN ('2012-09-21'::date)`, `true`}, + {`'2010-09-28 12:00:00.1'::timestamp IN ('2010-09-28 12:00:00.1'::timestamp)`, `true`}, + {`'34h'::interval IN ('34h'::interval)`, `true`}, {`(1,2) IN ((0+1,1+1), (3,4), (5,6))`, `true`}, // Func expressions. {`length('hel'||'lo')`, `5`}, @@ -187,6 +190,7 @@ func TestEvalExpr(t *testing.T) { {`'2010-09-28 12:00:00.1-07:00'::timestamp`, `2010-09-28 19:00:00.1+00:00`}, {`('2010-09-28'::date)::timestamp`, `2010-09-28 00:00:00+00:00`}, {`'12h2m1s23ms'::interval`, `12h2m1.023s`}, + {`1::interval`, `1ns`}, {`'2010-09-28'::date + '12h2m'::interval`, `2010-09-28 12:02:00+00:00`}, {`'12h2m'::interval + '2010-09-28'::date`, `2010-09-28 12:02:00+00:00`}, {`'2010-09-28'::date - '12h2m'::interval`, `2010-09-27 11:58:00+00:00`}, diff --git a/sql/parser/type_check.go b/sql/parser/type_check.go index 43225e75b0c9..88f3276fa64f 100644 --- a/sql/parser/type_check.go +++ b/sql/parser/type_check.go @@ -458,7 +458,7 @@ func typeCheckCastExpr(expr *CastExpr) (Datum, error) { case *IntervalType: switch dummyExpr { - case DummyString: + case DummyString, DummyInt: return DummyInterval, nil } diff --git a/sql/parser/type_check_test.go b/sql/parser/type_check_test.go index 870ca60926b1..398cb2645dfa 100644 --- a/sql/parser/type_check_test.go +++ b/sql/parser/type_check_test.go @@ -65,7 +65,6 @@ func TestTypeCheckExprError(t *testing.T) { {`1::decimal`, `invalid cast: int -> DECIMAL`}, {`1::date`, `invalid cast: int -> DATE`}, {`1::timestamp`, `invalid cast: int -> TIMESTAMP`}, - {`1::interval`, `invalid cast: int -> INTERVAL`}, {`CASE 'one' WHEN 1 THEN 1 WHEN 'two' THEN 2 END`, `incompatible condition type`}, {`CASE 1 WHEN 1 THEN 'one' WHEN 2 THEN 2 END`, `incompatible value type`}, {`CASE 1 WHEN 1 THEN 'one' ELSE 2 END`, `incompatible value type`}, diff --git a/sql/server.go b/sql/server.go index d9f3fef19dd6..0f717475dc60 100644 --- a/sql/server.go +++ b/sql/server.go @@ -19,7 +19,10 @@ package sql import ( "errors" + "fmt" "net/http" + "strconv" + "time" "github.com/cockroachdb/cockroach/client" "github.com/cockroachdb/cockroach/proto" @@ -53,7 +56,7 @@ func (s server) execute(args driver.Request) (driver.Response, int, error) { // Send the Request for SQL execution and set the application-level error // for each result in the reply. - reply := s.execStmts(args.Sql, args.GetParameters(), &planMaker) + reply := s.execStmts(args.Sql, parameters(args.Params), &planMaker) // Send back the session state even if there were application-level errors. // Add transaction to session state. @@ -75,7 +78,7 @@ func (s server) execute(args driver.Request) (driver.Response, int, error) { // exec executes the request. Any error encountered is returned; it is // the caller's responsibility to update the response. -func (s server) execStmts(sql string, params driver.Parameters, planMaker *planner) driver.Response { +func (s server) execStmts(sql string, params parameters, planMaker *planner) driver.Response { var resp driver.Response stmts, err := parser.Parse(sql, parser.Syntax(planMaker.session.Syntax)) if err != nil { @@ -94,7 +97,7 @@ func (s server) execStmts(sql string, params driver.Parameters, planMaker *plann return resp } -func (s server) execStmt(stmt parser.Statement, params driver.Parameters, planMaker *planner) (driver.Result, error) { +func (s server) execStmt(stmt parser.Statement, params parameters, planMaker *planner) (driver.Result, error) { var result driver.Result if planMaker.txn == nil { if _, ok := stmt.(*parser.BeginTransaction); ok { @@ -192,3 +195,47 @@ func rollbackTxnAndReturnResultWithError(planMaker *planner, err error) driver.R errProto.SetResponseGoError(err) return driver.Result{Error: &errProto} } + +// parameters implements the parser.Args interface. +type parameters []driver.Datum + +// Arg implements the parser.Args interface. +func (p parameters) Arg(name string) (parser.Datum, bool) { + if len(name) == 0 { + // This shouldn't happen unless the parser let through an invalid parameter + // specification. + panic(fmt.Sprintf("invalid empty parameter name")) + } + if ch := name[0]; ch < '0' || ch > '9' { + // TODO(pmattis): Add support for named parameters (vs the numbered + // parameter support below). + return nil, false + } + i, err := strconv.ParseInt(name, 10, 0) + if err != nil { + return nil, false + } + if i < 1 || int(i) > len(p) { + return nil, false + } + arg := p[i-1].GetValue() + if arg == nil { + return parser.DNull, true + } + switch t := arg.(type) { + case *bool: + return parser.DBool(*t), true + case *int64: + return parser.DInt(*t), true + case *float64: + return parser.DFloat(*t), true + case []byte: + return parser.DString(t), true + case *string: + return parser.DString(*t), true + case *driver.Datum_Timestamp: + return parser.DTimestamp{Time: time.Unix((*t).Sec, int64((*t).Nsec)).UTC()}, true + default: + panic(fmt.Sprintf("unexpected type %T", t)) + } +}