Skip to content

Commit

Permalink
Support Decimal DIV changes in cudf (#7527)
Browse files Browse the repository at this point in the history
@codereport is making changes to the way `DIV` will behave for fixed-point types #7435. This PR contains Java changes to support those changes. 

Note: This is a draft until #7435 is merged

Authors:
  - Raza Jafri (@razajafri)

Approvers:
  - MithunR (@mythrocks)
  - Jason Lowe (@jlowe)
  - Gera Shegalov (@gerashegalov)

URL: #7527
  • Loading branch information
razajafri authored Mar 16, 2021
1 parent c1c60ba commit 2b2c0d2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
50 changes: 27 additions & 23 deletions java/src/main/java/ai/rapids/cudf/BinaryOperable.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public interface BinaryOperable {
* with scale=0 as scale is required. Dtype is discarded for binary operations for decimal
* types in cudf as a new DType is created for output type with the new scale.
*/
static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable rhs) {
DType a = lhs.getType();
DType b = rhs.getType();
if (a.equals(DType.FLOAT64) || b.equals(DType.FLOAT64)) {
Expand Down Expand Up @@ -86,13 +86,15 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) {
int scale = 0;
if (a.typeId == DType.DTypeEnum.DECIMAL32) {
if (b.typeId == DType.DTypeEnum.DECIMAL32) {
return DType.create(DType.DTypeEnum.DECIMAL32, scale);
return DType.create(DType.DTypeEnum.DECIMAL32,
ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
} else if (a.typeId == DType.DTypeEnum.DECIMAL64) {
if (b.typeId == DType.DTypeEnum.DECIMAL64) {
return DType.create(DType.DTypeEnum.DECIMAL64, scale);
return DType.create(DType.DTypeEnum.DECIMAL64,
ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType()));
} else {
throw new IllegalArgumentException("Both columns must be of the same fixed_point type");
}
Expand Down Expand Up @@ -128,7 +130,7 @@ default ColumnVector add(BinaryOperable rhs, DType outType) {
* Add + operator. this + rhs
*/
default ColumnVector add(BinaryOperable rhs) {
return add(rhs, implicitConversion(this, rhs));
return add(rhs, implicitConversion(BinaryOp.ADD, this, rhs));
}

/**
Expand All @@ -144,7 +146,7 @@ default ColumnVector sub(BinaryOperable rhs, DType outType) {
* Subtract one vector from another. this - rhs
*/
default ColumnVector sub(BinaryOperable rhs) {
return sub(rhs, implicitConversion(this, rhs));
return sub(rhs, implicitConversion(BinaryOp.SUB, this, rhs));
}

/**
Expand All @@ -160,7 +162,7 @@ default ColumnVector mul(BinaryOperable rhs, DType outType) {
* Multiply two vectors together. this * rhs
*/
default ColumnVector mul(BinaryOperable rhs) {
return mul(rhs, implicitConversion(this, rhs));
return mul(rhs, implicitConversion(BinaryOp.MUL, this, rhs));
}

/**
Expand All @@ -176,7 +178,7 @@ default ColumnVector div(BinaryOperable rhs, DType outType) {
* Divide one vector by another. this / rhs
*/
default ColumnVector div(BinaryOperable rhs) {
return div(rhs, implicitConversion(this, rhs));
return div(rhs, implicitConversion(BinaryOp.DIV, this, rhs));
}

/**
Expand All @@ -192,7 +194,7 @@ default ColumnVector trueDiv(BinaryOperable rhs, DType outType) {
* (double)this / (double)rhs
*/
default ColumnVector trueDiv(BinaryOperable rhs) {
return trueDiv(rhs, implicitConversion(this, rhs));
return trueDiv(rhs, implicitConversion(BinaryOp.TRUE_DIV, this, rhs));
}

/**
Expand All @@ -208,7 +210,7 @@ default ColumnVector floorDiv(BinaryOperable rhs, DType outType) {
* Math.floor(this/rhs)
*/
default ColumnVector floorDiv(BinaryOperable rhs) {
return floorDiv(rhs, implicitConversion(this, rhs));
return floorDiv(rhs, implicitConversion(BinaryOp.FLOOR_DIV, this, rhs));
}

/**
Expand All @@ -224,7 +226,7 @@ default ColumnVector mod(BinaryOperable rhs, DType outType) {
* this % rhs
*/
default ColumnVector mod(BinaryOperable rhs) {
return mod(rhs, implicitConversion(this, rhs));
return mod(rhs, implicitConversion(BinaryOp.MOD, this, rhs));
}

/**
Expand All @@ -240,7 +242,7 @@ default ColumnVector pow(BinaryOperable rhs, DType outType) {
* Math.pow(this, rhs)
*/
default ColumnVector pow(BinaryOperable rhs) {
return pow(rhs, implicitConversion(this, rhs));
return pow(rhs, implicitConversion(BinaryOp.POW, this, rhs));
}

/**
Expand Down Expand Up @@ -338,7 +340,7 @@ default ColumnVector bitAnd(BinaryOperable rhs, DType outType) {
* Bit wise and (&). this & rhs
*/
default ColumnVector bitAnd(BinaryOperable rhs) {
return bitAnd(rhs, implicitConversion(this, rhs));
return bitAnd(rhs, implicitConversion(BinaryOp.BITWISE_AND, this, rhs));
}

/**
Expand All @@ -352,7 +354,7 @@ default ColumnVector bitOr(BinaryOperable rhs, DType outType) {
* Bit wise or (|). this | rhs
*/
default ColumnVector bitOr(BinaryOperable rhs) {
return bitOr(rhs, implicitConversion(this, rhs));
return bitOr(rhs, implicitConversion(BinaryOp.BITWISE_OR, this, rhs));
}

/**
Expand All @@ -366,7 +368,7 @@ default ColumnVector bitXor(BinaryOperable rhs, DType outType) {
* Bit wise xor (^). this ^ rhs
*/
default ColumnVector bitXor(BinaryOperable rhs) {
return bitXor(rhs, implicitConversion(this, rhs));
return bitXor(rhs, implicitConversion(BinaryOp.BITWISE_XOR, this, rhs));
}

/**
Expand All @@ -380,7 +382,7 @@ default ColumnVector and(BinaryOperable rhs, DType outType) {
* Logical and (&&). this && rhs
*/
default ColumnVector and(BinaryOperable rhs) {
return and(rhs, implicitConversion(this, rhs));
return and(rhs, implicitConversion(BinaryOp.LOGICAL_AND, this, rhs));
}

/**
Expand All @@ -394,7 +396,7 @@ default ColumnVector or(BinaryOperable rhs, DType outType) {
* Logical or (||). this || rhs
*/
default ColumnVector or(BinaryOperable rhs) {
return or(rhs, implicitConversion(this, rhs));
return or(rhs, implicitConversion(BinaryOp.LOGICAL_OR, this, rhs));
}

/**
Expand All @@ -421,7 +423,7 @@ default ColumnVector shiftLeft(BinaryOperable shiftBy, DType outType) {
* with this[i] << shiftBy
*/
default ColumnVector shiftLeft(BinaryOperable shiftBy) {
return shiftLeft(shiftBy, implicitConversion(this, shiftBy));
return shiftLeft(shiftBy, implicitConversion(BinaryOp.SHIFT_LEFT, this, shiftBy));
}

/**
Expand All @@ -447,7 +449,7 @@ default ColumnVector shiftRight(BinaryOperable shiftBy, DType outType) {
* with this[i] >> shiftBy
*/
default ColumnVector shiftRight(BinaryOperable shiftBy) {
return shiftRight(shiftBy, implicitConversion(this, shiftBy));
return shiftRight(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT, this, shiftBy));
}

/**
Expand Down Expand Up @@ -475,7 +477,8 @@ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy, DType outType) {
* with this[i] >>> shiftBy
*/
default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy) {
return shiftRightUnsigned(shiftBy, implicitConversion(this, shiftBy));
return shiftRightUnsigned(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT_UNSIGNED, this,
shiftBy));
}

/**
Expand Down Expand Up @@ -505,7 +508,7 @@ default ColumnVector arctan2(BinaryOperable xCoordinate, DType outType) {
* in radians, between the positive x axis and the ray to the point (x, y) ≠ (0, 0).
*/
default ColumnVector arctan2(BinaryOperable xCoordinate) {
return arctan2(xCoordinate, implicitConversion(this, xCoordinate));
return arctan2(xCoordinate, implicitConversion(BinaryOp.ATAN2, this, xCoordinate));
}

/**
Expand All @@ -529,7 +532,7 @@ default ColumnVector pmod(BinaryOperable rhs, DType outputType) {
*
*/
default ColumnVector pmod(BinaryOperable rhs) {
return pmod(rhs, implicitConversion(this, rhs));
return pmod(rhs, implicitConversion(BinaryOp.PMOD, this, rhs));
}

/**
Expand Down Expand Up @@ -557,7 +560,7 @@ default ColumnVector maxNullAware(BinaryOperable rhs, DType outType) {
* Returns the max non null value.
*/
default ColumnVector maxNullAware(BinaryOperable rhs) {
return maxNullAware(rhs, implicitConversion(this, rhs));
return maxNullAware(rhs, implicitConversion(BinaryOp.NULL_MAX, this, rhs));
}

/**
Expand All @@ -571,6 +574,7 @@ default ColumnVector minNullAware(BinaryOperable rhs, DType outType) {
* Returns the min non null value.
*/
default ColumnVector minNullAware(BinaryOperable rhs) {
return minNullAware(rhs, implicitConversion(this, rhs));
return minNullAware(rhs, implicitConversion(BinaryOp.NULL_MIN, this, rhs));
}

}
7 changes: 7 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ public final long getNativeView() {
return viewHandle;
}

static int getFixedPointOutputScale(BinaryOp op, DType lhsType, DType rhsType) {
assert (lhsType.isDecimalType() && rhsType.isDecimalType());
return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale());
}

private static native int fixedPointOutputScale(int op, int lhsScale, int rhsScale);

public final DType getType() {
return type;
}
Expand Down
13 changes: 13 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include <map_lookup.hpp>
#include "cudf/types.hpp"

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -1026,6 +1027,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass,
jint int_op,
jint lhs_scale,
jint rhs_scale) {
try {
// we just return the scale as the types will be the same as the lhs input
return cudf::binary_operation_fixed_point_scale(static_cast<cudf::binary_operator>(int_op),
lhs_scale, rhs_scale);
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVS(JNIEnv *env, jclass,
jlong lhs_view, jlong rhs_ptr,
jint int_op, jint out_dtype,
Expand Down

0 comments on commit 2b2c0d2

Please sign in to comment.