Skip to content

Commit

Permalink
Refactor duplicated transaction management code
Browse files Browse the repository at this point in the history
The code for transactions is duplicated between SQLite and PostgreSQL.
MySQL would have also used identical code. However, the SQL being
executed is not universal across all backends. Oracle appears to use the
same SQL, but SQL Server has its own special syntax for this. As such,
I'm not comfortable promoting this to a default impl on the trait.
Instead I've moved the code out into a shared trait/struct, and operate
on that instead.

I had wanted to make `TransactionManager` be generic over the backend,
not the connection itself, since constraints for it will always be about
the backend, but I ran into
rust-lang/rust#39532 when attempting to do so.
  • Loading branch information
sgrif committed Feb 4, 2017
1 parent 2c41bf9 commit 972fcb9
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 115 deletions.
2 changes: 2 additions & 0 deletions diesel/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub trait TypeMetadata {

pub trait SupportsReturningClause {}
pub trait SupportsDefaultKeyword {}
pub trait UsesAnsiSavepointSyntax {}

#[derive(Debug, Copy, Clone)]
pub struct Debug;
Expand All @@ -50,3 +51,4 @@ impl TypeMetadata for Debug {

impl SupportsReturningClause for Debug {}
impl SupportsDefaultKeyword for Debug {}
impl UsesAnsiSavepointSyntax for Debug {}
24 changes: 14 additions & 10 deletions diesel/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
mod transaction_manager;

use backend::Backend;
use query_builder::{AsQuery, QueryFragment, QueryId};
use query_source::Queryable;
use result::*;
use types::HasSqlType;

pub use self::transaction_manager::{TransactionManager, AnsiTransactionManager};

pub trait SimpleConnection {
#[doc(hidden)]
fn batch_execute(&self, query: &str) -> QueryResult<()>;
}

pub trait Connection: SimpleConnection + Sized {
type Backend: Backend;
#[doc(hidden)]
type TransactionManager: TransactionManager<Self>;

/// Establishes a new connection to the database at the given URL. The URL
/// should be a valid connection string for a given backend. See the
Expand All @@ -28,14 +34,15 @@ pub trait Connection: SimpleConnection + Sized {
fn transaction<T, E, F>(&self, f: F) -> TransactionResult<T, E> where
F: FnOnce() -> Result<T, E>,
{
try!(self.begin_transaction());
let transaction_manager = self.transaction_manager();
try!(transaction_manager.begin_transaction(self));
match f() {
Ok(value) => {
try!(self.commit_transaction());
try!(transaction_manager.commit_transaction(self));
Ok(value)
},
Err(e) => {
try!(self.rollback_transaction());
try!(transaction_manager.rollback_transaction(self));
Err(TransactionError::UserReturnedError(e))
},
}
Expand All @@ -44,8 +51,9 @@ pub trait Connection: SimpleConnection + Sized {
/// Creates a transaction that will never be committed. This is useful for
/// tests. Panics if called while inside of a transaction.
fn begin_test_transaction(&self) -> QueryResult<()> {
assert_eq!(self.get_transaction_depth(), 0);
self.begin_transaction()
let transaction_manager = self.transaction_manager();
assert_eq!(transaction_manager.get_transaction_depth(), 0);
transaction_manager.begin_transaction(self)
}

/// Executes the given function inside a transaction, but does not commit
Expand Down Expand Up @@ -87,10 +95,6 @@ pub trait Connection: SimpleConnection + Sized {
T: QueryFragment<Self::Backend> + QueryId;

#[doc(hidden)] fn silence_notices<F: FnOnce() -> T, T>(&self, f: F) -> T;
#[doc(hidden)] fn begin_transaction(&self) -> QueryResult<()>;
#[doc(hidden)] fn rollback_transaction(&self) -> QueryResult<()>;
#[doc(hidden)] fn commit_transaction(&self) -> QueryResult<()>;
#[doc(hidden)] fn get_transaction_depth(&self) -> i32;

#[doc(hidden)] fn transaction_manager(&self) -> &Self::TransactionManager;
#[doc(hidden)] fn setup_helper_functions(&self);
}
92 changes: 92 additions & 0 deletions diesel/src/connection/transaction_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use backend::UsesAnsiSavepointSyntax;
use connection::Connection;
use result::QueryResult;

/// Manages the internal transaction state for a connection. You should not
/// interface with this trait unless you are implementing a new connection
/// adapter. You should use [`Connection::transaction`][transaction],
/// [`Connection::test_transaction`][test_transaction], or
/// [`Connection::begin_test_transaction`][begin_test_transaction] instead.
pub trait TransactionManager<Conn: Connection> {
/// Begin a new transaction. If the transaction depth is greater than 0,
/// this should create a savepoint instead. This function is expected to
/// increment the transaction depth by 1.
fn begin_transaction(&self, conn: &Conn) -> QueryResult<()>;

/// Rollback the inner-most transcation. If the transaction depth is greater
/// than 1, this should rollback to the most recent savepoint. This function
/// is expected to decrement the transaction depth by 1.
fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()>;

/// Commit the inner-most transcation. If the transaction depth is greater
/// than 1, this should release the most recent savepoint. This function is
/// expected to decrement the transaction depth by 1.
fn commit_transaction(&self, conn: &Conn) -> QueryResult<()>;

/// Fetch the current transaction depth. Used to ensure that
/// `begin_test_transaction` is not called when already inside of a
/// transaction.
fn get_transaction_depth(&self) -> u32;
}

use std::cell::Cell;

/// An implementation of TransactionManager which can be used for backends
/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
#[allow(missing_debug_implementations)]
pub struct AnsiTransactionManager {
transaction_depth: Cell<i32>,
}

impl AnsiTransactionManager {
pub fn new() -> Self {
AnsiTransactionManager {
transaction_depth: Cell::new(0),
}
}

fn change_transaction_depth(&self, by: i32, query: QueryResult<()>) -> QueryResult<()> {
if query.is_ok() {
self.transaction_depth.set(self.transaction_depth.get() + by)
}
query
}
}

impl<Conn> TransactionManager<Conn> for AnsiTransactionManager where
Conn: Connection,
Conn::Backend: UsesAnsiSavepointSyntax,
{
fn begin_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(1, if transaction_depth == 0 {
conn.batch_execute("BEGIN")
} else {
conn.batch_execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
})
}

fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(-1, if transaction_depth == 1 {
conn.batch_execute("ROLLBACK")
} else {
conn.batch_execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1))
})
}

fn commit_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(-1, if transaction_depth <= 1 {
conn.batch_execute("COMMIT")
} else {
conn.batch_execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1))
})
}

