[HUDI-7830] Add predicate filter pruning for snapshot queries in hudi related sources (apache#11396)
vinishjail97 committed Sep 19, 2024
1 parent f71a47f commit 7a242fe
Showing 7 changed files with 195 additions and 7 deletions.
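As an overview of how the new pieces fit together: a SnapshotLoadQuerySplitter implementation can now return a CheckpointWithPredicates, QueryInfo carries the predicate alongside the end instant, and the sources apply it to the Spark dataset only when present. The sketch below is editorial and not part of the diff; the class and variable names are illustrative, while the Option.map(...).orElse(...) pattern mirrors the changed sources shown further down.

import org.apache.hudi.common.util.Option;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter.CheckpointWithPredicates;
import org.apache.hudi.utilities.sources.helpers.QueryInfo;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

class PredicatePruningFlowSketch {
  static Dataset<Row> prune(Dataset<Row> snapshot,
                            QueryInfo queryInfo,
                            Option<CheckpointWithPredicates> checkpoint) {
    // Fold the splitter's result (end instant plus predicate filter) into QueryInfo,
    // falling back to the original QueryInfo when the splitter returns empty.
    QueryInfo updated = checkpoint.map(queryInfo::withUpdatedCheckpoint).orElse(queryInfo);
    // Apply the predicate filter only when one was supplied.
    return updated.getPredicateFilter().map(snapshot::filter).orElse(snapshot);
  }
}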
@@ -130,7 +130,7 @@ public UTF8String getPartitionPath(InternalRow row, StructType schema) {
return UTF8String.fromString(getPartitionPath(Option.empty(), Option.empty(), Option.of(Pair.of(row, schema))));
}

private String getPartitionPath(Option<GenericRecord> record, Option<Row> row, Option<Pair<InternalRow, StructType>> internalRowStructTypePair) {
public String getPartitionPath(Option<GenericRecord> record, Option<Row> row, Option<Pair<InternalRow, StructType>> internalRowStructTypePair) {
if (getPartitionPathFields() == null) {
throw new HoodieKeyException("Unable to find field names for partition path in cfg");
}
@@ -240,6 +240,7 @@ public Pair<Option<Dataset<Row>>, String> fetchNextBatch(Option<String> lastCkpt
queryInfo.getStartInstant()))
.filter(String.format("%s <= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
queryInfo.getEndInstant()));
source = queryInfo.getPredicateFilter().map(source::filter).orElse(source);
}

HoodieRecord.HoodieRecordType recordType = createRecordMerger(props).getRecordType();
@@ -53,6 +53,27 @@ public static class Config {
public static final String SNAPSHOT_LOAD_QUERY_SPLITTER_CLASS_NAME = "hoodie.deltastreamer.snapshotload.query.splitter.class.name";
}

/**
* Checkpoint, together with a predicate filter, returned by the SnapshotLoadQuerySplitter.
*/
public static class CheckpointWithPredicates {
String endInstant;
String predicateFilter;

public CheckpointWithPredicates(String endInstant, String predicateFilter) {
this.endInstant = endInstant;
this.predicateFilter = predicateFilter;
}

public String getEndInstant() {
return endInstant;
}

public String getPredicateFilter() {
return predicateFilter;
}
}

/**
* Constructor initializing the properties.
*
@@ -62,6 +83,15 @@ public SnapshotLoadQuerySplitter(TypedProperties properties) {
this.properties = properties;
}

/**
* Abstract method to retrieve the next checkpoint with predicates.
*
* @param df The dataset to process.
* @param beginCheckpointStr The starting checkpoint string.
* @return The next checkpoint, along with predicates (e.g. on partitionPath) used to optimise the snapshot query.
*/
public abstract Option<CheckpointWithPredicates> getNextCheckpointWithPredicates(Dataset<Row> df, String beginCheckpointStr);

/**
* Abstract method to retrieve the next checkpoint.
*
@@ -83,8 +113,8 @@ public SnapshotLoadQuerySplitter(TypedProperties properties) {
* returning endPoint same as queryInfo.getEndInstant().
*/
public QueryInfo getNextCheckpoint(Dataset<Row> df, QueryInfo queryInfo, Option<SourceProfileSupplier> sourceProfileSupplier) {
return getNextCheckpoint(df, queryInfo.getStartInstant(), sourceProfileSupplier)
.map(checkpoint -> queryInfo.withUpdatedEndInstant(checkpoint))
return getNextCheckpointWithPredicates(df, queryInfo.getStartInstant())
.map(queryInfo::withUpdatedCheckpoint)
.orElse(queryInfo);
}

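For context, a concrete splitter is typically selected through the SNAPSHOT_LOAD_QUERY_SPLITTER_CLASS_NAME property defined at the top of this class; a minimal wiring sketch follows (the implementation class name is hypothetical and not part of this commit).

import org.apache.hudi.common.config.TypedProperties;

class SplitterWiringSketch {
  static TypedProperties withSplitter() {
    TypedProperties props = new TypedProperties();
    // Any SnapshotLoadQuerySplitter subclass can be plugged in here; with this
    // change it may also return a CheckpointWithPredicates so the snapshot
    // query is pruned by the returned predicate filter.
    props.setProperty("hoodie.deltastreamer.snapshotload.query.splitter.class.name",
        "org.example.PartitionPruningQuerySplitter"); // hypothetical class
    return props;
  }
}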
@@ -18,6 +18,10 @@

package org.apache.hudi.utilities.sources.helpers;

import org.apache.hudi.common.util.Option;
import org.apache.hudi.common.util.StringUtils;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;

import java.util.Arrays;
import java.util.List;

@@ -27,12 +31,24 @@
/**
* This class is used to prepare query information for the S3 and GCS incremental sources.
* Some of the information in this class is used for batching based on sourceLimit.
* <p>
* queryType: Incremental or Snapshot query on the hudi table
* previousInstant: instant before startInstant.
* startInstant: start instant for range query
* endInstant: end instant for range query
* predicateFilter: predicate filters on columns to prune partitions and files.
* orderColumn: column used for ordering results, e.g. _hoodie_record_key can be used.
* keyColumn: column used for performing the range query, e.g. _hoodie_commit_time > startInstant and _hoodie_commit_time <= endInstant
* limitColumn: limits the number of rows returned by the query
* orderByColumns: (orderColumn, keyColumn)
* </p>
*/
public class QueryInfo {
private final String queryType;
private final String previousInstant;
private final String startInstant;
private final String endInstant;
private final String predicateFilter;
private final String orderColumn;
private final String keyColumn;
private final String limitColumn;
@@ -43,10 +59,32 @@ public QueryInfo(
String startInstant, String endInstant,
String orderColumn, String keyColumn,
String limitColumn) {
this(
queryType,
previousInstant,
startInstant,
endInstant,
StringUtils.EMPTY_STRING,
orderColumn,
keyColumn,
limitColumn
);
}

public QueryInfo(
String queryType,
String previousInstant,
String startInstant,
String endInstant,
String predicateFilter,
String orderColumn,
String keyColumn,
String limitColumn) {
this.queryType = queryType;
this.previousInstant = previousInstant;
this.startInstant = startInstant;
this.endInstant = endInstant;
this.predicateFilter = predicateFilter;
this.orderColumn = orderColumn;
this.keyColumn = keyColumn;
this.limitColumn = limitColumn;
@@ -97,6 +135,13 @@ public List<String> getOrderByColumns() {
return orderByColumns;
}

public Option<String> getPredicateFilter() {
if (!StringUtils.isNullOrEmpty(predicateFilter)) {
return Option.of(predicateFilter);
}
return Option.empty();
}

public QueryInfo withUpdatedEndInstant(String newEndInstant) {
return new QueryInfo(
this.queryType,
@@ -109,6 +154,19 @@ public QueryInfo withUpdatedEndInstant(String newEndInstant) {
);
}

public QueryInfo withUpdatedCheckpoint(SnapshotLoadQuerySplitter.CheckpointWithPredicates checkpointWithPredicates) {
return new QueryInfo(
this.queryType,
this.previousInstant,
this.startInstant,
checkpointWithPredicates.getEndInstant(),
checkpointWithPredicates.getPredicateFilter(),
this.orderColumn,
this.keyColumn,
this.limitColumn
);
}

@Override
public String toString() {
return ("Query information for Incremental Source "
@@ -106,11 +106,12 @@ public Pair<QueryInfo, Dataset<Row>> runSnapshotQuery(QueryInfo queryInfo, Optio
}

public Dataset<Row> applySnapshotQueryFilters(Dataset<Row> snapshot, QueryInfo snapshotQueryInfo) {
return snapshot
Dataset<Row> df = snapshot
// add filtering so that only interested records are returned.
.filter(String.format("%s >= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
snapshotQueryInfo.getStartInstant()))
.filter(String.format("%s <= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD,
snapshotQueryInfo.getEndInstant()));
return snapshotQueryInfo.getPredicateFilter().map(df::filter).orElse(df);
}
}
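Tying the QueryInfo changes and the filter application above together, a brief usage sketch (all instant, partition, and column values are illustrative and not taken from the diff):

import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.utilities.sources.helpers.QueryInfo;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

class SnapshotFilterSketch {
  static Dataset<Row> filterSnapshot(Dataset<Row> snapshot) {
    // The new 8-arg constructor carries a predicate filter; the existing 7-arg
    // constructor delegates with an empty predicate, so getPredicateFilter()
    // returns Option.empty() for existing callers.
    QueryInfo queryInfo = new QueryInfo(
        "snapshot",                               // queryType (example value)
        "000", "100", "300",                      // previousInstant, startInstant, endInstant (examples)
        "_hoodie_partition_path >= '2021/01/01' and _hoodie_partition_path <= '2021/01/03'",
        "_hoodie_record_key",                     // orderColumn (example)
        HoodieRecord.COMMIT_TIME_METADATA_FIELD,  // keyColumn (example)
        "limit_col");                             // limitColumn (placeholder)
    Dataset<Row> df = snapshot
        .filter(String.format("%s >= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD, queryInfo.getStartInstant()))
        .filter(String.format("%s <= '%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD, queryInfo.getEndInstant()));
    // The partition predicate is applied only when present.
    return queryInfo.getPredicateFilter().map(df::filter).orElse(df);
  }
}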
@@ -73,6 +73,7 @@
import static org.apache.hudi.common.model.WriteOperationType.BULK_INSERT;
import static org.apache.hudi.common.model.WriteOperationType.INSERT;
import static org.apache.hudi.common.model.WriteOperationType.UPSERT;
import static org.apache.hudi.common.testutils.HoodieTestUtils.DEFAULT_PARTITION_PATHS;
import static org.apache.hudi.common.testutils.HoodieTestUtils.RAW_TRIPS_TEST_NAME;
import static org.apache.hudi.testutils.Assertions.assertNoWriteErrors;
import static org.apache.hudi.utilities.sources.helpers.IncrSourceHelper.MissingCheckpointStrategy.READ_UPTO_LATEST_COMMIT;
@@ -414,8 +415,51 @@ public void testHoodieIncrSourceWithDataSourceOptions(HoodieTableType tableType)
}
}

@Test
public void testPartitionPruningInHoodieIncrSource()
throws IOException {
this.tableType = MERGE_ON_READ;
metaClient = getHoodieMetaClient(storageConf(), basePath());
HoodieWriteConfig writeConfig = getConfigBuilder(basePath(), metaClient)
.withArchivalConfig(HoodieArchivalConfig.newBuilder().archiveCommitsWith(10, 12).build())
.withCleanConfig(HoodieCleanConfig.newBuilder().retainCommits(9).build())
.withCompactionConfig(
HoodieCompactionConfig.newBuilder()
.withScheduleInlineCompaction(true)
.withMaxNumDeltaCommitsBeforeCompaction(1)
.build())
.withMetadataConfig(HoodieMetadataConfig.newBuilder().enable(true).build())
.build();
List<Pair<String, List<HoodieRecord>>> inserts = new ArrayList<>();
try (SparkRDDWriteClient writeClient = getHoodieWriteClient(writeConfig)) {
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "100", DEFAULT_PARTITION_PATHS[0]));
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "200", DEFAULT_PARTITION_PATHS[1]));
inserts.add(writeRecordsForPartition(writeClient, BULK_INSERT, "300", DEFAULT_PARTITION_PATHS[2]));
// Go over all possible test cases to assert behaviour.
getArgsForPartitionPruningInHoodieIncrSource().forEach(argumentsStream -> {
Object[] arguments = argumentsStream.get();
String checkpointToPullFromHoodieInstant = (String) arguments[0];
int maxRowsPerSnapshotBatch = (int) arguments[1];
String expectedCheckpointHoodieInstant = (String) arguments[2];
int expectedCount = (int) arguments[3];
int expectedRDDPartitions = (int) arguments[4];

TypedProperties extraProps = new TypedProperties();
extraProps.setProperty(TestSnapshotQuerySplitterImpl.MAX_ROWS_PER_BATCH, String.valueOf(maxRowsPerSnapshotBatch));
readAndAssert(IncrSourceHelper.MissingCheckpointStrategy.READ_UPTO_LATEST_COMMIT,
Option.ofNullable(checkpointToPullFromHoodieInstant),
expectedCount,
expectedCheckpointHoodieInstant,
Option.of(TestSnapshotQuerySplitterImpl.class.getName()),
extraProps,
Option.ofNullable(expectedRDDPartitions)
);
});
}
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull, int expectedCount,
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps) {
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps, Option<Integer> expectedRDDPartitions) {

Properties properties = new Properties();
properties.setProperty("hoodie.streamer.source.hoodieincr.path", basePath());
@@ -435,10 +479,16 @@ private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingChe
assertFalse(batchCheckPoint.getKey().isPresent());
} else {
assertEquals(expectedCount, batchCheckPoint.getKey().get().count());
expectedRDDPartitions.ifPresent(rddPartitions -> assertEquals(rddPartitions, batchCheckPoint.getKey().get().rdd().getNumPartitions()));
}
assertEquals(expectedCheckpoint, batchCheckPoint.getRight());
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull, int expectedCount,
String expectedCheckpoint, Option<String> snapshotCheckPointImplClassOpt, TypedProperties extraProps) {
readAndAssert(missingCheckpointStrategy, checkpointToPull, expectedCount, expectedCheckpoint, snapshotCheckPointImplClassOpt, extraProps, Option.empty());
}

private void readAndAssert(IncrSourceHelper.MissingCheckpointStrategy missingCheckpointStrategy, Option<String> checkpointToPull,
int expectedCount, String expectedCheckpoint) {
readAndAssert(missingCheckpointStrategy, checkpointToPull, expectedCount, expectedCheckpoint, Option.empty(), new TypedProperties());
@@ -460,13 +510,39 @@ private Pair<String, List<HoodieRecord>> writeRecords(SparkRDDWriteClient writeC
return Pair.of(commit, records);
}

private Pair<String, List<HoodieRecord>> writeRecordsForPartition(SparkRDDWriteClient writeClient,
WriteOperationType writeOperationType,
String commit,
String partitionPath) {
writeClient.startCommitWithTime(commit);
List<HoodieRecord> records = dataGen.generateInsertsForPartition(commit, 100, partitionPath);
JavaRDD<WriteStatus> result = writeOperationType == WriteOperationType.BULK_INSERT
? writeClient.bulkInsert(jsc().parallelize(records, 1), commit)
: writeClient.upsert(jsc().parallelize(records, 1), commit);
List<WriteStatus> statuses = result.collect();
assertNoWriteErrors(statuses);
return Pair.of(commit, records);
}

private HoodieWriteConfig.Builder getConfigBuilder(String basePath, HoodieTableMetaClient metaClient) {
return HoodieWriteConfig.newBuilder().withPath(basePath).withSchema(HoodieTestDataGenerator.TRIP_EXAMPLE_SCHEMA)
.withParallelism(2, 2).withBulkInsertParallelism(2).withFinalizeWriteParallelism(2).withDeleteParallelism(2)
.withTimelineLayoutVersion(TimelineLayoutVersion.CURR_VERSION)
.forTable(metaClient.getTableConfig().getTableName());
}

private static Stream<Arguments> getArgsForPartitionPruningInHoodieIncrSource() {
// Arguments are in order -> checkpointToPullFromHoodieInstant, maxRowsPerSnapshotBatch, expectedCheckpointHoodieInstant, expectedCount, expectedRDDPartitions.
return Stream.of(
Arguments.of(null, 1, "100", 100, 1),
Arguments.of(null, 101, "200", 200, 3),
Arguments.of(null, 10001, "300", 300, 3),
Arguments.of("100", 101, "300", 200, 2),
Arguments.of("200", 101, "300", 100, 1),
Arguments.of("300", 101, "300", 0, 0)
);
}

private static class DummySchemaProvider extends SchemaProvider {

private final Schema schema;
@@ -19,7 +19,6 @@
package org.apache.hudi.utilities.sources.helpers;

import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.common.model.HoodieRecord;
import org.apache.hudi.common.util.Option;
import org.apache.hudi.utilities.sources.SnapshotLoadQuerySplitter;
import org.apache.hudi.utilities.streamer.SourceProfileSupplier;
@@ -29,12 +28,16 @@

import java.util.List;

import static org.apache.hudi.common.model.HoodieRecord.COMMIT_TIME_METADATA_FIELD;
import static org.apache.hudi.common.model.HoodieRecord.PARTITION_PATH_METADATA_FIELD;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.min;

public class TestSnapshotQuerySplitterImpl extends SnapshotLoadQuerySplitter {

private static final String COMMIT_TIME_METADATA_FIELD = HoodieRecord.COMMIT_TIME_METADATA_FIELD;
public static final String MAX_ROWS_PER_BATCH = "test.snapshot.load.max.row.count";

/**
* Constructor initializing the properties.
@@ -51,4 +54,23 @@ public Option<String> getNextCheckpoint(Dataset<Row> df, String beginCheckpointS
.orderBy(col(COMMIT_TIME_METADATA_FIELD)).limit(1).collectAsList();
return Option.ofNullable(row.size() > 0 ? row.get(0).getAs(COMMIT_TIME_METADATA_FIELD) : null);
}

@Override
public Option<CheckpointWithPredicates> getNextCheckpointWithPredicates(Dataset<Row> df, String beginCheckpointStr) {
int maxRowsPerBatch = properties.getInteger(MAX_ROWS_PER_BATCH, 1);
List<Row> row = df.select(col(COMMIT_TIME_METADATA_FIELD)).filter(col(COMMIT_TIME_METADATA_FIELD).gt(lit(beginCheckpointStr)))
.orderBy(col(COMMIT_TIME_METADATA_FIELD)).limit(maxRowsPerBatch).collectAsList();
if (!row.isEmpty()) {
String endInstant = row.get(row.size() - 1).getAs(COMMIT_TIME_METADATA_FIELD);
List<Row> minMax =
df.filter(col(COMMIT_TIME_METADATA_FIELD).gt(lit(beginCheckpointStr)))
.filter(col(COMMIT_TIME_METADATA_FIELD).leq(endInstant))
.select(PARTITION_PATH_METADATA_FIELD).agg(min(PARTITION_PATH_METADATA_FIELD).alias("min_partition_path"), max(PARTITION_PATH_METADATA_FIELD).alias("max_partition_path"))
.collectAsList();
String partitionFilter = String.format("partition_path >= '%s' and partition_path <= '%s'", minMax.get(0).getAs("min_partition_path"), minMax.get(0).getAs("max_partition_path"));
return Option.of(new CheckpointWithPredicates(endInstant, partitionFilter));
} else {
return Option.empty();
}
}
}
