Skip to content

Commit

Permalink
Support creating decimal vectors from scalar (#6723)
Browse files Browse the repository at this point in the history
Support creating decimal vectors from a scalar by performing in_place_fill with the underlying data type of fixed-point columns. Related tests are provided in the Java package.
  • Loading branch information
sperlingxx authored Nov 17, 2020
1 parent bee3229 commit 42fe218
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
- PR #6328 Java and JNI bindings for getMapValue/map_lookup
- PR #6371 Use ColumnViewAccess on Host side
- PR #6297 cuDF Python Scalars
- PR #6723 Support creating decimal vectors from scalar

## Improvements

Expand Down
27 changes: 25 additions & 2 deletions cpp/include/cudf/column/column_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ class mutable_column_view : public detail::column_view_base {
operator column_view() const;

private:
friend mutable_column_view logical_cast(mutable_column_view const& input, data_type type);

std::vector<mutable_column_view> mutable_children;
};

Expand All @@ -555,8 +557,7 @@ size_type count_descendants(column_view parent);
* duration count. However, an INT32 column cannot be logically cast to INT64 as the sizes differ,
* nor can an INT32 column be logically cast to a FLOAT32 since what the bits represent differs.
*
* The validity of the conversion can be checked with `cudf::is_logically_castable()`. For other
* conversions between fixed-width types which require a copy, see `cudf::cast()`.
* The validity of the conversion can be checked with `cudf::is_logically_castable()`.
*
* @throws cudf::logic_error if the specified cast is not possible, i.e.,
* `is_logically_castable(input.type(), type)` is false.
Expand All @@ -567,4 +568,26 @@ size_type count_descendants(column_view parent);
*/
column_view logical_cast(column_view const& input, data_type type);

/**
* @brief Zero-copy cast between types with the same underlying representation.
*
* This is the `mutable_column_view` overload: it accepts a `mutable_column_view` and returns a
* `mutable_column_view` of the requested type.
*
* This is similar to `reinterpret_cast` or `bit_cast` in that it gives a view of the same raw bits
* as a different type. Unlike `reinterpret_cast` however, this cast is only allowed on types that
* have the same width and underlying representation. For example, the way timestamp types are laid
* out in memory is equivalent to an integer representing a duration since a fixed epoch; logically
* casting to the same integer type (INT32 for days, INT64 for others) results in a raw view of the
* duration count. However, an INT32 column cannot be logically cast to INT64 as the sizes differ,
* nor can an INT32 column be logically cast to a FLOAT32 since what the bits represent differs.
*
* The validity of the conversion can be checked with `cudf::is_logically_castable()`.
*
* @throws cudf::logic_error if the specified cast is not possible, i.e.,
* `is_logically_castable(input.type(), type)` is false.
*
* @param input The `mutable_column_view` to cast from
* @param type The `data_type` to cast to
* @return New `mutable_column_view` wrapping the same data as `input` but cast to `type`
*/
mutable_column_view logical_cast(mutable_column_view const& input, data_type type);

} // namespace cudf
3 changes: 3 additions & 0 deletions cpp/include/cudf/utilities/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,9 @@ MAP_CASTABLE_TYPES(cudf::duration_s, cudf::duration_s::rep);
MAP_CASTABLE_TYPES(cudf::duration_ms, cudf::duration_ms::rep);
MAP_CASTABLE_TYPES(cudf::duration_us, cudf::duration_us::rep);
MAP_CASTABLE_TYPES(cudf::duration_ns, cudf::duration_ns::rep);
// Allow cast between decimals and integer representation
MAP_CASTABLE_TYPES(numeric::decimal32, numeric::decimal32::rep);
MAP_CASTABLE_TYPES(numeric::decimal64, numeric::decimal64::rep);

template <typename FromType>
struct is_logically_castable_to_impl {
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/column/column_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,16 @@ column_view logical_cast(column_view const& input, data_type type)
input._children};
}

