Skip to content

Commit

Permalink
Merge pull request #2375 from tamird/sql-clean-up
Browse files Browse the repository at this point in the history
sql/driver: misc cleanup
  • Loading branch information
tamird committed Sep 6, 2015
2 parents 450daae + a6c8f09 commit d4d8ea9
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 144 deletions.
128 changes: 55 additions & 73 deletions sql/driver/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

package driver

import (
"database/sql/driver"
"fmt"
"time"
)
import "database/sql/driver"

var _ driver.Conn = &conn{}
var _ driver.Queryer = &conn{}
var _ driver.Execer = &conn{}

// conn implements the sql/driver.Conn interface. Note that conn is assumed to
// be stateful and is not used concurrently by multiple goroutines; See
Expand All @@ -46,70 +46,75 @@ func (c *conn) Begin() (driver.Tx, error) {
}

func (c *conn) Exec(stmt string, args []driver.Value) (driver.Result, error) {
rows, err := c.Query(stmt, args)
result, err := c.internalQuery(stmt, args)
if err != nil {
return nil, err
}
return driver.RowsAffected(len(rows.rows)), nil
return driver.RowsAffected(len(result.Rows)), nil
}

func (c *conn) Query(stmt string, args []driver.Value) (driver.Rows, error) {
result, err := c.internalQuery(stmt, args)
if err != nil {
return nil, err
}

resultRows := &rows{
columns: result.Columns,
rows: make([][]driver.Value, 0, len(result.Rows)),
}
for _, row := range result.Rows {
values := make([]driver.Value, 0, len(row.Values))
for _, datum := range row.Values {
val, err := datum.Value()
if err != nil {
return nil, err
}
values = append(values, val)
}
resultRows.rows = append(resultRows.rows, values)
}

return resultRows, nil
}

