diff --git a/sql/driver/conn.go b/sql/driver/conn.go index 6755de196020..c15b5d5ae044 100644 --- a/sql/driver/conn.go +++ b/sql/driver/conn.go @@ -17,11 +17,11 @@ package driver -import ( - "database/sql/driver" - "fmt" - "time" -) +import "database/sql/driver" + +var _ driver.Conn = &conn{} +var _ driver.Queryer = &conn{} +var _ driver.Execer = &conn{} // conn implements the sql/driver.Conn interface. Note that conn is assumed to // be stateful and is not used concurrently by multiple goroutines; See @@ -46,52 +46,66 @@ func (c *conn) Begin() (driver.Tx, error) { } func (c *conn) Exec(stmt string, args []driver.Value) (driver.Result, error) { - rows, err := c.Query(stmt, args) + result, err := c.internalQuery(stmt, args) if err != nil { return nil, err } - return driver.RowsAffected(len(rows.rows)), nil + return driver.RowsAffected(len(result.Rows)), nil +} + +func (c *conn) Query(stmt string, args []driver.Value) (driver.Rows, error) { + result, err := c.internalQuery(stmt, args) + if err != nil { + return nil, err + } + + resultRows := &rows{ + columns: result.Columns, + rows: make([][]driver.Value, 0, len(result.Rows)), + } + for _, row := range result.Rows { + values := make([]driver.Value, 0, len(row.Values)) + for _, datum := range row.Values { + val, err := datum.Value() + if err != nil { + return nil, err + } + values = append(values, val) + } + resultRows.rows = append(resultRows.rows, values) + } + + return resultRows, nil } -func (c *conn) Query(stmt string, args []driver.Value) (*rows, error) { +func (c *conn) internalQuery(stmt string, args []driver.Value) (*Result, error) { if c.beginTransaction { stmt = "BEGIN TRANSACTION; " + stmt c.beginTransaction = false } - params := make([]Datum, 0, len(args)) + dArgs := make([]Datum, 0, len(args)) for _, arg := range args { - var param Datum - switch value := arg.(type) { - case int64: - param.IntVal = &value - case float64: - param.FloatVal = &value - case bool: - param.BoolVal = &value - case []byte: - param.BytesVal = value - case string: - param.StringVal = &value - case time.Time: - // Send absolute time devoid of time-zone. - t := Datum_Timestamp{Sec: value.Unix(), Nsec: uint32(value.Nanosecond())} - param.TimeVal = &t + datum, err := makeDatum(arg) + if err != nil { + return nil, err } - params = append(params, param) + dArgs = append(dArgs, datum) + } + + return c.send(stmt, dArgs) +} + +// send sends the statement to the server. +func (c *conn) send(stmt string, dArgs []Datum) (*Result, error) { + args := Request{ + Session: c.session, + Sql: stmt, + Params: dArgs, } // Forget the session state, and use the one provided in the server // response for the next request. - session := c.session c.session = nil - return c.send(Request{ - Session: session, - Sql: stmt, - Params: params, - }) -} -// send sends the call to the server. -func (c *conn) send(args Request) (*rows, error) { resp, err := c.sender.Send(args) if err != nil { return nil, err @@ -99,17 +113,8 @@ func (c *conn) send(args Request) (*rows, error) { // Set the session state even if the server returns an application error. // The server is responsible for constructing the correct session state // and sending it back. - if c.session != nil { - panic("connection has lingering session state") - } c.session = resp.Session - // Translate into rows - r := &rows{} - // Only use the last result to populate the response - index := len(resp.Results) - 1 - if index < 0 { - return r, nil - } + // Check for any application errors. // TODO(vivek): We might want to bunch all errors found here into // a single error. @@ -118,35 +123,12 @@ func (c *conn) send(args Request) (*rows, error) { return nil, result.Error } } - result := resp.Results[index] - r.columns = make([]string, len(result.Columns)) - for i, column := range result.Columns { - r.columns[i] = column - } - r.rows = make([]row, len(result.Rows)) - for i, p := range result.Rows { - t := make(row, len(p.Values)) - for j, datum := range p.Values { - if datum.BoolVal != nil { - t[j] = *datum.BoolVal - } else if datum.IntVal != nil { - t[j] = *datum.IntVal - } else if datum.FloatVal != nil { - t[j] = *datum.FloatVal - } else if datum.BytesVal != nil { - t[j] = datum.BytesVal - } else if datum.StringVal != nil { - t[j] = []byte(*datum.StringVal) - } else if datum.TimeVal != nil { - t[j] = time.Unix((*datum.TimeVal).Sec, int64((*datum.TimeVal).Nsec)).UTC() - } - if !driver.IsScanValue(t[j]) { - panic(fmt.Sprintf("unsupported type %T returned by database", t[j])) - } - } - r.rows[i] = t + + // Only use the last result. + if index := len(resp.Results); index != 0 { + return &resp.Results[index-1], nil } - return r, nil + return nil, nil } // Execute all the URL settings against the db to create diff --git a/sql/driver/driver.go b/sql/driver/driver.go index 45bf340acf82..dc75509681cc 100644 --- a/sql/driver/driver.go +++ b/sql/driver/driver.go @@ -26,17 +26,17 @@ import ( "github.com/cockroachdb/cockroach/util" ) +var _ driver.Driver = roachDriver{} + func init() { - sql.Register("cockroach", &roachDriver{}) + sql.Register("cockroach", roachDriver{}) } // roachDriver implements the database/sql/driver.Driver interface. Named // roachDriver so as not to conflict with the "driver" package name. type roachDriver struct{} -var _ driver.Driver = &roachDriver{} - -func (d *roachDriver) Open(dsn string) (driver.Conn, error) { +func (roachDriver) Open(dsn string) (driver.Conn, error) { u, err := url.Parse(dsn) if err != nil { return nil, err diff --git a/sql/driver/result.go b/sql/driver/result.go index f9dcd9600142..14393ac003f7 100644 --- a/sql/driver/result.go +++ b/sql/driver/result.go @@ -17,6 +17,10 @@ package driver +import "database/sql/driver" + +var _ driver.Result = result{} + // TODO(pmattis): Currently unused, but will be needed when we support // LastInsertId. type result struct { @@ -24,10 +28,10 @@ type result struct { rowsAffected int64 } -func (r *result) LastInsertId() (int64, error) { +func (r result) LastInsertId() (int64, error) { return r.lastInsertID, nil } -func (r *result) RowsAffected() (int64, error) { +func (r result) RowsAffected() (int64, error) { return r.rowsAffected, nil } diff --git a/sql/driver/rows.go b/sql/driver/rows.go index a0289fba93d1..f220ffaf0acd 100644 --- a/sql/driver/rows.go +++ b/sql/driver/rows.go @@ -22,28 +22,14 @@ import ( "io" ) -type row []driver.Value +var _ driver.Rows = &rows{} type rows struct { columns []string - rows []row + rows [][]driver.Value pos int // Next iteration index into rows. } -// newSingleColumnRows returns a rows structure initialized with a single -// column of values using the specified column name and values. This is a -// convenience routine used by operations which return only a single column. -func newSingleColumnRows(column string, vals []string) *rows { - r := make([]row, len(vals)) - for i, v := range vals { - r[i] = row{v} - } - return &rows{ - columns: []string{column}, - rows: r, - } -} - func (r *rows) Columns() []string { return r.columns } diff --git a/sql/driver/stmt.go b/sql/driver/stmt.go index 6e5a04218d96..f5a8340a0c28 100644 --- a/sql/driver/stmt.go +++ b/sql/driver/stmt.go @@ -19,24 +19,26 @@ package driver import "database/sql/driver" +var _ driver.Stmt = stmt{} + type stmt struct { conn *conn stmt string } -func (s *stmt) Close() error { +func (stmt) Close() error { return nil } -func (s *stmt) NumInput() int { +func (stmt) NumInput() int { // TODO(pmattis): Count the number of parameters. return -1 } -func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { +func (s stmt) Exec(args []driver.Value) (driver.Result, error) { return s.conn.Exec(s.stmt, args) } -func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { +func (s stmt) Query(args []driver.Value) (driver.Rows, error) { return s.conn.Query(s.stmt, args) } diff --git a/sql/driver/tx.go b/sql/driver/tx.go index a9ea2415fa8c..911e01df1444 100644 --- a/sql/driver/tx.go +++ b/sql/driver/tx.go @@ -17,16 +17,20 @@ package driver +import "database/sql/driver" + +var _ driver.Tx = tx{} + type tx struct { conn *conn } -func (t *tx) Commit() error { +func (t tx) Commit() error { _, err := t.conn.Exec("COMMIT TRANSACTION", nil) return err } -func (t *tx) Rollback() error { +func (t tx) Rollback() error { _, err := t.conn.Exec("ROLLBACK TRANSACTION", nil) return err } diff --git a/sql/driver/wire.go b/sql/driver/wire.go index 6ce028dcad89..cace38cbaf1d 100644 --- a/sql/driver/wire.go +++ b/sql/driver/wire.go @@ -18,42 +18,92 @@ package driver import ( + "database/sql/driver" "fmt" - "strconv" "time" + + "github.com/cockroachdb/cockroach/util" ) +var _ driver.Valuer = Datum{} + const ( // Endpoint is the URL path prefix which accepts incoming // HTTP requests for the SQL API. Endpoint = "/sql/" - - timestampWithOffsetZoneFormat = "2006-01-02 15:04:05.999999999-07:00" ) -func (d Datum) String() string { - val := d.GetValue() +func makeDatum(val driver.Value) (Datum, error) { + var datum Datum if val == nil { - return "NULL" + return datum, nil + } + + switch t := val.(type) { + case int64: + datum.IntVal = &t + case float64: + datum.FloatVal = &t + case bool: + datum.BoolVal = &t + case []byte: + datum.BytesVal = t + case string: + datum.StringVal = &t + case time.Time: + // Send absolute time devoid of time-zone. + datum.TimeVal = &Datum_Timestamp{ + Sec: t.Unix(), + Nsec: uint32(t.Nanosecond()), + } + default: + return datum, util.Errorf("unsupported type %T", t) } + return datum, nil +} + +// Value implements the driver.Valuer interface. +func (d Datum) Value() (driver.Value, error) { + val := d.GetValue() + switch t := val.(type) { case *bool: - return strconv.FormatBool(*t) + val = *t case *int64: - return strconv.FormatInt(*t, 10) + val = *t case *float64: - return strconv.FormatFloat(*t, 'g', -1, 64) + val = *t case []byte: - return string(t) + val = t case *string: - return *t + val = *t case *Datum_Timestamp: - return time.Unix((*t).Sec, int64((*t).Nsec)).UTC().Format(timestampWithOffsetZoneFormat) - default: - panic(fmt.Sprintf("unexpected type %T", t)) + val = time.Unix((*t).Sec, int64((*t).Nsec)).UTC() + } + + if driver.IsValue(val) { + return val, nil + } + return nil, util.Errorf("unsupported type %T", val) +} + +func (d Datum) String() string { + v, err := d.Value() + if err != nil { + panic(err) + } + + if v == nil { + return "NULL" } + + if bytes, ok := v.([]byte); ok { + return string(bytes) + } + + return fmt.Sprint(v) } // Method returns the method. diff --git a/sql/driver/wire_test.go b/sql/driver/wire_test.go index de8547d379f5..7186c2bcd86a 100644 --- a/sql/driver/wire_test.go +++ b/sql/driver/wire_test.go @@ -19,47 +19,33 @@ package driver import ( "testing" + "time" "github.com/cockroachdb/cockroach/util/leaktest" ) -func dBool(v bool) Datum { - return Datum{BoolVal: &v} -} - -func dInt(v int64) Datum { - return Datum{IntVal: &v} -} - -func dFloat(v float64) Datum { - return Datum{FloatVal: &v} -} - -func dBytes(v []byte) Datum { - return Datum{BytesVal: v} -} - -func dString(v string) Datum { - return Datum{StringVal: &v} -} - func TestDatumString(t *testing.T) { defer leaktest.AfterTest(t) testData := []struct { - datum Datum + value interface{} expected string }{ - {Datum{}, "NULL"}, - {dBool(false), "false"}, - {dBool(true), "true"}, - {dInt(-2), "-2"}, - {dFloat(4.5), "4.5"}, - {dBytes([]byte("6")), "6"}, - {dString("hello"), "hello"}, + {nil, "NULL"}, + {false, "false"}, + {true, "true"}, + {int64(-2), "-2"}, + {float64(4.5), "4.5"}, + {[]byte("6"), "6"}, + {"hello", "hello"}, + {time.Date(2015, 9, 6, 2, 19, 36, 342, time.UTC), "2015-09-06 02:19:36.000000342 +0000 UTC"}, } for i, d := range testData { - s := d.datum.String() + datum, err := makeDatum(d.value) + if err != nil { + t.Fatal(err) + } + s := datum.String() if d.expected != s { t.Errorf("%d: expected %s, but got %s", i, d.expected, s) }