diff --git a/src/main/cpp/src/CastStringJni.cpp b/src/main/cpp/src/CastStringJni.cpp index f18e0de956..df3b67d8ca 100644 --- a/src/main/cpp/src/CastStringJni.cpp +++ b/src/main/cpp/src/CastStringJni.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include "cast_string.hpp" #include "cudf_jni_apis.hpp" @@ -111,4 +112,51 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromDecimal } CATCH_CAST_EXCEPTION(env, 0); } + + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_changeRadix(JNIEnv* env, + jclass, + jlong input_column, + jint fromRadix, + jint toRadix) +{ + JNI_NULL_CHECK(env, input_column, "input column is null", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::column_view cv{*reinterpret_cast(input_column)}; + auto const uint64_cv = [&] { + switch (fromRadix) { + case 10: { + return spark_rapids_jni::string_to_integer( + cudf::data_type(cudf::type_id::UINT64), + cv, + JNI_FALSE, + JNI_TRUE, + cudf::get_default_stream()); + } break; + case 16: { + return cudf::strings::hex_to_integers(cv, cudf::data_type(cudf::type_id::UINT64)); + } + } + return std::unique_ptr(nullptr); + }(); + + std::unique_ptr result_col = [&] { + switch (toRadix) { + case 16: { + return cudf::strings::integers_to_hex(*uint64_cv); + } break; + case 10: { + return cudf::strings::from_integers(*uint64_cv); + } break; + } + return std::unique_ptr(nullptr); + }(); + + return cudf::jni::release_as_jlong(result_col); + } + CATCH_CAST_EXCEPTION(env, 0); +} } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java index 4368d6e098..d2ed6a1ef5 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java @@ -82,7 +82,7 @@ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, boolean st /** * Convert a decimal column to a string column. - * + * * @param cv the column data to process * @return the converted column */ @@ -102,10 +102,16 @@ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type) return new ColumnVector(toFloat(cv.getNativeView(), ansiMode, type.getTypeId().getNativeId())); } + + public static ColumnVector changeRadix(ColumnView cv, int fromRadix, int toRadix) { + return new ColumnVector(changeRadix(cv.getNativeView(), fromRadix, toRadix)); + } + private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip, int dtype); private static native long toDecimal(long nativeColumnView, boolean ansi_enabled, boolean strip, int precision, int scale); private static native long toFloat(long nativeColumnView, boolean ansi_enabled, int dtype); private static native long fromDecimal(long nativeColumnView); + private static native long changeRadix(long nativeColumnView, int fromRadix, int toRadix); } \ No newline at end of file