fn get_transaction_depth(&self) -> u32 {
self.transaction_depth.get() as u32
}
}
1 change: 1 addition & 0 deletions diesel/src/mysql/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ impl TypeMetadata for Mysql {

impl SupportsReturningClause for Mysql {}
impl SupportsDefaultKeyword for Mysql {}
impl UsesAnsiSavepointSyntax for Mysql {}

// FIXME: Move this out of this module
use types::HasSqlType;
Expand Down
27 changes: 12 additions & 15 deletions diesel/src/mysql/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod raw;
mod url;

use connection::{Connection, SimpleConnection};
use connection::{Connection, SimpleConnection, AnsiTransactionManager};
use query_builder::*;
use query_source::Queryable;
use result::*;
Expand All @@ -13,6 +13,7 @@ use types::HasSqlType;
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub struct MysqlConnection {
_raw_connection: RawConnection,
transaction_manager: AnsiTransactionManager,
}

impl SimpleConnection for MysqlConnection {
Expand All @@ -23,20 +24,24 @@ impl SimpleConnection for MysqlConnection {

impl Connection for MysqlConnection {
type Backend = Mysql;
type TransactionManager = AnsiTransactionManager;

fn establish(database_url: &str) -> ConnectionResult<Self> {
let raw_connection = RawConnection::new();
let connection_options = try!(ConnectionOptions::parse(database_url));
try!(raw_connection.connect(connection_options));
Ok(MysqlConnection {
_raw_connection: raw_connection,
transaction_manager: AnsiTransactionManager::new(),
})
}

#[doc(hidden)]
fn execute(&self, _query: &str) -> QueryResult<usize> {
unimplemented!()
}

#[doc(hidden)]
fn query_all<T, U>(&self, _source: T) -> QueryResult<Vec<U>> where
T: AsQuery,
T::Query: QueryFragment<Self::Backend> + QueryId,
Expand All @@ -46,30 +51,22 @@ impl Connection for MysqlConnection {
unimplemented!()
}

#[doc(hidden)]
fn silence_notices<F: FnOnce() -> T, T>(&self, _f: F) -> T {
unimplemented!()
}

#[doc(hidden)]
fn execute_returning_count<T>(&self, _source: &T) -> QueryResult<usize> {
unimplemented!()
}

fn begin_transaction(&self) -> QueryResult<()> {
unimplemented!()
}

fn rollback_transaction(&self) -> QueryResult<()> {
unimplemented!()
}

fn commit_transaction(&self) -> QueryResult<()> {
unimplemented!()
}

fn get_transaction_depth(&self) -> i32 {
unimplemented!()
#[doc(hidden)]
fn transaction_manager(&self) -> &Self::TransactionManager {
&self.transaction_manager
}

#[doc(hidden)]
fn setup_helper_functions(&self) {
unimplemented!()
}
Expand Down
1 change: 1 addition & 0 deletions diesel/src/pg/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ impl TypeMetadata for Pg {

impl SupportsReturningClause for Pg {}
impl SupportsDefaultKeyword for Pg {}
impl UsesAnsiSavepointSyntax for Pg {}
51 changes: 6 additions & 45 deletions diesel/src/pg/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ mod row;
pub mod result;
mod stmt;

use std::cell::Cell;
use std::ffi::{CString, CStr};
use std::rc::Rc;

use connection::{SimpleConnection, Connection};
use connection::{SimpleConnection, Connection, AnsiTransactionManager};
use pg::{Pg, PgQueryBuilder};
use query_builder::{AsQuery, QueryFragment, QueryId};
use query_builder::bind_collector::RawBytesBindCollector;
Expand All @@ -29,7 +28,7 @@ use types::HasSqlType;
#[allow(missing_debug_implementations)]
pub struct PgConnection {
raw_connection: RawConnection,
transaction_depth: Cell<i32>,
transaction_manager: AnsiTransactionManager,
statement_cache: StatementCache,
}

Expand All @@ -48,12 +47,13 @@ impl SimpleConnection for PgConnection {

impl Connection for PgConnection {
type Backend = Pg;
type TransactionManager = AnsiTransactionManager;

fn establish(database_url: &str) -> ConnectionResult<PgConnection> {
RawConnection::establish(database_url).map(|raw_conn| {
PgConnection {
raw_connection: raw_conn,
transaction_depth: Cell::new(0),
transaction_manager: AnsiTransactionManager::new(),
statement_cache: StatementCache::new(),
}
})
Expand Down Expand Up @@ -94,40 +94,8 @@ impl Connection for PgConnection {
}

#[doc(hidden)]
fn begin_transaction(&self) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(1, if transaction_depth == 0 {
self.execute("BEGIN")
} else {
self.execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
})
}

#[doc(hidden)]
fn rollback_transaction(&self) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(-1, if transaction_depth == 1 {
self.execute("ROLLBACK")
} else {
self.execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1))
})
}

#[doc(hidden)]
fn commit_transaction(&self) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(-1, if transaction_depth <= 1 {
self.execute("COMMIT")
} else {
self.execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1))
})
}

#[doc(hidden)]
fn get_transaction_depth(&self) -> i32 {
self.transaction_depth.get()
fn transaction_manager(&self) -> &Self::TransactionManager {
&self.transaction_manager
}

#[doc(hidden)]
Expand Down Expand Up @@ -166,13 +134,6 @@ impl PgConnection {
let query = try!(Query::sql(query, None));
query.execute(&self.raw_connection, &Vec::new())
}

fn change_transaction_depth(&self, by: i32, query: QueryResult<usize>) -> QueryResult<()> {
if query.is_ok() {
self.transaction_depth.set(self.transaction_depth.get() + by);
}
query.map(|_| ())
}
}

extern "C" fn noop_notice_processor(_: *mut libc::c_void, _message: *const libc::c_char) {
Expand Down
2 changes: 2 additions & 0 deletions diesel/src/sqlite/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ impl Backend for Sqlite {
impl TypeMetadata for Sqlite {
type TypeMetadata = SqliteType;
}

impl UsesAnsiSavepointSyntax for Sqlite {}
Loading

0 comments on commit 972fcb9

Please sign in to comment.