From 306c47ca1ef17f7bc62a249693a96aab8c48d608 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Thu, 8 Feb 2024 12:00:24 -0600 Subject: [PATCH] JNI JSON read with DataSource and infered schema, along with basic java nested Schema JSON reads (#14954) This adds in support for some more JSON reading functionality. It allows us to infer the JSON schema using a DataSource as the input. It also adds in support for using a nested Schema when parsing JSON. Authors: - Robert (Bobby) Evans (https://github.com/revans2) Approvers: - Jason Lowe (https://github.com/jlowe) URL: https://github.com/rapidsai/cudf/pull/14954 --- java/src/main/java/ai/rapids/cudf/Schema.java | 269 +++++++++++++++-- java/src/main/java/ai/rapids/cudf/Table.java | 205 ++++++++++++- .../java/ai/rapids/cudf/TableWithMeta.java | 97 +++++- java/src/main/native/src/TableJni.cpp | 277 ++++++++++++------ .../test/java/ai/rapids/cudf/TableTest.java | 144 ++++++++- 5 files changed, 845 insertions(+), 147 deletions(-) diff --git a/java/src/main/java/ai/rapids/cudf/Schema.java b/java/src/main/java/ai/rapids/cudf/Schema.java index 79e66cb608e..c8571dd841c 100644 --- a/java/src/main/java/ai/rapids/cudf/Schema.java +++ b/java/src/main/java/ai/rapids/cudf/Schema.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,78 +26,285 @@ */ public class Schema { public static final Schema INFERRED = new Schema(); - private final List names; - private final List types; - private Schema(List names, List types) { - this.names = new ArrayList<>(names); - this.types = new ArrayList<>(types); + private final DType topLevelType; + private final List childNames; + private final List childSchemas; + private boolean flattened = false; + private String[] flattenedNames; + private DType[] flattenedTypes; + private int[] flattenedCounts; + + private Schema(DType topLevelType, + List childNames, + List childSchemas) { + this.topLevelType = topLevelType; + this.childNames = childNames; + this.childSchemas = childSchemas; } /** * Inferred schema. */ private Schema() { - names = null; - types = null; + topLevelType = null; + childNames = null; + childSchemas = null; + } + + /** + * Get the schema of a child element. Note that an inferred schema will have no children. + * @param i the index of the child to read. + * @return the new Schema + * @throws IndexOutOfBoundsException if the index is not in the range of children. 
+ */ + public Schema getChild(int i) { + if (childSchemas == null) { + throw new IndexOutOfBoundsException("There are 0 children in this schema"); + } + return childSchemas.get(i); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(topLevelType); + if (topLevelType == DType.STRUCT) { + sb.append("{"); + if (childNames != null) { + for (int i = 0; i < childNames.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(childNames.get(i)); + sb.append(": "); + sb.append(childSchemas.get(i)); + } + } + sb.append("}"); + } else if (topLevelType == DType.LIST) { + sb.append("["); + if (childNames != null) { + for (int i = 0; i < childNames.size(); i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(childSchemas.get(i)); + } + } + sb.append("]"); + } + return sb.toString(); + } + + private void flattenIfNeeded() { + if (!flattened) { + int flatLen = flattenedLength(0); + if (flatLen == 0) { + flattenedNames = null; + flattenedTypes = null; + flattenedCounts = null; + } else { + String[] names = new String[flatLen]; + DType[] types = new DType[flatLen]; + int[] counts = new int[flatLen]; + collectFlattened(names, types, counts, 0); + flattenedNames = names; + flattenedTypes = types; + flattenedCounts = counts; + } + flattened = true; + } + } + + private int flattenedLength(int startingLength) { + if (childSchemas != null) { + for (Schema child: childSchemas) { + startingLength++; + startingLength = child.flattenedLength(startingLength); + } + } + return startingLength; + } + + private int collectFlattened(String[] names, DType[] types, int[] counts, int offset) { + if (childSchemas != null) { + for (int i = 0; i < childSchemas.size(); i++) { + Schema child = childSchemas.get(i); + names[offset] = childNames.get(i); + types[offset] = child.topLevelType; + if (child.childNames != null) { + counts[offset] = child.childNames.size(); + } else { + counts[offset] = 0; + } + offset++; + offset = this.childSchemas.get(i).collectFlattened(names, types, counts, offset); + } + } + return offset; } public static Builder builder() { - return new Builder(); + return new Builder(DType.STRUCT); + } + + public String[] getFlattenedColumnNames() { + flattenIfNeeded(); + return flattenedNames; } public String[] getColumnNames() { - if (names == null) { + if (childNames == null) { return null; } - return names.toArray(new String[names.size()]); + return childNames.toArray(new String[childNames.size()]); + } + + public boolean isNested() { + return childSchemas != null && childSchemas.size() > 0; + } + + /** + * This is really for a top level struct schema where it is nested, but + * for things like CSV we care that it does not have any children that are also + * nested. 
+ */ + public boolean hasNestedChildren() { + if (childSchemas != null) { + for (Schema child: childSchemas) { + if (child.isNested()) { + return true; + } + } + } + return false; } - int[] getTypeIds() { - if (types == null) { + int[] getFlattenedTypeIds() { + flattenIfNeeded(); + if (flattenedTypes == null) { return null; } - int[] ret = new int[types.size()]; - for (int i = 0; i < types.size(); i++) { - ret[i] = types.get(i).getTypeId().nativeId; + int[] ret = new int[flattenedTypes.length]; + for (int i = 0; i < flattenedTypes.length; i++) { + ret[i] = flattenedTypes[i].getTypeId().nativeId; } return ret; } - int[] getTypeScales() { - if (types == null) { + int[] getFlattenedTypeScales() { + flattenIfNeeded(); + if (flattenedTypes == null) { return null; } - int[] ret = new int[types.size()]; - for (int i = 0; i < types.size(); i++) { - ret[i] = types.get(i).getScale(); + int[] ret = new int[flattenedTypes.length]; + for (int i = 0; i < flattenedTypes.length; i++) { + ret[i] = flattenedTypes[i].getScale(); } return ret; } - DType[] getTypes() { - if (types == null) { + DType[] getFlattenedTypes() { + flattenIfNeeded(); + return flattenedTypes; + } + + public DType[] getChildTypes() { + if (childSchemas == null) { return null; } - DType[] ret = new DType[types.size()]; - for (int i = 0; i < types.size(); i++) { - ret[i] = types.get(i); + DType[] ret = new DType[childSchemas.size()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = childSchemas.get(i).topLevelType; } return ret; } + int[] getFlattenedNumChildren() { + flattenIfNeeded(); + return flattenedCounts; + } + + public DType getType() { + return topLevelType; + } + + /** + * Check to see if the schema includes a struct at all. + * @return true if this or any one of its descendants contains a struct, else false. + */ + public boolean isStructOrHasStructDescendant() { + if (DType.STRUCT == topLevelType) { + return true; + } else if (DType.LIST == topLevelType) { + return childSchemas.stream().anyMatch(Schema::isStructOrHasStructDescendant); + } + return false; + } + public static class Builder { - private final List names = new ArrayList<>(); - private final List types = new ArrayList<>(); + private final DType topLevelType; + private final List names; + private final List types; - public Builder column(DType type, String name) { - types.add(type); + private Builder(DType topLevelType) { + this.topLevelType = topLevelType; + if (topLevelType == DType.STRUCT || topLevelType == DType.LIST) { + // There can be children + names = new ArrayList<>(); + types = new ArrayList<>(); + } else { + names = null; + types = null; + } + } + + /** + * Add a new column + * @param type the type of column to add + * @param name the name of the column to add (Ignored for list types) + * @return the builder for the new column. This should really only be used when the type + * passed in is a LIST or a STRUCT. + */ + public Builder addColumn(DType type, String name) { + if (names == null) { + throw new IllegalStateException("A column of type " + topLevelType + + " cannot have children"); + } + if (topLevelType == DType.LIST && names.size() > 0) { + throw new IllegalStateException("A LIST column can only have one child"); + } + if (names.contains(name)) { + throw new IllegalStateException("Cannot add duplicate names to a schema"); + } + Builder ret = new Builder(type); + types.add(ret); names.add(name); + return ret; + } + + /** + * Adds a single column to the current schema. addColumn is preferred as it can be used + * to support nested types. 
+ * @param type the type of the column. + * @param name the name of the column. + * @return this for chaining. + */ + public Builder column(DType type, String name) { + addColumn(type, name); return this; } public Schema build() { - return new Schema(names, types); + List children = null; + if (types != null) { + children = new ArrayList<>(types.size()); + for (Builder b: types) { + children.add(b.build()); + } + } + return new Schema(topLevelType, names, children); } } } diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index ecf2e860351..9a790c8518b 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -246,7 +246,7 @@ private static native long[] readCSVFromDataSource(String[] columnNames, /** * read JSON data and return a pointer to a TableWithMeta object. */ - private static native long readJSON(String[] columnNames, + private static native long readJSON(int[] numChildren, String[] columnNames, int[] dTypeIds, int[] dTypeScales, String filePath, long address, long length, boolean dayFirst, boolean lines, @@ -254,7 +254,7 @@ private static native long readJSON(String[] columnNames, boolean normalizeSingleQuotes, boolean mixedTypesAsStrings) throws CudfException; - private static native long readJSONFromDataSource(String[] columnNames, + private static native long readJSONFromDataSource(int[] numChildren, String[] columnNames, int[] dTypeIds, int[] dTypeScales, boolean dayFirst, boolean lines, boolean recoverWithNulls, @@ -262,6 +262,11 @@ private static native long readJSONFromDataSource(String[] columnNames, boolean mixedTypesAsStrings, long dsHandle) throws CudfException; + private static native long readAndInferJSONFromDataSource(boolean dayFirst, boolean lines, + boolean recoverWithNulls, + boolean normalizeSingleQuotes, + boolean mixedTypesAsStrings, + long dsHandle) throws CudfException; private static native long readAndInferJSON(long address, long length, boolean dayFirst, boolean lines, boolean recoverWithNulls, boolean normalizeSingleQuotes, boolean mixedTypesAsStrings) throws CudfException; @@ -808,8 +813,11 @@ public static Table readCSV(Schema schema, File path) { * @return the file parsed as a table on the GPU. 
*/ public static Table readCSV(Schema schema, CSVOptions opts, File path) { + if (schema.hasNestedChildren()) { + throw new IllegalArgumentException("CSV does not support nested types"); + } return new Table( - readCSV(schema.getColumnNames(), schema.getTypeIds(), schema.getTypeScales(), + readCSV(schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), opts.getIncludeColumnNames(), path.getAbsolutePath(), 0, 0, opts.getHeaderRow(), @@ -890,7 +898,10 @@ public static Table readCSV(Schema schema, CSVOptions opts, HostMemoryBuffer buf assert len > 0; assert len <= buffer.getLength() - offset; assert offset >= 0 && offset < buffer.length; - return new Table(readCSV(schema.getColumnNames(), schema.getTypeIds(), schema.getTypeScales(), + if (schema.hasNestedChildren()) { + throw new IllegalArgumentException("CSV does not support nested types"); + } + return new Table(readCSV(schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), opts.getIncludeColumnNames(), null, buffer.getAddress() + offset, len, opts.getHeaderRow(), @@ -906,9 +917,12 @@ public static Table readCSV(Schema schema, CSVOptions opts, HostMemoryBuffer buf public static Table readCSV(Schema schema, CSVOptions opts, DataSource ds) { long dsHandle = DataSourceHelper.createWrapperDataSource(ds); try { - return new Table(readCSVFromDataSource(schema.getColumnNames(), - schema.getTypeIds(), - schema.getTypeScales(), + if (schema.hasNestedChildren()) { + throw new IllegalArgumentException("CSV does not support nested types"); + } + return new Table(readCSVFromDataSource(schema.getFlattenedColumnNames(), + schema.getFlattenedTypeIds(), + schema.getFlattenedTypeScales(), opts.getIncludeColumnNames(), opts.getHeaderRow(), opts.getDelim(), @@ -1043,6 +1057,134 @@ public static Table readJSON(Schema schema, JSONOptions opts, byte[] buffer) { return readJSON(schema, opts, buffer, 0, buffer.length); } + private static class DidViewChange { + ColumnVector changeWasNeeded = null; + boolean noChangeNeeded = false; + + public static DidViewChange yes(ColumnVector cv) { + DidViewChange ret = new DidViewChange(); + ret.changeWasNeeded = cv; + return ret; + } + + public static DidViewChange no() { + DidViewChange ret = new DidViewChange(); + ret.noChangeNeeded = true; + return ret; + } + } + + private static DidViewChange gatherJSONColumns(Schema schema, TableWithMeta.NestedChildren children, + ColumnView cv) { + // We need to do this recursively to be sure it all matches as expected. + // If we run into problems where the data types don't match, we are not + // going to fix up the data types. We are only going to reorder the columns. + if (schema.getType() == DType.STRUCT) { + if (cv.getType() != DType.STRUCT) { + // The types don't match so just return the input unchanged... + return DidViewChange.no(); + } else { + String[] foundNames = children.getNames(); + HashMap indices = new HashMap<>(); + for (int i = 0; i < foundNames.length; i++) { + indices.put(foundNames[i], i); + } + // We might need to rearrange the columns to match what we want. 
+ DType[] types = schema.getChildTypes(); + String[] neededNames = schema.getColumnNames(); + ColumnView[] columns = new ColumnView[neededNames.length]; + try { + boolean somethingChanged = false; + if (columns.length != foundNames.length) { + somethingChanged = true; + } + for (int i = 0; i < columns.length; i++) { + String neededColumnName = neededNames[i]; + Integer index = indices.get(neededColumnName); + if (index != null) { + if (schema.getChild(i).isStructOrHasStructDescendant()) { + ColumnView child = cv.getChildColumnView(index); + boolean shouldCloseChild = true; + try { + if (index != i) { + somethingChanged = true; + } + DidViewChange childResult = gatherJSONColumns(schema.getChild(i), + children.getChild(index), child); + if (childResult.noChangeNeeded) { + shouldCloseChild = false; + columns[i] = child; + } else { + somethingChanged = true; + columns[i] = childResult.changeWasNeeded; + } + } finally { + if (shouldCloseChild) { + child.close(); + } + } + } else { + if (index != i) { + somethingChanged = true; + } + columns[i] = cv.getChildColumnView(index); + } + } else { + somethingChanged = true; + try (Scalar s = Scalar.fromNull(types[i])) { + columns[i] = ColumnVector.fromScalar(s, (int) cv.getRowCount()); + } + } + } + if (somethingChanged) { + try (ColumnView ret = new ColumnView(cv.type, cv.rows, Optional.of(cv.nullCount), + cv.getValid(), null, columns)) { + return DidViewChange.yes(ret.copyToColumnVector()); + } + } else { + return DidViewChange.no(); + } + } finally { + for (ColumnView c: columns) { + if (c != null) { + c.close(); + } + } + } + } + } else if (schema.getType() == DType.LIST && cv.getType() == DType.LIST) { + if (schema.isStructOrHasStructDescendant()) { + String [] childNames = children.getNames(); + if (childNames.length == 2 && + "offsets".equals(childNames[0]) && + "element".equals(childNames[1])) { + try (ColumnView child = cv.getChildColumnView(0)){ + DidViewChange listResult = gatherJSONColumns(schema.getChild(0), + children.getChild(1), child); + if (listResult.noChangeNeeded) { + return DidViewChange.no(); + } else { + try (ColumnView listView = new ColumnView(cv.type, cv.rows, + Optional.of(cv.nullCount), cv.getValid(), cv.getOffsets(), + new ColumnView[]{listResult.changeWasNeeded})) { + return DidViewChange.yes(listView.copyToColumnVector()); + } finally { + listResult.changeWasNeeded.close(); + } + } + } + } + } + // Nothing to change so just return the input, but we need to inc a ref count to really + // make it work, so for now we are going to turn it into a ColumnVector. + return DidViewChange.no(); + } else { + // Nothing to change so just return the input, but we need to inc a ref count to really + // make it work, so for now we are going to turn it into a ColumnVector. + return DidViewChange.no(); + } + } + private static Table gatherJSONColumns(Schema schema, TableWithMeta twm) { String[] neededColumns = schema.getColumnNames(); if (neededColumns == null || neededColumns.length == 0) { @@ -1054,14 +1196,24 @@ private static Table gatherJSONColumns(Schema schema, TableWithMeta twm) { indices.put(foundNames[i], i); } // We might need to rearrange the columns to match what we want. 
- DType[] types = schema.getTypes(); + DType[] types = schema.getChildTypes(); ColumnVector[] columns = new ColumnVector[neededColumns.length]; try (Table tbl = twm.releaseTable()) { for (int i = 0; i < columns.length; i++) { String neededColumnName = neededColumns[i]; Integer index = indices.get(neededColumnName); if (index != null) { - columns[i] = tbl.getColumn(index).incRefCount(); + if (schema.getChild(i).isStructOrHasStructDescendant()) { + DidViewChange gathered = gatherJSONColumns(schema.getChild(i), twm.getChild(index), + tbl.getColumn(index)); + if (gathered.noChangeNeeded) { + columns[i] = tbl.getColumn(index).incRefCount(); + } else { + columns[i] = gathered.changeWasNeeded; + } + } else { + columns[i] = tbl.getColumn(index).incRefCount(); + } } else { try (Scalar s = Scalar.fromNull(types[i])) { columns[i] = ColumnVector.fromScalar(s, (int)tbl.getRowCount()); @@ -1088,7 +1240,8 @@ private static Table gatherJSONColumns(Schema schema, TableWithMeta twm) { */ public static Table readJSON(Schema schema, JSONOptions opts, File path) { try (TableWithMeta twm = new TableWithMeta( - readJSON(schema.getColumnNames(), schema.getTypeIds(), schema.getTypeScales(), + readJSON(schema.getFlattenedNumChildren(), schema.getFlattenedColumnNames(), + schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), path.getAbsolutePath(), 0, 0, opts.isDayFirst(), opts.isLines(), opts.isRecoverWithNull(), @@ -1150,6 +1303,26 @@ public static TableWithMeta readJSON(JSONOptions opts, HostMemoryBuffer buffer, opts.isMixedTypesAsStrings())); } + /** + * Read JSON formatted data and infer the column names and schema. + * @param opts various JSON parsing options. + * @return the data parsed as a table on the GPU and the metadata for the table returned. + */ + public static TableWithMeta readAndInferJSON(JSONOptions opts, DataSource ds) { + long dsHandle = DataSourceHelper.createWrapperDataSource(ds); + try { + TableWithMeta twm = new TableWithMeta(readAndInferJSONFromDataSource(opts.isDayFirst(), + opts.isLines(), + opts.isRecoverWithNull(), + opts.isNormalizeSingleQuotes(), + opts.isMixedTypesAsStrings(), + dsHandle)); + return twm; + } finally { + DataSourceHelper.destroyWrapperDataSource(dsHandle); + } + } + /** * Read JSON formatted data. * @param schema the schema of the data. You may use Schema.INFERRED to infer the schema. 
@@ -1167,8 +1340,9 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b assert len > 0; assert len <= buffer.length - offset; assert offset >= 0 && offset < buffer.length; - try (TableWithMeta twm = new TableWithMeta(readJSON(schema.getColumnNames(), - schema.getTypeIds(), schema.getTypeScales(), null, + try (TableWithMeta twm = new TableWithMeta(readJSON( + schema.getFlattenedNumChildren(), schema.getFlattenedColumnNames(), + schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), null, buffer.getAddress() + offset, len, opts.isDayFirst(), opts.isLines(), opts.isRecoverWithNull(), opts.isNormalizeSingleQuotes(), opts.isMixedTypesAsStrings()))) { @@ -1185,9 +1359,10 @@ public static Table readJSON(Schema schema, JSONOptions opts, HostMemoryBuffer b */ public static Table readJSON(Schema schema, JSONOptions opts, DataSource ds) { long dsHandle = DataSourceHelper.createWrapperDataSource(ds); - try (TableWithMeta twm = new TableWithMeta(readJSONFromDataSource(schema.getColumnNames(), - schema.getTypeIds(), schema.getTypeScales(), opts.isDayFirst(), opts.isLines(), - opts.isRecoverWithNull(), opts.isNormalizeSingleQuotes(), opts.isMixedTypesAsStrings(), dsHandle))) { + try (TableWithMeta twm = new TableWithMeta(readJSONFromDataSource(schema.getFlattenedNumChildren(), + schema.getFlattenedColumnNames(), schema.getFlattenedTypeIds(), schema.getFlattenedTypeScales(), opts.isDayFirst(), + opts.isLines(), opts.isRecoverWithNull(), opts.isNormalizeSingleQuotes(), + opts.isMixedTypesAsStrings(), dsHandle))) { return gatherJSONColumns(schema, twm); } finally { DataSourceHelper.destroyWrapperDataSource(dsHandle); diff --git a/java/src/main/java/ai/rapids/cudf/TableWithMeta.java b/java/src/main/java/ai/rapids/cudf/TableWithMeta.java index b6b8ad6bc28..040fa68f01e 100644 --- a/java/src/main/java/ai/rapids/cudf/TableWithMeta.java +++ b/java/src/main/java/ai/rapids/cudf/TableWithMeta.java @@ -1,6 +1,6 @@ /* * - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,12 +19,56 @@ package ai.rapids.cudf; +import java.util.Arrays; + /** * A table along with some metadata about the table. This is typically returned when * reading data from an input file where the metadata can be important. 
*/ public class TableWithMeta implements AutoCloseable { private long handle; + private NestedChildren children = null; + + public static class NestedChildren { + private final String[] names; + private final NestedChildren[] children; + + private NestedChildren(String[] names, NestedChildren[] children) { + this.names = names; + this.children = children; + } + + public String[] getNames() { + return names; + } + + public NestedChildren getChild(int i) { + return children[i]; + } + public boolean isChildNested(int i) { + return (getChild(i) != null); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + if (names != null) { + for (int i = 0; i < names.length; i++) { + if (i != 0) { + sb.append(", "); + } + sb.append(names[i]); + sb.append(": "); + if (children != null) { + sb.append(children[i]); + } + } + } + sb.append("}"); + return sb.toString(); + } + } TableWithMeta(long handle) { this.handle = handle; @@ -43,12 +87,57 @@ public Table releaseTable() { } } + private static class ChildAndOffset { + public NestedChildren child; + public int newOffset; + } + + private ChildAndOffset unflatten(int startOffset, String[] flatNames, int[] flatCounts) { + ChildAndOffset ret = new ChildAndOffset(); + int length = flatCounts[startOffset]; + if (length == 0) { + ret.newOffset = startOffset + 1; + return ret; + } else { + String[] names = new String[length]; + NestedChildren[] children = new NestedChildren[length]; + int currentOffset = startOffset + 1; + for (int i = 0; i < length; i++) { + names[i] = flatNames[currentOffset]; + ChildAndOffset tmp = unflatten(currentOffset, flatNames, flatCounts); + children[i] = tmp.child; + currentOffset = tmp.newOffset; + } + ret.newOffset = currentOffset; + ret.child = new NestedChildren(names, children); + return ret; + } + } + + NestedChildren getChildren() { + if (children == null) { + int[] flatCount = getFlattenedChildCounts(handle); + String[] flatNames = getFlattenedColumnNames(handle); + ChildAndOffset tmp = unflatten(0, flatNames, flatCount); + children = tmp.child; + } + return children; + } + /** * Get the names of the top level columns. In the future new APIs can be added to get * names of child columns. 
*/ public String[] getColumnNames() { - return getColumnNames(handle); + return getChildren().getNames(); + } + + public NestedChildren getChild(int i) { + return getChildren().getChild(i); + } + + public boolean isChildNested(int i) { + return getChildren().isChildNested(i); } @Override @@ -63,5 +152,7 @@ public void close() { private static native long[] releaseTable(long handle); - private static native String[] getColumnNames(long handle); + private static native String[] getFlattenedColumnNames(long handle); + + private static native int[] getFlattenedChildCounts(long handle); } diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index cef18b245e7..1d6f1332b06 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -925,6 +925,49 @@ cudf::table_view remove_validity_if_needed(cudf::table_view *input_table_view) { return cudf::table_view(views); } +cudf::io::schema_element read_schema_element(int &index, + cudf::jni::native_jintArray const &children, + cudf::jni::native_jstringArray const &names, + cudf::jni::native_jintArray const &types, + cudf::jni::native_jintArray const &scales) { + auto d_type = cudf::data_type{static_cast(types[index]), scales[index]}; + if (d_type.id() == cudf::type_id::STRUCT || d_type.id() == cudf::type_id::LIST) { + std::map child_elems; + int num_children = children[index]; + // go to the next entry, so recursion can parse it. + index++; + for (int i = 0; i < num_children; i++) { + child_elems.insert( + std::pair{names.get(index).get(), + cudf::jni::read_schema_element(index, children, names, types, scales)}); + } + return cudf::io::schema_element{d_type, std::move(child_elems)}; + } else { + if (children[index] != 0) { + throw std::invalid_argument("found children for a type that should have none"); + } + // go to the next entry before returning... 
+ index++; + return cudf::io::schema_element{d_type, {}}; + } +} + +void append_flattened_child_counts(cudf::io::column_name_info const &info, + std::vector &counts) { + counts.push_back(info.children.size()); + for (cudf::io::column_name_info const &child : info.children) { + append_flattened_child_counts(child, counts); + } +} + +void append_flattened_child_names(cudf::io::column_name_info const &info, + std::vector &names) { + names.push_back(info.name); + for (cudf::io::column_name_info const &child : info.children) { + append_flattened_child_names(child, names); + } +} + } // namespace } // namespace jni @@ -1148,14 +1191,12 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSVFromDataSource( cudf::jni::native_jintArray n_types(env, j_types); cudf::jni::native_jintArray n_scales(env, j_scales); if (n_types.is_null() != n_scales.is_null()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match null", NULL); } std::vector data_types; if (!n_types.is_null()) { if (n_types.size() != n_scales.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match size", NULL); } data_types.reserve(n_types.size()); std::transform(n_types.begin(), n_types.end(), n_scales.begin(), @@ -1207,11 +1248,10 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL); read_buffer = false; } else if (inputfilepath != NULL) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "cannot pass in both a buffer and an inputfilepath", NULL); } else if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", NULL); } try { @@ -1220,14 +1260,12 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( cudf::jni::native_jintArray n_types(env, j_types); cudf::jni::native_jintArray n_scales(env, j_scales); if (n_types.is_null() != n_scales.is_null()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match null", NULL); } std::vector data_types; if (!n_types.is_null()) { if (n_types.size() != n_scales.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match size", NULL); } data_types.reserve(n_types.size()); std::transform(n_types.begin(), n_types.end(), n_scales.begin(), @@ -1238,8 +1276,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readCSV( cudf::jni::native_jstring filename(env, inputfilepath); if (!read_buffer && filename.is_empty()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "inputfilepath can't be empty", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "inputfilepath can't be empty", NULL); } cudf::jni::native_jstringArray n_null_values(env, null_values); @@ -1390,13 +1427,43 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_Table_endWriteCSVToBuffer(JNIEnv *env CATCH_STD(env, ); } +JNIEXPORT jlong JNICALL 
Java_ai_rapids_cudf_Table_readAndInferJSONFromDataSource( + JNIEnv *env, jclass, jboolean day_first, jboolean lines, jboolean recover_with_null, + jboolean normalize_single_quotes, jboolean mixed_types_as_string, jlong ds_handle) { + + JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); + + try { + cudf::jni::auto_set_device(env); + auto ds = reinterpret_cast(ds_handle); + cudf::io::source_info source{ds}; + + auto const recovery_mode = recover_with_null ? + cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL : + cudf::io::json_recovery_mode_t::FAIL; + cudf::io::json_reader_options_builder opts = + cudf::io::json_reader_options::builder(source) + .dayfirst(static_cast(day_first)) + .lines(static_cast(lines)) + .recovery_mode(recovery_mode) + .normalize_single_quotes(static_cast(normalize_single_quotes)) + .mixed_types_as_string(mixed_types_as_string); + + auto result = + std::make_unique(cudf::io::read_json(opts.build())); + + return reinterpret_cast(result.release()); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readAndInferJSON( JNIEnv *env, jclass, jlong buffer, jlong buffer_length, jboolean day_first, jboolean lines, jboolean recover_with_null, jboolean normalize_single_quotes, jboolean mixed_types_as_string) { JNI_NULL_CHECK(env, buffer, "buffer cannot be null", 0); if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", 0); } try { @@ -1434,19 +1501,48 @@ JNIEXPORT void JNICALL Java_ai_rapids_cudf_TableWithMeta_close(JNIEnv *env, jcla CATCH_STD(env, ); } -JNIEXPORT jobjectArray JNICALL Java_ai_rapids_cudf_TableWithMeta_getColumnNames(JNIEnv *env, jclass, - jlong handle) { +JNIEXPORT jintArray JNICALL +Java_ai_rapids_cudf_TableWithMeta_getFlattenedChildCounts(JNIEnv *env, jclass, jlong handle) { JNI_NULL_CHECK(env, handle, "handle is null", nullptr); try { cudf::jni::auto_set_device(env); auto ptr = reinterpret_cast(handle); - auto length = ptr->metadata.schema_info.size(); + std::vector counts; + counts.push_back(ptr->metadata.schema_info.size()); + for (cudf::io::column_name_info const &child : ptr->metadata.schema_info) { + cudf::jni::append_flattened_child_counts(child, counts); + } + + auto length = counts.size(); + cudf::jni::native_jintArray ret(env, length); + for (size_t i = 0; i < length; i++) { + ret[i] = counts[i]; + } + ret.commit(); + return ret.get_jArray(); + } + CATCH_STD(env, nullptr); +} + +JNIEXPORT jobjectArray JNICALL +Java_ai_rapids_cudf_TableWithMeta_getFlattenedColumnNames(JNIEnv *env, jclass, jlong handle) { + JNI_NULL_CHECK(env, handle, "handle is null", nullptr); + + try { + cudf::jni::auto_set_device(env); + auto ptr = reinterpret_cast(handle); + std::vector names; + names.push_back("ROOT"); + for (cudf::io::column_name_info const &child : ptr->metadata.schema_info) { + cudf::jni::append_flattened_child_names(child, names); + } + + auto length = names.size(); auto ret = static_cast( env->NewObjectArray(length, env->FindClass("java/lang/String"), nullptr)); for (size_t i = 0; i < length; i++) { - env->SetObjectArrayElement(ret, i, - env->NewStringUTF(ptr->metadata.schema_info[i].name.c_str())); + env->SetObjectArrayElement(ret, i, env->NewStringUTF(names[i].c_str())); } return ret; @@ -1471,8 +1567,8 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_TableWithMeta_releaseTable(JNIE } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSONFromDataSource( 
- JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, - jboolean day_first, jboolean lines, jboolean recover_with_null, + JNIEnv *env, jclass, jintArray j_num_children, jobjectArray col_names, jintArray j_types, + jintArray j_scales, jboolean day_first, jboolean lines, jboolean recover_with_null, jboolean normalize_single_quotes, jboolean mixed_types_as_string, jlong ds_handle) { JNI_NULL_CHECK(env, ds_handle, "no data source handle given", 0); @@ -1482,21 +1578,15 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSONFromDataSource( cudf::jni::native_jstringArray n_col_names(env, col_names); cudf::jni::native_jintArray n_types(env, j_types); cudf::jni::native_jintArray n_scales(env, j_scales); + cudf::jni::native_jintArray n_children(env, j_num_children); if (n_types.is_null() != n_scales.is_null()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", - 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match null", 0); } - std::vector data_types; - if (!n_types.is_null()) { - if (n_types.size() != n_scales.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", - 0); - } - data_types.reserve(n_types.size()); - std::transform(n_types.begin(), n_types.end(), n_scales.begin(), - std::back_inserter(data_types), [](auto const &type, auto const &scale) { - return cudf::data_type{static_cast(type), scale}; - }); + if (n_types.is_null() != n_col_names.is_null()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and names must match null", 0); + } + if (n_types.is_null() != n_children.is_null()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and num children must match null", 0); } auto ds = reinterpret_cast(ds_handle); @@ -1513,20 +1603,26 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSONFromDataSource( .normalize_single_quotes(static_cast(normalize_single_quotes)) .mixed_types_as_string(mixed_types_as_string); - if (!n_col_names.is_null() && data_types.size() > 0) { + if (!n_types.is_null()) { + if (n_types.size() != n_scales.size()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match size", 0); + } if (n_col_names.size() != n_types.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "types and column names must match size", 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and column names must match size", + 0); + } + if (n_children.size() != n_types.size()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and num children must match size", + 0); } - std::map map; - - auto col_names_vec = n_col_names.as_cpp_vector(); - std::transform(col_names_vec.begin(), col_names_vec.end(), data_types.begin(), - std::inserter(map, map.end()), - [](std::string a, cudf::data_type b) { return std::make_pair(a, b); }); - opts.dtypes(map); - } else if (data_types.size() > 0) { + std::map data_types; + int at = 0; + while (at < n_types.size()) { + data_types.insert(std::pair{ + n_col_names.get(at).get(), + cudf::jni::read_schema_element(at, n_children, n_col_names, n_types, n_scales)}); + } opts.dtypes(data_types); } else { // should infer the types @@ -1541,19 +1637,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSONFromDataSource( } JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON( - JNIEnv *env, jclass, jobjectArray col_names, jintArray j_types, jintArray j_scales, - jstring inputfilepath, jlong buffer, jlong buffer_length, 
jboolean day_first, jboolean lines, - jboolean recover_with_null, jboolean normalize_single_quotes, jboolean mixed_types_as_string) { + JNIEnv *env, jclass, jintArray j_num_children, jobjectArray col_names, jintArray j_types, + jintArray j_scales, jstring inputfilepath, jlong buffer, jlong buffer_length, + jboolean day_first, jboolean lines, jboolean recover_with_null, + jboolean normalize_single_quotes, jboolean mixed_types_as_string) { bool read_buffer = true; if (buffer == 0) { JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", 0); read_buffer = false; } else if (inputfilepath != NULL) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "cannot pass in both a buffer and an inputfilepath", 0); } else if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", 0); } try { @@ -1561,26 +1658,20 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON( cudf::jni::native_jstringArray n_col_names(env, col_names); cudf::jni::native_jintArray n_types(env, j_types); cudf::jni::native_jintArray n_scales(env, j_scales); + cudf::jni::native_jintArray n_children(env, j_num_children); if (n_types.is_null() != n_scales.is_null()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match null", - 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match null", 0); } - std::vector data_types; - if (!n_types.is_null()) { - if (n_types.size() != n_scales.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "types and scales must match size", - 0); - } - data_types.reserve(n_types.size()); - std::transform(n_types.begin(), n_types.end(), n_scales.begin(), - std::back_inserter(data_types), [](auto const &type, auto const &scale) { - return cudf::data_type{static_cast(type), scale}; - }); + if (n_types.is_null() != n_col_names.is_null()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and names must match null", 0); + } + if (n_types.is_null() != n_children.is_null()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and num children must match null", 0); } cudf::jni::native_jstring filename(env, inputfilepath); if (!read_buffer && filename.is_empty()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "inputfilepath can't be empty", 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "inputfilepath can't be empty", 0); } auto source = read_buffer ? 
cudf::io::source_info{reinterpret_cast(buffer), @@ -1598,20 +1689,26 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_readJSON( .normalize_single_quotes(static_cast(normalize_single_quotes)) .mixed_types_as_string(mixed_types_as_string); - if (!n_col_names.is_null() && data_types.size() > 0) { + if (!n_types.is_null()) { + if (n_types.size() != n_scales.size()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and scales must match size", 0); + } if (n_col_names.size() != n_types.size()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", - "types and column names must match size", 0); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and column names must match size", + 0); + } + if (n_children.size() != n_types.size()) { + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "types and num children must match size", + 0); } - std::map map; - - auto col_names_vec = n_col_names.as_cpp_vector(); - std::transform(col_names_vec.begin(), col_names_vec.end(), data_types.begin(), - std::inserter(map, map.end()), - [](std::string a, cudf::data_type b) { return std::make_pair(a, b); }); - opts.dtypes(map); - } else if (data_types.size() > 0) { + std::map data_types; + int at = 0; + while (at < n_types.size()) { + data_types.insert(std::pair{ + n_col_names.get(at).get(), + cudf::jni::read_schema_element(at, n_children, n_col_names, n_types, n_scales)}); + } opts.dtypes(data_types); } else { // should infer the types @@ -1665,19 +1762,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readParquet( JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL); read_buffer = false; } else if (inputfilepath != NULL) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "cannot pass in both a buffer and an inputfilepath", NULL); } else if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", NULL); } try { cudf::jni::auto_set_device(env); cudf::jni::native_jstring filename(env, inputfilepath); if (!read_buffer && filename.is_empty()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "inputfilepath can't be empty", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "inputfilepath can't be empty", NULL); } cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); @@ -1731,19 +1826,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readAvro(JNIEnv *env, jcl if (!read_buffer) { JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL); } else if (inputfilepath != NULL) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "cannot pass in both a buffer and an inputfilepath", NULL); } else if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", NULL); } try { cudf::jni::auto_set_device(env); cudf::jni::native_jstring filename(env, inputfilepath); if (!read_buffer && filename.is_empty()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "inputfilepath can't be empty", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "inputfilepath can't be empty", NULL); } cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); @@ 
-1942,19 +2035,17 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_readORC( JNI_NULL_CHECK(env, inputfilepath, "input file or buffer must be supplied", NULL); read_buffer = false; } else if (inputfilepath != NULL) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "cannot pass in both a buffer and an inputfilepath", NULL); } else if (buffer_length <= 0) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "An empty buffer is not supported", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "An empty buffer is not supported", NULL); } try { cudf::jni::auto_set_device(env); cudf::jni::native_jstring filename(env, inputfilepath); if (!read_buffer && filename.is_empty()) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "inputfilepath can't be empty", - NULL); + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "inputfilepath can't be empty", NULL); } cudf::jni::native_jstringArray n_filter_col_names(env, filter_col_names); @@ -3187,7 +3278,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_dropDuplicates(JNIEnv *en case 2: return cudf::duplicate_keep_option::KEEP_LAST; case 3: return cudf::duplicate_keep_option::KEEP_NONE; default: - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Invalid `keep` option", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Invalid `keep` option", cudf::duplicate_keep_option::KEEP_ANY); } }(); @@ -3384,7 +3475,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rollingWindowAggregate( cudf::jni::native_jbooleanArray unbounded_following{env, j_unbounded_following}; if (not valid_window_parameters(values, agg_instances, min_periods, preceding, following)) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Number of aggregation columns must match number of agg ops, and window-specs", nullptr); } @@ -3459,7 +3550,7 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_rangeRollingWindowAggrega cudf::jni::native_jpointerArray following(env, j_following); if (not valid_window_parameters(values, agg_instances, min_periods, preceding, following)) { - JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", + JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Number of aggregation columns must match number of agg ops, and window-specs", nullptr); } diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index f1c4d0803a3..76f127eae77 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -33,6 +33,7 @@ import com.google.common.base.Charsets; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import org.apache.avro.SchemaBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.hadoop.ParquetFileReader; @@ -53,7 +54,6 @@ import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.IntStream; import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static ai.rapids.cudf.AssertUtils.assertPartialColumnsAreEqual; @@ -75,6 +75,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class TableTest extends CudfTestBase { + private static final HostMemoryAllocator hostMemoryAllocator = DefaultHostMemoryAllocator.get(); private static final File TEST_PARQUET_FILE = 
TestUtils.getResourceAsFile("acq.parquet"); @@ -348,6 +349,139 @@ void testReadSingleQuotesJSONFile() throws IOException { } } + private static final byte[] NESTED_JSON_DATA_BUFFER = ("{\"a\":{\"c\":\"C1\"}}\n" + + "{\"a\":{\"c\":\"C2\", \"b\":\"B2\"}}\n" + + "{\"d\":[1,2,3]}\n" + + "{\"e\": [{\"g\": 1}, {\"f\": 2}, {\"f\": 3, \"g\": 4}], \"d\":[]}").getBytes(StandardCharsets.UTF_8); + + @Test + void testReadJSONNestedTypes() { + Schema.Builder root = Schema.builder(); + Schema.Builder a = root.addColumn(DType.STRUCT, "a"); + a.addColumn(DType.STRING, "b"); + a.addColumn(DType.STRING, "c"); + a.addColumn(DType.STRING, "missing"); + Schema.Builder d = root.addColumn(DType.LIST, "d"); + d.addColumn(DType.INT64, "ignored"); + root.addColumn(DType.INT64, "also_missing"); + Schema.Builder e = root.addColumn(DType.LIST, "e"); + Schema.Builder eChild = e.addColumn(DType.STRUCT, "ignored"); + eChild.addColumn(DType.INT64, "f"); + eChild.addColumn(DType.STRING, "missing_in_list"); + eChild.addColumn(DType.INT64, "g"); + Schema schema = root.build(); + JSONOptions opts = JSONOptions.builder() + .withLines(true) + .build(); + StructType aStruct = new StructType(true, + new BasicType(true, DType.STRING), + new BasicType(true, DType.STRING), + new BasicType(true, DType.STRING)); + ListType dList = new ListType(true, new BasicType(true, DType.INT64)); + StructType eChildStruct = new StructType(true, + new BasicType(true, DType.INT64), + new BasicType(true, DType.STRING), + new BasicType(true, DType.INT64)); + ListType eList = new ListType(true, eChildStruct); + try (Table expected = new Table.TestBuilder() + .column(aStruct, + new StructData(null, "C1", null), + new StructData("B2", "C2", null), + null, + null) + .column(dList, + null, + null, + Arrays.asList(1L,2L,3L), + new ArrayList()) + .column((Long)null, null, null, null) // also_missing + .column(eList, + null, + null, + null, + Arrays.asList(new StructData(null, null, 1L), new StructData(2L, null, null), new StructData(3L, null, 4L))) + .build(); + Table table = Table.readJSON(schema, opts, NESTED_JSON_DATA_BUFFER)) { + assertTablesAreEqual(expected, table); + } + } + + @Test + void testReadJSONNestedTypesVerySmallChanges() { + Schema.Builder root = Schema.builder(); + Schema.Builder e = root.addColumn(DType.LIST, "e"); + Schema.Builder eChild = e.addColumn(DType.STRUCT, "ignored"); + eChild.addColumn(DType.INT64, "g"); + eChild.addColumn(DType.INT64, "f"); + Schema schema = root.build(); + JSONOptions opts = JSONOptions.builder() + .withLines(true) + .build(); + StructType eChildStruct = new StructType(true, + new BasicType(true, DType.INT64), + new BasicType(true, DType.INT64)); + ListType eList = new ListType(true, eChildStruct); + try (Table expected = new Table.TestBuilder() + .column(eList, + null, + null, + null, + Arrays.asList(new StructData(1L, null), new StructData(null, 2L), new StructData(4L, 3L))) + .build(); + Table table = Table.readJSON(schema, opts, NESTED_JSON_DATA_BUFFER)) { + assertTablesAreEqual(expected, table); + } + } + + @Test + void testReadJSONNestedTypesDataSource() { + Schema.Builder root = Schema.builder(); + Schema.Builder a = root.addColumn(DType.STRUCT, "a"); + a.addColumn(DType.STRING, "b"); + a.addColumn(DType.STRING, "c"); + a.addColumn(DType.STRING, "missing"); + Schema.Builder d = root.addColumn(DType.LIST, "d"); + d.addColumn(DType.INT64, "ignored"); + root.addColumn(DType.INT64, "also_missing"); + Schema.Builder e = root.addColumn(DType.LIST, "e"); + Schema.Builder eChild = e.addColumn(DType.STRUCT, 
"ignored"); + eChild.addColumn(DType.INT64, "g"); + Schema schema = root.build(); + JSONOptions opts = JSONOptions.builder() + .withLines(true) + .build(); + StructType aStruct = new StructType(true, + new BasicType(true, DType.STRING), + new BasicType(true, DType.STRING), + new BasicType(true, DType.STRING)); + ListType dList = new ListType(true, new BasicType(true, DType.INT64)); + StructType eChildStruct = new StructType(true, + new BasicType(true, DType.INT64)); + ListType eList = new ListType(true, eChildStruct); + try (Table expected = new Table.TestBuilder() + .column(aStruct, + new StructData(null, "C1", null), + new StructData("B2", "C2", null), + null, + null) + .column(dList, + null, + null, + Arrays.asList(1L,2L,3L), + new ArrayList()) + .column((Long)null, null, null, null) // also_missing + .column(eList, + null, + null, + null, + Arrays.asList(new StructData(1L), new StructData((Long)null), new StructData(4L))) + .build(); + MultiBufferDataSource source = sourceFrom(NESTED_JSON_DATA_BUFFER); + Table table = Table.readJSON(schema, opts, source)) { + assertTablesAreEqual(expected, table); + } + } + void testReadMixedType2JSONFileFeatureDisabled() { Schema schema = Schema.builder() .column(DType.STRING, "a") @@ -870,7 +1004,7 @@ private void testWriteCSVToFileImpl(char fieldDelim, boolean includeHeader, .column(DType.STRING, "str") .build(); CSVWriterOptions writeOptions = CSVWriterOptions.builder() - .withColumnNames(schema.getColumnNames()) + .withColumnNames(schema.getFlattenedColumnNames()) .withIncludeHeader(includeHeader) .withFieldDelimiter((byte)fieldDelim) .withRowDelimiter("\n") @@ -922,7 +1056,7 @@ private void testWriteUnquotedCSVToFileImpl(char fieldDelim) throws IOException .column(DType.STRING, "str") .build(); CSVWriterOptions writeOptions = CSVWriterOptions.builder() - .withColumnNames(schema.getColumnNames()) + .withColumnNames(schema.getFlattenedColumnNames()) .withIncludeHeader(false) .withFieldDelimiter((byte)fieldDelim) .withRowDelimiter("\n") @@ -966,7 +1100,7 @@ private void testChunkedCSVWriterUnquotedImpl(char fieldDelim) throws IOExceptio .column(DType.STRING, "str") .build(); CSVWriterOptions writeOptions = CSVWriterOptions.builder() - .withColumnNames(schema.getColumnNames()) + .withColumnNames(schema.getFlattenedColumnNames()) .withIncludeHeader(false) .withFieldDelimiter((byte)fieldDelim) .withRowDelimiter("\n") @@ -1020,7 +1154,7 @@ private void testChunkedCSVWriterImpl(char fieldDelim, boolean includeHeader, .column(DType.STRING, "str") .build(); CSVWriterOptions writeOptions = CSVWriterOptions.builder() - .withColumnNames(schema.getColumnNames()) + .withColumnNames(schema.getFlattenedColumnNames()) .withIncludeHeader(includeHeader) .withFieldDelimiter((byte)fieldDelim) .withRowDelimiter("\n")