// Re-types a mutable view without copying: the returned view carries the requested `type` and
// reuses the size, data pointer, null mask, null count, offset and children of `input`.
// Fails via CUDF_EXPECTS when `is_logically_castable(input._type, type)` is false.
mutable_column_view logical_cast(mutable_column_view const& input, data_type type)
{
  CUDF_EXPECTS(is_logically_castable(input._type, type), "types are not logically castable");

  // The mutable_column_view constructor is handed non-const pointers, hence the const_casts on
  // the members read from `input`.
  auto* const data_ptr = const_cast<void*>(input._data);
  auto* const mask_ptr = const_cast<cudf::bitmask_type*>(input._null_mask);

  return mutable_column_view{
    type, input._size, data_ptr, mask_ptr, input._null_count, input._offset, input.mutable_children};
}

} // namespace cudf
16 changes: 14 additions & 2 deletions cpp/src/filling/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/filling.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/detail/fill.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
Expand Down Expand Up @@ -59,11 +60,22 @@ struct in_place_fill_range_dispatch {
cudf::mutable_column_view& destination;

/**
 * @brief Fill overload for fixed-width types that are not fixed-point.
 *
 * Enabled only when `cudf::is_fixed_width<T>()` holds and `cudf::is_fixed_point<T>()` does not, so
 * that fixed-point types are routed to their dedicated overload instead of this one.
 * Forwards `destination`, the range bounds, the dispatcher's `value` and the stream to
 * `in_place_fill<T>`.
 *
 * @param begin  First bound of the range passed to `in_place_fill`
 * @param end    Second bound of the range passed to `in_place_fill`
 * @param stream CUDA stream forwarded to `in_place_fill` (defaults to stream 0)
 */
template <typename T>
std::enable_if_t<cudf::is_fixed_width<T>() && not cudf::is_fixed_point<T>(), void> operator()(
  cudf::size_type begin, cudf::size_type end, cudaStream_t stream = 0)
{
  in_place_fill<T>(destination, begin, end, value, stream);
}

/**
 * @brief Fill overload for fixed-point types.
 *
 * Rather than filling with `T` directly, the fill is performed on the underlying representation
 * type `T::rep`: the scalar's stored value is wrapped in a `numeric_scalar<T::rep>`, the
 * destination is logically cast to that scalar's type, and `in_place_fill<T::rep>` is run on the
 * re-typed view.
 *
 * @param begin  First bound of the range passed to `in_place_fill`
 * @param end    Second bound of the range passed to `in_place_fill`
 * @param stream CUDA stream forwarded to `in_place_fill` (defaults to stream 0)
 */
template <typename T>
std::enable_if_t<cudf::is_fixed_point<T>(), void> operator()(cudf::size_type begin,
                                                             cudf::size_type end,
                                                             cudaStream_t stream = 0)
{
  using RepType = typename T::rep;

  // Value held by the fixed-point scalar; presumably the raw integer without the scale applied.
  auto const unscaled = static_cast<cudf::fixed_point_scalar<T> const&>(value).value();

  // Carry the validity of the original scalar over to the rep-typed scalar.
  auto const rep_scalar = cudf::numeric_scalar<RepType>(unscaled, value.is_valid());

  // View the destination as a column of the rep type so the rep-typed fill can write into it.
  auto rep_view = cudf::logical_cast(destination, rep_scalar.type());
  in_place_fill<RepType>(rep_view, begin, end, rep_scalar, stream);
}

template <typename T>
Expand Down
21 changes: 16 additions & 5 deletions cpp/tests/column/column_view_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ struct rep_type_impl<T, std::enable_if_t<cudf::is_duration<T>()>> {
using type = typename T::rep;
};

// Specialization selected when cudf::is_fixed_point<T>() is true: exposes the type's nested
// `rep` as `type`, mirroring the specialization for duration types.
template <typename T>
struct rep_type_impl<T, std::enable_if_t<cudf::is_fixed_point<T>()>> {
using type = typename T::rep;
};

template <typename T>
using rep_type_t = typename rep_type_impl<T>::type;

Expand All @@ -54,23 +59,29 @@ struct ColumnViewAllTypesTests : public cudf::test::BaseFixture {
TYPED_TEST_CASE(ColumnViewAllTypesTests, cudf::test::FixedWidthTypes);

/**
 * @brief Exercises both `cudf::logical_cast` overloads (immutable and mutable view) on one column.
 *
 * The expectation depends on how `FromType` and `ToType` relate:
 *  - same type: casting to the column's own type must yield a view equal to the input;
 *  - one type is the `rep_type_t` of the other: the cast must succeed and compare equal to a
 *    `ToType` column built from `[begin, end)`;
 *  - otherwise: both overloads must throw `cudf::logic_error`.
 *
 * @tparam FromType Element type of `column_view`
 * @tparam ToType   Element type to cast to
 * @tparam Iterator Iterator over the expected values
 * @param column_view View of the column to cast
 * @param begin       Start of the expected values
 * @param end         End of the expected values
 */
template <typename FromType, typename ToType, typename Iterator>
void do_logical_cast(cudf::column_view const& column_view, Iterator begin, Iterator end)
{
  // Build the mutable view through its public constructor instead of reinterpret_cast-ing between
  // the two unrelated view classes. Children are not forwarded; this file's fixture is typed over
  // fixed-width types, which presumably have none.
  cudf::mutable_column_view mutable_column_view{
    column_view.type(),
    column_view.size(),
    const_cast<void*>(column_view.head()),
    const_cast<cudf::bitmask_type*>(column_view.null_mask()),
    column_view.null_count(),
    column_view.offset()};

  if (std::is_same<FromType, ToType>::value) {
    // Cast to same type
    auto output  = cudf::logical_cast(column_view, column_view.type());
    auto output1 = cudf::logical_cast(mutable_column_view, mutable_column_view.type());
    cudf::test::expect_columns_equal(output, column_view);
    cudf::test::expect_columns_equal(output1, mutable_column_view);
  } else if (std::is_same<rep_type_t<FromType>, ToType>::value ||
             std::is_same<FromType, rep_type_t<ToType>>::value) {
    // Cast between a type and its underlying representation type, in either direction
    cudf::data_type type{cudf::type_to_id<ToType>()};
    auto output  = cudf::logical_cast(column_view, type);
    auto output1 = cudf::logical_cast(mutable_column_view, type);
    cudf::test::fixed_width_column_wrapper<ToType, cudf::size_type> expected(begin, end);
    cudf::test::expect_columns_equal(output, expected);
    cudf::test::expect_columns_equal(output1, expected);
  } else {
    // Other casts not allowed
    cudf::data_type type{cudf::type_to_id<ToType>()};
    EXPECT_THROW(cudf::logical_cast(column_view, type), cudf::logic_error);
    EXPECT_THROW(cudf::logical_cast(mutable_column_view, type), cudf::logic_error);
  }
}

Expand Down
8 changes: 3 additions & 5 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,14 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromScalar(JNIEnv *env,
try {
cudf::jni::auto_set_device(env);
auto scalar_val = reinterpret_cast<cudf::scalar const *>(j_scalar);
auto dtype = scalar_val->type();
cudf::mask_state mask_state =
const auto dtype = scalar_val->type();
const auto mask_state =
scalar_val->is_valid() ? cudf::mask_state::UNALLOCATED : cudf::mask_state::ALL_NULL;
std::unique_ptr<cudf::column> col;
if (row_count == 0) {
col = cudf::make_empty_column(dtype);
} else if (cudf::is_fixed_width(dtype)) {
col = cudf::make_fixed_width_column(dtype, row_count, mask_state);
auto mut_view = col->mutable_view();
cudf::fill_in_place(mut_view, 0, row_count, *scalar_val);
col = cudf::make_column_from_scalar(*scalar_val, row_count);
} else if (dtype.id() == cudf::type_id::STRING) {
// create a string column of all empty strings to fill (cheapest string column to create)
auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, row_count + 1,
Expand Down
28 changes: 28 additions & 0 deletions java/src/test/java/ai/rapids/cudf/DecimalColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,32 @@ public void testAppendVector() {
}
}
}

@Test
public void testColumnVectorFromScalar() {
  final int rowCount = 10;

  // Decimal scalar built from an int literal: every row must match through the int accessor.
  try (Scalar scalar = Scalar.fromDecimal(-3, 1233456);
       ColumnVector column = ColumnVector.fromScalar(scalar, rowCount)) {
    assertEquals(scalar.getType(), column.getDataType());
    assertEquals((long) rowCount, column.getRowCount());
    try (HostColumnVector hostColumn = column.copyToHost()) {
      for (int row = 0; row < column.getRowCount(); row++) {
        assertEquals(scalar.getInt(), hostColumn.getInt(row));
        assertEquals(scalar.getBigDecimal(), hostColumn.getBigDecimal(row));
      }
    }
  }

  // Decimal scalar built from a long literal: every row must match through the long accessor.
  try (Scalar scalar = Scalar.fromDecimal(-6, 123456789098L);
       ColumnVector column = ColumnVector.fromScalar(scalar, rowCount)) {
    assertEquals(scalar.getType(), column.getDataType());
    assertEquals((long) rowCount, column.getRowCount());
    try (HostColumnVector hostColumn = column.copyToHost()) {
      for (int row = 0; row < column.getRowCount(); row++) {
        assertEquals(scalar.getLong(), hostColumn.getLong(row));
        assertEquals(scalar.getBigDecimal(), hostColumn.getBigDecimal(row));
      }
    }
  }
}
}

0 comments on commit 42fe218

Please sign in to comment.