Skip to content

Commit

Permalink
Merge pull request #15511 from vivekmenezes/nakama
Browse files Browse the repository at this point in the history
sql: rollback disallowing statements following schema change
  • Loading branch information
vivekmenezes authored Apr 30, 2017
2 parents 12adb26 + 7d11164 commit 3af1ee6
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 109 deletions.
36 changes: 36 additions & 0 deletions pkg/sql/drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,42 @@ func TestDropTableInterleaved(t *testing.T) {
}
}

func TestDropTableInTxn(t *testing.T) {
defer leaktest.AfterTest(t)()
params, _ := createTestServerParams()
s, sqlDB, _ := serverutils.StartServer(t, params)
defer s.Stopper().Stop(context.TODO())

if _, err := sqlDB.Exec(`
CREATE DATABASE t;
CREATE TABLE t.kv (k CHAR PRIMARY KEY, v CHAR);
`); err != nil {
t.Fatal(err)
}

tx, err := sqlDB.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`DROP TABLE t.kv`); err != nil {
t.Fatal(err)
}

// We might still be able to read/write in the table inside this transaction
// until the schema changer runs, but we shouldn't be able to ALTER it.
if _, err := tx.Exec(`ALTER TABLE t.kv ADD COLUMN w CHAR`); !testutils.IsError(err,
`table "kv" is being dropped`) {
t.Fatalf("different error than expected: %v", err)
}

// Can't commit after ALTER errored, so we ROLLBACK.
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}

}

