Skip to content

Commit

Permalink
Use dedicated Configuration object for configuring Dispatcher instances
Browse files Browse the repository at this point in the history
  • Loading branch information
pivovarit committed Sep 14, 2024
1 parent 1842264 commit 0739983
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.Dispatcher.Configuration.*;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
Expand Down Expand Up @@ -96,7 +97,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper) {
requireNonNull(mapper, "mapper can't be null");

return new AsyncParallelCollector<>(mapper, Dispatcher.virtual(), Function.identity());
return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial()), Function.identity());
}

static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper, int parallelism) {
Expand All @@ -110,7 +111,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), Function.identity());
return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), Function.identity());
}

static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper, Executor executor, int parallelism) {
Expand All @@ -120,7 +121,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity());
: new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), Function.identity());
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper) {
Expand All @@ -145,7 +146,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), s -> s.collect(collector));
return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), s -> s.collect(collector));
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper, Executor executor, int parallelism) {
Expand All @@ -156,7 +157,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
: new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
: new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), s -> s.collect(collector));
}

static void requireValidParallelism(int parallelism) {
Expand Down Expand Up @@ -212,13 +213,13 @@ private BatchingCollectors() {
return list.stream()
.collect(new AsyncParallelCollector<>(
mapper,
Dispatcher.from(executor, parallelism),
Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)),
finisher));
} else {
return partitioned(list, parallelism)
.collect(new AsyncParallelCollector<>(
batching(mapper),
Dispatcher.from(executor, parallelism),
Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)),
listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
}
});
Expand Down
86 changes: 46 additions & 40 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.pivovarit.collectors;

import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
Expand All @@ -9,6 +10,7 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
Expand All @@ -31,41 +33,27 @@ final class Dispatcher<T> {

private volatile boolean shortCircuited = false;

private Dispatcher() {
this.executor = defaultExecutorService();
this.limiter = null;
}

private Dispatcher(Executor executor, int permits) {
requireValidExecutor(executor);
this.executor = executor;
this.limiter = new Semaphore(permits);
}

private Dispatcher(int permits) {
this.executor = defaultExecutorService();
this.limiter = new Semaphore(permits);
}

private Dispatcher(Executor executor) {
this.executor = executor;
this.limiter = null;
}
private Dispatcher(Configuration configuration) {
if (configuration.maxParallelism == null) {
this.limiter = null;
} else {
requireValidMaxParallelism(configuration.maxParallelism);
this.limiter = new Semaphore(configuration.maxParallelism);
}

static <T> Dispatcher<T> from(Executor executor) {
return new Dispatcher<>(executor);
this.executor = Objects.requireNonNullElseGet(requireValidExecutor(configuration.executor), Dispatcher::defaultExecutorService);
}

static <T> Dispatcher<T> from(Executor executor, int permits) {
return new Dispatcher<>(executor, permits);
static <T> Dispatcher<T> virtual() {
return new Dispatcher<>(Configuration.initial());
}

static <T> Dispatcher<T> virtual() {
return new Dispatcher<>();
static <T> Dispatcher<T> virtual(int maxParallelism) {
return new Dispatcher<>(Configuration.initial().withMaxParallelism(maxParallelism));
}

static <T> Dispatcher<T> virtual(int permits) {
return new Dispatcher<>(permits);
static <T> Dispatcher<T> from(Configuration configuration) {
return new Dispatcher<>(configuration);
}

void start() {
Expand All @@ -82,17 +70,15 @@ void start() {
}
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
retry(() -> {
executor.execute(() -> {
try {
task.run();
} finally {
if (limiter != null) {
limiter.release();
}
retry(() -> executor.execute(() -> {
try {
task.run();
} finally {
if (limiter != null) {
limiter.release();
}
});
});
}
}));
} else {
break;
}
Expand Down Expand Up @@ -150,6 +136,19 @@ private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFu
};
}

record Configuration(Executor executor, Integer maxParallelism, ThreadFactory dispatcherFactory) {

public static Configuration initial() {
return new Configuration(null, null, null);
}
public Configuration withExecutor(Executor executor) {
return new Configuration(executor, this.maxParallelism, this.dispatcherFactory);
}

public Configuration withMaxParallelism(int permits) {
return new Configuration(this.executor, permits, this.dispatcherFactory);
}
}
static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {

private volatile FutureTask<?> backingTask;
Expand All @@ -165,13 +164,19 @@ public boolean cancel(boolean mayInterruptIfRunning) {
}
return super.cancel(mayInterruptIfRunning);
}
}

}
private static ExecutorService defaultExecutorService() {
return Executors.newVirtualThreadPerTaskExecutor();
}

private static void requireValidExecutor(Executor executor) {
static void requireValidMaxParallelism(int maxParallelism) {
if (maxParallelism < 1) {
throw new IllegalArgumentException("Max parallelism can't be lower than 1");
}
}

private static Executor requireValidExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
switch (tpe.getRejectedExecutionHandler()) {
case ThreadPoolExecutor.DiscardPolicy __ ->
Expand All @@ -183,6 +188,7 @@ private static void requireValidExecutor(Executor executor) {
}
}
}
return executor;
}

private static void retry(Runnable runnable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.CompletionStrategy.ordered;
import static com.pivovarit.collectors.CompletionStrategy.unordered;
import static com.pivovarit.collectors.Dispatcher.Configuration.initial;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
Expand Down Expand Up @@ -101,15 +102,15 @@ public Set<Characteristics> characteristics() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor));
return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor)));
}

static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper) {
Expand All @@ -129,7 +130,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor));
return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor)));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor,
Expand All @@ -138,7 +139,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism));
return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)));
}

static final class BatchingCollectors {
Expand Down Expand Up @@ -180,14 +181,14 @@ private BatchingCollectors() {
mapper,
ordered(),
emptySet(),
Dispatcher.from(executor, parallelism)));
Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))));
} else {
return partitioned(list, parallelism)
.collect(collectingAndThen(new ParallelStreamCollector<>(
batching(mapper),
ordered(),
emptySet(),
Dispatcher.from(executor, parallelism)),
Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))),
s -> s.flatMap(Collection::stream)));
}
});
Expand Down

0 comments on commit 0739983

Please sign in to comment.