diff --git a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java index e5e849a74c4..68213c21956 100644 --- a/java/src/main/java/ai/rapids/cudf/BinaryOperable.java +++ b/java/src/main/java/ai/rapids/cudf/BinaryOperable.java @@ -38,7 +38,7 @@ public interface BinaryOperable { * with scale=0 as scale is required. Dtype is discarded for binary operations for decimal * types in cudf as a new DType is created for output type with the new scale. */ - static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { + static DType implicitConversion(BinaryOp op, BinaryOperable lhs, BinaryOperable rhs) { DType a = lhs.getType(); DType b = rhs.getType(); if (a.equals(DType.FLOAT64) || b.equals(DType.FLOAT64)) { @@ -86,13 +86,15 @@ static DType implicitConversion(BinaryOperable lhs, BinaryOperable rhs) { int scale = 0; if (a.typeId == DType.DTypeEnum.DECIMAL32) { if (b.typeId == DType.DTypeEnum.DECIMAL32) { - return DType.create(DType.DTypeEnum.DECIMAL32, scale); + return DType.create(DType.DTypeEnum.DECIMAL32, + ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); } } else if (a.typeId == DType.DTypeEnum.DECIMAL64) { if (b.typeId == DType.DTypeEnum.DECIMAL64) { - return DType.create(DType.DTypeEnum.DECIMAL64, scale); + return DType.create(DType.DTypeEnum.DECIMAL64, + ColumnView.getFixedPointOutputScale(op, lhs.getType(), rhs.getType())); } else { throw new IllegalArgumentException("Both columns must be of the same fixed_point type"); } @@ -128,7 +130,7 @@ default ColumnVector add(BinaryOperable rhs, DType outType) { * Add + operator. this + rhs */ default ColumnVector add(BinaryOperable rhs) { - return add(rhs, implicitConversion(this, rhs)); + return add(rhs, implicitConversion(BinaryOp.ADD, this, rhs)); } /** @@ -144,7 +146,7 @@ default ColumnVector sub(BinaryOperable rhs, DType outType) { * Subtract one vector from another. this - rhs */ default ColumnVector sub(BinaryOperable rhs) { - return sub(rhs, implicitConversion(this, rhs)); + return sub(rhs, implicitConversion(BinaryOp.SUB, this, rhs)); } /** @@ -160,7 +162,7 @@ default ColumnVector mul(BinaryOperable rhs, DType outType) { * Multiply two vectors together. this * rhs */ default ColumnVector mul(BinaryOperable rhs) { - return mul(rhs, implicitConversion(this, rhs)); + return mul(rhs, implicitConversion(BinaryOp.MUL, this, rhs)); } /** @@ -176,7 +178,7 @@ default ColumnVector div(BinaryOperable rhs, DType outType) { * Divide one vector by another. this / rhs */ default ColumnVector div(BinaryOperable rhs) { - return div(rhs, implicitConversion(this, rhs)); + return div(rhs, implicitConversion(BinaryOp.DIV, this, rhs)); } /** @@ -192,7 +194,7 @@ default ColumnVector trueDiv(BinaryOperable rhs, DType outType) { * (double)this / (double)rhs */ default ColumnVector trueDiv(BinaryOperable rhs) { - return trueDiv(rhs, implicitConversion(this, rhs)); + return trueDiv(rhs, implicitConversion(BinaryOp.TRUE_DIV, this, rhs)); } /** @@ -208,7 +210,7 @@ default ColumnVector floorDiv(BinaryOperable rhs, DType outType) { * Math.floor(this/rhs) */ default ColumnVector floorDiv(BinaryOperable rhs) { - return floorDiv(rhs, implicitConversion(this, rhs)); + return floorDiv(rhs, implicitConversion(BinaryOp.FLOOR_DIV, this, rhs)); } /** @@ -224,7 +226,7 @@ default ColumnVector mod(BinaryOperable rhs, DType outType) { * this % rhs */ default ColumnVector mod(BinaryOperable rhs) { - return mod(rhs, implicitConversion(this, rhs)); + return mod(rhs, implicitConversion(BinaryOp.MOD, this, rhs)); } /** @@ -240,7 +242,7 @@ default ColumnVector pow(BinaryOperable rhs, DType outType) { * Math.pow(this, rhs) */ default ColumnVector pow(BinaryOperable rhs) { - return pow(rhs, implicitConversion(this, rhs)); + return pow(rhs, implicitConversion(BinaryOp.POW, this, rhs)); } /** @@ -338,7 +340,7 @@ default ColumnVector bitAnd(BinaryOperable rhs, DType outType) { * Bit wise and (&). this & rhs */ default ColumnVector bitAnd(BinaryOperable rhs) { - return bitAnd(rhs, implicitConversion(this, rhs)); + return bitAnd(rhs, implicitConversion(BinaryOp.BITWISE_AND, this, rhs)); } /** @@ -352,7 +354,7 @@ default ColumnVector bitOr(BinaryOperable rhs, DType outType) { * Bit wise or (|). this | rhs */ default ColumnVector bitOr(BinaryOperable rhs) { - return bitOr(rhs, implicitConversion(this, rhs)); + return bitOr(rhs, implicitConversion(BinaryOp.BITWISE_OR, this, rhs)); } /** @@ -366,7 +368,7 @@ default ColumnVector bitXor(BinaryOperable rhs, DType outType) { * Bit wise xor (^). this ^ rhs */ default ColumnVector bitXor(BinaryOperable rhs) { - return bitXor(rhs, implicitConversion(this, rhs)); + return bitXor(rhs, implicitConversion(BinaryOp.BITWISE_XOR, this, rhs)); } /** @@ -380,7 +382,7 @@ default ColumnVector and(BinaryOperable rhs, DType outType) { * Logical and (&&). this && rhs */ default ColumnVector and(BinaryOperable rhs) { - return and(rhs, implicitConversion(this, rhs)); + return and(rhs, implicitConversion(BinaryOp.LOGICAL_AND, this, rhs)); } /** @@ -394,7 +396,7 @@ default ColumnVector or(BinaryOperable rhs, DType outType) { * Logical or (||). this || rhs */ default ColumnVector or(BinaryOperable rhs) { - return or(rhs, implicitConversion(this, rhs)); + return or(rhs, implicitConversion(BinaryOp.LOGICAL_OR, this, rhs)); } /** @@ -421,7 +423,7 @@ default ColumnVector shiftLeft(BinaryOperable shiftBy, DType outType) { * with this[i] << shiftBy */ default ColumnVector shiftLeft(BinaryOperable shiftBy) { - return shiftLeft(shiftBy, implicitConversion(this, shiftBy)); + return shiftLeft(shiftBy, implicitConversion(BinaryOp.SHIFT_LEFT, this, shiftBy)); } /** @@ -447,7 +449,7 @@ default ColumnVector shiftRight(BinaryOperable shiftBy, DType outType) { * with this[i] >> shiftBy */ default ColumnVector shiftRight(BinaryOperable shiftBy) { - return shiftRight(shiftBy, implicitConversion(this, shiftBy)); + return shiftRight(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT, this, shiftBy)); } /** @@ -475,7 +477,8 @@ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy, DType outType) { * with this[i] >>> shiftBy */ default ColumnVector shiftRightUnsigned(BinaryOperable shiftBy) { - return shiftRightUnsigned(shiftBy, implicitConversion(this, shiftBy)); + return shiftRightUnsigned(shiftBy, implicitConversion(BinaryOp.SHIFT_RIGHT_UNSIGNED, this, + shiftBy)); } /** @@ -505,7 +508,7 @@ default ColumnVector arctan2(BinaryOperable xCoordinate, DType outType) { * in radians, between the positive x axis and the ray to the point (x, y) ≠ (0, 0). */ default ColumnVector arctan2(BinaryOperable xCoordinate) { - return arctan2(xCoordinate, implicitConversion(this, xCoordinate)); + return arctan2(xCoordinate, implicitConversion(BinaryOp.ATAN2, this, xCoordinate)); } /** @@ -529,7 +532,7 @@ default ColumnVector pmod(BinaryOperable rhs, DType outputType) { * */ default ColumnVector pmod(BinaryOperable rhs) { - return pmod(rhs, implicitConversion(this, rhs)); + return pmod(rhs, implicitConversion(BinaryOp.PMOD, this, rhs)); } /** @@ -557,7 +560,7 @@ default ColumnVector maxNullAware(BinaryOperable rhs, DType outType) { * Returns the max non null value. */ default ColumnVector maxNullAware(BinaryOperable rhs) { - return maxNullAware(rhs, implicitConversion(this, rhs)); + return maxNullAware(rhs, implicitConversion(BinaryOp.NULL_MAX, this, rhs)); } /** @@ -571,6 +574,7 @@ default ColumnVector minNullAware(BinaryOperable rhs, DType outType) { * Returns the min non null value. */ default ColumnVector minNullAware(BinaryOperable rhs) { - return minNullAware(rhs, implicitConversion(this, rhs)); + return minNullAware(rhs, implicitConversion(BinaryOp.NULL_MIN, this, rhs)); } + } diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index f36896a3c96..2f3f2bf80cf 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -129,6 +129,13 @@ public final long getNativeView() { return viewHandle; } + static int getFixedPointOutputScale(BinaryOp op, DType lhsType, DType rhsType) { + assert (lhsType.isDecimalType() && rhsType.isDecimalType()); + return fixedPointOutputScale(op.nativeId, lhsType.getScale(), rhsType.getScale()); + } + + private static native int fixedPointOutputScale(int op, int lhsScale, int rhsScale); + public final DType getType() { return type; } diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index e8474bda1be..0ce9d6303e4 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -60,6 +60,7 @@ #include #include #include +#include "cudf/types.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" @@ -1026,6 +1027,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVV(JNIEnv *env, j CATCH_STD(env, 0); } +JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnView_fixedPointOutputScale(JNIEnv *env, jclass, + jint int_op, + jint lhs_scale, + jint rhs_scale) { + try { + // we just return the scale as the types will be the same as the lhs input + return cudf::binary_operation_fixed_point_scale(static_cast(int_op), + lhs_scale, rhs_scale); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_binaryOpVS(JNIEnv *env, jclass, jlong lhs_view, jlong rhs_ptr, jint int_op, jint out_dtype,