Skip to content

Commit

Permalink
Change ColumnViewAccess usage to work with ColumnView (#1105)
Browse files Browse the repository at this point in the history
Signed-off-by: Kuhu Shukla <kuhus@nvidia.com>
  • Loading branch information
Kuhu Shukla authored Nov 18, 2020
1 parent c748b95 commit de005d3
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids;

import ai.rapids.cudf.ColumnViewAccess;
import ai.rapids.cudf.ColumnView;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Scalar;
Expand Down Expand Up @@ -305,8 +305,8 @@ public static ColumnarBatch from(Table table, DataType[] colTypes) {
/**
* This should only ever be called from an assertion.
*/
private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataType colType) {
DType dt = cv.getDataType();
private static boolean typeConversionAllowed(ColumnView cv, DataType colType) {
DType dt = cv.getType();
if (!dt.isNestedType()) {
return getRapidsType(colType).equals(dt);
}
Expand All @@ -316,27 +316,27 @@ private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataTyp
if (!(dt.equals(DType.LIST))) {
return false;
}
try (ColumnViewAccess<T> structCv = cv.getChildColumnViewAccess(0)) {
if (!(structCv.getDataType().equals(DType.STRUCT))) {
try (ColumnView structCv = cv.getChildColumnView(0)) {
if (!(structCv.getType().equals(DType.STRUCT))) {
return false;
}
if (structCv.getNumChildren() != 2) {
return false;
}
try (ColumnViewAccess<T> keyCv = structCv.getChildColumnViewAccess(0)) {
try (ColumnView keyCv = structCv.getChildColumnView(0)) {
if (!typeConversionAllowed(keyCv, mType.keyType())) {
return false;
}
}
try (ColumnViewAccess<T> valCv = structCv.getChildColumnViewAccess(1)) {
try (ColumnView valCv = structCv.getChildColumnView(1)) {
return typeConversionAllowed(valCv, mType.valueType());
}
}
} else if (colType instanceof ArrayType) {
if (!(dt.equals(DType.LIST))) {
return false;
}
try (ColumnViewAccess<T> tmp = cv.getChildColumnViewAccess(0)) {
try (ColumnView tmp = cv.getChildColumnView(0)) {
return typeConversionAllowed(tmp, ((ArrayType) colType).elementType());
}
} else if (colType instanceof StructType) {
Expand All @@ -349,7 +349,7 @@ private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataTyp
return false;
}
for (int childIndex = 0; childIndex < numChildren; childIndex++) {
try (ColumnViewAccess<T> tmp = cv.getChildColumnViewAccess(childIndex)) {
try (ColumnView tmp = cv.getChildColumnView(childIndex)) {
StructField entry = ((StructType) colType).apply(childIndex);
if (!typeConversionAllowed(tmp, entry.dataType())) {
return false;
Expand All @@ -361,8 +361,8 @@ private static <T> boolean typeConversionAllowed(ColumnViewAccess<T> cv, DataTyp
if (!(dt.equals(DType.LIST))) {
return false;
}
try (ColumnViewAccess<T> tmp = cv.getChildColumnViewAccess(0)) {
DType tmpType = tmp.getDataType();
try (ColumnView tmp = cv.getChildColumnView(0)) {
DType tmpType = tmp.getType();
return tmpType.equals(DType.INT8) || tmpType.equals(DType.UINT8);
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

package com.nvidia.spark.rapids;

import ai.rapids.cudf.ColumnViewAccess;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.HostMemoryBuffer;

import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
Expand Down Expand Up @@ -127,7 +126,7 @@ public final ColumnarArray getArray(int rowId) {
if (cachedChildren[0] == null) {
// cache the child data
ArrayType at = (ArrayType) dataType();
HostColumnVectorCore data = (HostColumnVectorCore) cudfCv.getChildColumnViewAccess(0);
HostColumnVectorCore data = cudfCv.getChildColumnView(0);
cachedChildren[0] = new RapidsHostColumnVectorCore(at.elementType(), data);
}
RapidsHostColumnVectorCore data = cachedChildren[0];
Expand All @@ -141,12 +140,11 @@ public final ColumnarMap getMap(int ordinal) {
if (cachedChildren[0] == null) {
// Cache the key/value
MapType mt = (MapType) dataType();
ColumnViewAccess<HostMemoryBuffer> structHcv = cudfCv.getChildColumnViewAccess(0);
HostColumnVectorCore structHcv = cudfCv.getChildColumnView(0);
// keys
HostColumnVectorCore firstHcvCore = (HostColumnVectorCore) structHcv.getChildColumnViewAccess(0);

HostColumnVectorCore firstHcvCore = structHcv.getChildColumnView(0);
// values
HostColumnVectorCore secondHcvCore = (HostColumnVectorCore) structHcv.getChildColumnViewAccess(1);
HostColumnVectorCore secondHcvCore = structHcv.getChildColumnView(1);

cachedChildren[0] = new RapidsHostColumnVectorCore(mt.keyType(), firstHcvCore);
cachedChildren[1] = new RapidsHostColumnVectorCore(mt.valueType(), secondHcvCore);
Expand Down Expand Up @@ -180,7 +178,7 @@ public final ColumnVector getChild(int ordinal) {
StructType st = (StructType) dataType();
StructField[] fields = st.fields();
for (int i = 0; i < fields.length; i++) {
HostColumnVectorCore tmp = (HostColumnVectorCore) cudfCv.getChildColumnViewAccess(i);
HostColumnVectorCore tmp = cudfCv.getChildColumnView(i);
cachedChildren[i] = new RapidsHostColumnVectorCore(fields[i].dataType(), tmp);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class AcceleratedColumnarToRowIterator(
currentCv = Some(wip)
at = 0
total = wip.getRowCount().toInt
val byteBuffer = currentCv.get.getChildColumnViewAccess(0).getDataBuffer
val byteBuffer = currentCv.get.getChildColumnView(0).getData
baseDataAddress = byteBuffer.getAddress
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
val expectedType = GpuColumnVector.getRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getDataType) {
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
Expand Down Expand Up @@ -272,7 +272,7 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
val expectedType = GpuColumnVector.getRapidsType(windowFunc.dataType)
if (expectedType != aggColumn.getDataType) {
if (expectedType != aggColumn.getType) {
withResource(aggColumn) { aggColumn =>
GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{ColumnVector, DType, PadSide, Scalar, Table}
import ai.rapids.cudf.{ColumnVector, ColumnView, DType, PadSide, Scalar, Table}
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits._

Expand Down Expand Up @@ -289,7 +289,7 @@ case class GpuConcat(children: Seq[Expression]) extends GpuComplexTypeMergingExp
}
emptyStrScalar = GpuScalar.from("", StringType)
GpuColumnVector.from(ColumnVector.stringConcatenate(emptyStrScalar, nullStrScalar,
columns.toArray[ColumnVector]), dataType)
columns.toArray[ColumnView]), dataType)
} finally {
columns.safeClose()
if (emptyStrScalar != null) {
Expand Down

0 comments on commit de005d3

Please sign in to comment.