Skip to content

Commit

Permalink
Enable Strings as a supported type for GpuColumnarToRow transitions (NVIDIA#5998)
Browse files Browse the repository at this point in the history

* Enable Strings as a supported type for GpuColumnarToRow transitions

Signed-off-by: Ahmed Hussein (amahussein) <a@ahussein.me>

* fix pr-feedback

Signed-off-by: Ahmed Hussein (amahussein) <a@ahussein.me>

Co-authored-by: Mike Wilson <hyperbolic2346@users.noreply.github.com>
  • Loading branch information
amahussein and hyperbolic2346 authored Jul 27, 2022
1 parent 5ba525c commit 452e7ba
Show file tree
Hide file tree
Showing 7 changed files with 1,013 additions and 150 deletions.
45 changes: 45 additions & 0 deletions integration_tests/src/main/python/row_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,48 @@ def test_host_columnar_transition(spark_tmp_path, data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: spark.read.parquet(data_path).filter("a IS NOT NULL"),
conf={ 'spark.rapids.sql.exec.FileSourceScanExec' : 'false'})

# Most basic variable-width round-trip check: verify data can move onto and
# off of the GPU when the schema mixes fixed-width and string columns (no nulls).
def test_row_conversions_var_width_basic():
    def build_df(spark):
        # Two int columns interleaved with two string columns, all non-nullable.
        fields = [StructField("col_00_int", IntegerType(), nullable=False),
                  StructField("col_01_str", StringType(), nullable=False),
                  StructField("col_02_int", IntegerType(), nullable=False),
                  StructField("col_03_str", StringType(), nullable=False)]
        rows = [(1, "string_val_00", 2, "string_val_01"),
                (3, "string_val_10", 4, "string_val_11")]
        # The extra projection forces a columnar <-> row transition in the plan.
        return spark.createDataFrame(rows, schema=StructType(fields)) \
            .selectExpr("*", "col_00_int as 1st_column")
    assert_gpu_and_cpu_are_equal_collect(build_df)

# Verify round-tripping data through the GPU when the schema is variable width.
# The supported variable-width type (string) is deliberately scattered among the
# fixed-width columns so the test exercises packing, where columns are reordered
# (grouped by alignment requirement) to shrink the per-row data size.
def test_row_conversions_var_width():
    names = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n"]
    generators = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
                  boolean_gen, string_gen, timestamp_gen, date_gen, string_gen,
                  decimal_gen_64bit, decimal_gen_32bit, string_gen]
    gens = [[name, gen] for name, gen in zip(names, generators)]
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: gen_df(spark, gens).selectExpr("*", "a as a_again"))

def test_row_conversions_var_width_wide():
    # Build a wide schema: a run of columns per fixed-width type, with string
    # runs scattered in between so the row-packing code has many columns of
    # differing alignment requirements to reorder.
    def _columns(prefix, make_gen, count):
        # One [name, generator] pair per column, e.g. ["a0", ByteGen(...)].
        return [["{}{}".format(prefix, i), make_gen()] for i in range(count)]

    gens = (_columns("a", lambda: ByteGen(nullable=True), 10) +
            _columns("b", lambda: ShortGen(nullable=True), 10) +
            _columns("c", lambda: IntegerGen(nullable=True), 10) +
            _columns("d", lambda: LongGen(nullable=True), 10) +
            _columns("e", lambda: FloatGen(nullable=True), 10) +
            _columns("f", lambda: DoubleGen(nullable=True), 10) +
            _columns("g", lambda: StringGen(nullable=True), 5) +
            _columns("h", lambda: BooleanGen(nullable=True), 10) +
            _columns("i", lambda: StringGen(nullable=True), 5) +
            _columns("j", lambda: TimestampGen(nullable=True), 10) +
            _columns("k", lambda: DateGen(nullable=True), 10) +
            _columns("l", lambda: DecimalGen(precision=12, scale=2, nullable=True), 10) +
            _columns("m", lambda: DecimalGen(precision=7, scale=3, nullable=True), 10))

    def do_it(spark):
        # length=1 keeps the test fast; the width, not the row count, is the point.
        return gen_df(spark, gens, length=1).selectExpr("*", "a0 as a_again")

    assert_gpu_and_cpu_are_equal_collect(do_it)
