In this article, I will show you how to use the ForkJoinPool, which is still not widely known among Java developers. ForkJoinPool is one of the ExecutorService implementations. It is used in CompletableFuture and the Stream API. It was designed to simplify parallelism for recursive tasks by breaking a single task into independent subtasks until they are small enough to be executed asynchronously. With this class, you can execute a very large number of tasks using a small number of threads.
Java provides a common ForkJoinPool instance that can be obtained through the static commonPool() method:
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
ForkJoinTask has its own counterparts to Runnable and Callable. These implementations are called RecursiveAction and RecursiveTask, respectively. Each of them declares an abstract compute() method, which you must implement in your subclass. RecursiveTask<T>#compute() returns a generic <T> value, while RecursiveAction#compute() returns void. Both of these classes extend the abstract ForkJoinTask class.
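To make the difference concrete, here is a minimal sketch of both skeletons (the class names SumTask and PrintAction are placeholders for this illustration, not classes from the examples below):
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;

// RecursiveTask: compute() returns a value
class SumTask extends RecursiveTask<Long> {
    @Override
    protected Long compute() {
        return 42L; // placeholder result
    }
}

// RecursiveAction: compute() returns no value
class PrintAction extends RecursiveAction {
    @Override
    protected void compute() {
        System.out.println("done"); // placeholder side effect
    }
}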
To start a task in a ForkJoinPool, use the T invoke(ForkJoinTask<T> task) method:
ForkJoinTask forkJoinTask = new ForkJoinTaskImpl(...);
forkJoinPool.invoke(forkJoinTask);
In addition to the compute() method, ForkJoinTask has the following methods: fork() and join(). In terms of usage, ForkJoinTask#join() is similar to Thread#join(). But in the case of fork/join, the thread may not actually fall asleep; instead, it switches to another pending task. This strategy is called "work stealing", and it allows a limited number of threads to be used more efficiently.
Let's look at a simple example of a ForkJoinTask implementation that calculates Fibonacci numbers:
public class FibonacciTask extends RecursiveTask<Integer> {

    private final int n;

    public FibonacciTask(int n) {
        this.n = n;
    }

    @Override
    public Integer compute() {
        if (n < 2) return n;
        FibonacciTask f1 = new FibonacciTask(n - 1);
        FibonacciTask f2 = new FibonacciTask(n - 2);
        f1.fork();
        return f2.compute() + f1.join();
    }
}
In this example, FibonacciTask implements the compute() method, which creates two additional FibonacciTask instances: f1 is forked to run asynchronously, while f2 is computed directly in the current thread. The join() call then makes the current thread wait until the forked task returns its result.
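Running the task is a matter of submitting it to a pool. Here is a minimal usage sketch with the common pool (the input value 20 is just an arbitrary example):
ForkJoinPool pool = ForkJoinPool.commonPool();
int fib = pool.invoke(new FibonacciTask(20)); // blocks until the result is ready
System.out.println(fib); // prints 6765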
Let's take a look at a more complex example of a ForkJoinTask: merge sort. It is based on the "Divide and Conquer" principle, as old as the world. We need to divide the source problem into subtasks, solve them recursively, and combine the results.
For this task we will use RecursiveAction as the ForkJoinTask implementation, since we don't need a return value. We will add an int[] arr parameter to the constructor and initialize the corresponding field of the class:
public class MergeSortAction extends RecursiveAction {

    private final int[] arr;

    public MergeSortAction(int[] arr) {
        this.arr = arr;
    }

    @Override
    public void compute() {
        ...
    }

    private void merge(int[] left, int[] right) {
        ...
    }
}
Next, we implement the merge() method. It accepts two arrays: the left and right halves. We merge these already sorted halves by comparing their elements and assigning the smaller one to the appropriate index of arr. When one of the halves is exhausted, the first while loop ends, but the data remaining in the other half still needs to be copied. To do this, we add two additional while loops that drain both arrays:
private void merge(int[] left, int[] right) {
    int i = 0, j = 0, k = 0;
    while (i < left.length && j < right.length) {
        if (left[i] < right[j])
            arr[k++] = left[i++];
        else
            arr[k++] = right[j++];
    }
    while (i < left.length) {
        arr[k++] = left[i++];
    }
    while (j < right.length) {
        arr[k++] = right[j++];
    }
}
Next, we will write the compute() method to recursively divide the original array and pass the results to the merge() method. To do this, we calculate the middle index of the array and divide the original array into two parts: left and right. To fill them, we copy the original data by calling System.arraycopy(Object src, int srcPos, Object dest, int destPos, int length):
@Override
public void compute() {
    if (arr.length < 2) return;
    int mid = arr.length / 2;
    int[] left = new int[mid];
    System.arraycopy(arr, 0, left, 0, mid);
    int[] right = new int[arr.length - mid];
    System.arraycopy(arr, mid, right, 0, arr.length - mid);
    ...
}
Now we have two separate arrays. Let's divide them recursively by creating new MergeSortAction tasks and running them asynchronously by passing them to the invokeAll() method. To sort and combine the two halves, we call the merge() method:
@Override
public void compute() {
    if (arr.length < 2) return;
    int mid = arr.length / 2;
    int[] left = new int[mid];
    System.arraycopy(arr, 0, left, 0, mid);
    int[] right = new int[arr.length - mid];
    System.arraycopy(arr, mid, right, 0, arr.length - mid);
    invokeAll(new MergeSortAction(left), new MergeSortAction(right));
    merge(left, right);
}
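As with the Fibonacci example, the sort is started by submitting the root task to a pool. A minimal usage sketch with the common pool (the array contents are arbitrary):
int[] arr = {5, 3, 8, 1, 9, 2};
ForkJoinPool.commonPool().invoke(new MergeSortAction(arr)); // sorts arr in place
System.out.println(Arrays.toString(arr)); // [1, 2, 3, 5, 8, 9]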
We’ve finished our parallel merge sort. Let’s test it and compare its performance with the non-parallel version. I used ThreadLocalRandom to generate random numbers and ZonedDateTime to measure the execution time:
class MergeSortTest {

    private final List<MergeSort> mergeSortImpls =
            Arrays.asList(new MergeSortImpl(), new ParallelMergeSortImpl());

    @Test
    void sort() {
        for (MergeSort mergeSort : mergeSortImpls) {
            int[] arr = IntStream
                    .range(0, 100_000_000)
                    .map(i -> ThreadLocalRandom.current().nextInt())
                    .toArray();
            ZonedDateTime now = ZonedDateTime.now();
            mergeSort.sort(arr);
            System.out.printf("%s exec time: %dms\n",
                    mergeSort.getClass().getSimpleName(),
                    ChronoUnit.MILLIS.between(now, ZonedDateTime.now()));
            assertTrue(isSorted(arr));
        }
    }

    private boolean isSorted(int[] arr) {
        for (int i = 0; i < arr.length - 1; i++) {
            if (arr[i] > arr[i + 1])
                return false;
        }
        return true;
    }
}
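The MergeSort interface and its two implementations are not listed here; the full versions are in the repository linked below. As a rough idea, a minimal sketch of the parallel implementation, assuming the interface exposes a single sort(int[]) method, might look like this:
// MergeSort.java
public interface MergeSort {
    void sort(int[] arr);
}

// ParallelMergeSortImpl.java
public class ParallelMergeSortImpl implements MergeSort {
    @Override
    public void sort(int[] arr) {
        // submit the root task to the common pool and wait for it to finish
        ForkJoinPool.commonPool().invoke(new MergeSortAction(arr));
    }
}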
Here are the execution results for 100 million numbers on an 8-core CPU:
In this article, I showed you an example of how to use Fork/Join Framework. I hope you now have a basic idea of how to speed up your applications. The source code is available over on GitHub.