diff --git a/diesel/src/backend.rs b/diesel/src/backend.rs index b5f6a02d28ad..b5595e501fc5 100644 --- a/diesel/src/backend.rs +++ b/diesel/src/backend.rs @@ -27,6 +27,7 @@ pub trait TypeMetadata { pub trait SupportsReturningClause {} pub trait SupportsDefaultKeyword {} +pub trait UsesAnsiSavepointSyntax {} #[derive(Debug, Copy, Clone)] pub struct Debug; @@ -50,3 +51,4 @@ impl TypeMetadata for Debug { impl SupportsReturningClause for Debug {} impl SupportsDefaultKeyword for Debug {} +impl UsesAnsiSavepointSyntax for Debug {} diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 9fcd3ce7c9a1..42c48625501c 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -1,9 +1,13 @@ +mod transaction_manager; + use backend::Backend; use query_builder::{AsQuery, QueryFragment, QueryId}; use query_source::Queryable; use result::*; use types::HasSqlType; +pub use self::transaction_manager::{TransactionManager, AnsiTransactionManager}; + pub trait SimpleConnection { #[doc(hidden)] fn batch_execute(&self, query: &str) -> QueryResult<()>; @@ -11,6 +15,8 @@ pub trait SimpleConnection { pub trait Connection: SimpleConnection + Sized { type Backend: Backend; + #[doc(hidden)] + type TransactionManager: TransactionManager; /// Establishes a new connection to the database at the given URL. The URL /// should be a valid connection string for a given backend. See the @@ -28,14 +34,15 @@ pub trait Connection: SimpleConnection + Sized { fn transaction(&self, f: F) -> TransactionResult where F: FnOnce() -> Result, { - try!(self.begin_transaction()); + let transaction_manager = self.transaction_manager(); + try!(transaction_manager.begin_transaction(self)); match f() { Ok(value) => { - try!(self.commit_transaction()); + try!(transaction_manager.commit_transaction(self)); Ok(value) }, Err(e) => { - try!(self.rollback_transaction()); + try!(transaction_manager.rollback_transaction(self)); Err(TransactionError::UserReturnedError(e)) }, } @@ -44,8 +51,9 @@ pub trait Connection: SimpleConnection + Sized { /// Creates a transaction that will never be committed. This is useful for /// tests. Panics if called while inside of a transaction. fn begin_test_transaction(&self) -> QueryResult<()> { - assert_eq!(self.get_transaction_depth(), 0); - self.begin_transaction() + let transaction_manager = self.transaction_manager(); + assert_eq!(transaction_manager.get_transaction_depth(), 0); + transaction_manager.begin_transaction(self) } /// Executes the given function inside a transaction, but does not commit @@ -87,10 +95,6 @@ pub trait Connection: SimpleConnection + Sized { T: QueryFragment + QueryId; #[doc(hidden)] fn silence_notices T, T>(&self, f: F) -> T; - #[doc(hidden)] fn begin_transaction(&self) -> QueryResult<()>; - #[doc(hidden)] fn rollback_transaction(&self) -> QueryResult<()>; - #[doc(hidden)] fn commit_transaction(&self) -> QueryResult<()>; - #[doc(hidden)] fn get_transaction_depth(&self) -> i32; - + #[doc(hidden)] fn transaction_manager(&self) -> &Self::TransactionManager; #[doc(hidden)] fn setup_helper_functions(&self); } diff --git a/diesel/src/connection/transaction_manager.rs b/diesel/src/connection/transaction_manager.rs new file mode 100644 index 000000000000..9db94e0addfa --- /dev/null +++ b/diesel/src/connection/transaction_manager.rs @@ -0,0 +1,92 @@ +use backend::UsesAnsiSavepointSyntax; +use connection::Connection; +use result::QueryResult; + +/// Manages the internal transaction state for a connection. You should not +/// interface with this trait unless you are implementing a new connection +/// adapter. You should use [`Connection::transaction`][transaction], +/// [`Connection::test_transaction`][test_transaction], or +/// [`Connection::begin_test_transaction`][begin_test_transaction] instead. +pub trait TransactionManager { + /// Begin a new transaction. If the transaction depth is greater than 0, + /// this should create a savepoint instead. This function is expected to + /// increment the transaction depth by 1. + fn begin_transaction(&self, conn: &Conn) -> QueryResult<()>; + + /// Rollback the inner-most transcation. If the transaction depth is greater + /// than 1, this should rollback to the most recent savepoint. This function + /// is expected to decrement the transaction depth by 1. + fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()>; + + /// Commit the inner-most transcation. If the transaction depth is greater + /// than 1, this should release the most recent savepoint. This function is + /// expected to decrement the transaction depth by 1. + fn commit_transaction(&self, conn: &Conn) -> QueryResult<()>; + + /// Fetch the current transaction depth. Used to ensure that + /// `begin_test_transaction` is not called when already inside of a + /// transaction. + fn get_transaction_depth(&self) -> u32; +} + +use std::cell::Cell; + +/// An implementation of TransactionManager which can be used for backends +/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL. +#[allow(missing_debug_implementations)] +pub struct AnsiTransactionManager { + transaction_depth: Cell, +} + +impl AnsiTransactionManager { + pub fn new() -> Self { + AnsiTransactionManager { + transaction_depth: Cell::new(0), + } + } + + fn change_transaction_depth(&self, by: i32, query: QueryResult<()>) -> QueryResult<()> { + if query.is_ok() { + self.transaction_depth.set(self.transaction_depth.get() + by) + } + query + } +} + +impl TransactionManager for AnsiTransactionManager where + Conn: Connection, + Conn::Backend: UsesAnsiSavepointSyntax, +{ + fn begin_transaction(&self, conn: &Conn) -> QueryResult<()> { + let transaction_depth = self.transaction_depth.get(); + self.change_transaction_depth(1, if transaction_depth == 0 { + conn.batch_execute("BEGIN") + } else { + conn.batch_execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth)) + }) + } + + fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()> { + let transaction_depth = self.transaction_depth.get(); + self.change_transaction_depth(-1, if transaction_depth == 1 { + conn.batch_execute("ROLLBACK") + } else { + conn.batch_execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}", + transaction_depth - 1)) + }) + } + + fn commit_transaction(&self, conn: &Conn) -> QueryResult<()> { + let transaction_depth = self.transaction_depth.get(); + self.change_transaction_depth(-1, if transaction_depth <= 1 { + conn.batch_execute("COMMIT") + } else { + conn.batch_execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}", + transaction_depth - 1)) + }) + } + + fn get_transaction_depth(&self) -> u32 { + self.transaction_depth.get() as u32 + } +} diff --git a/diesel/src/mysql/backend.rs b/diesel/src/mysql/backend.rs index e4a368eb75ee..2068ae5ddb45 100644 --- a/diesel/src/mysql/backend.rs +++ b/diesel/src/mysql/backend.rs @@ -39,6 +39,7 @@ impl TypeMetadata for Mysql { impl SupportsReturningClause for Mysql {} impl SupportsDefaultKeyword for Mysql {} +impl UsesAnsiSavepointSyntax for Mysql {} // FIXME: Move this out of this module use types::HasSqlType; diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index e2f5f4a47f47..a57927acc799 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -1,7 +1,7 @@ mod raw; mod url; -use connection::{Connection, SimpleConnection}; +use connection::{Connection, SimpleConnection, AnsiTransactionManager}; use query_builder::*; use query_source::Queryable; use result::*; @@ -13,6 +13,7 @@ use types::HasSqlType; #[allow(missing_debug_implementations, missing_copy_implementations)] pub struct MysqlConnection { _raw_connection: RawConnection, + transaction_manager: AnsiTransactionManager, } impl SimpleConnection for MysqlConnection { @@ -23,6 +24,7 @@ impl SimpleConnection for MysqlConnection { impl Connection for MysqlConnection { type Backend = Mysql; + type TransactionManager = AnsiTransactionManager; fn establish(database_url: &str) -> ConnectionResult { let raw_connection = RawConnection::new(); @@ -30,13 +32,16 @@ impl Connection for MysqlConnection { try!(raw_connection.connect(connection_options)); Ok(MysqlConnection { _raw_connection: raw_connection, + transaction_manager: AnsiTransactionManager::new(), }) } + #[doc(hidden)] fn execute(&self, _query: &str) -> QueryResult { unimplemented!() } + #[doc(hidden)] fn query_all(&self, _source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, @@ -46,30 +51,22 @@ impl Connection for MysqlConnection { unimplemented!() } + #[doc(hidden)] fn silence_notices T, T>(&self, _f: F) -> T { unimplemented!() } + #[doc(hidden)] fn execute_returning_count(&self, _source: &T) -> QueryResult { unimplemented!() } - fn begin_transaction(&self) -> QueryResult<()> { - unimplemented!() - } - - fn rollback_transaction(&self) -> QueryResult<()> { - unimplemented!() - } - - fn commit_transaction(&self) -> QueryResult<()> { - unimplemented!() - } - - fn get_transaction_depth(&self) -> i32 { - unimplemented!() + #[doc(hidden)] + fn transaction_manager(&self) -> &Self::TransactionManager { + &self.transaction_manager } + #[doc(hidden)] fn setup_helper_functions(&self) { unimplemented!() } diff --git a/diesel/src/pg/backend.rs b/diesel/src/pg/backend.rs index 69b8030e0deb..9599fd99cbd0 100644 --- a/diesel/src/pg/backend.rs +++ b/diesel/src/pg/backend.rs @@ -23,3 +23,4 @@ impl TypeMetadata for Pg { impl SupportsReturningClause for Pg {} impl SupportsDefaultKeyword for Pg {} +impl UsesAnsiSavepointSyntax for Pg {} diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 4a73ae5153f6..9c3570846317 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -7,11 +7,10 @@ mod row; pub mod result; mod stmt; -use std::cell::Cell; use std::ffi::{CString, CStr}; use std::rc::Rc; -use connection::{SimpleConnection, Connection}; +use connection::{SimpleConnection, Connection, AnsiTransactionManager}; use pg::{Pg, PgQueryBuilder}; use query_builder::{AsQuery, QueryFragment, QueryId}; use query_builder::bind_collector::RawBytesBindCollector; @@ -29,7 +28,7 @@ use types::HasSqlType; #[allow(missing_debug_implementations)] pub struct PgConnection { raw_connection: RawConnection, - transaction_depth: Cell, + transaction_manager: AnsiTransactionManager, statement_cache: StatementCache, } @@ -48,12 +47,13 @@ impl SimpleConnection for PgConnection { impl Connection for PgConnection { type Backend = Pg; + type TransactionManager = AnsiTransactionManager; fn establish(database_url: &str) -> ConnectionResult { RawConnection::establish(database_url).map(|raw_conn| { PgConnection { raw_connection: raw_conn, - transaction_depth: Cell::new(0), + transaction_manager: AnsiTransactionManager::new(), statement_cache: StatementCache::new(), } }) @@ -94,40 +94,8 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn begin_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(1, if transaction_depth == 0 { - self.execute("BEGIN") - } else { - self.execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth)) - }) - } - - #[doc(hidden)] - fn rollback_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(-1, if transaction_depth == 1 { - self.execute("ROLLBACK") - } else { - self.execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}", - transaction_depth - 1)) - }) - } - - #[doc(hidden)] - fn commit_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(-1, if transaction_depth <= 1 { - self.execute("COMMIT") - } else { - self.execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}", - transaction_depth - 1)) - }) - } - - #[doc(hidden)] - fn get_transaction_depth(&self) -> i32 { - self.transaction_depth.get() + fn transaction_manager(&self) -> &Self::TransactionManager { + &self.transaction_manager } #[doc(hidden)] @@ -166,13 +134,6 @@ impl PgConnection { let query = try!(Query::sql(query, None)); query.execute(&self.raw_connection, &Vec::new()) } - - fn change_transaction_depth(&self, by: i32, query: QueryResult) -> QueryResult<()> { - if query.is_ok() { - self.transaction_depth.set(self.transaction_depth.get() + by); - } - query.map(|_| ()) - } } extern "C" fn noop_notice_processor(_: *mut libc::c_void, _message: *const libc::c_char) { diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 72d776e7f8a9..e58a7699932b 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -26,3 +26,5 @@ impl Backend for Sqlite { impl TypeMetadata for Sqlite { type TypeMetadata = SqliteType; } + +impl UsesAnsiSavepointSyntax for Sqlite {} diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 4ca3de47368f..47a76fa189ad 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -11,11 +11,11 @@ pub use self::sqlite_value::SqliteValue; use std::any::TypeId; use std::borrow::Cow; -use std::cell::{Cell, RefCell}; +use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; -use connection::{SimpleConnection, Connection}; +use connection::{SimpleConnection, Connection, AnsiTransactionManager}; use query_builder::*; use query_builder::bind_collector::RawBytesBindCollector; use query_source::*; @@ -32,7 +32,7 @@ use types::HasSqlType; pub struct SqliteConnection { statement_cache: RefCell>, raw_connection: Rc, - transaction_depth: Cell, + transaction_manager: AnsiTransactionManager, } #[derive(Hash, PartialEq, Eq)] @@ -54,13 +54,14 @@ impl SimpleConnection for SqliteConnection { impl Connection for SqliteConnection { type Backend = Sqlite; + type TransactionManager = AnsiTransactionManager; fn establish(database_url: &str) -> ConnectionResult { RawConnection::establish(database_url).map(|conn| { SqliteConnection { statement_cache: RefCell::new(HashMap::new()), raw_connection: Rc::new(conn), - transaction_depth: Cell::new(0), + transaction_manager: AnsiTransactionManager::new(), } }) } @@ -98,40 +99,8 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn begin_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(1, if transaction_depth == 0 { - self.execute("BEGIN") - } else { - self.execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth)) - }) - } - - #[doc(hidden)] - fn rollback_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(-1, if transaction_depth == 1 { - self.execute("ROLLBACK") - } else { - self.execute(&format!("ROLLBACK TO SAVEPOINT diesel_savepoint_{}", - transaction_depth - 1)) - }) - } - - #[doc(hidden)] - fn commit_transaction(&self) -> QueryResult<()> { - let transaction_depth = self.transaction_depth.get(); - self.change_transaction_depth(-1, if transaction_depth <= 1 { - self.execute("COMMIT") - } else { - self.execute(&format!("RELEASE SAVEPOINT diesel_savepoint_{}", - transaction_depth - 1)) - }) - } - - #[doc(hidden)] - fn get_transaction_depth(&self) -> i32 { - self.transaction_depth.get() + fn transaction_manager(&self) -> &Self::TransactionManager { + &self.transaction_manager } #[doc(hidden)] @@ -156,13 +125,6 @@ impl SqliteConnection { Ok(result) } - fn change_transaction_depth(&self, by: i32, query: QueryResult) -> QueryResult<()> { - if query.is_ok() { - self.transaction_depth.set(self.transaction_depth.get() + by); - } - query.map(|_| ()) - } - fn cached_prepared_statement + QueryId>(&self, source: &T) -> QueryResult {