diff options
Diffstat (limited to 'core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java')
-rw-r--r-- | core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java | 104 |
1 files changed, 100 insertions, 4 deletions
diff --git a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java index 4c02aab1fe..b5417aa238 100644 --- a/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java +++ b/core/src/main/java/org/elasticsearch/common/util/concurrent/EsThreadPoolExecutor.java @@ -30,6 +30,7 @@ import java.util.concurrent.TimeUnit; */ public class EsThreadPoolExecutor extends ThreadPoolExecutor { + private final ThreadContext contextHolder; private volatile ShutdownListener listener; private final Object monitor = new Object(); @@ -38,13 +39,14 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor { */ private final String name; - EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) { - this(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, new EsAbortPolicy()); + EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, ThreadContext contextHolder) { + this(name, corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, new EsAbortPolicy(), contextHolder); } - EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, XRejectedExecutionHandler handler) { + EsThreadPoolExecutor(String name, int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, XRejectedExecutionHandler handler, ThreadContext contextHolder) { super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler); this.name = name; + this.contextHolder = contextHolder; } public void shutdown(ShutdownListener listener) { @@ -80,7 +82,11 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor { } @Override - public void execute(Runnable command) { + public void execute(final Runnable command) { + doExecute(wrapRunnable(command)); + } + + protected void doExecute(final Runnable command) { try { super.execute(command); } catch (EsRejectedExecutionException ex) { @@ -116,4 +122,94 @@ public class EsThreadPoolExecutor extends ThreadPoolExecutor { b.append(super.toString()).append(']'); return b.toString(); } + + protected Runnable wrapRunnable(Runnable command) { + final Runnable wrappedCommand; + if (command instanceof AbstractRunnable) { + wrappedCommand = new FilterAbstractRunnable(contextHolder, (AbstractRunnable) command); + } else { + wrappedCommand = new FilterRunnable(contextHolder, command); + } + return wrappedCommand; + } + + protected Runnable unwrap(Runnable runnable) { + if (runnable instanceof FilterAbstractRunnable) { + return ((FilterAbstractRunnable) runnable).in; + } else if (runnable instanceof FilterRunnable) { + return ((FilterRunnable) runnable).in; + } + return runnable; + } + + private static class FilterAbstractRunnable extends AbstractRunnable { + private final ThreadContext contextHolder; + private final AbstractRunnable in; + private final ThreadContext.StoredContext ctx; + + FilterAbstractRunnable(ThreadContext contextHolder, AbstractRunnable in) { + this.contextHolder = contextHolder; + ctx = contextHolder.newStoredContext(); + this.in = in; + } + + @Override + public boolean isForceExecution() { + return in.isForceExecution(); + } + + @Override + public void onAfter() { + in.onAfter(); + } + + @Override + public void onFailure(Throwable t) { + in.onFailure(t); + } + + @Override + public void onRejection(Throwable t) { + in.onRejection(t); + } + + @Override + protected void doRun() throws Exception { + try (ThreadContext.StoredContext ingore = contextHolder.stashContext()){ + ctx.restore(); + in.doRun(); + } + } + + @Override + public String toString() { + return in.toString(); + } + + } + + private static class FilterRunnable implements Runnable { + private final ThreadContext contextHolder; + private final Runnable in; + private final ThreadContext.StoredContext ctx; + + FilterRunnable(ThreadContext contextHolder, Runnable in) { + this.contextHolder = contextHolder; + ctx = contextHolder.newStoredContext(); + this.in = in; + } + + @Override + public void run() { + try (ThreadContext.StoredContext ingore = contextHolder.stashContext()){ + ctx.restore(); + in.run(); + } + } + @Override + public String toString() { + return in.toString(); + } + } + } |