summaryrefslogtreecommitdiff
path: root/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java')
-rw-r--r--core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java104
1 files changed, 100 insertions, 4 deletions
diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java
index 4c02aab1fe..b5417aa238 100644
--- a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java
+++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java
@@ -30,6 +30,7 @@ import java.util.concurrent.TimeUnit;
*/
public class EsThreadPoolExecutor extends ThreadPoolExecutor {
+ private final ThreadContext contextHolder;
private volatile ShutdownListener listener;
private final Object monitor = new Object();
@@ -38,13 +39,14 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor {
*/
private final String name;
- EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
- this(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, new EsAbortPolicy());
+ EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, ThreadContext contextHolder) {
+ this(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, new EsAbortPolicy(), contextHolder);
}
- EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, XRejectedExecutionHandler handler) {
+ EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, XRejectedExecutionHandler handler, ThreadContext contextHolder) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
this.name = name;
+ this.contextHolder = contextHolder;
}
public void shutdown(ShutdownListener listener) {
@@ -80,7 +82,11 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor {
}
@Override
- public void execute(Runnable command) {
+ public void execute(final Runnable command) {
+ doExecute(wrapRunnable(command));
+ }
+
+ protected void doExecute(final Runnable command) {
try {
super.execute(command);
} catch (EsRejectedExecutionException ex) {
@@ -116,4 +122,94 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor {
b.append(super.toString()).append(']');
return b.toString();
}
+
+ protected Runnable wrapRunnable(Runnable command) {
+ final Runnable wrappedCommand;
+ if (command instanceof AbstractRunnable) {
+ wrappedCommand = new FilterAbstractRunnable(contextHolder, (AbstractRunnable) command);
+ } else {
+ wrappedCommand = new FilterRunnable(contextHolder, command);
+ }
+ return wrappedCommand;
+ }
+
+ protected Runnable unwrap(Runnable runnable) {
+ if (runnable instanceof FilterAbstractRunnable) {
+ return ((FilterAbstractRunnable) runnable).in;
+ } else if (runnable instanceof FilterRunnable) {
+ return ((FilterRunnable) runnable).in;
+ }
+ return runnable;
+ }
+
+ private static class FilterAbstractRunnable extends AbstractRunnable {
+ private final ThreadContext contextHolder;
+ private final AbstractRunnable in;
+ private final ThreadContext.StoredContext ctx;
+
+ FilterAbstractRunnable(ThreadContext contextHolder, AbstractRunnable in) {
+ this.contextHolder = contextHolder;
+ ctx = contextHolder.newStoredContext();
+ this.in = in;
+ }
+
+ @Override
+ public boolean isForceExecution() {
+ return in.isForceExecution();
+ }
+
+ @Override
+ public void onAfter() {
+ in.onAfter();
+ }
+
+ @Override
+ public void onFailure(Throwable t) {
+ in.onFailure(t);
+ }
+
+ @Override
+ public void onRejection(Throwable t) {
+ in.onRejection(t);
+ }
+
+ @Override
+ protected void doRun() throws Exception {
+ try (ThreadContext.StoredContext ingore = contextHolder.stashContext()){
+ ctx.restore();
+ in.doRun();
+ }
+ }
+
+ @Override
+ public String toString() {
+ return in.toString();
+ }
+
+ }
+
+ private static class FilterRunnable implements Runnable {
+ private final ThreadContext contextHolder;
+ private final Runnable in;
+ private final ThreadContext.StoredContext ctx;
+
+ FilterRunnable(ThreadContext contextHolder, Runnable in) {
+ this.contextHolder = contextHolder;
+ ctx = contextHolder.newStoredContext();
+ this.in = in;
+ }
+
+ @Override
+ public void run() {
+ try (ThreadContext.StoredContext ingore = contextHolder.stashContext()){
+ ctx.restore();
+ in.run();
+ }
+ }
+ @Override
+ public String toString() {
+ return in.toString();
+ }
+ }
+
}