This is for anyone who, like me, urgently needs to parallelize flatMap and wants a practical solution, not only history and theory.
The simplest solution I came up with is to do the flattening by hand: replace flatMap with map followed by reduce(Stream::concat).
Here's an example to demonstrate how to do this:
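import static java.lang.Thread.currentThread;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;

@Test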
void testParallelStream_NOT_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()
                // flatMap consumes each nested stream sequentially, even if it is marked parallel
                .flatMap(i -> generateRangeParallel(i, 100))
                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}
@Test
void testParallelStream_WORKING() throws InterruptedException, ExecutionException {
    new ForkJoinPool(10).submit(() -> {
        Stream.iterate(0, i -> i + 1).limit(2)
                .parallel()
                // concatenate the nested streams instead of using flatMap; this parallelizes ALL the items
                .map(i -> generateRangeParallel(i, 100))
                .reduce(Stream::concat).orElse(Stream.empty())
                .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
    }).get();
    System.out.println("done");
}
Stream<Integer> generateRangeParallel(int start, int num) {
    return Stream.iterate(start, i -> i + 1).limit(num).parallel();
}
// pass the captured console output to this method to see how the work was distributed across threads
void countThreads(String strOut) {
    var res = Arrays.stream(strOut.split("\n"))
            .map(line -> line.split("\\s+"))
            .collect(Collectors.groupingBy(s -> s[0], Collectors.counting()));
    System.out.println(res);
    System.out.println("threads  : " + res.keySet().size());
    System.out.println("work     : " + res.values());
}
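To avoid copying the console output by hand, the output can be captured programmatically and then counted. The following is only a sketch under my own assumptions: the class name ThreadStats and the helpers captureStdout and countByThread are hypothetical names, not part of the code above, and it assumes Java 11 or newer (for String.lines and String.strip).

```java
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.stream.Collectors;

public class ThreadStats {

    // Hypothetical helper: runs the task while System.out is redirected
    // into an in-memory buffer, then restores the original stream.
    public static String captureStdout(Runnable task) {
        PrintStream original = System.out;
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (PrintStream capture = new PrintStream(buffer, true, StandardCharsets.UTF_8)) {
            System.setOut(capture);
            task.run();
        } finally {
            System.setOut(original);
        }
        return buffer.toString(StandardCharsets.UTF_8);
    }

    // Same grouping idea as countThreads: the first whitespace-separated
    // token of each non-blank line is treated as the thread name.
    public static Map<String, Long> countByThread(String out) {
        return out.lines()
                .filter(line -> !line.isBlank())
                .map(line -> line.strip().split("\\s+")[0])
                .collect(Collectors.groupingBy(name -> name, Collectors.counting()));
    }

    public static void main(String[] args) {
        // Stand-in for one of the test bodies above; any code that prints
        // "<thread name> : ..." lines can be passed here.
        String out = captureStdout(() ->
                System.out.println(Thread.currentThread().getName() + " : generated value: i=0"));
        Map<String, Long> stats = countByThread(out);
        System.out.println(stats);
        System.out.println("threads  : " + stats.size());
    }
}
```

Because System.out is global, output printed by ForkJoinPool worker threads during the task ends up in the same buffer, so wrapping one of the pipelines above in captureStdout should yield the per-thread counts directly.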
Stats from a run on my machine:
NOT_WORKING case stats:
{ForkJoinPool-1-worker-23=100, ForkJoinPool-1-worker-5=300}
threads  : 2
work     : [100, 300]
WORKING case stats:
{ForkJoinPool-1-worker-9=16, ForkJoinPool-1-worker-23=20, ForkJoinPool-1-worker-21=36, ForkJoinPool-1-worker-31=17, ForkJoinPool-1-worker-27=177, ForkJoinPool-1-worker-13=17, ForkJoinPool-1-worker-5=21, ForkJoinPool-1-worker-19=8, ForkJoinPool-1-worker-17=21, ForkJoinPool-1-worker-3=67}
threads  : 10
work     : [16, 20, 36, 17, 177, 17, 21, 8, 21, 67]
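If the map + reduce(Stream::concat) replacement is needed in more than one place, it could be wrapped in a small helper. This is a minimal sketch, not a drop-in library method: the class ConcatFlatten and the method flattenByConcat are names I made up for illustration. It also marks the result as parallel explicitly, which is my own choice, since Stream.concat only returns a parallel stream when at least one of its inputs is parallel.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class ConcatFlatten {

    // Hypothetical helper (not part of the JDK): flattens by concatenating
    // the nested streams instead of calling flatMap.
    public static <T, R> Stream<R> flattenByConcat(Stream<T> source,
                                                   Function<T, Stream<R>> mapper) {
        return source
                .map(mapper)
                // reduce is a terminal operation on the outer stream, so every
                // nested stream is created here (but not yet consumed)
                .reduce(Stream::concat)
                .orElseGet(Stream::empty)
                // assumption: the caller wants the flattened items processed in parallel
                .parallel();
    }

    public static void main(String[] args) {
        List<Integer> flat = flattenByConcat(
                Stream.of(0, 10).parallel(),
                i -> IntStream.range(i, i + 3).boxed())
                .collect(Collectors.toList());
        System.out.println(flat); // prints [0, 1, 2, 10, 11, 12]
    }
}
```

Two caveats apply. Since reduce consumes the outer stream completely, this cannot work with an infinite outer stream. The Stream.concat documentation also warns that deeply concatenated streams can lead to deep call chains or even a StackOverflowError, so the approach is probably best kept to a modest number of nested streams.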