Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: rollback disallowing statements following schema change #15511

Merged
merged 2 commits into from
Apr 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pkg/sql/drop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,42 @@ func TestDropTableInterleaved(t *testing.T) {
}
}

func TestDropTableInTxn(t *testing.T) {
defer leaktest.AfterTest(t)()
params, _ := createTestServerParams()
s, sqlDB, _ := serverutils.StartServer(t, params)
defer s.Stopper().Stop(context.TODO())

if _, err := sqlDB.Exec(`
CREATE DATABASE t;
CREATE TABLE t.kv (k CHAR PRIMARY KEY, v CHAR);
`); err != nil {
t.Fatal(err)
}

tx, err := sqlDB.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`DROP TABLE t.kv`); err != nil {
t.Fatal(err)
}

// We might still be able to read/write in the table inside this transaction
// until the schema changer runs, but we shouldn't be able to ALTER it.
if _, err := tx.Exec(`ALTER TABLE t.kv ADD COLUMN w CHAR`); !testutils.IsError(err,
`table "kv" is being dropped`) {
t.Fatalf("different error than expected: %v", err)
}

// Can't commit after ALTER errored, so we ROLLBACK.
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}

}

func TestDropAndCreateTable(t *testing.T) {
defer leaktest.AfterTest(t)()
params, _ := createTestServerParams()
Expand Down
20 changes: 0 additions & 20 deletions pkg/sql/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import (
var errNoTransactionInProgress = errors.New("there is no transaction in progress")
var errStaleMetadata = errors.New("metadata is still stale")
var errTransactionInProgress = errors.New("there is already a transaction in progress")
var errStmtFollowsSchemaChange = errors.New("statement cannot follow a schema change in a transaction")

func errWrongNumberOfPreparedStatements(n int) error {
return pgerror.NewErrorf(pgerror.CodeInvalidPreparedStatementDefinitionError,
Expand Down Expand Up @@ -453,12 +452,6 @@ func (e *Executor) Prepare(
txn.Proto().OrigTimestamp = e.cfg.Clock.Now()
}

if len(session.TxnState.schemaChangers.schemaChangers) > 0 {
if _, ok := stmt.(parser.ValidAfterSchemaUpdateStatement); !ok {
return nil, errStmtFollowsSchemaChange
}
}

planner := session.newPlanner(e, txn)
planner.semaCtx.Placeholders.SetTypes(pinfo)
planner.evalCtx.PrepareOnly = true
Expand Down Expand Up @@ -672,7 +665,6 @@ func (e *Executor) execRequest(

// Track if we are retrying this query, so that we do not double count.
automaticRetryCount := 0
schemaChangerCount := len(txnState.schemaChangers.schemaChangers)
txnClosure := func(ctx context.Context, txn *client.Txn, opt *client.TxnExecOptions) error {
defer func() { automaticRetryCount++ }()
if txnState.State == Open && txnState.txn != txn {
Expand All @@ -681,12 +673,6 @@ func (e *Executor) execRequest(
}
txnState.txn = txn

// Remove all schema changers added by the closure.
if automaticRetryCount > 0 && len(txnState.schemaChangers.schemaChangers) > 0 {
txnState.schemaChangers.schemaChangers =
txnState.schemaChangers.schemaChangers[:schemaChangerCount]
}

if protoTS != nil {
SetTxnTimestamps(txnState.txn, *protoTS)
}
Expand Down Expand Up @@ -1288,12 +1274,6 @@ func (e *Executor) execStmtInOpenTxn(
return Result{PGTag: s.StatementTag()}, nil
}

if len(txnState.schemaChangers.schemaChangers) > 0 {
if _, ok := stmt.(parser.ValidAfterSchemaUpdateStatement); !ok {
return Result{}, errStmtFollowsSchemaChange
}
}

// Create a new planner from the Session to execute the statement.
planner := session.newPlanner(e, txnState.txn)
planner.evalCtx.SetTxnTimestamp(txnState.sqlTimestamp)
Expand Down
45 changes: 0 additions & 45 deletions pkg/sql/parser/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,6 @@ type IndependentFromParallelizedPriors interface {
independentFromParallelizedPriors()
}

// ValidAfterSchemaUpdateStatement is a pseudo-interface to be implemented by
// statements which do not risk conflicting with an earlier schema change
// operation issued previously within the same transaction.
type ValidAfterSchemaUpdateStatement interface {
validAfterSchemaUpdateStatement()
}

// StatementType implements the Statement interface.
func (*AlterTable) StatementType() StatementType { return DDL }

Expand All @@ -114,8 +107,6 @@ func (*BeginTransaction) StatementTag() string { return "BEGIN" }

func (*BeginTransaction) hiddenFromStats() {}

func (*BeginTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CommitTransaction) StatementType() StatementType { return Ack }

Expand All @@ -124,8 +115,6 @@ func (*CommitTransaction) StatementTag() string { return "COMMIT" }

func (*CommitTransaction) hiddenFromStats() {}

func (*CommitTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CopyFrom) StatementType() StatementType { return CopyIn }

Expand Down Expand Up @@ -161,8 +150,6 @@ func (*CreateUser) StatementType() StatementType { return Ack }
// StatementTag returns a short string identifying the type of statement.
func (*CreateUser) StatementTag() string { return "CREATE USER" }

func (*CreateUser) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*CreateView) StatementType() StatementType { return DDL }

Expand All @@ -183,8 +170,6 @@ func (n *Deallocate) StatementTag() string {

func (*Deallocate) hiddenFromStats() {}

func (*Deallocate) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (n *Delete) StatementType() StatementType { return n.Returning.statementType() }

Expand Down Expand Up @@ -229,8 +214,6 @@ func (*Explain) StatementTag() string { return "EXPLAIN" }

func (*Explain) hiddenFromStats() {}

func (*Explain) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Grant) StatementType() StatementType { return DDL }

Expand All @@ -239,8 +222,6 @@ func (*Grant) StatementTag() string { return "GRANT" }

func (*Grant) hiddenFromStats() {}

func (*Grant) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (n *Insert) StatementType() StatementType { return n.Returning.statementType() }

Expand All @@ -261,8 +242,6 @@ func (*Prepare) StatementTag() string { return "PREPARE" }

func (*Prepare) hiddenFromStats() {}

func (*Prepare) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ReleaseSavepoint) StatementType() StatementType { return Ack }

Expand All @@ -271,8 +250,6 @@ func (*ReleaseSavepoint) StatementTag() string { return "RELEASE" }

func (*ReleaseSavepoint) hiddenFromStats() {}

func (*ReleaseSavepoint) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RenameColumn) StatementType() StatementType { return DDL }

Expand Down Expand Up @@ -322,8 +299,6 @@ func (*Revoke) StatementTag() string { return "REVOKE" }

func (*Revoke) hiddenFromStats() {}

func (*Revoke) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RollbackToSavepoint) StatementType() StatementType { return Ack }

Expand All @@ -332,8 +307,6 @@ func (*RollbackToSavepoint) StatementTag() string { return "ROLLBACK" }

func (*RollbackToSavepoint) hiddenFromStats() {}

func (*RollbackToSavepoint) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*RollbackTransaction) StatementType() StatementType { return Ack }

Expand All @@ -342,8 +315,6 @@ func (*RollbackTransaction) StatementTag() string { return "ROLLBACK" }

func (*RollbackTransaction) hiddenFromStats() {}

func (*RollbackTransaction) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Savepoint) StatementType() StatementType { return Ack }

Expand Down Expand Up @@ -376,8 +347,6 @@ func (*Set) StatementTag() string { return "SET" }

func (*Set) hiddenFromStats() {}

func (*Set) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*SetTransaction) StatementType() StatementType { return Ack }

