Use dedicated Configuration object for configuring Dispatcher instances #924

Draft · wants to merge 1 commit into base: master
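For orientation, this is the call-site shape the PR moves to, summarized from the diff below. The helper class is hypothetical and only compiles against this branch (Dispatcher and its Configuration are package-private, so it has to sit in the same package); the executor and parallelism values are placeholders.

```java
package com.pivovarit.collectors;

import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

// Hypothetical caller, illustrating the new construction style only.
class ConfigurationUsageSketch {

    static Dispatcher<String> bounded(Executor executor, int parallelism) {
        // replaces the former Dispatcher.from(executor, parallelism) overload
        return Dispatcher.from(Dispatcher.Configuration.initial()
                .withExecutor(executor)
                .withMaxParallelism(parallelism));
    }

    static Dispatcher<String> unbounded(Executor executor) {
        // replaces the former Dispatcher.from(executor) overload
        return Dispatcher.from(Dispatcher.Configuration.initial()
                .withExecutor(executor));
    }

    public static void main(String[] args) {
        Dispatcher<String> dispatcher = bounded(Executors.newFixedThreadPool(4), 2);
    }
}
```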
@@ -17,6 +17,7 @@

import static com.pivovarit.collectors.BatchingSpliterator.batching;
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
+ import static com.pivovarit.collectors.Dispatcher.Configuration.*;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
@@ -110,7 +111,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

- return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), Function.identity());
+ return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), Function.identity());
}

static <T, R> Collector<T, ?, CompletableFuture<Stream<R>>> collectingToStream(Function<T, R> mapper, Executor executor, int parallelism) {
@@ -120,7 +121,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, i -> i)
- : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), Function.identity());
+ : new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), Function.identity());
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper) {
@@ -145,7 +146,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

- return new AsyncParallelCollector<>(mapper, Dispatcher.from(executor), s -> s.collect(collector));
+ return new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor)), s -> s.collect(collector));
}

static <T, R, RR> Collector<T, ?, CompletableFuture<RR>> collectingWithCollector(Collector<R, ?, RR> collector, Function<T, R> mapper, Executor executor, int parallelism) {
@@ -156,7 +157,7 @@ private static <T> CompletableFuture<Stream<T>> combine(List<CompletableFuture<T

return parallelism == 1
? asyncCollector(mapper, executor, s -> s.collect(collector))
- : new AsyncParallelCollector<>(mapper, Dispatcher.from(executor, parallelism), s -> s.collect(collector));
+ : new AsyncParallelCollector<>(mapper, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)), s -> s.collect(collector));
}

static void requireValidParallelism(int parallelism) {
@@ -212,13 +213,13 @@ private BatchingCollectors() {
return list.stream()
.collect(new AsyncParallelCollector<>(
mapper,
- Dispatcher.from(executor, parallelism),
+ Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)),
finisher));
} else {
return partitioned(list, parallelism)
.collect(new AsyncParallelCollector<>(
batching(mapper),
- Dispatcher.from(executor, parallelism),
+ Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)),
listStream -> finisher.apply(listStream.flatMap(Collection::stream))));
}
});
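A reading aid for the `initial().withExecutor(...).withMaxParallelism(...)` chains above: each with-er returns a new Configuration record instead of mutating the receiver. A minimal standalone sketch of that behaviour, using a simplified copy of the record from the Dispatcher diff below (not the PR code itself):

```java
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;

// Simplified standalone copy of Dispatcher.Configuration, for illustration only.
record Configuration(Executor executor, Integer maxParallelism, ThreadFactory dispatcherFactory) {

    static Configuration initial() {
        return new Configuration(null, null, null);
    }

    Configuration withExecutor(Executor executor) {
        return new Configuration(executor, maxParallelism, dispatcherFactory);
    }

    Configuration withMaxParallelism(int maxParallelism) {
        return new Configuration(executor, maxParallelism, dispatcherFactory);
    }

    public static void main(String[] args) {
        Executor direct = Runnable::run;

        Configuration base = initial();                      // (null, null, null)
        Configuration bounded = base.withExecutor(direct)
                                    .withMaxParallelism(4);  // (direct, 4, null)

        // 'base' is untouched; every with-er call allocates a fresh record.
        System.out.println(base);
        System.out.println(bounded);
    }
}
```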
86 changes: 46 additions & 40 deletions src/main/java/com/pivovarit/collectors/Dispatcher.java
@@ -1,5 +1,6 @@
package com.pivovarit.collectors;

+ import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
@@ -9,6 +10,7 @@
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
+ import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
@@ -31,41 +33,27 @@ final class Dispatcher<T> {

private volatile boolean shortCircuited = false;

- private Dispatcher() {
- this.executor = defaultExecutorService();
- this.limiter = null;
- }
-
- private Dispatcher(Executor executor, int permits) {
- requireValidExecutor(executor);
- this.executor = executor;
- this.limiter = new Semaphore(permits);
- }
-
- private Dispatcher(int permits) {
- this.executor = defaultExecutorService();
- this.limiter = new Semaphore(permits);
- }
-
- private Dispatcher(Executor executor) {
- this.executor = executor;
- this.limiter = null;
- }
+ private Dispatcher(Configuration configuration) {
+ if (configuration.maxParallelism == null) {
+ this.limiter = null;
+ } else {
+ requireValidMaxParallelism(configuration.maxParallelism);
+ this.limiter = new Semaphore(configuration.maxParallelism);
+ }

- static <T> Dispatcher<T> from(Executor executor) {
- return new Dispatcher<>(executor);
+ this.executor = Objects.requireNonNullElseGet(requireValidExecutor(configuration.executor), Dispatcher::defaultExecutorService);
}

- static <T> Dispatcher<T> from(Executor executor, int permits) {
- return new Dispatcher<>(executor, permits);
+ static <T> Dispatcher<T> virtual() {
+ return new Dispatcher<>(Configuration.initial());
}

- static <T> Dispatcher<T> virtual() {
- return new Dispatcher<>();
+ static <T> Dispatcher<T> virtual(int maxParallelism) {
+ return new Dispatcher<>(Configuration.initial().withMaxParallelism(maxParallelism));
}

- static <T> Dispatcher<T> virtual(int permits) {
- return new Dispatcher<>(permits);
+ static <T> Dispatcher<T> from(Configuration configuration) {
+ return new Dispatcher<>(configuration);
}

void start() {
@@ -82,17 +70,15 @@ void start() {
}
Runnable task;
if ((task = workingQueue.take()) != POISON_PILL) {
- retry(() -> {
- executor.execute(() -> {
- try {
- task.run();
- } finally {
- if (limiter != null) {
- limiter.release();
- }
+ retry(() -> executor.execute(() -> {
+ try {
+ task.run();
+ } finally {
+ if (limiter != null) {
+ limiter.release();
}
- });
- });
+ }
+ }));
} else {
break;
}
@@ -150,6 +136,19 @@ private static Function<Throwable, Void> shortcircuit(InterruptibleCompletableFu
};
}

+ record Configuration(Executor executor, Integer maxParallelism, ThreadFactory dispatcherFactory) {
+
+ public static Configuration initial() {
+ return new Configuration(null, null, null);
+ }
+ public Configuration withExecutor(Executor executor) {
+ return new Configuration(executor, this.maxParallelism, this.dispatcherFactory);
+ }
+
+ public Configuration withMaxParallelism(int permits) {
+ return new Configuration(this.executor, permits, this.dispatcherFactory);
+ }
+ }
static final class InterruptibleCompletableFuture<T> extends CompletableFuture<T> {

private volatile FutureTask<?> backingTask;
@@ -165,13 +164,19 @@ public boolean cancel(boolean mayInterruptIfRunning) {
}
return super.cancel(mayInterruptIfRunning);
}
- }

+ }
private static ExecutorService defaultExecutorService() {
return Executors.newVirtualThreadPerTaskExecutor();
}

- private static void requireValidExecutor(Executor executor) {
+ static void requireValidMaxParallelism(int maxParallelism) {
+ if (maxParallelism < 1) {
+ throw new IllegalArgumentException("Max parallelism can't be lower than 1");
+ }
+ }
+
+ private static Executor requireValidExecutor(Executor executor) {
if (executor instanceof ThreadPoolExecutor tpe) {
switch (tpe.getRejectedExecutionHandler()) {
case ThreadPoolExecutor.DiscardPolicy __ ->
@@ -183,6 +188,7 @@ private static void requireValidExecutor(Executor executor) {
}
}
}
+ return executor;
}

private static void retry(Runnable runnable) {
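Reading the new constructor: a null `maxParallelism` means no semaphore is created at all, and a null `executor` falls back to a virtual-thread-per-task executor via `Objects.requireNonNullElseGet`, which is what keeps `Dispatcher.virtual()` working on top of `Configuration.initial()`. A standalone sketch of that null-handling idiom (simplified names, not the PR code; the executor-validation step is omitted):

```java
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;

// Illustration of how null Configuration fields turn into defaults.
class DefaultsSketch {

    private final Executor executor;
    private final Semaphore limiter; // null means "no concurrency cap"

    DefaultsSketch(Executor configuredExecutor, Integer maxParallelism) {
        // no cap requested -> no semaphore; otherwise validate and create one
        if (maxParallelism == null) {
            this.limiter = null;
        } else {
            if (maxParallelism < 1) {
                throw new IllegalArgumentException("Max parallelism can't be lower than 1");
            }
            this.limiter = new Semaphore(maxParallelism);
        }

        // no executor configured -> fall back to virtual threads
        this.executor = Objects.requireNonNullElseGet(configuredExecutor, DefaultsSketch::defaultExecutorService);
    }

    private static ExecutorService defaultExecutorService() {
        return Executors.newVirtualThreadPerTaskExecutor();
    }
}
```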
@@ -19,6 +19,7 @@
import static com.pivovarit.collectors.BatchingSpliterator.partitioned;
import static com.pivovarit.collectors.CompletionStrategy.ordered;
import static com.pivovarit.collectors.CompletionStrategy.unordered;
+ import static com.pivovarit.collectors.Dispatcher.Configuration.initial;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.collectingAndThen;
@@ -101,15 +102,15 @@ public Set<Characteristics> characteristics() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

- return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor));
+ return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor)));
}

static <T, R> Collector<T, ?, Stream<R>> streaming(Function<T, R> mapper, Executor executor, int parallelism) {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

- return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(executor, parallelism));
+ return new ParallelStreamCollector<>(mapper, unordered(), UNORDERED, Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper) {
@@ -129,7 +130,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(executor, "executor can't be null");
requireNonNull(mapper, "mapper can't be null");

- return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor));
+ return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor)));
}

static <T, R> Collector<T, ?, Stream<R>> streamingOrdered(Function<T, R> mapper, Executor executor,
@@ -138,7 +139,7 @@ public Set<Characteristics> characteristics() {
requireNonNull(mapper, "mapper can't be null");
requireValidParallelism(parallelism);

- return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(executor, parallelism));
+ return new ParallelStreamCollector<>(mapper, ordered(), emptySet(), Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism)));
}

static final class BatchingCollectors {
@@ -180,14 +181,14 @@ private BatchingCollectors() {
mapper,
ordered(),
emptySet(),
- Dispatcher.from(executor, parallelism)));
+ Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))));
} else {
return partitioned(list, parallelism)
.collect(collectingAndThen(new ParallelStreamCollector<>(
batching(mapper),
ordered(),
emptySet(),
- Dispatcher.from(executor, parallelism)),
+ Dispatcher.from(initial().withExecutor(executor).withMaxParallelism(parallelism))),
s -> s.flatMap(Collection::stream)));
}
});
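Two closing observations. First, `Configuration` already carries a `ThreadFactory dispatcherFactory` component, but this diff adds no with-er for it and no call site sets it, so it stays null everywhere for now; a `withDispatcherFactory(...)` method mirroring the other two would be the obvious follow-up, but it is not part of this PR. Second, since Configuration stays package-private, nothing changes for library users; assuming the existing public entry point `ParallelCollectors.parallel(mapper, collector, executor, parallelism)`, code like the following behaves the same before and after this change:

```java
import com.pivovarit.collectors.ParallelCollectors;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static java.util.stream.Collectors.toList;

// User-facing sketch: the Configuration refactor is internal, so existing call sites are untouched.
class PublicApiUnchanged {

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            CompletableFuture<List<Integer>> result = List.of(1, 2, 3, 4).stream()
                    .collect(ParallelCollectors.parallel(i -> i * 2, toList(), executor, 2));
            System.out.println(result.join()); // [2, 4, 6, 8]
        } finally {
            executor.shutdown();
        }
    }
}
```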