Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java: Support struct scalar [skip ci] #8327

Merged
merged 8 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion java/src/main/java/ai/rapids/cudf/Scalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Objects;
Expand Down Expand Up @@ -373,6 +372,83 @@ public static Scalar listFromColumnView(ColumnView list) {
return new Scalar(DType.LIST, makeListScalar(list.getNativeView(), true));
}

/**
* Creates a null scalar of struct type.
*
* @param elementTypes data types of children in the struct
* @return a null scalar of struct type
*/
public static Scalar structFromNull(HostColumnVector.DataType... elementTypes) {
ColumnVector[] children = new ColumnVector[elementTypes.length];
firestarman marked this conversation as resolved.
Show resolved Hide resolved
long[] childHandles = new long[elementTypes.length];
try {
for (int i = 0; i < elementTypes.length; i++) {
// Build column vector having single null value rather than empty column vector,
// because struct scalar requires row count of children columns == 1.
children[i] = buildNullColumnVector(elementTypes[i]);
childHandles[i] = children[i].getNativeView();
}
return new Scalar(DType.STRUCT, makeStructScalar(childHandles, false));
} finally {
// close all empty children
for (ColumnVector child : children) {
// We closed all created ColumnViews when we hit null. Therefore we exit the loop.
if (child == null) break;
// make sure the close process is exception-free
try {
child.close();
} catch (Exception ignored) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}

/**
* Creates a scalar of struct from a ColumnView.
*
* @param columns children columns of struct
* @return a Struct scalar
*/
public static Scalar structFromColumnViews(ColumnView... columns) {
if (columns == null || columns.length == 0) {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException("......");
revans2 marked this conversation as resolved.
Show resolved Hide resolved
}
long[] columnHandles = new long[columns.length];
for (int i = 0; i < columns.length; i++) {
columnHandles[i] = columns[i].getNativeView();
}
return new Scalar(DType.STRUCT, makeStructScalar(columnHandles, true));
}

/**
* Build column vector of single row who holds a null value
*
* @param hostType host data type of null column vector
* @return the null vector
*/
private static ColumnVector buildNullColumnVector(HostColumnVector.DataType hostType) {
DType dt = hostType.getType();
if (!dt.isNestedType()) {
try (HostColumnVector.Builder builder = HostColumnVector.builder(dt, 1)) {
builder.appendNull();
try (HostColumnVector hcv = builder.build()) {
return hcv.copyToDevice();
}
}
} else if (dt.typeId == DType.DTypeEnum.LIST) {
try (HostColumnVector hcv = HostColumnVector.fromLists(hostType, Arrays.asList())) {
revans2 marked this conversation as resolved.
Show resolved Hide resolved
return hcv.copyToDevice();
}
} else if (dt.typeId == DType.DTypeEnum.STRUCT) {
try (HostColumnVector hcv = HostColumnVector.fromStructs(hostType,
revans2 marked this conversation as resolved.
Show resolved Hide resolved
new HostColumnVector.StructData(new Object[hostType.getNumChildren()]))) {
return hcv.copyToDevice();
}
} else {
throw new IllegalArgumentException("Unsupported data type: " + hostType);
}
}

private static native void closeScalar(long scalarHandle);
private static native boolean isScalarValid(long scalarHandle);
private static native byte getByte(long scalarHandle);
Expand All @@ -383,6 +459,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native double getDouble(long scalarHandle);
private static native byte[] getUTF8(long scalarHandle);
private static native long getListAsColumnView(long scalarHandle);
private static native long[] getChildrenFromStructScalar(long scalarHandle);
private static native long makeBool8Scalar(boolean isValid, boolean value);
private static native long makeInt8Scalar(byte value, boolean isValid);
private static native long makeUint8Scalar(byte value, boolean isValid);
Expand All @@ -402,6 +479,7 @@ public static Scalar listFromColumnView(ColumnView list) {
private static native long makeDecimal32Scalar(int value, int scale, boolean isValid);
private static native long makeDecimal64Scalar(long value, int scale, boolean isValid);
private static native long makeListScalar(long viewHandle, boolean isValid);
private static native long makeStructScalar(long[] viewHandles, boolean isValid);


Scalar(DType type, long scalarHandle) {
Expand Down Expand Up @@ -539,6 +617,36 @@ public ColumnView getListAsColumnView() {
return new ColumnView(getListAsColumnView(getScalarHandle()));
}

/**
* Fetches views of children columns from struct scalar.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* @return array of column views refer to children of struct scalar
*/
public ColumnView[] getChildrenFromStructScalar() {
assert DType.STRUCT.equals(type) : "Cannot get table for the vector of type " + type;

long[] childHandles = getChildrenFromStructScalar(getScalarHandle());
ColumnView[] children = new ColumnView[childHandles.length];
try {
for (int i = 0; i < children.length; i++) {
children[i] = new ColumnView(childHandles[i]);
}
} catch (Exception ex) {
// close all created ColumnViews if exception thrown
for (ColumnView child : children) {
// We closed all created ColumnViews when we hit null. Therefore we exit the loop.
if (child == null) break;
// make sure the close process is exception-free
try {
child.close();
} catch (Exception ignore) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has the same problem as up above we need to not just eat the exceptions but add them to the suppression list.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

}
}
throw ex;
}
return children;
}

@Override
public ColumnVector binaryOp(BinaryOp op, BinaryOperable rhs, DType outType) {
if (rhs instanceof ColumnView) {
Expand Down
37 changes: 37 additions & 0 deletions java/src/main/native/src/ScalarJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_getListAsColumnView(JNIEnv *e
CATCH_STD(env, 0);
}

// JNI entry point for Scalar.getChildrenFromStructScalar.
// Interprets scalar_handle as a pointer to a cudf::struct_scalar and returns a Java long
// array holding one newly allocated cudf::column_view pointer per child column.
// Returns 0 if the handle is null or a C++ exception is translated by CATCH_STD.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Scalar_getChildrenFromStructScalar(JNIEnv *env, jclass,
                                                                                    jlong scalar_handle) {
  JNI_NULL_CHECK(env, scalar_handle, "scalar handle is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto *const struct_scalar_ptr = reinterpret_cast<cudf::struct_scalar *>(scalar_handle);
    cudf::table_view const children = struct_scalar_ptr->view();
    auto const num_children = children.num_columns();
    cudf::jni::native_jlongArray child_handles(env, num_children);
    for (int idx = 0; idx < num_children; ++idx) {
      // Each view is heap allocated; its address is what the Java side receives as a handle.
      auto *child_view = new cudf::column_view(children.column(idx));
      child_handles[idx] = reinterpret_cast<jlong>(child_view);
    }
    return child_handles.get_jArray();
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeBool8Scalar(JNIEnv *env, jclass,
jboolean value,
jboolean is_valid) {
Expand Down Expand Up @@ -477,4 +493,25 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeListScalar(JNIEnv *env, j
CATCH_STD(env, 0);
}

// JNI entry point for Scalar.makeStructScalar.
// `handles` is a Java long array of cudf::column_view pointers (one per child column).
// The views are copied into a vector and used to construct a cudf::struct_scalar with the
// requested validity; ownership of the new scalar is released to the caller as a jlong.
// Returns 0 if `handles` is null or a C++ exception is translated by CATCH_STD.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Scalar_makeStructScalar(JNIEnv *env, jclass,
                                                                    jlongArray handles,
                                                                    jboolean is_valid) {
  JNI_NULL_CHECK(env, handles, "native view handles are null", 0);
  try {
    cudf::jni::auto_set_device(env);
    cudf::jni::native_jpointerArray<cudf::column_view> column_pointers(env, handles);
    // Dereference every pointer so the scalar constructor receives a contiguous span of views.
    std::vector<cudf::column_view> columns;
    std::transform(column_pointers.data(),
                   column_pointers.data() + column_pointers.size(),
                   std::back_inserter(columns),
                   [](auto const& col_ptr) { return *col_ptr; });
    auto s = std::make_unique<cudf::struct_scalar>(
        cudf::host_span<cudf::column_view const>{columns}, is_valid);
    return reinterpret_cast<jlong>(s.release());
  }
  CATCH_STD(env, 0);
}

} // extern "C"
112 changes: 111 additions & 1 deletion java/src/test/java/ai/rapids/cudf/ScalarTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public void testNull() {
}
}

// list scalar
// create elementType for nested types
HostColumnVector.DataType hDataType;
if (DType.EMPTY.equals(type)) {
continue;
Expand All @@ -84,6 +84,8 @@ public void testNull() {
// list of non nested type
hDataType = new BasicType(true, type);
}

// test list scalar with elementType(`type`)
try (Scalar s = Scalar.listFromNull(hDataType);
ColumnView listCv = s.getListAsColumnView()) {
assertFalse(s.isValid(), "null validity for " + type);
Expand All @@ -99,6 +101,23 @@ public void testNull() {
}
}
}

// test struct scalar with elementType(`type`)
try (Scalar s = Scalar.structFromNull(hDataType, hDataType, hDataType)) {
assertFalse(s.isValid(), "null validity for " + type);
assertEquals(DType.STRUCT, s.getType());

ColumnView[] children = s.getChildrenFromStructScalar();
try {
for (ColumnView child : children) {
assertEquals(hDataType.getType(), child.getType());
assertEquals(1L, child.getRowCount());
assertEquals(1L, child.getNullCount());
}
} finally {
for (ColumnView child : children) child.close();
}
}
}
}

Expand Down Expand Up @@ -287,4 +306,95 @@ public void testList() {
}
}
}

