Skip to content

Commit

Permalink
add new typed in filter (#16039)
Browse files Browse the repository at this point in the history
changes:
* adds TypedInFilter which preserves matching sets in the native match value type
* SQL planner uses new TypedInFilter when druid.generic.useDefaultValueForNull=false (the default)
  • Loading branch information
clintropolis authored Mar 22, 2024
1 parent a70e28a commit b0a9c31
Show file tree
Hide file tree
Showing 42 changed files with 2,776 additions and 1,002 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import org.apache.druid.collections.bitmap.MutableBitmap;
import org.apache.druid.collections.bitmap.RoaringBitmapFactory;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ByteBufferUtils;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.data.BitmapSerdeFactory;
import org.apache.druid.segment.data.GenericIndexed;
import org.apache.druid.segment.data.RoaringBitmapSerdeFactory;
import org.apache.druid.segment.index.BitmapColumnIndex;
import org.apache.druid.segment.index.IndexedUtf8ValueIndexes;
import org.apache.druid.segment.index.semantic.StringValueSetIndexes;
import org.apache.druid.segment.serde.StringUtf8ColumnIndexSupplier;
Expand Down Expand Up @@ -73,7 +73,7 @@ public static class BenchmarkState
{
@Nullable
private IndexedUtf8ValueIndexes<?> stringValueSetIndex;
private final TreeSet<ByteBuffer> values = new TreeSet<>();
private final List<ByteBuffer> values = new ArrayList<>();
private static final int START_INT = 10_000_000;

// cardinality of the dictionary. it will contain this many ints (as strings, of course), starting at START_INT,
Expand Down Expand Up @@ -122,14 +122,16 @@ public void setup()
Random r = new Random(9001);
Collections.shuffle(filterValues);
Collections.shuffle(nonFilterValues);
values.clear();
TreeSet<ByteBuffer> sortedValues = new TreeSet<>(ByteBufferUtils.utf8Comparator());
for (int i = 0; i < filterToDictionaryPercentage * dictionarySize / 100; i++) {
if (r.nextInt(100) < selectivityPercentage) {
values.add(ByteBuffer.wrap((filterValues.get(i).toString()).getBytes(StandardCharsets.UTF_8)));
sortedValues.add(ByteBuffer.wrap((filterValues.get(i).toString()).getBytes(StandardCharsets.UTF_8)));
} else {
values.add(ByteBuffer.wrap((nonFilterValues.get(i).toString()).getBytes(StandardCharsets.UTF_8)));
sortedValues.add(ByteBuffer.wrap((nonFilterValues.get(i).toString()).getBytes(StandardCharsets.UTF_8)));
}
}
values.clear();
values.addAll(sortedValues);
}

private Iterable<Integer> intGenerator()
Expand All @@ -144,6 +146,6 @@ private Iterable<Integer> intGenerator()
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void doValueSetCheck(Blackhole blackhole, BenchmarkState state)
{
BitmapColumnIndex bitmapIndex = state.stringValueSetIndex.forSortedValuesUtf8(state.values);
blackhole.consume(state.stringValueSetIndex.forSortedValuesUtf8(state.values));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.filter.ColumnIndexSelector;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.query.filter.TypedInFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.data.BitmapSerdeFactory;
import org.apache.druid.segment.data.GenericIndexed;
import org.apache.druid.segment.data.RoaringBitmapSerdeFactory;
Expand All @@ -53,8 +55,8 @@

@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@Warmup(iterations = 2)
@Measurement(iterations = 3)
public class InFilterBenchmark
{
static {
Expand All @@ -65,6 +67,8 @@ public class InFilterBenchmark

private InDimFilter inFilter;
private InDimFilter endInDimFilter;
private TypedInFilter newInFilter;
private TypedInFilter newEndInFilter;

// cardinality of the dictionary. it will contain this many ints (as strings, of course), starting at START_INT,
// even numbers only.
Expand Down Expand Up @@ -110,12 +114,29 @@ public void setup()
"dummy",
IntStream.range(START_INT, START_INT + filterSize).mapToObj(String::valueOf).collect(Collectors.toSet())
);
newInFilter = (TypedInFilter) new TypedInFilter(
"dummy",
ColumnType.STRING,
IntStream.range(START_INT, START_INT + filterSize).mapToObj(String::valueOf).collect(Collectors.toList()),
null,
null
).toFilter();
endInDimFilter = new InDimFilter(
"dummy",
IntStream.range(START_INT + dictionarySize * 2, START_INT + dictionarySize * 2 + 1)
.mapToObj(String::valueOf)
.collect(Collectors.toSet())
);

newEndInFilter = (TypedInFilter) new TypedInFilter(
"dummy",
ColumnType.STRING,
IntStream.range(START_INT + dictionarySize * 2, START_INT + dictionarySize * 2 + 1)
.mapToObj(String::valueOf)
.collect(Collectors.toList()),
null,
null
).toFilter();
}

@Benchmark
Expand All @@ -136,6 +157,24 @@ public void doFilterAtEnd(Blackhole blackhole)
blackhole.consume(bitmapIndex);
}

/**
 * Measures the average time to compute bitmap results for {@code newInFilter}, the
 * {@link TypedInFilter} counterpart of {@code inFilter} that {@code setup()} builds with
 * {@link ColumnType#STRING} match values.
 *
 * @param blackhole JMH sink that receives the computed bitmap
 */
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void doFilter2(Blackhole blackhole)
{
  final ImmutableBitmap typedInMatches =
      Filters.computeDefaultBitmapResults(newInFilter, selector);
  // Hand the bitmap to the blackhole rather than dropping it.
  blackhole.consume(typedInMatches);
}

/**
 * Measures the average time to compute bitmap results for {@code newEndInFilter}, the
 * {@link TypedInFilter} that {@code setup()} builds with the single match value
 * {@code START_INT + dictionarySize * 2}.
 *
 * @param blackhole JMH sink that receives the computed bitmap
 */
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void doFilterAtEnd2(Blackhole blackhole)
{
  final ImmutableBitmap typedInMatches =
      Filters.computeDefaultBitmapResults(newEndInFilter, selector);
  // Hand the bitmap to the blackhole rather than dropping it.
  blackhole.consume(typedInMatches);
}

private Iterable<Integer> intGenerator()
{
// i * 2 => half of these values will be present in the inFilter, half won't.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,8 @@ public void setup()
.writeValueAsString(jsonMapper.readValue((String) planResult[0], List.class))
);
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
catch (JsonProcessingException ignored) {

}

try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, ImmutableMap.of())) {
Expand All @@ -598,6 +598,9 @@ public void setup()
}
log.info("Total result row count:" + rowCounter);
}
catch (Throwable ignored) {

}
}