Expand All @@ -394,8 +363,6 @@ func (*SetTimeZone) StatementTag() string { return "SET TIME ZONE" }

func (*SetTimeZone) hiddenFromStats() {}

func (*SetTimeZone) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*SetDefaultIsolation) StatementType() StatementType { return Ack }

Expand All @@ -404,8 +371,6 @@ func (*SetDefaultIsolation) StatementTag() string { return "SET" }

func (*SetDefaultIsolation) hiddenFromStats() {}

func (*SetDefaultIsolation) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*Show) StatementType() StatementType { return Rows }

Expand All @@ -415,8 +380,6 @@ func (*Show) StatementTag() string { return "SHOW" }
func (*Show) hiddenFromStats() {}
func (*Show) independentFromParallelizedPriors() {}

func (*Show) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowColumns) StatementType() StatementType { return Rows }

Expand Down Expand Up @@ -462,8 +425,6 @@ func (*ShowGrants) StatementTag() string { return "SHOW GRANTS" }
func (*ShowGrants) hiddenFromStats() {}
func (*ShowGrants) independentFromParallelizedPriors() {}

func (*ShowGrants) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowIndex) StatementType() StatementType { return Rows }

Expand All @@ -482,8 +443,6 @@ func (*ShowTransactionStatus) StatementTag() string { return "SHOW TRANSACTION S
func (*ShowTransactionStatus) hiddenFromStats() {}
func (*ShowTransactionStatus) independentFromParallelizedPriors() {}

func (*ShowTransactionStatus) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowUsers) StatementType() StatementType { return Rows }

