Skip to content

Commit

Permalink
Restore thread's original context before returning to the ThreadPool
Browse files Browse the repository at this point in the history
This commit ensures that we always restore the thread's original context after execution of
a context preserving runnable. We always wrap runnables in a wrapper that restores the context
at the time it was submitted to the execute method. The ContextPreservingAbstractRunnable
would restore the calling context in the doRun method and then in a try with resources
block would restore the thread's original context. However, the onFailure and onAfter methods
of a AbstractRunnable could modify the thread context and this modified thread context would
continue on as the thread's context after it was returned to the pool and potentially used
for a different purpose.
  • Loading branch information
jaymode authored Nov 8, 2016
1 parent b326f0b commit 6ecb023
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ protected void doExecute(final Runnable command) {
}
}

@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
assert contextHolder.isDefaultContext() : "the thread context is not the default context and the thread [" +
Thread.currentThread().getName() + "] is being returned to the pool after executing [" + r + "]";
}

/**
* Returns a stream of all pending tasks. This is similar to {@link #getQueue()} but will expose the originally submitted
* {@link Runnable} instances rather than potentially wrapped ones.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ protected void beforeExecute(Thread t, Runnable r) {

@Override
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
current.remove(r);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ public Runnable unwrap(Runnable command) {
return command;
}

/**
* Returns true if the current context is the default context.
*/
boolean isDefaultContext() {
return threadLocal.get() == DEFAULT_CONTEXT;
}

@FunctionalInterface
public interface StoredContext extends AutoCloseable {
@Override
Expand Down Expand Up @@ -468,10 +475,12 @@ public Runnable unwrap() {
*/
private class ContextPreservingAbstractRunnable extends AbstractRunnable {
private final AbstractRunnable in;
private final ThreadContext.StoredContext ctx;
private final ThreadContext.StoredContext creatorsContext;

private ThreadContext.StoredContext threadsOriginalContext = null;

private ContextPreservingAbstractRunnable(AbstractRunnable in) {
ctx = newStoredContext();
creatorsContext = newStoredContext();
this.in = in;
}

Expand All @@ -482,7 +491,13 @@ public boolean isForceExecution() {

@Override
public void onAfter() {
in.onAfter();
try {
in.onAfter();
} finally {
if (threadsOriginalContext != null) {
threadsOriginalContext.restore();
}
}
}

@Override
Expand All @@ -498,8 +513,9 @@ public void onRejection(Exception e) {
@Override
protected void doRun() throws Exception {
boolean whileRunning = false;
try (ThreadContext.StoredContext ignore = stashContext()){
ctx.restore();
threadsOriginalContext = stashContext();
try {
creatorsContext.restore();
whileRunning = true;
in.doRun();
whileRunning = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ public void testPreserveContext() throws IOException {

// But we do inside of it
withContext.run();

// but not after
assertNull(threadContext.getHeader("foo"));
}
}

Expand Down Expand Up @@ -350,6 +353,177 @@ public void testPreserveContextKeepsOriginalContextWhenCalledTwice() throws IOEx
}
}

public void testPreservesThreadsOriginalContextOnRunException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;

// create a abstract runnable, add headers and transient objects and verify in the methods
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {

@Override
public void onAfter() {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertNotNull(threadContext.getTransient("failure"));
assertEquals("exception from doRun", ((RuntimeException)threadContext.getTransient("failure")).getMessage());
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("after", "after");
}

@Override
public void onFailure(Exception e) {
assertEquals("exception from doRun", e.getMessage());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("failure", e);
}

@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
throw new RuntimeException("exception from doRun");
}
});
}

// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("failure"));
assertNull(threadContext.getTransient("after"));
assertTrue(threadContext.isDefaultContext());

// But we do inside of it
withContext.run();

// verify not seen after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("failure"));
assertNull(threadContext.getTransient("after"));
assertTrue(threadContext.isDefaultContext());

// repeat with regular runnable
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(() -> {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
threadContext.putTransient("run", true);
throw new RuntimeException("exception from run");
});
}

assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("run"));
assertTrue(threadContext.isDefaultContext());

final Runnable runnable = withContext;
RuntimeException e = expectThrows(RuntimeException.class, runnable::run);
assertEquals("exception from run", e.getMessage());
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertNull(threadContext.getTransient("run"));
assertTrue(threadContext.isDefaultContext());
}
}

public void testPreservesThreadsOriginalContextOnFailureException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;

// a runnable that throws from onFailure
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
throw new RuntimeException("from onFailure", e);
}

@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
throw new RuntimeException("from doRun");
}
});
}

// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());

// But we do inside of it
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
assertEquals("from onFailure", e.getMessage());
assertEquals("from doRun", e.getCause().getMessage());

// but not after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
}
}

public void testPreservesThreadsOriginalContextOnAfterException() throws IOException {
try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) {
Runnable withContext;

// a runnable that throws from onAfter
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("foo", "bar_transient");
withContext = threadContext.preserveContext(new AbstractRunnable() {

@Override
public void onAfter() {
throw new RuntimeException("from onAfter");
}

@Override
public void onFailure(Exception e) {
throw new RuntimeException("from onFailure", e);
}

@Override
protected void doRun() throws Exception {
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("bar_transient", threadContext.getTransient("foo"));
assertFalse(threadContext.isDefaultContext());
}
});
}

// We don't see the header outside of the runnable
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());

// But we do inside of it
RuntimeException e = expectThrows(RuntimeException.class, withContext::run);
assertEquals("from onAfter", e.getMessage());
assertNull(e.getCause());

// but not after
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("foo"));
assertTrue(threadContext.isDefaultContext());
}
}

/**
* Sometimes wraps a Runnable in an AbstractRunnable.
*/
Expand Down

0 comments on commit 6ecb023

Please sign in to comment.