85 changes: 41 additions & 44 deletions sql-plugin/src/main/java/com/nvidia/spark/rapids/CudfUnsafeRow.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -44,26 +44,7 @@
* UnsafeRow works.
*/
public final class CudfUnsafeRow extends InternalRow {
public static int alignOffset(int offset, int alignment) {
return (offset + alignment - 1) & -alignment;
}

public static int calculateBitSetWidthInBytes(int numFields) {
return (numFields + 7)/ 8;
}

public static int getRowSizeEstimate(Attribute[] attributes) {
// This needs to match what is in cudf and what is in the constructor.
int offset = 0;
for (Attribute attr : attributes) {
int length = GpuColumnVector.getNonNestedRapidsType(attr.dataType()).getSizeInBytes();
offset = alignOffset(offset, length);
offset += length;
}
int bitSetWidthInBytes = calculateBitSetWidthInBytes(attributes.length);
// Each row is 64-bit aligned
return alignOffset(offset + bitSetWidthInBytes, 8);
}

//////////////////////////////////////////////////////////////////////////////
// Private fields and methods
Expand All @@ -75,15 +56,15 @@ public static int getRowSizeEstimate(Attribute[] attributes) {
private long address;

/**
* For each column the starting location to read from. The index to the is the position in
* the row bytes, not the user faceing ordinal.
* For each column the starting location to read from. The index is the position in
* the row bytes, not the user facing ordinal.
*/
private int[] startOffsets;

/**
* At what point validity data starts.
* At what point validity data starts from the beginning of a row's data.
*/
private int fixedWidthSizeInBytes;
private int validityOffsetInBytes;

/**
* The size of this row's backing data, in bytes.
Expand All @@ -95,6 +76,8 @@ public static int getRowSizeEstimate(Attribute[] attributes) {
*/
private int[] remapping;

private boolean variableWidthSchema;

/**
* Get the address where a field is stored.
* @param ordinal the user facing ordinal.
Expand Down Expand Up @@ -131,17 +114,11 @@ private void assertIndexIsValid(int index) {
* backing row.
*/
public CudfUnsafeRow(Attribute[] attributes, int[] remapping) {
int offset = 0;
startOffsets = new int[attributes.length];
for (int i = 0; i < attributes.length; i++) {
Attribute attr = attributes[i];
int length = GpuColumnVector.getNonNestedRapidsType(attr.dataType()).getSizeInBytes();
assert length > 0 : "Only fixed width types are currently supported.";
offset = alignOffset(offset, length);
startOffsets[i] = offset;
offset += length;
}
fixedWidthSizeInBytes = offset;
JCudfUtil.RowOffsetsCalculator jCudfBuilder =
JCudfUtil.getRowOffsetsCalculator(attributes, startOffsets);
this.validityOffsetInBytes = jCudfBuilder.getValidityBytesOffset();
this.variableWidthSchema = jCudfBuilder.hasVarSizeData();
this.remapping = remapping;
assert startOffsets.length == remapping.length;
}
Expand Down Expand Up @@ -183,7 +160,7 @@ public boolean isNullAt(int ordinal) {
assertIndexIsValid(i);
int validByteIndex = i / 8;
int validBitIndex = i % 8;
byte b = Platform.getByte(null, address + fixedWidthSizeInBytes + validByteIndex);
byte b = Platform.getByte(null, address + validityOffsetInBytes + validByteIndex);
return ((1 << validBitIndex) & b) == 0;
}

Expand All @@ -193,9 +170,9 @@ public void setNullAt(int ordinal) {
assertIndexIsValid(i);
int validByteIndex = i / 8;
int validBitIndex = i % 8;
byte b = Platform.getByte(null, address + fixedWidthSizeInBytes + validByteIndex);
byte b = Platform.getByte(null, address + validityOffsetInBytes + validByteIndex);
b = (byte)((b & ~(1 << validBitIndex)) & 0xFF);
Platform.putByte(null, address + fixedWidthSizeInBytes + validByteIndex, b);
Platform.putByte(null, address + validityOffsetInBytes + validByteIndex, b);
}

@Override
Expand Down Expand Up @@ -253,12 +230,15 @@ public Decimal getDecimal(int ordinal, int precision, int scale) {

@Override
public UTF8String getUTF8String(int ordinal) {
// if (isNullAt(ordinal)) return null;
// final long offsetAndSize = getLong(ordinal);
// final int offset = (int) (offsetAndSize >> 32);
// final int size = (int) offsetAndSize;
// return UTF8String.fromAddress(null, address + offset, size);
throw new IllegalArgumentException("NOT IMPLEMENTED YET");
if (isNullAt(ordinal)) {
return null;
}
final long columnOffset = getFieldAddressFromOrdinal(ordinal);
// data format for the fixed-width portion of variable-width data is 4 bytes of offset from the
// start of the row followed by 4 bytes of length.
final int offset = Platform.getInt(null, columnOffset);
final int size = Platform.getInt(null, columnOffset + 4);
return UTF8String.fromAddress(null, address + offset, size);
}

@Override
Expand Down Expand Up @@ -397,4 +377,21 @@ public boolean anyNull() {
throw new IllegalArgumentException("NOT IMPLEMENTED YET");
// return BitSetMethods.anySet(baseObject, address, bitSetWidthInBytes / 8);
}
}

/**
 * Reports whether this row's schema contains any variable-width columns
 * (e.g. strings). The flag is computed once in the constructor from
 * {@code JCudfUtil.RowOffsetsCalculator.hasVarSizeData()}.
 *
 * @return true if the backing schema has variable-width data.
 */
public boolean isVariableWidthSchema() {
  return variableWidthSchema;
}

/**
 * Returns the byte offset, from the start of a row's data, at which the
 * validity (null) bitset begins. Computed once in the constructor from
 * {@code JCudfUtil.RowOffsetsCalculator.getValidityBytesOffset()}.
 *
 * @return offset in bytes of the validity data within the row.
 */
public int getValidityOffsetInBytes() {
  return validityOffsetInBytes;
}

/**
 * Calculates the total size of the fixed-width portion of the row: the
 * fixed-width column slots plus the validity bytes, with no row alignment
 * applied. This is also the offset at which the variable-width section
 * starts; note that the variable-width data is 1-byte aligned.
 *
 * @return bytes used by the fixed-width offsets and validity bits, without
 *         row alignment.
 */
public int getFixedWidthInBytes() {
  return getValidityOffsetInBytes() + JCudfUtil.calculateBitSetWidthInBytes(numFields());
}
}
Loading

0 comments on commit 452e7ba

Please sign in to comment.