private StringEncodingStrategy getStringEncodingStrategy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

package org.apache.druid.benchmark.query;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -28,6 +31,8 @@
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExpressionProcessing;
Expand Down Expand Up @@ -198,7 +203,17 @@ public String getFormatString()
"SELECT SUM(long1) FROM foo WHERE string5 LIKE '%1%' AND string1 = '1000'",
"SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.string5') LIKE '%1%' AND JSON_VALUE(nested, '$.nesteder.string1') = '1000'",
"SELECT SUM(long1) FROM foo WHERE string1 = '1000' AND string5 LIKE '%1%'",
"SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.string1') = '1000' AND JSON_VALUE(nested, '$.nesteder.string5') LIKE '%1%'"
"SELECT SUM(JSON_VALUE(nested, '$.long1' RETURNING BIGINT)) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.string1') = '1000' AND JSON_VALUE(nested, '$.nesteder.string5') LIKE '%1%'",
// 48, 49: larger IN filter on long2
"SELECT long2 FROM foo WHERE long2 IN (1, 19, 21, 23, 25, 26, 46, 50, 51, 55, 60, 61, 66, 68, 69, 70, 77, 88, 90, 92, 93, 94, 95, 100, 101, 102, 104, 109, 111, 113, 114, 115, 120, 121, 122, 134, 135, 136, 140, 142, 150, 155, 170, 172, 173, 174, 180, 181, 190, 199, 200, 201, 202, 203, 204)",
"SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) IN (1, 19, 21, 23, 25, 26, 46, 50, 51, 55, 60, 61, 66, 68, 69, 70, 77, 88, 90, 92, 93, 94, 95, 100, 101, 102, 104, 109, 111, 113, 114, 115, 120, 121, 122, 134, 135, 136, 140, 142, 150, 155, 170, 172, 173, 174, 180, 181, 190, 199, 200, 201, 202, 203, 204)",
// 50, 51: larger IN filter on long2 with GROUP BY
"SELECT long2 FROM foo WHERE long2 IN (1, 19, 21, 23, 25, 26, 46, 50, 51, 55, 60, 61, 66, 68, 69, 70, 77, 88, 90, 92, 93, 94, 95, 100, 101, 102, 104, 109, 111, 113, 114, 115, 120, 121, 122, 134, 135, 136, 140, 142, 150, 155, 170, 172, 173, 174, 180, 181, 190, 199, 200, 201, 202, 203, 204) GROUP BY 1",
"SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) IN (1, 19, 21, 23, 25, 26, 46, 50, 51, 55, 60, 61, 66, 68, 69, 70, 77, 88, 90, 92, 93, 94, 95, 100, 101, 102, 104, 109, 111, 113, 114, 115, 120, 121, 122, 134, 135, 136, 140, 142, 150, 155, 170, 172, 173, 174, 180, 181, 190, 199, 200, 201, 202, 203, 204) GROUP BY 1",
"SELECT long2 FROM foo WHERE double3 IN (1.0, 19.0, 21.0, 23.0, 25.0, 26.0, 46.0, 50.0, 51.0, 55.0, 60.0, 61.0, 66.0, 68.0, 69.0, 70.0, 77.0, 88.0, 90.0, 92.0, 93.0, 94.0, 95.0, 100.0, 101.0, 102.0, 104.0, 109.0, 111.0, 113.0, 114.0, 115.0, 120.0, 121.0, 122.0, 134.0, 135.0, 136.0, 140.0, 142.0, 150.0, 155.0, 170.0, 172.0, 173.0, 174.0, 180.0, 181.0, 190.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0)",
"SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) IN (1.0, 19.0, 21.0, 23.0, 25.0, 26.0, 46.0, 50.0, 51.0, 55.0, 60.0, 61.0, 66.0, 68.0, 69.0, 70.0, 77.0, 88.0, 90.0, 92.0, 93.0, 94.0, 95.0, 100.0, 101.0, 102.0, 104.0, 109.0, 111.0, 113.0, 114.0, 115.0, 120.0, 121.0, 122.0, 134.0, 135.0, 136.0, 140.0, 142.0, 150.0, 155.0, 170.0, 172.0, 173.0, 174.0, 180.0, 181.0, 190.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0)",
"SELECT long2 FROM foo WHERE double3 IN (1.0, 19.0, 21.0, 23.0, 25.0, 26.0, 46.0, 50.0, 51.0, 55.0, 60.0, 61.0, 66.0, 68.0, 69.0, 70.0, 77.0, 88.0, 90.0, 92.0, 93.0, 94.0, 95.0, 100.0, 101.0, 102.0, 104.0, 109.0, 111.0, 113.0, 114.0, 115.0, 120.0, 121.0, 122.0, 134.0, 135.0, 136.0, 140.0, 142.0, 150.0, 155.0, 170.0, 172.0, 173.0, 174.0, 180.0, 181.0, 190.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0) GROUP BY 1",
"SELECT JSON_VALUE(nested, '$.nesteder.long2' RETURNING BIGINT) FROM foo WHERE JSON_VALUE(nested, '$.nesteder.double3' RETURNING DOUBLE) IN (1.0, 19.0, 21.0, 23.0, 25.0, 26.0, 46.0, 50.0, 51.0, 55.0, 60.0, 61.0, 66.0, 68.0, 69.0, 70.0, 77.0, 88.0, 90.0, 92.0, 93.0, 94.0, 95.0, 100.0, 101.0, 102.0, 104.0, 109.0, 111.0, 113.0, 114.0, 115.0, 120.0, 121.0, 122.0, 134.0, 135.0, 136.0, 140.0, 142.0, 150.0, 155.0, 170.0, 172.0, 173.0, 174.0, 180.0, 181.0, 190.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0) GROUP BY 1"
);

