Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/driver: misc cleanup #2375

Merged
merged 1 commit into from
Sep 6, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 55 additions & 73 deletions sql/driver/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

package driver

import (
"database/sql/driver"
"fmt"
"time"
)
import "database/sql/driver"

var _ driver.Conn = &conn{}
var _ driver.Queryer = &conn{}
var _ driver.Execer = &conn{}

// conn implements the sql/driver.Conn interface. Note that conn is assumed to
// be stateful and is not used concurrently by multiple goroutines; See
Expand All @@ -46,70 +46,75 @@ func (c *conn) Begin() (driver.Tx, error) {
}

func (c *conn) Exec(stmt string, args []driver.Value) (driver.Result, error) {
rows, err := c.Query(stmt, args)
result, err := c.internalQuery(stmt, args)
if err != nil {
return nil, err
}
return driver.RowsAffected(len(rows.rows)), nil
return driver.RowsAffected(len(result.Rows)), nil
}

func (c *conn) Query(stmt string, args []driver.Value) (driver.Rows, error) {
result, err := c.internalQuery(stmt, args)
if err != nil {
return nil, err
}

resultRows := &rows{
columns: result.Columns,
rows: make([][]driver.Value, 0, len(result.Rows)),
}
for _, row := range result.Rows {
values := make([]driver.Value, 0, len(row.Values))
for _, datum := range row.Values {
val, err := datum.Value()
if err != nil {
return nil, err
}
values = append(values, val)
}
resultRows.rows = append(resultRows.rows, values)
}

return resultRows, nil
}

