需求
需要根据优先级执行任务,有任务不是特别重要,可以稍后执行;需要对正在执行的线程做超时监控;有的API依赖任务返回结果,线程池执行的时候任务也支持同步任务;
简单测试
创建一个使用支持优先级队列(new PriorityBlockingQueue<>() )的线程,然后每个任务实现 java.lang.Comparable 接口
new ThreadPoolExecutor(MAX_THREAD, MAX_THREAD * 2, 0L, TimeUnit.MILLISECONDS, new PriorityBlockingQueue());
简单测试一下基本可行,但是发现使用优先级列队线程池发现使用FutureTask 会抛异常 java.lang.ClassCastException: java.util.concurrent.FutureTask cannot be cast to java.lang.Comparable 所以这个方法不能再用了
Future<?> future = executor.submit(new Task()); 1
设计思路
- 定义一个任务接口,继承 Runnable, Comparable,并且获取超时时间方法、优先级方法、执行时间方法、超时中止任务方法;
- 有一个管理线程池的类,用于执行任务
- 有一个超时检测的类
- 有一个基本执行类,用于每次任务执行时把任务加入监控列表,有一个同步任务类,管理执行同步任务
代码实现
任务接口
/** * Created by zengrenyuan on 18/6/11. */ public interface ExecuteTask extends Runnable, Comparable { /** * @return 任务超时时间 */ long getTimeout(); /** * @return 任务的优先级 */ int getPriority(); /** * 任务超时时调用 * 用于结束任务运行 */ void destroy(); /** * @return 任务当前的执行时间 */ long elapsed(); } 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
线程池管理
import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.PreDestroy; import java.util.concurrent.*; /** * Created by zengrenyuan on 18/06/11. */ public class ExecuteThreadManager { private static final Logger LOG = LoggerFactory.getLogger(ExecuteThreadManager.class); /** * 线程池线程名称 */ public static final String SCRIPT_TASK_THREAD_POOL = "TaskThreadPool"; /** * 最大线程数 */ private final int MAX_THREAD = 10; /** * 一个按优先级执行的线程池 */ private final ThreadPoolExecutor executor = ThreadPoolUtils.newThreadPool(MAX_THREAD, SCRIPT_TASK_THREAD_POOL, new PriorityBlockingQueue<>()); private static class LazyHolder { private static final ExecuteThreadManager INSTANCE = new ExecuteThreadManager(); } private ExecuteThreadManager() { } /** * 单例 */ public static final ExecuteThreadManager getInstance() { return LazyHolder.INSTANCE; } /** * 把任务提交到线程池执行 */ public void execute(ExecuteTask task) { executor.execute(new BaseExecuteTask(task)); } /** * 同步执行任务,等任务执行完才会往下走 */ public void syncExecute(ExecuteTask task) { SynchronizedExecuteTask syncTask = new SynchronizedExecuteTask(task); executor.execute(syncTask); syncTask.await(); } @PreDestroy private void destroy() { try { executor.shutdownNow(); } catch (Exception e) { LOG.error(null, e); } } } 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
线程池生成工具类
import java.util.concurrent.*; /** * Created by zengrenyuan on 18/06/11. */ public class ThreadPoolUtils { /** * 生成一个线程池 * * @param maxThread 线程池最大数 * @param threadPoolName 线程池名称 * @return */ public static ThreadPoolExecutor newThreadPool(int maxThread, final String threadPoolName) { return newThreadPool(maxThread, threadPoolName, new LinkedBlockingDeque<>()); } /** * 生成一个线程池 * * @param maxThread 线程池最大数 * @param threadPoolName 线程池名称 * @param workQueue 队列 * @return 线程池 */ public static ThreadPoolExecutor newThreadPool(int maxThread, final String threadPoolName, BlockingQueue workQueue) { return new ThreadPoolExecutor( maxThread, maxThread, 0L, TimeUnit.MILLISECONDS, workQueue, getThreadFactory(threadPoolName) ); } /** * @param threadPoolName 线程名称 * @return 线程生成器 */ public static ThreadFactory getThreadFactory(final String threadPoolName) { return new ThreadFactory() { private int threadNumber = 1; @Override public Thread newThread(Runnable r) { return new Thread(r, threadPoolName + "-" + threadNumber); } }; } } 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
基础任务类
import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * 简单包装一下 * 在任务执行前把任务加入到运行任务监控池 * Created by zengrenyuan on 18/06/11. */ public class BaseExecuteTask implements ExecuteTask { private static final Logger LOG = LoggerFactory.getLogger(BaseExecuteTask.class); protected ExecuteTask task; public BaseExecuteTask(ExecuteTask task) { this.task = task; } @Override public long getTimeout() { return task.getTimeout(); } @Override public void destroy() { task.destroy(); } @Override public int getPriority() { return task.getPriority(); } @Override public long elapsed() { return task.elapsed(); } @Override public int compareTo(Object o) { return task.compareTo(o); } @Override public void run() { try { //任务开始前把任务加入到正在运行的任务池 TaskPoolListener.addTask(this); task.run(); } finally { LOG.info("运行结束"); TaskPoolListener.removeTask(this); //任务运行后,把任务从正在运行的任务池中移除 } } } 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
同步线程工具类
import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; /** * 同步执行任务 * Created by zengrenyuan on 18/06/11. */ public class SynchronizedExecuteTask extends BaseExecuteTask { private static final Logger LOG = LoggerFactory.getLogger(SynchronizedExecuteTask.class); //用于记录任务是否执行完成 private CountDownLatch latch = new CountDownLatch(1); public SynchronizedExecuteTask(ExecuteTask task) { super(task); } @Override public void run() { try { super.run(); } finally { LOG.info("运行结束"); latch.countDown(); } } public void await() { //任务是否在超时时间之前执行完成 boolean finished = false; try { finished = latch.await(getTimeout(), TimeUnit.MILLISECONDS); } catch (InterruptedException e) { LOG.error(null, e); } finally { if (!finished) { destroy(); throw new RuntimeException("任务执行超时"); } } } } 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
任务超时监控
import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; import javax.annotation.PostConstruct; import java.util.ArrayList; import java.util.List; /** * Created by zengrenyuan on 18/06/11. */ @Service public class TaskPoolListener { private static final Logger LOG = LoggerFactory.getLogger(TaskPoolListener.class); private final static List<ExecuteTask> scriptTasks = new ArrayList<>(); public static final int SLEEP_INTERVAL = 10_000; private Thread thread; @PostConstruct public void init() { thread = new Thread(new Runnable() { @Override public void run() { while (true) { try { doTask(); } catch (Exception e) { LOG.info(null, e); } finally { this.sleep(SLEEP_INTERVAL); } } } private void sleep(long interval) { try { Thread.sleep(interval); } catch (InterruptedException var4) { LOG.info(var4.getMessage()); } } }); thread.setName("任务运行超时检测线程"); thread.start(); } /** * 用于监测超时任务 * 如果任务超时调用destroy方法 */ public void doTask() { List<ExecuteTask> scriptTasks = new ArrayList<>(this.scriptTasks); for (ExecuteTask scriptTask : scriptTasks) { if (scriptTask.elapsed() > scriptTask.getTimeout()) { scriptTask.destroy(); } } } public static synchronized void addTask(ExecuteTask scriptTask) { scriptTasks.add(scriptTask); } public static synchronized void removeTask(ExecuteTask scriptTask) { scriptTasks.remove(scriptTask); } public final int size() { return scriptTasks.size(); } }