func TestDropAndCreateTable(t *testing.T) {
defer leaktest.AfterTest(t)()
params, _ := createTestServerParams()
Expand Down
20 changes: 0 additions & 20 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import (
var errNoTransactionInProgress = errors.New("there is no transaction in progress")
var errStaleMetadata = errors.New("metadata is still stale")
var errTransactionInProgress = errors.New("there is already a transaction in progress")
var errStmtFollowsSchemaChange = errors.New("statement cannot follow a schema change in a transaction")

func errWrongNumberOfPreparedStatements(n int) error {
return pgerror.NewErrorf(pgerror.CodeInvalidPreparedStatementDefinitionError,
Expand Down Expand Up @@ -453,12 +452,6 @@ func (e *Executor) Prepare(
txn.Proto().OrigTimestamp = e.cfg.Clock.Now()
}

if len(session.TxnState.schemaChangers.schemaChangers) > 0 {
if _, ok := stmt.(parser.ValidAfterSchemaUpdateStatement); !ok {
return nil, errStmtFollowsSchemaChange
}
}

planner := session.newPlanner(e, txn)
planner.semaCtx.Placeholders.SetTypes(pinfo)
planner.evalCtx.PrepareOnly = true
Expand Down Expand Up @@ -672,7 +665,6 @@ func (e *Executor) execRequest(

// Track if we are retrying this query, so that we do not double count.
automaticRetryCount := 0
schemaChangerCount := len(txnState.schemaChangers.schemaChangers)
txnClosure := func(ctx context.Context, txn *client.Txn, opt *client.TxnExecOptions) error {
defer func() { automaticRetryCount++ }()
if txnState.State == Open && txnState.txn != txn {
Expand All @@ -681,12 +673,6 @@ func (e *Executor) execRequest(
}
txnState.txn = txn

// Remove all schema changers added by the closure.
if automaticRetryCount > 0 && len(txnState.schemaChangers.schemaChangers) > 0 {
txnState.schemaChangers.schemaChangers =
txnState.schemaChangers.schemaChangers[:schemaChangerCount]
}

if protoTS != nil {
SetTxnTimestamps(txnState.txn, *protoTS)
}
Expand Down Expand Up @@ -1288,12 +1274,6 @@ func (e *Executor) execStmtInOpenTxn(
return Result{PGTag: s.StatementTag()}, nil
}

if len(txnState.schemaChangers.schemaChangers) > 0 {
if _, ok := stmt.(parser.ValidAfterSchemaUpdateStatement); !ok {
return Result{}, errStmtFollowsSchemaChange
}
}

// Create a new planner from the Session to execute the statement.
planner := session.newPlanner(e, txnState.txn)
planner.evalCtx.SetTxnTimestamp(txnState.sqlTimestamp)
Expand Down
45 changes: 0 additions & 45 deletions pkg/sql/parser/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ type IndependentFromParallelizedPriors interface {
independentFromParallelizedPriors()
}

// ValidAfterSchemaUpdateStatement is a pseudo-interface to be implemented by
// statements which do not risk conflicting with an earlier schema change
// operation issued previously within the same transaction.
type ValidAfterSchemaUpdateStatement interface {
validAfterSchemaUpdateStatement()
}

// StatementType implements the Statement interface.
func (*AlterTable) StatementType() StatementType { return DDL }

Expand All @@ -114,8 +107,6 @@ func (*BeginTransaction) StatementTag() string { return "BEGIN" }

func (*BeginTransaction) hiddenFromStats() {}

func (*BeginTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CommitTransaction) StatementType() StatementType { return Ack }

Expand All @@ -124,8 +115,6 @@ func (*CommitTransaction) StatementTag() string { return "COMMIT" }

func (*CommitTransaction) hiddenFromStats() {}

func (*CommitTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CopyFrom) StatementType() StatementType { return CopyIn }

Expand Down Expand Up @@ -161,8 +150,6 @@ func (*CreateUser) StatementType() StatementType { return Ack }
// StatementTag returns a short string identifying the type of statement.
func (*CreateUser) StatementTag() string { return "CREATE USER" }

func (*CreateUser) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CreateView) StatementType() StatementType { return DDL }

Expand All @@ -183,8 +170,6 @@ func (n *Deallocate) StatementTag() string {

func (*Deallocate) hiddenFromStats() {}

func (*Deallocate) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (n *Delete) StatementType() StatementType { return n.Returning.statementType() }

Expand Down Expand Up @@ -229,8 +214,6 @@ func (*Explain) StatementTag() string { return "EXPLAIN" }

func (*Explain) hiddenFromStats() {}

func (*Explain) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Grant) StatementType() StatementType { return DDL }

Expand All @@ -239,8 +222,6 @@ func (*Grant) StatementTag() string { return "GRANT" }

func (*Grant) hiddenFromStats() {}

func (*Grant) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (n *Insert) StatementType() StatementType { return n.Returning.statementType() }

Expand All @@ -261,8 +242,6 @@ func (*Prepare) StatementTag() string { return "PREPARE" }

func (*Prepare) hiddenFromStats() {}

func (*Prepare) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ReleaseSavepoint) StatementType() StatementType { return Ack }

Expand All @@ -271,8 +250,6 @@ func (*ReleaseSavepoint) StatementTag() string { return "RELEASE" }

func (*ReleaseSavepoint) hiddenFromStats() {}

func (*ReleaseSavepoint) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RenameColumn) StatementType() StatementType { return DDL }

Expand Down Expand Up @@ -322,8 +299,6 @@ func (*Revoke) StatementTag() string { return "REVOKE" }

func (*Revoke) hiddenFromStats() {}

func (*Revoke) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RollbackToSavepoint) StatementType() StatementType { return Ack }

Expand All @@ -332,8 +307,6 @@ func (*RollbackToSavepoint) StatementTag() string { return "ROLLBACK" }

func (*RollbackToSavepoint) hiddenFromStats() {}

func (*RollbackToSavepoint) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RollbackTransaction) StatementType() StatementType { return Ack }

Expand All @@ -342,8 +315,6 @@ func (*RollbackTransaction) StatementTag() string { return "ROLLBACK" }

func (*RollbackTransaction) hiddenFromStats() {}

func (*RollbackTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Savepoint) StatementType() StatementType { return Ack }

Expand Down Expand Up @@ -376,8 +347,6 @@ func (*Set) StatementTag() string { return "SET" }

func (*Set) hiddenFromStats() {}

func (*Set) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*SetTransaction) StatementType() StatementType { return Ack }

Expand All @@ -394,8 +363,6 @@ func (*SetTimeZone) StatementTag() string { return "SET TIME ZONE" }

func (*SetTimeZone) hiddenFromStats() {}

func (*SetTimeZone) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*SetDefaultIsolation) StatementType() StatementType { return Ack }

