diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java index 882f709966..0057bb15be 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AsyncTransactionManagerImpl.java @@ -28,6 +28,7 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.ByteString; /** Implementation of {@link AsyncTransactionManager}. */ final class AsyncTransactionManagerImpl @@ -80,7 +81,19 @@ public TransactionContextFutureImpl beginAsync() { private ApiFuture internalBeginAsync(boolean firstAttempt) { txnState = TransactionState.STARTED; - txn = session.newTransaction(options); + + // Determine the latest transactionId when using a multiplexed session. + ByteString multiplexedSessionPreviousTransactionId = ByteString.EMPTY; + if (txn != null && session.getIsMultiplexed() && !firstAttempt) { + // Use the current transactionId if available, otherwise fallback to the previous aborted + // transactionId. + multiplexedSessionPreviousTransactionId = + txn.transactionId != null ? txn.transactionId : txn.getPreviousTransactionId(); + } + + txn = + session.newTransaction( + options, /* previousTransactionId = */ multiplexedSessionPreviousTransactionId); if (firstAttempt) { session.setActive(this); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index b9d1ce054d..91edce7932 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -59,7 +59,7 @@ class DatabaseClientImpl implements DatabaseClient { /* useMultiplexedSessionBlindWrite = */ false, /* multiplexedSessionDatabaseClient = */ null, tracer, - false); + /* useMultiplexedSessionForRW = */ false); } DatabaseClientImpl( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 7b9abc71a8..60c9d45d18 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -69,7 +69,8 @@ static void throwIfTransactionsPending() { } } - static TransactionOptions createReadWriteTransactionOptions(Options options) { + static TransactionOptions createReadWriteTransactionOptions( + Options options, ByteString previousTransactionId) { TransactionOptions.Builder transactionOptions = TransactionOptions.newBuilder(); if (options.withExcludeTxnFromChangeStreams() == Boolean.TRUE) { transactionOptions.setExcludeTxnFromChangeStreams(true); @@ -78,6 +79,10 @@ static TransactionOptions createReadWriteTransactionOptions(Options options) { if (options.withOptimisticLock() == Boolean.TRUE) { readWrite.setReadLockMode(TransactionOptions.ReadWrite.ReadLockMode.OPTIMISTIC); } + if (previousTransactionId != null + && previousTransactionId != com.google.protobuf.ByteString.EMPTY) { + readWrite.setMultiplexedSessionPreviousTransactionId(previousTransactionId); + } transactionOptions.setReadWrite(readWrite); return transactionOptions.build(); } @@ -427,13 +432,17 @@ public void close() { } ApiFuture beginTransactionAsync( - Options transactionOptions, boolean routeToLeader, Map channelHint) { + Options transactionOptions, + boolean routeToLeader, + Map channelHint, + ByteString previousTransactionId) { final SettableApiFuture res = SettableApiFuture.create(); final ISpan span = tracer.spanBuilder(SpannerImpl.BEGIN_TRANSACTION); final BeginTransactionRequest request = BeginTransactionRequest.newBuilder() .setSession(getName()) - .setOptions(createReadWriteTransactionOptions(transactionOptions)) + .setOptions( + createReadWriteTransactionOptions(transactionOptions, previousTransactionId)) .build(); final ApiFuture requestFuture; try (IScope ignore = tracer.withSpan(span)) { @@ -469,11 +478,12 @@ ApiFuture beginTransactionAsync( return res; } - TransactionContextImpl newTransaction(Options options) { + TransactionContextImpl newTransaction(Options options, ByteString previousTransactionId) { return TransactionContextImpl.newBuilder() .setSession(this) .setOptions(options) .setTransactionId(null) + .setPreviousTransactionId(previousTransactionId) .setOptions(options) .setTrackTransactionStarter(spanner.getOptions().isTrackTransactionStarter()) .setRpc(spanner.getRpc()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java index bbbc9aeb44..cafb27ba6b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionManagerImpl.java @@ -20,6 +20,7 @@ import com.google.cloud.spanner.Options.TransactionOption; import com.google.cloud.spanner.SessionImpl.SessionTransaction; import com.google.common.base.Preconditions; +import com.google.protobuf.ByteString; /** Implementation of {@link TransactionManager}. */ final class TransactionManagerImpl implements TransactionManager, SessionTransaction { @@ -53,7 +54,7 @@ public void setSpan(ISpan span) { public TransactionContext begin() { Preconditions.checkState(txn == null, "begin can only be called once"); try (IScope s = tracer.withSpan(span)) { - txn = session.newTransaction(options); + txn = session.newTransaction(options, /* previousTransactionId = */ ByteString.EMPTY); session.setActive(this); txnState = TransactionState.STARTED; return txn; @@ -102,7 +103,18 @@ public TransactionContext resetForRetry() { } try (IScope s = tracer.withSpan(span)) { boolean useInlinedBegin = txn.transactionId != null; - txn = session.newTransaction(options); + + // Determine the latest transactionId when using a multiplexed session. + ByteString multiplexedSessionPreviousTransactionId = ByteString.EMPTY; + if (session.getIsMultiplexed()) { + // Use the current transactionId if available, otherwise fallback to the previous aborted + // transactionId. + multiplexedSessionPreviousTransactionId = + txn.transactionId != null ? txn.transactionId : txn.getPreviousTransactionId(); + } + txn = + session.newTransaction( + options, /* previousTransactionId = */ multiplexedSessionPreviousTransactionId); if (!useInlinedBegin) { txn.ensureTxn(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java index a7250e1ef7..48affde355 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionRunnerImpl.java @@ -93,6 +93,9 @@ static class Builder extends AbstractReadContext.Builder ensureTxnAsync() { private void createTxnAsync(final SettableApiFuture res) { span.addAnnotation("Creating Transaction"); final ApiFuture fut = - session.beginTransactionAsync(options, isRouteToLeader(), getTransactionChannelHint()); + session.beginTransactionAsync( + options, isRouteToLeader(), getTransactionChannelHint(), getPreviousTransactionId()); fut.addListener( () -> { try { @@ -558,7 +574,9 @@ TransactionSelector getTransactionSelector() { } if (tx == null) { return TransactionSelector.newBuilder() - .setBegin(SessionImpl.createReadWriteTransactionOptions(options)) + .setBegin( + SessionImpl.createReadWriteTransactionOptions( + options, getPreviousTransactionId())) .build(); } else { // Wait for the transaction to come available. The tx.get() call will fail with an @@ -1079,7 +1097,7 @@ public TransactionRunner allowNestedTransaction() { TransactionRunnerImpl(SessionImpl session, TransactionOption... options) { this.session = session; this.options = Options.fromTransactionOptions(options); - this.txn = session.newTransaction(this.options); + this.txn = session.newTransaction(this.options, /* previousTransactionId = */ ByteString.EMPTY); this.tracer = session.getTracer(); } @@ -1118,7 +1136,19 @@ private T runInternal(final TransactionCallable txCallable) { // Do not inline the BeginTransaction during a retry if the initial attempt did not // actually start a transaction. useInlinedBegin = txn.transactionId != null; - txn = session.newTransaction(options); + + // Determine the latest transactionId when using a multiplexed session. + ByteString multiplexedSessionPreviousTransactionId = ByteString.EMPTY; + if (session.getIsMultiplexed()) { + // Use the current transactionId if available, otherwise fallback to the previous + // transactionId. + multiplexedSessionPreviousTransactionId = + txn.transactionId != null ? txn.transactionId : txn.getPreviousTransactionId(); + } + + txn = + session.newTransaction( + options, /* previousTransactionId = */ multiplexedSessionPreviousTransactionId); } checkState( isValid, "TransactionRunner has been invalidated by a new operation on the session"); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerImplTest.java index 08d22dd2d6..006a926e90 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/AsyncTransactionManagerImplTest.java @@ -16,12 +16,18 @@ package com.google.cloud.spanner; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.api.core.ApiFutures; import com.google.cloud.Timestamp; +import com.google.protobuf.ByteString; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Scope; import org.junit.Test; @@ -42,7 +48,7 @@ public void testCommitReturnsCommitStats() { when(oTspan.makeCurrent()).thenReturn(mock(Scope.class)); try (AsyncTransactionManagerImpl manager = new AsyncTransactionManagerImpl(session, span, Options.commitStats())) { - when(session.newTransaction(Options.fromTransactionOptions(Options.commitStats()))) + when(session.newTransaction(eq(Options.fromTransactionOptions(Options.commitStats())), any())) .thenReturn(transaction); when(transaction.ensureTxnAsync()).thenReturn(ApiFutures.immediateFuture(null)); Timestamp commitTimestamp = Timestamp.ofTimeMicroseconds(1); @@ -54,4 +60,67 @@ public void testCommitReturnsCommitStats() { verify(transaction).commitAsync(); } } + + @Test + public void testRetryUsesPreviousTransactionIdOnMultiplexedSession() { + // Set up mock transaction IDs + final ByteString mockTransactionId = ByteString.copyFromUtf8("mockTransactionId"); + final ByteString mockPreviousTransactionId = + ByteString.copyFromUtf8("mockPreviousTransactionId"); + + Span oTspan = mock(Span.class); + ISpan span = new OpenTelemetrySpan(oTspan); + when(oTspan.makeCurrent()).thenReturn(mock(Scope.class)); + // Mark the session as multiplexed. + when(session.getIsMultiplexed()).thenReturn(true); + + // Initialize a mock transaction with transactionId = null, previousTransactionId = null. + transaction = mock(TransactionRunnerImpl.TransactionContextImpl.class); + when(transaction.ensureTxnAsync()).thenReturn(ApiFutures.immediateFuture(null)); + when(session.newTransaction(eq(Options.fromTransactionOptions(Options.commitStats())), any())) + .thenReturn(transaction); + + // Simulate an ABORTED error being thrown when `commitAsync()` is called. + doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, "")) + .when(transaction) + .commitAsync(); + + try (AsyncTransactionManagerImpl manager = + new AsyncTransactionManagerImpl(session, span, Options.commitStats())) { + manager.beginAsync(); + + // Verify that for the first transaction attempt, the `previousTransactionId` is + // ByteString.EMPTY. + // This is because no transaction has been previously aborted at this point. + verify(session) + .newTransaction(Options.fromTransactionOptions(Options.commitStats()), ByteString.EMPTY); + assertThrows(AbortedException.class, manager::commitAsync); + clearInvocations(session); + + // Mock the transaction object to contain transactionID=null and + // previousTransactionId=mockPreviousTransactionId + when(transaction.getPreviousTransactionId()).thenReturn(mockPreviousTransactionId); + manager.resetForRetryAsync(); + // Verify that in the first retry attempt, the `previousTransactionId` + // (mockPreviousTransactionId) is passed to the new transaction. + // This allows Spanner to retry the transaction using the ID of the aborted transaction. + verify(session) + .newTransaction( + Options.fromTransactionOptions(Options.commitStats()), mockPreviousTransactionId); + assertThrows(AbortedException.class, manager::commitAsync); + clearInvocations(session); + + // Mock the transaction object to contain transactionID=mockTransactionId and + // previousTransactionId=mockPreviousTransactionId and transactionID = null + transaction.transactionId = mockTransactionId; + manager.resetForRetryAsync(); + // Verify that the latest `transactionId` (mockTransactionId) is used in the retry. + // This ensures the retry logic is working as expected with the latest transaction ID. + verify(session) + .newTransaction(Options.fromTransactionOptions(Options.commitStats()), mockTransactionId); + + when(transaction.rollbackAsync()).thenReturn(ApiFutures.immediateFuture(null)); + manager.closeAsync(); + } + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java index 2e41253788..adf7ed2a40 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MultiplexedSessionDatabaseClientMockServerTest.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.MockSpannerTestUtil.INVALID_UPDATE_STATEMENT; import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_COUNT; import static com.google.cloud.spanner.MockSpannerTestUtil.UPDATE_STATEMENT; import static com.google.common.truth.Truth.assertThat; @@ -36,11 +37,13 @@ import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.Options.RpcPriority; +import com.google.cloud.spanner.TransactionRunnerImpl.TransactionContextImpl; import com.google.cloud.spanner.connection.RandomResultSetGenerator; import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; +import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.RequestOptions.Priority; @@ -54,6 +57,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import org.junit.Before; import org.junit.BeforeClass; @@ -70,6 +74,10 @@ public static void setupResults() { mockSpanner.putStatementResults( StatementResult.query(STATEMENT, new RandomResultSetGenerator(1).generate())); mockSpanner.putStatementResult(StatementResult.update(UPDATE_STATEMENT, UPDATE_COUNT)); + mockSpanner.putStatementResult( + StatementResult.exception( + INVALID_UPDATE_STATEMENT, + Status.INVALID_ARGUMENT.withDescription("invalid statement").asRuntimeException())); } @Before @@ -745,6 +753,179 @@ public void testAsyncRunnerIsNonBlockingWithMultiplexedSession() throws Exceptio assertEquals(1L, client.multiplexedSessionDatabaseClient.getNumSessionsReleased().get()); } + @Test + public void testAbortedReadWriteTxnUsesPreviousTxnIdOnRetryWithInlineBegin() { + DatabaseClientImpl client = + (DatabaseClientImpl) spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + // Force the Commit RPC to return Aborted the first time it is called. The exception is cleared + // after the first call, so the retry should succeed. + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + mockSpanner.createAbortedException(ByteString.copyFromUtf8("test")))); + TransactionRunner runner = client.readWriteTransaction(); + AtomicReference validTransactionId = new AtomicReference<>(); + runner.run( + transaction -> { + try (ResultSet resultSet = transaction.executeQuery(STATEMENT)) { + while (resultSet.next()) {} + } + + TransactionContextImpl impl = (TransactionContextImpl) transaction; + if (validTransactionId.get() == null) { + // Track the first not-null transactionId. This transaction gets ABORTED during commit + // operation and gets retried. + validTransactionId.set(impl.transactionId); + } + return null; + }); + + List executeSqlRequests = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class); + assertEquals(2, executeSqlRequests.size()); + + // Verify the requests are executed using multiplexed sessions + for (ExecuteSqlRequest request : executeSqlRequests) { + assertTrue(mockSpanner.getSession(request.getSession()).getMultiplexed()); + } + + // Verify that the first request uses inline begin, and the previous transaction ID is set to + // ByteString.EMPTY + assertTrue(executeSqlRequests.get(0).hasTransaction()); + assertTrue(executeSqlRequests.get(0).getTransaction().hasBegin()); + assertTrue(executeSqlRequests.get(0).getTransaction().getBegin().hasReadWrite()); + assertNotNull( + executeSqlRequests + .get(0) + .getTransaction() + .getBegin() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertEquals( + ByteString.EMPTY, + executeSqlRequests + .get(0) + .getTransaction() + .getBegin() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + + // Verify that the second request uses inline begin, and the previous transaction ID is set + // appropriately + assertTrue(executeSqlRequests.get(1).hasTransaction()); + assertTrue(executeSqlRequests.get(1).getTransaction().hasBegin()); + assertTrue(executeSqlRequests.get(1).getTransaction().getBegin().hasReadWrite()); + assertNotNull( + executeSqlRequests + .get(1) + .getTransaction() + .getBegin() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertNotEquals( + ByteString.EMPTY, + executeSqlRequests + .get(1) + .getTransaction() + .getBegin() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertEquals( + validTransactionId.get(), + executeSqlRequests + .get(1) + .getTransaction() + .getBegin() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + } + + @Test + public void testAbortedReadWriteTxnUsesPreviousTxnIdOnRetryWithExplicitBegin() { + DatabaseClientImpl client = + (DatabaseClientImpl) spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + // Force the Commit RPC to return Aborted the first time it is called. The exception is cleared + // after the first call, so the retry should succeed. + mockSpanner.setCommitExecutionTime( + SimulatedExecutionTime.ofException( + mockSpanner.createAbortedException(ByteString.copyFromUtf8("test")))); + TransactionRunner runner = client.readWriteTransaction(); + AtomicReference validTransactionId = new AtomicReference<>(); + Long updateCount = + runner.run( + transaction -> { + // This update statement carries the BeginTransaction, but fails. This will + // cause the entire transaction to be retried with an explicit + // BeginTransaction RPC to ensure all statements in the transaction are + // actually executed against the same transaction. + TransactionContextImpl impl = (TransactionContextImpl) transaction; + if (validTransactionId.get() == null) { + // Track the first not-null transactionId. This transaction gets ABORTED during + // commit operation and gets retried. + validTransactionId.set(impl.transactionId); + } + SpannerException e = + assertThrows( + SpannerException.class, + () -> transaction.executeUpdate(INVALID_UPDATE_STATEMENT)); + assertEquals(ErrorCode.INVALID_ARGUMENT, e.getErrorCode()); + return transaction.executeUpdate(UPDATE_STATEMENT); + }); + + assertThat(updateCount).isEqualTo(1L); + List beginTransactionRequests = + mockSpanner.getRequestsOfType(BeginTransactionRequest.class); + assertEquals(2, beginTransactionRequests.size()); + + // Verify the requests are executed using multiplexed sessions + for (BeginTransactionRequest request : beginTransactionRequests) { + assertTrue(mockSpanner.getSession(request.getSession()).getMultiplexed()); + } + + // Verify that explicit begin transaction is called during retry, and the previous transaction + // ID is set to ByteString.EMPTY + assertTrue(beginTransactionRequests.get(0).hasOptions()); + assertTrue(beginTransactionRequests.get(0).getOptions().hasReadWrite()); + assertNotNull( + beginTransactionRequests + .get(0) + .getOptions() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertEquals( + ByteString.EMPTY, + beginTransactionRequests + .get(0) + .getOptions() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + + // The previous transaction with id (txn1) fails during commit operation with ABORTED error. + // Verify that explicit begin transaction is called during retry, and the previous transaction + // ID is not ByteString.EMPTY (should be set to txn1) + assertTrue(beginTransactionRequests.get(1).hasOptions()); + assertTrue(beginTransactionRequests.get(1).getOptions().hasReadWrite()); + assertNotNull( + beginTransactionRequests + .get(1) + .getOptions() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertNotEquals( + ByteString.EMPTY, + beginTransactionRequests + .get(1) + .getOptions() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + assertEquals( + validTransactionId.get(), + beginTransactionRequests + .get(1) + .getOptions() + .getReadWrite() + .getMultiplexedSessionPreviousTransactionId()); + } + private void waitForSessionToBeReplaced(DatabaseClientImpl client) { assertNotNull(client.multiplexedSessionDatabaseClient); SessionReference sessionReference = diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index 998678e429..c3e8d887de 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -1495,9 +1495,10 @@ public void testSessionNotFoundReadWriteTransaction() { .build(); when(closedSession.asyncClose()) .thenReturn(ApiFutures.immediateFuture(Empty.getDefaultInstance())); - when(closedSession.newTransaction(Options.fromTransactionOptions())) + when(closedSession.newTransaction(eq(Options.fromTransactionOptions()), any())) .thenReturn(closedTransactionContext); - when(closedSession.beginTransactionAsync(any(), eq(true), any())).thenThrow(sessionNotFound); + when(closedSession.beginTransactionAsync(any(), eq(true), any(), any())) + .thenThrow(sessionNotFound); when(closedSession.getTracer()).thenReturn(tracer); TransactionRunnerImpl closedTransactionRunner = new TransactionRunnerImpl(closedSession); closedTransactionRunner.setSpan(span); @@ -1510,9 +1511,9 @@ public void testSessionNotFoundReadWriteTransaction() { when(openSession.getName()) .thenReturn("projects/dummy/instances/dummy/database/dummy/sessions/session-open"); final TransactionContextImpl openTransactionContext = mock(TransactionContextImpl.class); - when(openSession.newTransaction(Options.fromTransactionOptions())) + when(openSession.newTransaction(eq(Options.fromTransactionOptions()), any())) .thenReturn(openTransactionContext); - when(openSession.beginTransactionAsync(any(), eq(true), any())) + when(openSession.beginTransactionAsync(any(), eq(true), any(), any())) .thenReturn(ApiFutures.immediateFuture(ByteString.copyFromUtf8("open-txn"))); when(openSession.getTracer()).thenReturn(tracer); TransactionRunnerImpl openTransactionRunner = new TransactionRunnerImpl(openSession); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java index c3fcf1c748..10b1312515 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionManagerImplTest.java @@ -20,7 +20,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -98,7 +100,7 @@ public void setUp() { @Test public void beginCalledTwiceFails() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); assertThat(manager.begin()).isEqualTo(txn); assertThat(manager.getState()).isEqualTo(TransactionState.STARTED); IllegalStateException e = assertThrows(IllegalStateException.class, () -> manager.begin()); @@ -126,7 +128,7 @@ public void resetBeforeBeginFails() { @Test public void transactionRolledBackOnClose() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); when(txn.isAborted()).thenReturn(false); manager.begin(); manager.close(); @@ -135,7 +137,7 @@ public void transactionRolledBackOnClose() { @Test public void commitSucceeds() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); Timestamp commitTimestamp = Timestamp.ofTimeMicroseconds(1); CommitResponse response = new CommitResponse(commitTimestamp); when(txn.getCommitResponse()).thenReturn(response); @@ -147,7 +149,7 @@ public void commitSucceeds() { @Test public void resetAfterSuccessfulCommitFails() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); manager.begin(); manager.commit(); IllegalStateException e = @@ -157,21 +159,21 @@ public void resetAfterSuccessfulCommitFails() { @Test public void resetAfterAbortSucceeds() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); manager.begin(); doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, "")).when(txn).commit(); assertThrows(AbortedException.class, () -> manager.commit()); assertEquals(TransactionState.ABORTED, manager.getState()); txn = Mockito.mock(TransactionRunnerImpl.TransactionContextImpl.class); - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); assertThat(manager.resetForRetry()).isEqualTo(txn); assertThat(manager.getState()).isEqualTo(TransactionState.STARTED); } @Test public void resetAfterErrorFails() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); manager.begin(); doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.UNKNOWN, "")).when(txn).commit(); SpannerException e = assertThrows(SpannerException.class, () -> manager.commit()); @@ -184,7 +186,7 @@ public void resetAfterErrorFails() { @Test public void rollbackAfterCommitFails() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); manager.begin(); manager.commit(); IllegalStateException e = assertThrows(IllegalStateException.class, () -> manager.rollback()); @@ -193,7 +195,7 @@ public void rollbackAfterCommitFails() { @Test public void commitAfterRollbackFails() { - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); manager.begin(); manager.rollback(); IllegalStateException e = assertThrows(IllegalStateException.class, () -> manager.commit()); @@ -363,4 +365,61 @@ public void inlineBegin() { assertThat(transactionsStarted.get()).isEqualTo(1); } } + + // This test ensures that when a transaction is aborted in a multiplexed session, + // the transaction ID of the aborted transaction is saved during the retry when a new transaction + // is created. + @Test + public void storePreviousTxnIdOnAbortForMultiplexedSession() { + txn = Mockito.mock(TransactionRunnerImpl.TransactionContextImpl.class); + final ByteString mockTransactionId = ByteString.copyFromUtf8("mockTransactionId"); + txn.transactionId = mockTransactionId; + when(session.newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY)) + .thenReturn(txn); + manager.begin(); + // Verify that for the first transaction attempt, the `previousTransactionId` is + // ByteString.EMPTY. + // This is because no transaction has been previously aborted at this point. + verify(session).newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY); + doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, "")).when(txn).commit(); + assertThrows(AbortedException.class, () -> manager.commit()); + + txn = Mockito.mock(TransactionRunnerImpl.TransactionContextImpl.class); + when(txn.getPreviousTransactionId()).thenReturn(mockTransactionId); + when(session.newTransaction(Options.fromTransactionOptions(), mockTransactionId)) + .thenReturn(txn); + when(session.getIsMultiplexed()).thenReturn(true); + assertThat(manager.resetForRetry()).isEqualTo(txn); + // Verify that in the first retry attempt, the `previousTransactionId` is passed to the new + // transaction. + // This allows Spanner to retry the transaction using the ID of the aborted transaction. + verify(session).newTransaction(Options.fromTransactionOptions(), mockTransactionId); + } + + // This test ensures that when a transaction is aborted in a regular session, + // the transaction ID of the aborted transaction is not saved during the retry when a new + // transaction is created. + @Test + public void skipTxnIdStorageOnAbortForRegularSession() { + txn = Mockito.mock(TransactionRunnerImpl.TransactionContextImpl.class); + final ByteString mockTransactionId = ByteString.copyFromUtf8("mockTransactionId"); + txn.transactionId = mockTransactionId; + when(session.newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY)) + .thenReturn(txn); + manager.begin(); + verify(session).newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY); + doThrow(SpannerExceptionFactory.newSpannerException(ErrorCode.ABORTED, "")).when(txn).commit(); + assertThrows(AbortedException.class, () -> manager.commit()); + clearInvocations(session); + + txn = Mockito.mock(TransactionRunnerImpl.TransactionContextImpl.class); + when(session.newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY)) + .thenReturn(txn); + when(session.getIsMultiplexed()).thenReturn(false); + assertThat(manager.resetForRetry()).isEqualTo(txn); + // Verify that in the first retry attempt, the `previousTransactionId` is not passed to the new + // transaction + // in case of regular sessions. + verify(session).newTransaction(Options.fromTransactionOptions(), ByteString.EMPTY); + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java index c647bb3642..1fd6817ea9 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/TransactionRunnerImplTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; @@ -117,7 +118,7 @@ public void setUp() { tracer = new TraceWrapper(Tracing.getTracer(), OpenTelemetry.noop().getTracer(""), false); firstRun = true; when(session.getErrorHandler()).thenReturn(DefaultErrorHandler.INSTANCE); - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(txn); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())).thenReturn(txn); when(session.getTracer()).thenReturn(tracer); when(rpc.executeQuery(Mockito.any(ExecuteSqlRequest.class), Mockito.anyMap(), eq(true))) .thenAnswer( @@ -343,7 +344,8 @@ private long[] batchDmlException(int status) { .setTracer(session.getTracer()) .setSpan(session.getTracer().getCurrentSpan()) .build(); - when(session.newTransaction(Options.fromTransactionOptions())).thenReturn(transaction); + when(session.newTransaction(eq(Options.fromTransactionOptions()), any())) + .thenReturn(transaction); when(session.getName()).thenReturn(SessionId.of("p", "i", "d", "test").getName()); TransactionRunnerImpl runner = new TransactionRunnerImpl(session); runner.setSpan(span);