Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming collectors for stable performance #536

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions avro-builder/builder-spi/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies {
implementation "org.apache.logging.log4j:log4j-api:2.17.1"
implementation "commons-io:commons-io:2.11.0"
implementation "jakarta.json:jakarta.json-api:2.0.1"
implementation "com.pivovarit:parallel-collectors:2.5.0"

testImplementation "org.apache.avro:avro:1.9.2"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@

import com.linkedin.avroutil1.builder.operations.Operation;
import com.linkedin.avroutil1.builder.operations.OperationContext;
import com.linkedin.avroutil1.builder.util.StreamUtil;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
* context for running a set of {@link com.linkedin.avroutil1.builder.plugins.BuilderPlugin}s
*/
public class BuilderPluginContext {

private static final Logger LOGGER = LoggerFactory.getLogger(BuilderPluginContext.class);

private final List<Operation> operations = new ArrayList<>(1);
private volatile boolean sealed = false;
private final OperationContext operationContext;
Expand All @@ -43,12 +48,16 @@ public void run() throws Exception {
//"seal" any internal state to prevent plugins from trying to do weird things during execution
sealed = true;

operations.parallelStream().forEach(op -> {
int operationCount = operations.stream().collect(StreamUtil.toParallelStream(op -> {
try {
op.run(operationContext);
} catch (Exception e) {
throw new IllegalStateException("Exception running operation", e);
}
});

return 1;
}, 2)).reduce(0, Integer::sum);

LOGGER.info("Executed {} operations for builder plugins", operationCount);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2024 LinkedIn Corp.
* Licensed under the BSD 2-Clause License (the "License").
* See License in the project root for license information.
*/

package com.linkedin.avroutil1.builder.util;

import com.pivovarit.collectors.ParallelCollectors;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Stream;


/**
* Utilities for dealing with java streams.
*/
public final class StreamUtil {

  /**
   * An (effectively) unbounded {@link ExecutorService} used for parallel processing. This is kept unbounded to avoid
   * deadlocks caused when using {@link #toParallelStream(Function, int)} recursively. Callers are supposed to set
   * sane values for parallelism to avoid spawning a crazy number of concurrent threads.
   *
   * <p>Core size 0 with a {@link SynchronousQueue} means every submitted task gets a fresh or idle thread
   * immediately (direct hand-off, no queuing), and idle threads are reclaimed after 60 seconds.
   *
   * <p>Worker threads are created as <em>daemon</em> threads: with the default thread factory they would be
   * non-daemon, and an idle pool (threads linger up to 60s after their last task) could delay or block JVM
   * shutdown. Each parallel collection blocks the calling thread until its tasks complete, so in-flight work
   * is never lost to JVM exit.
   */
  private static final ExecutorService WORK_EXECUTOR =
      new ThreadPoolExecutor(0, Integer.MAX_VALUE, 60, TimeUnit.SECONDS, new SynchronousQueue<>(), runnable -> {
        Thread worker = new Thread(runnable, "StreamUtil-worker");
        worker.setDaemon(true);
        return worker;
      });

  private StreamUtil() {
    // Disallow external instantiation.
  }

  /**
   * A convenience {@link Collector} used for executing parallel computations on a custom {@link Executor}
   * and returning a {@link Stream} instance returning results as they arrive.
   * <p>
   * For the parallelism of 1, the stream is executed by the calling thread.
   *
   * @param mapper a transformation to be performed in parallel
   * @param parallelism the max parallelism level
   * @param <T> the type of the collected elements
   * @param <R> the result returned by {@code mapper}
   *
   * @return a {@code Collector} which collects all processed elements into a {@code Stream} in parallel.
   */
  public static <T, R> Collector<T, ?, Stream<R>> toParallelStream(Function<T, R> mapper, int parallelism) {
    return ParallelCollectors.parallelToStream(mapper, WORK_EXECUTOR, parallelism);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.linkedin.avroutil1.builder.operations.OperationContext;
import com.linkedin.avroutil1.builder.operations.SchemaSet;
import com.linkedin.avroutil1.builder.operations.codegen.CodeGenOpConfig;
import com.linkedin.avroutil1.builder.util.StreamUtil;
import com.linkedin.avroutil1.builder.plugins.BuilderPlugin;
import com.linkedin.avroutil1.builder.plugins.BuilderPluginContext;
import com.linkedin.avroutil1.codegen.SpecificRecordClassGenerator;
Expand Down Expand Up @@ -101,14 +102,14 @@ private void generateCode(OperationContext opContext) {
AvroJavaStringRepresentation.fromJson(config.getMethodStringRepresentation().toString()),
config.getMinAvroVersion(), config.isUtf8EncodingPutByIndexEnabled());
final SpecificRecordClassGenerator generator = new SpecificRecordClassGenerator();
List<JavaFile> generatedClasses = allNamedSchemas.parallelStream().map(namedSchema -> {
List<JavaFile> generatedClasses = allNamedSchemas.stream().collect(StreamUtil.toParallelStream(namedSchema -> {
try {
// Top level schema
return generator.generateSpecificClass(namedSchema, generationConfig);
} catch (Exception e) {
throw new RuntimeException("failed to generate class for " + namedSchema.getFullName(), e);
}
}).collect(Collectors.toList());
}, 10)).collect(Collectors.toList());
long genEnd = System.currentTimeMillis();
LOGGER.info("Generated {} java source files in {} millis", generatedClasses.size(), genEnd - genStart);

Expand All @@ -129,15 +130,15 @@ private void writeJavaFilesToDisk(Collection<JavaFile> javaFiles, Path outputFol
long writeStart = System.currentTimeMillis();

// write out the files we generated
int filesWritten = javaFiles.parallelStream().map(javaFile -> {
int filesWritten = javaFiles.stream().collect(StreamUtil.toParallelStream(javaFile -> {
try {
javaFile.writeToPath(outputFolderPath);
} catch (Exception e) {
throw new IllegalStateException("while writing file " + javaFile.typeSpec.name, e);
}

return 1;
}).reduce(0, Integer::sum);
}, 10)).reduce(0, Integer::sum);

long writeEnd = System.currentTimeMillis();
LOGGER.info("Wrote out {} generated java source files under {} in {} millis", filesWritten, outputFolderPath,
Expand Down
Loading