func (c *conn) Query(stmt string, args []driver.Value) (*rows, error) {
func (c *conn) internalQuery(stmt string, args []driver.Value) (*Result, error) {
if c.beginTransaction {
stmt = "BEGIN TRANSACTION; " + stmt
c.beginTransaction = false
}
params := make([]Datum, 0, len(args))
dArgs := make([]Datum, 0, len(args))
for _, arg := range args {
var param Datum
switch value := arg.(type) {
case int64:
param.IntVal = &value
case float64:
param.FloatVal = &value
case bool:
param.BoolVal = &value
case []byte:
param.BytesVal = value
case string:
param.StringVal = &value
case time.Time:
// Send absolute time devoid of time-zone.
t := Datum_Timestamp{Sec: value.Unix(), Nsec: uint32(value.Nanosecond())}
param.TimeVal = &t
datum, err := makeDatum(arg)
if err != nil {
return nil, err
}
params = append(params, param)
dArgs = append(dArgs, datum)
}

return c.send(stmt, dArgs)
}

// send sends the statement to the server.
func (c *conn) send(stmt string, dArgs []Datum) (*Result, error) {
args := Request{
Session: c.session,
Sql: stmt,
Params: dArgs,
}
// Forget the session state, and use the one provided in the server
// response for the next request.
session := c.session
c.session = nil
return c.send(Request{
Session: session,
Sql: stmt,
Params: params,
})
}

// send sends the call to the server.
func (c *conn) send(args Request) (*rows, error) {
resp, err := c.sender.Send(args)
if err != nil {
return nil, err
}
// Set the session state even if the server returns an application error.
// The server is responsible for constructing the correct session state
// and sending it back.
if c.session != nil {
panic("connection has lingering session state")
}
c.session = resp.Session
// Translate into rows
r := &rows{}
// Only use the last result to populate the response
index := len(resp.Results) - 1
if index < 0 {
return r, nil
}

// Check for any application errors.
// TODO(vivek): We might want to bunch all errors found here into
// a single error.
Expand All @@ -118,35 +123,12 @@ func (c *conn) send(args Request) (*rows, error) {
return nil, result.Error
}
}
result := resp.Results[index]
r.columns = make([]string, len(result.Columns))
for i, column := range result.Columns {
r.columns[i] = column
}
r.rows = make([]row, len(result.Rows))
for i, p := range result.Rows {
t := make(row, len(p.Values))
for j, datum := range p.Values {
if datum.BoolVal != nil {
t[j] = *datum.BoolVal
} else if datum.IntVal != nil {
t[j] = *datum.IntVal
} else if datum.FloatVal != nil {
t[j] = *datum.FloatVal
} else if datum.BytesVal != nil {
t[j] = datum.BytesVal
} else if datum.StringVal != nil {
t[j] = []byte(*datum.StringVal)
} else if datum.TimeVal != nil {
t[j] = time.Unix((*datum.TimeVal).Sec, int64((*datum.TimeVal).Nsec)).UTC()
}
if !driver.IsScanValue(t[j]) {
panic(fmt.Sprintf("unsupported type %T returned by database", t[j]))
}
}
r.rows[i] = t

// Only use the last result.
if index := len(resp.Results); index != 0 {
return &resp.Results[index-1], nil
}
return r, nil
return nil, nil
}

// Execute all the URL settings against the db to create
Expand Down
8 changes: 4 additions & 4 deletions sql/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import (
"github.com/cockroachdb/cockroach/util"
)

var _ driver.Driver = roachDriver{}

func init() {
sql.Register("cockroach", &roachDriver{})
sql.Register("cockroach", roachDriver{})
}

// roachDriver implements the database/sql/driver.Driver interface. Named
// roachDriver so as not to conflict with the "driver" package name.
type roachDriver struct{}

var _ driver.Driver = &roachDriver{}

func (d *roachDriver) Open(dsn string) (driver.Conn, error) {
func (roachDriver) Open(dsn string) (driver.Conn, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, err
Expand Down
8 changes: 6 additions & 2 deletions sql/driver/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@

package driver

import "database/sql/driver"

var _ driver.Result = result{}

// TODO(pmattis): Currently unused, but will be needed when we support
// LastInsertId.
type result struct {
lastInsertID int64
rowsAffected int64
}

func (r *result) LastInsertId() (int64, error) {
func (r result) LastInsertId() (int64, error) {
return r.lastInsertID, nil
}

func (r *result) RowsAffected() (int64, error) {
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
18 changes: 2 additions & 16 deletions sql/driver/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,14 @@ import (
"io"
)

type row []driver.Value
var _ driver.Rows = &rows{}

type rows struct {
columns []string
rows []row
rows [][]driver.Value
pos int // Next iteration index into rows.
}

// newSingleColumnRows returns a rows structure initialized with a single
// column of values using the specified column name and values. This is a
// convenience routine used by operations which return only a single column.
func newSingleColumnRows(column string, vals []string) *rows {
r := make([]row, len(vals))
for i, v := range vals {
r[i] = row{v}
}
return &rows{
columns: []string{column},
rows: r,
}
}

func (r *rows) Columns() []string {
return r.columns
}
Expand Down
10 changes: 6 additions & 4 deletions sql/driver/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,26 @@ package driver

import "database/sql/driver"

var _ driver.Stmt = stmt{}

type stmt struct {
conn *conn
stmt string
}

func (s *stmt) Close() error {
func (stmt) Close() error {
return nil
}

func (s *stmt) NumInput() int {
func (stmt) NumInput() int {
// TODO(pmattis): Count the number of parameters.
return -1
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.conn.Exec(s.stmt, args)
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.conn.Query(s.stmt, args)
}
8 changes: 6 additions & 2 deletions sql/driver/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@

package driver

import "database/sql/driver"

var _ driver.Tx = tx{}

type tx struct {
conn *conn
}

func (t *tx) Commit() error {
func (t tx) Commit() error {
_, err := t.conn.Exec("COMMIT TRANSACTION", nil)
return err
}

func (t *tx) Rollback() error {
func (t tx) Rollback() error {
_, err := t.conn.Exec("ROLLBACK TRANSACTION", nil)
return err
}
Loading

0 comments on commit d4d8ea9

Please sign in to comment.