diff --git a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java index ab7f4a04..cf48aa01 100644 --- a/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java +++ b/src/main/java/com/pivovarit/collectors/AsyncParallelCollector.java @@ -17,6 +17,7 @@ import static com.pivovarit.collectors.BatchingSpliterator.batching; import static com.pivovarit.collectors.BatchingSpliterator.partitioned; +import static com.pivovarit.collectors.Dispatcher.Configuration.*; import static java.util.Objects.requireNonNull; import static java.util.concurrent.CompletableFuture.allOf; import static java.util.concurrent.CompletableFuture.supplyAsync; @@ -110,7 +111,7 @@ private static CompletableFuture> combine(List(mapper, Dispatcher.from(executor), Function.identity()); + return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), Function.identity()); } static Collector>> collectingToStream(Function mapper, Executor executor, int parallelism) { @@ -120,7 +121,7 @@ private static CompletableFuture> combine(List i) - : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity()); + : new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), Function.identity()); } static Collector> collectingWithCollector(Collector collector, Function mapper) { @@ -145,7 +146,7 @@ private static CompletableFuture> combine(List(mapper, Dispatcher.from(executor), s -> s.collect(collector)); + return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), s -> s.collect(collector)); } static Collector> collectingWithCollector(Collector collector, Function mapper, Executor executor, int parallelism) { @@ -156,7 +157,7 @@ private static CompletableFuture> combine(List s.collect(collector)) - : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector)); + : new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), s -> s.collect(collector)); } static void requireValidParallelism(int parallelism) { @@ -212,13 +213,13 @@ private BatchingCollectors() { return list.stream() .collect(new AsyncParallelCollector<>( mapper, - Dispatcher.from(executor, parallelism), + Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), finisher)); } else { return partitioned(list, parallelism) .collect(new AsyncParallelCollector<>( batching(mapper), - Dispatcher.from(executor, parallelism), + Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), listStream -> finisher.apply(listStream.flatMap(Collection::stream)))); } }); diff --git a/src/main/java/com/pivovarit/collectors/Dispatcher.java b/src/main/java/com/pivovarit/collectors/Dispatcher.java index de4db63a..59f53659 100644 --- a/src/main/java/com/pivovarit/collectors/Dispatcher.java +++ b/src/main/java/com/pivovarit/collectors/Dispatcher.java @@ -1,5 +1,6 @@ package com.pivovarit.collectors; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; @@ -9,6 +10,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -31,41 +33,27 @@ final class Dispatcher { private volatile boolean shortCircuited = false; - private Dispatcher() { - this.executor = defaultExecutorService(); - this.limiter = null; - } - - private Dispatcher(Executor executor, int permits) { - requireValidExecutor(executor); - this.executor = executor; - this.limiter = new Semaphore(permits); - } - - private Dispatcher(int permits) { - this.executor = defaultExecutorService(); - this.limiter = new Semaphore(permits); - } - - private Dispatcher(Executor executor) { - this.executor = executor; - this.limiter = null; - } + private Dispatcher(Configuration configuration) { + if (configuration.maxParallelism == null) { + this.limiter = null; + } else { + requireValidMaxParallelism(configuration.maxParallelism); + this.limiter = new Semaphore(configuration.maxParallelism); + } - static Dispatcher from(Executor executor) { - return new Dispatcher<>(executor); + this.executor = Objects.requireNonNullElseGet(requireValidExecutor(configuration.executor), Dispatcher::defaultExecutorService); } - static Dispatcher from(Executor executor, int permits) { - return new Dispatcher<>(executor, permits); + static Dispatcher virtual() { + return new Dispatcher<>(Configuration.initial()); } - static Dispatcher virtual() { - return new Dispatcher<>(); + static Dispatcher virtual(int maxParallelism) { + return new Dispatcher<>(Configuration.initial().withMaxParallelism(maxParallelism)); } - static Dispatcher virtual(int permits) { - return new Dispatcher<>(permits); + static Dispatcher from(Configuration configuration) { + return new Dispatcher<>(configuration); } void start() { @@ -82,17 +70,15 @@ void start() { } Runnable task; if ((task = workingQueue.take()) != POISON_PILL) { - retry(() -> { - executor.execute(() -> { - try { - task.run(); - } finally { - if (limiter != null) { - limiter.release(); - } + retry(() -> executor.execute(() -> { + try { + task.run(); + } finally { + if (limiter != null) { + limiter.release(); } - }); - }); + } + })); } else { break; } @@ -150,6 +136,19 @@ private static Function shortcircuit(InterruptibleCompletableFu }; } + record Configuration(Executor executor, Integer maxParallelism, ThreadFactory dispatcherFactory) { + + public static Configuration initial() { + return new Configuration(null, null, null); + } + public Configuration withExecutor(Executor executor) { + return new Configuration(executor, this.maxParallelism, this.dispatcherFactory); + } + + public Configuration withMaxParallelism(int permits) { + return new Configuration(this.executor, permits, this.dispatcherFactory); + } + } static final class InterruptibleCompletableFuture extends CompletableFuture { private volatile FutureTask backingTask; @@ -165,13 +164,19 @@ public boolean cancel(boolean mayInterruptIfRunning) { } return super.cancel(mayInterruptIfRunning); } - } + } private static ExecutorService defaultExecutorService() { return Executors.newVirtualThreadPerTaskExecutor(); } - private static void requireValidExecutor(Executor executor) { + static void requireValidMaxParallelism(int maxParallelism) { + if (maxParallelism < 1) { + throw new IllegalArgumentException("Max parallelism can't be lower than 1"); + } + } + + private static Executor requireValidExecutor(Executor executor) { if (executor instanceof ThreadPoolExecutor tpe) { switch (tpe.getRejectedExecutionHandler()) { case ThreadPoolExecutor.DiscardPolicy __ -> @@ -183,6 +188,7 @@ private static void requireValidExecutor(Executor executor) { } } } + return executor; } private static void retry(Runnable runnable) { diff --git a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java index 64fe8a84..b1036118 100644 --- a/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java +++ b/src/main/java/com/pivovarit/collectors/ParallelStreamCollector.java @@ -19,6 +19,7 @@ import static com.pivovarit.collectors.BatchingSpliterator.partitioned; import static com.pivovarit.collectors.CompletionStrategy.ordered; import static com.pivovarit.collectors.CompletionStrategy.unordered; +import static com.pivovarit.collectors.Dispatcher.Configuration.initial; import static java.util.Collections.emptySet; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.collectingAndThen; @@ -101,7 +102,7 @@ public Set characteristics() { requireNonNull(executor, "executor can't be null"); requireNonNull(mapper, "mapper can't be null"); - return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor)); + return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor))); } static Collector> streaming(Function mapper, Executor executor, int parallelism) { @@ -109,7 +110,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism)); + return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))); } static Collector> streamingOrdered(Function mapper) { @@ -129,7 +130,7 @@ public Set characteristics() { requireNonNull(executor, "executor can't be null"); requireNonNull(mapper, "mapper can't be null"); - return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor)); + return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor))); } static Collector> streamingOrdered(Function mapper, Executor executor, @@ -138,7 +139,7 @@ public Set characteristics() { requireNonNull(mapper, "mapper can't be null"); requireValidParallelism(parallelism); - return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism)); + return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))); } static final class BatchingCollectors { @@ -180,14 +181,14 @@ private BatchingCollectors() { mapper, ordered(), emptySet(), - Dispatcher.from(executor, parallelism))); + Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)))); } else { return partitioned(list, parallelism) .collect(collectingAndThen(new ParallelStreamCollector<>( batching(mapper), ordered(), emptySet(), - Dispatcher.from(executor, parallelism)), + Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))), s -> s.flatMap(Collection::stream))); } });