Expand All @@ -404,8 +371,6 @@ func (*SetDefaultIsolation) StatementTag() string { return "SET" }

func (*SetDefaultIsolation) hiddenFromStats() {}

func (*SetDefaultIsolation) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Show) StatementType() StatementType { return Rows }

Expand All @@ -415,8 +380,6 @@ func (*Show) StatementTag() string { return "SHOW" }
func (*Show) hiddenFromStats() {}
func (*Show) independentFromParallelizedPriors() {}

func (*Show) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowColumns) StatementType() StatementType { return Rows }

Expand Down Expand Up @@ -462,8 +425,6 @@ func (*ShowGrants) StatementTag() string { return "SHOW GRANTS" }
func (*ShowGrants) hiddenFromStats() {}
func (*ShowGrants) independentFromParallelizedPriors() {}

func (*ShowGrants) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowIndex) StatementType() StatementType { return Rows }

Expand All @@ -482,8 +443,6 @@ func (*ShowTransactionStatus) StatementTag() string { return "SHOW TRANSACTION S
func (*ShowTransactionStatus) hiddenFromStats() {}
func (*ShowTransactionStatus) independentFromParallelizedPriors() {}

func (*ShowTransactionStatus) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowUsers) StatementType() StatementType { return Rows }

Expand All @@ -493,8 +452,6 @@ func (*ShowUsers) StatementTag() string { return "SHOW USERS" }
func (*ShowUsers) hiddenFromStats() {}
func (*ShowUsers) independentFromParallelizedPriors() {}

func (*ShowUsers) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowRanges) StatementType() StatementType { return Rows }

Expand All @@ -512,8 +469,6 @@ func (*Help) StatementTag() string { return "HELP" }
func (*Help) hiddenFromStats() {}
func (*Help) independentFromParallelizedPriors() {}

func (*Help) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowConstraints) StatementType() StatementType { return Rows }

Expand Down
95 changes: 56 additions & 39 deletions pkg/sql/pgwire/pgwire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,59 +450,76 @@ func TestPGPrepareWithCreateDropInTxn(t *testing.T) {
}
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
{
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`
if _, err := tx.Exec(`
CREATE DATABASE d;
CREATE TABLE d.kv (k CHAR PRIMARY KEY, v CHAR);
`); err != nil {
t.Fatal(err)
}
t.Fatal(err)
}

stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}
stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}

res, err := stmt.Exec('a', 'b')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}
res, err := stmt.Exec('a', 'b')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}

if err := tx.Commit(); err != nil {
t.Fatal(err)
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}

tx, err = db.Begin()
if err != nil {
t.Fatal(err)
}
{
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`
if _, err := tx.Exec(`
DROP TABLE d.kv;
`); err != nil {
t.Fatal(err)
}
t.Fatal(err)
}

if _, err := tx.Prepare(`
INSERT INTO d.kv (k,v) VALUES ($1, $2);
`); !testutils.IsError(err, "statement cannot follow a schema change in a transaction") {
t.Fatalf("got err: %s", err)
}
stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}

if err := tx.Rollback(); err != nil {
t.Fatal(err)
// INSERT works because it is using a cached descriptor that is leased.
res, err := stmt.Exec('c', 'd')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}

if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}
}

Expand Down
Loading

0 comments on commit 3af1ee6

Please sign in to comment.