Expand All @@ -493,8 +452,6 @@ func (*ShowUsers) StatementTag() string { return "SHOW USERS" }
func (*ShowUsers) hiddenFromStats() {}
func (*ShowUsers) independentFromParallelizedPriors() {}

func (*ShowUsers) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowRanges) StatementType() StatementType { return Rows }

Expand All @@ -512,8 +469,6 @@ func (*Help) StatementTag() string { return "HELP" }
func (*Help) hiddenFromStats() {}
func (*Help) independentFromParallelizedPriors() {}

func (*Help) validAfterSchemaUpdateStatement() {}

// StatementType implements the Statement interface.
func (*ShowConstraints) StatementType() StatementType { return Rows }

Expand Down
95 changes: 56 additions & 39 deletions pkg/sql/pgwire/pgwire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,59 +450,76 @@ func TestPGPrepareWithCreateDropInTxn(t *testing.T) {
}
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}
{
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`
if _, err := tx.Exec(`
CREATE DATABASE d;
CREATE TABLE d.kv (k CHAR PRIMARY KEY, v CHAR);
`); err != nil {
t.Fatal(err)
}
t.Fatal(err)
}

stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}
stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}

res, err := stmt.Exec('a', 'b')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}
res, err := stmt.Exec('a', 'b')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}

if err := tx.Commit(); err != nil {
t.Fatal(err)
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}

tx, err = db.Begin()
if err != nil {
t.Fatal(err)
}
{
tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}

if _, err := tx.Exec(`
if _, err := tx.Exec(`
DROP TABLE d.kv;
`); err != nil {
t.Fatal(err)
}
t.Fatal(err)
}

if _, err := tx.Prepare(`
INSERT INTO d.kv (k,v) VALUES ($1, $2);
`); !testutils.IsError(err, "statement cannot follow a schema change in a transaction") {
t.Fatalf("got err: %s", err)
}
stmt, err := tx.Prepare(`INSERT INTO d.kv (k,v) VALUES ($1, $2);`)
if err != nil {
t.Fatal(err)
}

if err := tx.Rollback(); err != nil {
t.Fatal(err)
// INSERT works because it is using a cached descriptor that is leased.
res, err := stmt.Exec('c', 'd')
if err != nil {
t.Fatal(err)
}
stmt.Close()
affected, err := res.RowsAffected()
if err != nil {
t.Fatal(err)
}
if affected != 1 {
t.Fatalf("unexpected number of rows affected: %d", affected)
}

if err := tx.Commit(); err != nil {
t.Fatal(err)
}
}
}

Expand Down
Loading