@Param({"5000000"})
Expand Down Expand Up @@ -271,7 +286,15 @@ public String getFormatString()
"44",
"45",
"46",
"47"
"47",
"48",
"49",
"50",
"51",
"52",
"53",
"54",
"55"
})
private String query;

Expand Down Expand Up @@ -386,8 +409,41 @@ public void setup()
QUERIES.get(Integer.parseInt(query))
);
}
catch (Throwable ignored) {
// the show must go on
catch (Throwable ex) {
log.warn(ex, "failed to sanity check");
}

final String sql = QUERIES.get(Integer.parseInt(query));
final ObjectMapper jsonMapper = CalciteTests.getJsonMapper();
try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, "EXPLAIN PLAN FOR " + sql, ImmutableMap.of("useNativeQueryExplain", true))) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Object[] planResult = resultSequence.toList().get(0);
log.info("Native query plan:\n" +
jsonMapper.writerWithDefaultPrettyPrinter()
.writeValueAsString(jsonMapper.readValue((String) planResult[0], List.class))
);
}
catch (JsonMappingException e) {
throw new RuntimeException(e);
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

try (final DruidPlanner planner = plannerFactory.createPlannerForTesting(engine, sql, ImmutableMap.of())) {
final PlannerResult plannerResult = planner.plan();
final Sequence<Object[]> resultSequence = plannerResult.run().getResults();
final Yielder<Object[]> yielder = Yielders.each(resultSequence);
int rowCounter = 0;
while (!yielder.isDone()) {
rowCounter++;
yielder.next(yielder.get());
}
log.info("Total result row count:" + rowCounter);
}
catch (Throwable ex) {
log.warn(ex, "failed to count rows");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
import java.util.Objects;
import java.util.Set;

/**
* Recommended to use {@link RangeFilter} instead
*/
public class BoundDimFilter extends AbstractOptimizableDimFilter implements DimFilter
{
private final String dimension;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
@JsonSubTypes.Type(name = "range", value = RangeFilter.class),
@JsonSubTypes.Type(name = "isfalse", value = IsFalseDimFilter.class),
@JsonSubTypes.Type(name = "istrue", value = IsTrueDimFilter.class),
@JsonSubTypes.Type(name = "arrayContainsElement", value = ArrayContainsElementFilter.class)
@JsonSubTypes.Type(name = "arrayContainsElement", value = ArrayContainsElementFilter.class),
@JsonSubTypes.Type(name = "inType", value = TypedInFilter.class)
})
public interface DimFilter extends Cacheable
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public class DimFilterUtils
static final byte RANGE_CACHE_ID = 0x14;
static final byte IS_FILTER_BOOLEAN_FILTER_CACHE_ID = 0x15;
static final byte ARRAY_CONTAINS_CACHE_ID = 0x16;
static final byte TYPED_IN_CACHE_ID = 0x17;


public static final byte STRING_SEPARATOR = (byte) 0xFF;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.common.collect.ForwardingSortedSet;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
Expand All @@ -42,7 +43,6 @@
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.ByteBufferUtils;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Comparators;
Expand All @@ -69,12 +69,20 @@
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
* Approximately like the SQL 'IN' filter, with the main difference being that this will match NULL values if contained
* in the values list instead of ignoring them.
* <p>
* This filter specifies all match values as a sorted string set; matching against other column types must incur the
* cost of converting values to check for matches. For the most part, {@link TypedInFilter} should be used instead.
*/
public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
{
/**
Expand All @@ -84,7 +92,7 @@ public class InDimFilter extends AbstractOptimizableDimFilter implements Filter
*/
private final ValuesSet values;
// Computed eagerly, not lazily, because lazy computations would block all processing threads for a given query.
private final SortedSet<ByteBuffer> valuesUtf8;
private final List<ByteBuffer> valuesUtf8;
private final String dimension;
@Nullable
private final ExtractionFn extractionFn;
Expand Down Expand Up @@ -806,9 +814,9 @@ public static ValuesSet copyOf(final Collection<String> values)
return copyOf(values.iterator());
}

public SortedSet<ByteBuffer> toUtf8()
public List<ByteBuffer> toUtf8()
{
final TreeSet<ByteBuffer> valuesUtf8 = new TreeSet<>(ByteBufferUtils.utf8Comparator());
final List<ByteBuffer> valuesUtf8 = Lists.newArrayListWithCapacity(values.size());

for (final String value : values) {
if (value == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import java.util.Set;

/**
*
* Recommended to use {@link EqualityFilter} or {@link NullFilter} instead
*/
public class SelectorDimFilter extends AbstractOptimizableDimFilter implements DimFilter
{
Expand Down
Loading

0 comments on commit b0a9c31

Please sign in to comment.