Skip to content

Commit

Permalink
Reuse CompensatedSum object in agg collect loops (#49548)
Browse files Browse the repository at this point in the history
The new CompensatedSum is a nice DRY refactor, but had the unanticipated 
side effect of creating a lot of object allocation in the aggregation hot collection 
loop: one object per visited document, per aggregator. In some places it 
created two per-doc-per-agg (weighted avg, geo centroids, etc) since there 
were multiple compensations being maintained.

This PR moves the object creation out of the hot loop so that it is now 
created once per segment, and resets the internal state each time through 
the loop
  • Loading branch information
polyfractal committed Nov 25, 2019
1 parent 2fd58bb commit 99e3136
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -87,7 +89,8 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);

kahanSummation.reset(sum, compensation);

for (int i = 0; i < valueCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ public CompensatedSum add(double value) {
return add(value, NO_CORRECTION);
}

/**
* Resets the internal state to use the new value and compensation delta
*/
public void reset(double value, double delta) {
this.value = value;
this.delta = delta;
}

/**
* Increments the Kahan sum by adding two sums, and updating the correction term for reducing numeric errors.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum compensatedSum = new CompensatedSum(0, 0);
final CompensatedSum compensatedSumOfSqr = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {

@Override
Expand Down Expand Up @@ -117,11 +119,11 @@ public void collect(int doc, long bucket) throws IOException {
// which is more accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum compensatedSum = new CompensatedSum(sum, compensation);
compensatedSum.reset(sum, compensation);

double sumOfSqr = sumOfSqrs.get(bucket);
double compensationOfSqr = compensationOfSqrs.get(bucket);
CompensatedSum compensatedSumOfSqr = new CompensatedSum(sumOfSqr, compensationOfSqr);
compensatedSumOfSqr.reset(sumOfSqr, compensationOfSqr);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
}
final BigArrays bigArrays = context.bigArrays();
final MultiGeoPointValues values = valuesSource.geoPointValues(ctx);
final CompensatedSum compensatedSumLat = new CompensatedSum(0, 0);
final CompensatedSum compensatedSumLon = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -88,8 +91,8 @@ public void collect(int doc, long bucket) throws IOException {
double sumLon = lonSum.get(bucket);
double compensationLon = lonCompensations.get(bucket);

CompensatedSum compensatedSumLat = new CompensatedSum(sumLat, compensationLat);
CompensatedSum compensatedSumLon = new CompensatedSum(sumLon, compensationLon);
compensatedSumLat.reset(sumLat, compensationLat);
compensatedSumLon.reset(sumLon, compensationLon);

// update the sum
for (int i = 0; i < valueCount; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -105,7 +107,7 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
}
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
return new LeafBucketCollectorBase(sub, values) {
@Override
public void collect(int doc, long bucket) throws IOException {
Expand All @@ -81,7 +82,7 @@ public void collect(int doc, long bucket) throws IOException {
// accurate than naive summation.
double sum = sums.get(bucket);
double compensation = compensations.get(bucket);
CompensatedSum kahanSummation = new CompensatedSum(sum, compensation);
kahanSummation.reset(sum, compensation);

for (int i = 0; i < valuesCount; i++) {
double value = values.nextValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
private final MultiValuesSource.NumericMultiValuesSource valuesSources;

private DoubleArray weights;
private DoubleArray sums;
private DoubleArray sumCompensations;
private DoubleArray valueSums;
private DoubleArray valueCompensations;
private DoubleArray weightCompensations;
private DocValueFormat format;

Expand All @@ -60,8 +60,8 @@ class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue {
if (valuesSources != null) {
final BigArrays bigArrays = context.bigArrays();
weights = bigArrays.newDoubleArray(1, true);
sums = bigArrays.newDoubleArray(1, true);
sumCompensations = bigArrays.newDoubleArray(1, true);
valueSums = bigArrays.newDoubleArray(1, true);
valueCompensations = bigArrays.newDoubleArray(1, true);
weightCompensations = bigArrays.newDoubleArray(1, true);
}
}
Expand All @@ -80,13 +80,15 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx);
final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx);
final CompensatedSum compensatedValueSum = new CompensatedSum(0, 0);
final CompensatedSum compensatedWeightSum = new CompensatedSum(0, 0);

return new LeafBucketCollectorBase(sub, docValues) {
@Override
public void collect(int doc, long bucket) throws IOException {
weights = bigArrays.grow(weights, bucket + 1);
sums = bigArrays.grow(sums, bucket + 1);
sumCompensations = bigArrays.grow(sumCompensations, bucket + 1);
valueSums = bigArrays.grow(valueSums, bucket + 1);
valueCompensations = bigArrays.grow(valueCompensations, bucket + 1);
weightCompensations = bigArrays.grow(weightCompensations, bucket + 1);

if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) {
Expand All @@ -102,42 +104,43 @@ public void collect(int doc, long bucket) throws IOException {
final int numValues = docValues.docValueCount();
assert numValues > 0;

double valueSum = valueSums.get(bucket);
double valueCompensation = valueCompensations.get(bucket);
compensatedValueSum.reset(valueSum, valueCompensation);

double weightSum = weights.get(bucket);
double weightCompensation = weightCompensations.get(bucket);
compensatedWeightSum.reset(weightSum, weightCompensation);

for (int i = 0; i < numValues; i++) {
kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket);
kahanSum(weight, weights, weightCompensations, bucket);
compensatedValueSum.add(docValues.nextValue() * weight);
compensatedWeightSum.add(weight);
}

valueSums.set(bucket, compensatedValueSum.value());
valueCompensations.set(bucket, compensatedValueSum.delta());
weights.set(bucket, compensatedWeightSum.value());
weightCompensations.set(bucket, compensatedWeightSum.delta());
}
}
};
}

private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) {
// Compute the sum of double values with Kahan summation algorithm which is more
// accurate than naive summation.
double sum = values.get(bucket);
double compensation = compensations.get(bucket);

CompensatedSum kahanSummation = new CompensatedSum(sum, compensation)
.add(value);

values.set(bucket, kahanSummation.value());
compensations.set(bucket, kahanSummation.delta());
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSources == null || owningBucketOrd >= sums.size()) {
if (valuesSources == null || owningBucketOrd >= valueSums.size()) {
return Double.NaN;
}
return sums.get(owningBucketOrd) / weights.get(owningBucketOrd);
return valueSums.get(owningBucketOrd) / weights.get(owningBucketOrd);
}

@Override
public InternalAggregation buildAggregation(long bucket) {
if (valuesSources == null || bucket >= sums.size()) {
if (valuesSources == null || bucket >= valueSums.size()) {
return buildEmptyAggregation();
}
return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
return new InternalWeightedAvg(name, valueSums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData());
}

@Override
Expand All @@ -147,7 +150,7 @@ public InternalAggregation buildEmptyAggregation() {

@Override
public void doClose() {
Releasables.close(weights, sums, sumCompensations, weightCompensations);
Releasables.close(weights, valueSums, valueCompensations, weightCompensations);
}

}

0 comments on commit 99e3136

Please sign in to comment.