@Test
public void testStruct() {
try (ColumnVector col0 = ColumnVector.fromInts(1);
ColumnVector col1 = ColumnVector.fromBoxedDoubles(1.2);
ColumnVector col2 = ColumnVector.fromStrings("a");
ColumnVector col3 = ColumnVector.fromDecimals(BigDecimal.TEN);
ColumnVector col4 = ColumnVector.daysFromInts(10);
ColumnVector col5 = ColumnVector.durationSecondsFromLongs(12345L);
Scalar s = Scalar.structFromColumnViews(col0, col1, col2, col3, col4, col5, col0, col1)) {
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
ColumnView[] children = s.getChildrenFromStructScalar();
try {
assertColumnsAreEqual(col0, children[0]);
assertColumnsAreEqual(col1, children[1]);
assertColumnsAreEqual(col2, children[2]);
assertColumnsAreEqual(col3, children[3]);
assertColumnsAreEqual(col4, children[4]);
assertColumnsAreEqual(col5, children[5]);
assertColumnsAreEqual(col0, children[6]);
assertColumnsAreEqual(col1, children[7]);
} finally {
for (ColumnView child : children) child.close();
}
}

// test Struct Scalar with null members
try (ColumnVector col0 = ColumnVector.fromInts(1);
ColumnVector col1 = ColumnVector.fromBoxedDoubles((Double) null);
ColumnVector col2 = ColumnVector.fromStrings((String) null);
Scalar s1 = Scalar.structFromColumnViews(col0, col1, col2);
Scalar s2 = Scalar.structFromColumnViews(col1, col2)) {
ColumnView[] children = s1.getChildrenFromStructScalar();
try {
assertColumnsAreEqual(col0, children[0]);
assertColumnsAreEqual(col1, children[1]);
assertColumnsAreEqual(col2, children[2]);
} finally {
for (ColumnView child : children) child.close();
}

ColumnView[] children2 = s2.getChildrenFromStructScalar();
try {
assertColumnsAreEqual(col1, children2[0]);
assertColumnsAreEqual(col2, children2[1]);
} finally {
for (ColumnView child : children2) child.close();
}
}

// test Struct Scalar with single column
try (ColumnVector col0 = ColumnVector.fromInts(1234);
Scalar s = Scalar.structFromColumnViews(col0)) {
ColumnView[] children = s.getChildrenFromStructScalar();
try {
assertColumnsAreEqual(col0, children[0]);
} finally {
children[0].close();
}
}

// test Struct Scalar with nested types
HostColumnVector.DataType listType = new HostColumnVector.ListType(true,
new HostColumnVector.BasicType(true, DType.INT32));
HostColumnVector.DataType structType = new HostColumnVector.StructType(true,
new HostColumnVector.BasicType(true, DType.INT32),
new HostColumnVector.BasicType(true, DType.INT64));
HostColumnVector.DataType nestedStructType = new HostColumnVector.StructType(true,
new HostColumnVector.BasicType(true, DType.STRING),
listType, structType);
try (ColumnVector strCol = ColumnVector.fromStrings("AAAAAA");
ColumnVector listCol = ColumnVector.fromLists(listType, Arrays.asList(1, 2, 3, 4, 5));
ColumnVector structCol = ColumnVector.fromStructs(structType,
new HostColumnVector.StructData(1, -1L));
ColumnVector nestedStructCol = ColumnVector.fromStructs(nestedStructType,
new HostColumnVector.StructData(null,
Arrays.asList(1, 2, null),
new HostColumnVector.StructData(null, 10L)));
Scalar s = Scalar.structFromColumnViews(strCol, listCol, structCol, nestedStructCol)) {
assertEquals(DType.STRUCT, s.getType());
assertTrue(s.isValid());
ColumnView[] children = s.getChildrenFromStructScalar();
try {
assertColumnsAreEqual(strCol, children[0]);
assertColumnsAreEqual(listCol, children[1]);
assertColumnsAreEqual(structCol, children[2]);
assertColumnsAreEqual(nestedStructCol, children[3]);
} finally {
for (ColumnView child : children) child.close();
}
}
}
}