PCollection data sampling for Java SDK harness #25064 #25354

Merged
merged 43 commits into from
Feb 22, 2023
Changes from 1 commit
Commits (43)
cd630e4
Data Sampling Java Impl
Jan 26, 2023
e530d14
comments
Jan 26, 2023
a09f5d6
add PBD id to context
rohdesamuel Jan 31, 2023
1cb08f9
merge
Jan 31, 2023
566c7b8
Add more tests and spotless
Jan 31, 2023
7395a2e
Finish Java data sampling impl with tests, adding comments
Feb 1, 2023
4c1253f
more comments, remove Payload
Feb 1, 2023
3d88254
more comments
Feb 1, 2023
5bbef91
spotless
Feb 1, 2023
6993b83
Encode in the nested context
rohdesamuel Feb 3, 2023
769902f
Update sdks/java/harness/src/main/java/org/apache/beam/fn/harness/con…
rohdesamuel Feb 7, 2023
5c06d4e
Apply suggestions from code review
rohdesamuel Feb 7, 2023
99087e8
address pr comments
Feb 9, 2023
5609e4a
give default pbd id to test context
rohdesamuel Feb 9, 2023
694c929
address spotlesscheck
rohdesamuel Feb 9, 2023
519aece
spotless apply
rohdesamuel Feb 9, 2023
8efbb15
style guide spotless apply
rohdesamuel Feb 9, 2023
cd3732d
add serviceloader
rohdesamuel Feb 9, 2023
415e3f0
change datasampling to modify the consumers and not graph for sampling
rohdesamuel Feb 10, 2023
553fc5e
remove redundant SamplerState obj
rohdesamuel Feb 10, 2023
e260eb4
spotless
rohdesamuel Feb 10, 2023
4f29308
replace mutex with atomics in output sampler to reduce contention
rohdesamuel Feb 10, 2023
8cbc6c8
spotless and fix OutputSamplerTest
rohdesamuel Feb 10, 2023
f67234f
Update sdks/java/harness/src/main/java/org/apache/beam/fn/harness/dat…
rohdesamuel Feb 13, 2023
6815330
Update sdks/java/harness/src/main/java/org/apache/beam/fn/harness/dat…
rohdesamuel Feb 13, 2023
4848725
Update sdks/java/harness/src/main/java/org/apache/beam/fn/harness/deb…
rohdesamuel Feb 13, 2023
6c16576
always init outputsampler
rohdesamuel Feb 13, 2023
822587d
add final to DataSampler in FnHarness
rohdesamuel Feb 13, 2023
fe9ab2a
spotless apply
rohdesamuel Feb 13, 2023
07d37ea
update from proto names
rohdesamuel Feb 13, 2023
5e9c4b0
spotless bugs
Feb 14, 2023
f5f97fb
Apply suggestions from code review
rohdesamuel Feb 14, 2023
c3db7c0
address pr comments
Feb 14, 2023
fce6d69
spotlessapply and add byte[] test
Feb 15, 2023
69d8bb4
validate datasampler args
Feb 15, 2023
be2ebe4
add concurrency tests
Feb 15, 2023
93af06e
Merge branch 'master' into data-sampling-java
lukecwik Feb 21, 2023
cbdbbc3
Update sdks/java/harness/src/main/java/org/apache/beam/fn/harness/deb…
rohdesamuel Feb 21, 2023
8cc8ee0
Update sdks/java/harness/src/test/java/org/apache/beam/fn/harness/deb…
rohdesamuel Feb 21, 2023
10aa7de
Update sdks/java/harness/src/test/java/org/apache/beam/fn/harness/deb…
rohdesamuel Feb 21, 2023
04c6c88
Update sdks/java/harness/src/test/java/org/apache/beam/fn/harness/deb…
rohdesamuel Feb 21, 2023
862439b
improve contention tests
Feb 22, 2023
490be4a
spotless
Feb 22, 2023
Finish Java data sampling impl with tests, adding comments
Sam Rohde authored and rohdesamuel committed Feb 13, 2023
commit 7395a2e661ddcfd55dc7f3529cfc989739c6501e
FnHarness.java
@@ -21,6 +21,7 @@
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
@@ -256,9 +257,12 @@ public static void main(

// Add any graph modifications.
List<ProcessBundleDescriptorModifier> modifierList = new ArrayList<>();
List<String> experimentList = options.as(ExperimentalOptions.class).getExperiments();
Optional<List<String>> experimentList =
Optional.ofNullable(options.as(ExperimentalOptions.class).getExperiments());

if (experimentList != null && experimentList.contains(ENABLE_DATA_SAMPLING_EXPERIMENT)) {
// If data sampling is enabled, then modify the graph to add any DataSampling Operations.
if (experimentList.isPresent()
&& experimentList.get().contains(ENABLE_DATA_SAMPLING_EXPERIMENT)) {
modifierList.add(new DataSamplingDescriptorModifier());
}

@@ -357,6 +361,7 @@ private BeamFnApi.ProcessBundleDescriptor loadDescriptor(String id) {
handlers.put(
InstructionRequest.RequestCase.HARNESS_MONITORING_INFOS,
processWideHandler::harnessMonitoringInfos);
handlers.put(InstructionRequest.RequestCase.SAMPLE, dataSampler::handleDataSampleRequest);

JvmInitializers.runBeforeProcessing(options);

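The FnHarness changes above only register the DataSamplingDescriptorModifier and the SAMPLE handler when the data-sampling experiment is present in the pipeline options. A minimal sketch of how a pipeline would opt in; the experiment string "enable_data_sampling" and the standalone class are assumptions for illustration, since the diff only references the ENABLE_DATA_SAMPLING_EXPERIMENT constant:

import java.util.List;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class EnableDataSamplingSketch {
  public static void main(String[] args) {
    // Hypothetical experiment value; the diff only shows the ENABLE_DATA_SAMPLING_EXPERIMENT constant.
    PipelineOptions options =
        PipelineOptionsFactory.fromArgs("--experiments=enable_data_sampling").create();

    // Mirrors the null-safe check added to FnHarness.main above: getExperiments() may return null.
    List<String> experiments = options.as(ExperimentalOptions.class).getExperiments();
    boolean samplingEnabled =
        experiments != null && experiments.contains("enable_data_sampling");
    System.out.println("Data sampling enabled: " + samplingEnabled);
  }
}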
DataSampler.java
@@ -17,33 +17,65 @@
*/
package org.apache.beam.fn.harness.debug;

import avro.shaded.com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.SampleDataResponse.ElementList;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.SampledElement;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;

/**
* The DataSampler is a global (per SDK harness) object that takes samples of PCollection elements
* and returns them to the runner harness. The class is thread-safe with respect to executing
* ProcessBundleDescriptors: different threads executing different PBDs can sample simultaneously,
* even if computing the same logical PCollection.
*/
public class DataSampler {

/** Creates a DataSampler to sample every 1000 elements while keeping a maximum of 10 in memory. */
public DataSampler() {}

/**
* @param maxSamples Sets the maximum number of samples held in memory at once.
* @param sampleEveryN Sets how often to sample, i.e. every N-th element.
*/
public DataSampler(int maxSamples, int sampleEveryN) {
this.maxSamples = maxSamples;
this.sampleEveryN = sampleEveryN;
}

public static Set<String> EMPTY = new HashSet<>();

// Maximum number of elements in buffer.
private int maxSamples = 10;

// Sampling rate.
private int sampleEveryN = 1000;

private final Map<String, Map<String, OutputSampler<?>>> outputSamplers = new HashMap<>();

// The fully-qualified type is: Map[ProcessBundleDescriptorId, Map[PCollectionId, OutputSampler]].
// The DataSampler lives at the same level as the FnHarness, which means that many threads can and
// will access it simultaneously. However, ProcessBundleDescriptors are unique per thread, so
// synchronization is only needed on the outermost map.
private final Map<String, Map<String, OutputSampler<?>>> outputSamplers =
new ConcurrentHashMap<>();

/**
* Creates and returns an OutputSampler to sample the given PCollection in the given
* ProcessBundleDescriptor. Uses the given coder to encode samples as bytes when responding to a
* SampleDataRequest.
*
* @param processBundleDescriptorId The PBD to sample from.
* @param pcollectionId The PCollection to take intermittent samples from.
* @param coder The coder associated with the PCollection. Coder may be from a nested context.
* @return the OutputSampler corresponding to the unique PBD and PCollection.
* @param <T> The type of element contained in the PCollection.
*/
public <T> OutputSampler<T> sampleOutput(
String processBundleDescriptorId, String pcollectionId, Coder<T> coder) {
outputSamplers.putIfAbsent(processBundleDescriptorId, new HashMap<>());
@@ -54,12 +86,50 @@ public <T> OutputSampler<T> sampleOutput(
return (OutputSampler<T>) samplers.get(pcollectionId);
}

public Map<String, List<byte[]>> samples() {
return samplesFor(EMPTY, EMPTY);
/**
* Returns all collected samples. Thread-safe.
*
* @param request The instruction request from the FnApi. Filters based on the given
* SampleDataRequest.
* @return Returns all collected samples.
*/
public BeamFnApi.InstructionResponse.Builder handleDataSampleRequest(
BeamFnApi.InstructionRequest request) {
BeamFnApi.SampleDataRequest sampleDataRequest = request.getSample();

Map<String, List<byte[]>> responseSamples =
samplesFor(
ImmutableSet.copyOf(sampleDataRequest.getProcessBundleDescriptorIdsList()),
ImmutableSet.copyOf(sampleDataRequest.getPcollectionIdsList()));

BeamFnApi.SampleDataResponse.Builder response = BeamFnApi.SampleDataResponse.newBuilder();
for (String pcollectionId : responseSamples.keySet()) {
ElementList.Builder elementList = ElementList.newBuilder();
for (byte[] sample : responseSamples.get(pcollectionId)) {
elementList.addElements(
SampledElement.newBuilder().setElement(ByteString.copyFrom(sample)).build());
}
response.putElementSamples(pcollectionId, elementList.build());
}

return BeamFnApi.InstructionResponse.newBuilder().setSample(response);
}

/**
* Returns a map from PCollection to its samples. Samples are filtered on
* ProcessBundleDescriptorIds and PCollections. Thread-safe.
*
* @param descriptors PCollections under each PBD id will be unioned. If empty, allows all
* descriptors.
* @param pcollections Filters all PCollections on this set. If empty, allows all PCollections.
* @return a map from PCollection to its samples.
*/
public Map<String, List<byte[]>> samplesFor(Set<String> descriptors, Set<String> pcollections) {
Map<String, List<byte[]>> samples = new HashMap<>();

// Safe to iterate as the ConcurrentHashMap will return each element at most once and will not
// throw ConcurrentModificationException.
outputSamplers.forEach(
(descriptorId, samplers) -> {
if (!descriptors.isEmpty() && !descriptors.contains(descriptorId)) {
@@ -80,11 +150,24 @@ public Map<String, List<byte[]>> samplesFor(Set<String> descriptors, Set<String>
return samples;
}

/** @return samples from all PBDs and all PCollections. */
public Map<String, List<byte[]>> allSamples() {
return samplesFor(ImmutableSet.of(), ImmutableSet.of());
}

/**
* @param descriptors PBDs to filter on.
* @return samples only from the given descriptors.
*/
public Map<String, List<byte[]>> samplesForDescriptors(Set<String> descriptors) {
return samplesFor(descriptors, EMPTY);
return samplesFor(descriptors, ImmutableSet.of());
}

/**
* @param pcollections PCollection ids to filter on.
* @return samples only from the given PCollections.
*/
public Map<String, List<byte[]>> samplesForPCollections(Set<String> pcollections) {
return samplesFor(EMPTY, pcollections);
return samplesFor(ImmutableSet.of(), pcollections);
}
}
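Taken together, the DataSampler methods added above can be exercised roughly as follows. This is a sketch rather than code from the PR: the class name DataSamplerSketch and the "pbd-1"/"pcollection-1" ids are illustrative, and it assumes OutputSampler lives in the same org.apache.beam.fn.harness.debug package as DataSampler.

import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.beam.fn.harness.debug.DataSampler;
import org.apache.beam.fn.harness.debug.OutputSampler;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.StringUtf8Coder;

public class DataSamplerSketch {
  public static void main(String[] args) throws Exception {
    // Defaults per the fields above: keep at most 10 samples, sampling every 1000th element.
    DataSampler dataSampler = new DataSampler();

    // Each (ProcessBundleDescriptor, PCollection) pair gets its own OutputSampler; the coder is
    // used to encode sampled elements when a SampleDataRequest arrives.
    OutputSampler<String> sampler =
        dataSampler.sampleOutput("pbd-1", "pcollection-1", StringUtf8Coder.of());
    sampler.sample("hello");

    // Samples can be pulled directly, filtered by descriptor and/or PCollection id.
    Map<String, List<byte[]>> byPCollection =
        dataSampler.samplesForPCollections(Collections.singleton("pcollection-1"));
    System.out.println("PCollections sampled: " + byPCollection.keySet());

    // Or returned to the runner in response to a SampleDataRequest over the control channel.
    BeamFnApi.InstructionRequest request =
        BeamFnApi.InstructionRequest.newBuilder()
            .setSample(BeamFnApi.SampleDataRequest.newBuilder().addPcollectionIds("pcollection-1"))
            .build();
    BeamFnApi.InstructionResponse response = dataSampler.handleDataSampleRequest(request).build();
    System.out.println(
        "Sampled PCollections in response: " + response.getSample().getElementSamplesMap().keySet());
  }
}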
DataSamplerTest.java
@@ -27,8 +27,10 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.*;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -49,17 +51,27 @@ public byte[] encodeString(String s) throws IOException {
return stream.toByteArray();
}

/**
* Smoke test that samples show in the output map.
*
* @throws Exception
*/
@Test
public void testSingleOutput() throws Exception {
DataSampler sampler = new DataSampler();

VarIntCoder coder = VarIntCoder.of();
sampler.sampleOutput("descriptor-id", "pcollection-id", coder).sample(1);

Map<String, List<byte[]>> samples = sampler.samples();
Map<String, List<byte[]>> samples = sampler.allSamples();
assertThat(samples.get("pcollection-id"), contains(encodeInt(1)));
}

/**
* Test that sampling multiple PCollections under the same descriptor is OK.
*
* @throws Exception
*/
@Test
public void testMultipleOutputs() throws Exception {
DataSampler sampler = new DataSampler();
@@ -68,11 +80,16 @@ public void testMultipleOutputs() throws Exception {
sampler.sampleOutput("descriptor-id", "pcollection-id-1", coder).sample(1);
sampler.sampleOutput("descriptor-id", "pcollection-id-2", coder).sample(2);

Map<String, List<byte[]>> samples = sampler.samples();
Map<String, List<byte[]>> samples = sampler.allSamples();
assertThat(samples.get("pcollection-id-1"), contains(encodeInt(1)));
assertThat(samples.get("pcollection-id-2"), contains(encodeInt(2)));
}

/**
* Test that the response contains samples from the same PCollection across descriptors.
*
* @throws Exception
*/
@Test
public void testMultipleDescriptors() throws Exception {
DataSampler sampler = new DataSampler();
@@ -81,10 +98,15 @@ public void testMultipleDescriptors() throws Exception {
sampler.sampleOutput("descriptor-id-1", "pcollection-id", coder).sample(1);
sampler.sampleOutput("descriptor-id-2", "pcollection-id", coder).sample(2);

Map<String, List<byte[]>> samples = sampler.samples();
Map<String, List<byte[]>> samples = sampler.allSamples();
assertThat(samples.get("pcollection-id"), contains(encodeInt(1), encodeInt(2)));
}

/**
* Test that samples can be filtered based on ProcessBundleDescriptor id.
*
* @throws Exception
*/
@Test
public void testFiltersSingleDescriptorId() throws Exception {
DataSampler sampler = new DataSampler(10, 10);
@@ -101,6 +123,11 @@ public void testFiltersSingleDescriptorId() throws Exception {
assertThat(samples.get("2"), contains(encodeString("a2")));
}

/**
* Test that samples are unioned based on ProcessBundleDescriptor id.
*
* @throws Exception
*/
@Test
public void testFiltersMultipleDescriptorId() throws Exception {
DataSampler sampler = new DataSampler(10, 10);
@@ -116,6 +143,11 @@ public void testFiltersMultipleDescriptorId() throws Exception {
assertThat(samples.get("2"), contains(encodeString("a2"), encodeString("b2")));
}

/**
* Test that samples can be filtered based on PCollection id.
*
* @throws Exception
*/
@Test
public void testFiltersSinglePCollectionId() throws Exception {
DataSampler sampler = new DataSampler(10, 10);
@@ -131,14 +163,6 @@ public void testFiltersSinglePCollectionId() throws Exception {
assertThat(samples.get("1"), containsInAnyOrder(encodeString("a1"), encodeString("b1")));
}

Map<String, List<byte[]>> singletonSample(String pcollectionId, byte[] element) {
Map<String, List<byte[]>> ret = new HashMap<>();
List<byte[]> list = new ArrayList<>();
list.add(element);
ret.put(pcollectionId, list);
return ret;
}

void generateStringSamples(DataSampler sampler) {
StringUtf8Coder coder = StringUtf8Coder.of();
sampler.sampleOutput("a", "1", coder).sample("a1");
@@ -147,11 +171,17 @@ void generateStringSamples(DataSampler sampler) {
sampler.sampleOutput("b", "2", coder).sample("b2");
}

/**
* Test that samples can be filtered both on PCollection and ProcessBundleDescriptor id.
*
* @throws Exception
*/
@Test
public void testFiltersDescriptorAndPCollectionIds() throws Exception {
List<String> descriptorIds = ImmutableList.of("a", "b");
List<String> pcollectionIds = ImmutableList.of("1", "2");

// Try all combinations for descriptor and PCollection ids.
for (String descriptorId : descriptorIds) {
for (String pcollectionId : pcollectionIds) {
DataSampler sampler = new DataSampler(10, 10);
@@ -166,4 +196,46 @@ public void testFiltersDescriptorAndPCollectionIds() throws Exception {
}
}
}

/**
* Test that the DataSampler can respond with the correct samples with filters.
*
* @throws Exception
*/
@Test
public void testMakesCorrectResponse() throws Exception {
DataSampler dataSampler = new DataSampler();
generateStringSamples(dataSampler);

// SampleDataRequest that filters on PCollection=1 and PBD ids = "a" or "b".
BeamFnApi.InstructionRequest request =
BeamFnApi.InstructionRequest.newBuilder()
.setSample(
BeamFnApi.SampleDataRequest.newBuilder()
.addPcollectionIds("1")
.addProcessBundleDescriptorIds("a")
.addProcessBundleDescriptorIds("b")
.build())
.build();
BeamFnApi.InstructionResponse actual = dataSampler.handleDataSampleRequest(request).build();
BeamFnApi.InstructionResponse expected =
BeamFnApi.InstructionResponse.newBuilder()
.setSample(
BeamFnApi.SampleDataResponse.newBuilder()
.putElementSamples(
"1",
BeamFnApi.SampleDataResponse.ElementList.newBuilder()
.addElements(
BeamFnApi.SampledElement.newBuilder()
.setElement(ByteString.copyFrom(encodeString("a1")))
.build())
.addElements(
BeamFnApi.SampledElement.newBuilder()
.setElement(ByteString.copyFrom(encodeString("b1")))
.build())
.build())
.build())
.build();
assertThat(actual, equalTo(expected));
}
}
(additional harness test file)
@@ -89,7 +89,7 @@ public void testCreatingAndProcessingWithSampling() throws Exception {
.withPipeline(Pipeline.create());
Coder<String> rehydratedCoder = (Coder<String>) rehydratedComponents.getCoder("coder-id");

Map<String, List<byte[]>> samples = dataSampler.samples();
Map<String, List<byte[]>> samples = dataSampler.allSamples();
assertThat(samples.keySet(), contains("inputTarget"));

// Ensure that the value was sampled.
OutputSamplerTest.java
@@ -56,7 +56,7 @@ public void testSamplesFirstN() throws Exception {
expected.add(encodeInt(i));
}

Map<String, List<byte[]>> samples = sampler.samples();
Map<String, List<byte[]>> samples = sampler.allSamples();
assertThat(samples.get("pcollection-id"), containsInAnyOrder(expected.toArray()));
}

@@ -79,7 +79,7 @@ public void testActsLikeCircularBuffer() throws Exception {
expected.add(encodeInt(79));
expected.add(encodeInt(99));

Map<String, List<byte[]>> samples = sampler.samples();
Map<String, List<byte[]>> samples = sampler.allSamples();
assertThat(samples.get("pcollection-id"), containsInAnyOrder(expected.toArray()));
}
}
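For reference, the buffer semantics these last tests pin down — keep the first few elements, then sample every N-th, retaining at most maxSamples in a circular buffer — can be sketched as below. This illustrates the expected behavior only; the actual OutputSampler implementation is not part of this commit's diff, and the RingBufferSampler name is invented for the sketch.

import java.util.ArrayList;
import java.util.List;

// Illustrative sampler matching the semantics asserted by testSamplesFirstN and
// testActsLikeCircularBuffer (e.g. with sampleEveryN=20 over 100 elements, the
// retained samples end up being elements 19, 39, 59, 79, 99).
class RingBufferSampler<T> {
  private final int maxSamples;
  private final int sampleEveryN;
  private final List<T> buffer = new ArrayList<>();
  private long seen = 0;
  private int nextIndex = 0;

  RingBufferSampler(int maxSamples, int sampleEveryN) {
    this.maxSamples = maxSamples;
    this.sampleEveryN = sampleEveryN;
  }

  void sample(T element) {
    seen++;
    // Keep the first maxSamples elements unconditionally, then only every N-th element.
    if (seen > maxSamples && seen % sampleEveryN != 0) {
      return;
    }
    if (buffer.size() < maxSamples) {
      buffer.add(element);
    } else {
      // Overwrite the oldest retained sample, like a circular buffer.
      buffer.set(nextIndex, element);
      nextIndex = (nextIndex + 1) % maxSamples;
    }
  }

  List<T> samples() {
    return new ArrayList<>(buffer);
  }
}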