func (c *conn) Query(stmt string, args []driver.Value) (*rows, error) {
func (c *conn) internalQuery(stmt string, args []driver.Value) (*Result, error) {
if c.beginTransaction {
stmt = "BEGIN TRANSACTION; " + stmt
c.beginTransaction = false
}
params := make([]Datum, 0, len(args))
dArgs := make([]Datum, 0, len(args))
for _, arg := range args {
var param Datum
switch value := arg.(type) {
case int64:
param.IntVal = &value
case float64:
param.FloatVal = &value
case bool:
param.BoolVal = &value
case []byte:
param.BytesVal = value
case string:
param.StringVal = &value
case time.Time:
// Send absolute time devoid of time-zone.
t := Datum_Timestamp{Sec: value.Unix(), Nsec: uint32(value.Nanosecond())}
param.TimeVal = &t
datum, err := makeDatum(arg)
if err != nil {
return nil, err
}
params = append(params, param)
dArgs = append(dArgs, datum)
}

return c.send(stmt, dArgs)
}

// send sends the statement to the server.
func (c *conn) send(stmt string, dArgs []Datum) (*Result, error) {
args := Request{
Session: c.session,
Sql: stmt,
Params: dArgs,
}
// Forget the session state, and use the one provided in the server
// response for the next request.
session := c.session
c.session = nil
return c.send(Request{
Session: session,
Sql: stmt,
Params: params,
})
}

// send sends the call to the server.
func (c *conn) send(args Request) (*rows, error) {
resp, err := c.sender.Send(args)
if err != nil {
return nil, err
}
// Set the session state even if the server returns an application error.
// The server is responsible for constructing the correct session state
// and sending it back.
if c.session != nil {
panic("connection has lingering session state")
}
c.session = resp.Session
// Translate into rows
r := &rows{}
// Only use the last result to populate the response
index := len(resp.Results) - 1
if index < 0 {
return r, nil
}

// Check for any application errors.
// TODO(vivek): We might want to bunch all errors found here into
// a single error.
Expand All @@ -118,35 +123,12 @@ func (c *conn) send(args Request) (*rows, error) {
return nil, result.Error
}
}
result := resp.Results[index]
r.columns = make([]string, len(result.Columns))
for i, column := range result.Columns {
r.columns[i] = column
}
r.rows = make([]row, len(result.Rows))
for i, p := range result.Rows {
t := make(row, len(p.Values))
for j, datum := range p.Values {
if datum.BoolVal != nil {
t[j] = *datum.BoolVal
} else if datum.IntVal != nil {
t[j] = *datum.IntVal
} else if datum.FloatVal != nil {
t[j] = *datum.FloatVal
} else if datum.BytesVal != nil {
t[j] = datum.BytesVal
} else if datum.StringVal != nil {
t[j] = []byte(*datum.StringVal)
} else if datum.TimeVal != nil {
t[j] = time.Unix((*datum.TimeVal).Sec, int64((*datum.TimeVal).Nsec)).UTC()
}
if !driver.IsScanValue(t[j]) {
panic(fmt.Sprintf("unsupported type %T returned by database", t[j]))
}
}
r.rows[i] = t

// Only use the last result.
if index := len(resp.Results); index != 0 {
return &resp.Results[index-1], nil
}
return r, nil
return nil, nil
}

// Execute all the URL settings against the db to create
Expand Down
8 changes: 4 additions & 4 deletions sql/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import (
"github.com/cockroachdb/cockroach/util"
)

var _ driver.Driver = roachDriver{}

func init() {
sql.Register("cockroach", &roachDriver{})
sql.Register("cockroach", roachDriver{})
}

// roachDriver implements the database/sql/driver.Driver interface. Named
// roachDriver so as not to conflict with the "driver" package name.
type roachDriver struct{}

var _ driver.Driver = &roachDriver{}

func (d *roachDriver) Open(dsn string) (driver.Conn, error) {
func (roachDriver) Open(dsn string) (driver.Conn, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, err
Expand Down
8 changes: 6 additions & 2 deletions sql/driver/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,21 @@

package driver

import "database/sql/driver"

var _ driver.Result = result{}

// TODO(pmattis): Currently unused, but will be needed when we support
// LastInsertId.
type result struct {
lastInsertID int64
rowsAffected int64
}

func (r *result) LastInsertId() (int64, error) {
func (r result) LastInsertId() (int64, error) {
return r.lastInsertID, nil
}

func (r *result) RowsAffected() (int64, error) {
func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
18 changes: 2 additions & 16 deletions sql/driver/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,14 @@ import (
"io"
)

type row []driver.Value
var _ driver.Rows = &rows{}

type rows struct {
columns []string
rows []row
rows [][]driver.Value
pos int // Next iteration index into rows.
}

// newSingleColumnRows returns a rows structure initialized with a single
// column of values using the specified column name and values. This is a
// convenience routine used by operations which return only a single column.
func newSingleColumnRows(column string, vals []string) *rows {
r := make([]row, len(vals))
for i, v := range vals {
r[i] = row{v}
}
return &rows{
columns: []string{column},
rows: r,
}
}

func (r *rows) Columns() []string {
return r.columns
}
Expand Down
10 changes: 6 additions & 4 deletions sql/driver/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,26 @@ package driver

import "database/sql/driver"

var _ driver.Stmt = stmt{}

type stmt struct {
conn *conn
stmt string
}

func (s *stmt) Close() error {
func (stmt) Close() error {
return nil
}

func (s *stmt) NumInput() int {
func (stmt) NumInput() int {
// TODO(pmattis): Count the number of parameters.
return -1
}

func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
return s.conn.Exec(s.stmt, args)
}

func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.conn.Query(s.stmt, args)
}
8 changes: 6 additions & 2 deletions sql/driver/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,20 @@

package driver

import "database/sql/driver"

var _ driver.Tx = tx{}

type tx struct {
conn *conn
}

func (t *tx) Commit() error {
func (t tx) Commit() error {
_, err := t.conn.Exec("COMMIT TRANSACTION", nil)
return err
}

func (t *tx) Rollback() error {
func (t tx) Rollback() error {
_, err := t.conn.Exec("ROLLBACK TRANSACTION", nil)
return err
}
Loading