Commit d87285b

Merge branch 'branch-24.10' into doc-2410-cuvs_migration
cjnolet authored Sep 26, 2024
2 parents f776e22 + 704feb1
Showing 4 changed files with 465 additions and 133 deletions.
1 change: 1 addition & 0 deletions cpp/bench/prims/CMakeLists.txt
@@ -132,6 +132,7 @@ if(BUILD_PRIMS_BENCH)
    linalg/reduce_rows_by_key.cu
    linalg/reduce.cu
    linalg/sddmm.cu
+   linalg/transpose.cu
    main.cpp
  )

85 changes: 85 additions & 0 deletions cpp/bench/prims/linalg/transpose.cu
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct transpose_input {
  IdxT rows, cols;
};

template <typename IdxT>
inline auto operator<<(std::ostream& os, const transpose_input<IdxT>& p) -> std::ostream&
{
  os << p.rows << "#" << p.cols;
  return os;
}

template <typename T, typename IdxT, typename Layout>
struct TransposeBench : public fixture {
  TransposeBench(const transpose_input<IdxT>& p)
    : params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
  {
    raft::random::RngState rng{1234};
    raft::random::uniform(handle, rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0);
  }

  void run_benchmark(::benchmark::State& state) override
  {
    std::ostringstream label_stream;
    label_stream << params;
    state.SetLabel(label_stream.str());

    loop_on_state(state, [this]() {
      auto input_view =
        raft::make_device_matrix_view<T, IdxT, Layout>(in.data(), params.rows, params.cols);
      auto output_view = raft::make_device_vector_view<T, IdxT, Layout>(out.data(), params.rows);
      raft::linalg::transpose(handle,
                              input_view.data_handle(),
                              output_view.data_handle(),
                              params.rows,
                              params.cols,
                              handle.get_stream());
    });
  }

 private:
  transpose_input<IdxT> params;
  rmm::device_uvector<T> in, out;
};  // struct TransposeBench

const std::vector<transpose_input<int>> transpose_inputs_i32 =
  raft::util::itertools::product<transpose_input<int>>({10, 128, 256, 512, 1024},
                                                       {10000, 100000, 1000000});

RAFT_BENCH_REGISTER((TransposeBench<float, int, raft::row_major>), "", transpose_inputs_i32);
RAFT_BENCH_REGISTER((TransposeBench<half, int, raft::row_major>), "", transpose_inputs_i32);

RAFT_BENCH_REGISTER((TransposeBench<float, int, raft::col_major>), "", transpose_inputs_i32);
RAFT_BENCH_REGISTER((TransposeBench<half, int, raft::col_major>), "", transpose_inputs_i32);

} // namespace raft::bench::linalg
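The registered grid is the cross product of five row counts and three column counts, i.e. 15 shapes per (type, layout) combination. For context, below is a minimal sketch of the pointer-based overload the benchmark loop times. It is illustrative only and not part of the commit; the 4 x 3 shape and the default device_resources setup are invented for the example.

#include <raft/core/device_resources.hpp>
#include <raft/linalg/transpose.cuh>

#include <rmm/device_uvector.hpp>

int main()
{
  raft::device_resources handle;
  auto stream = handle.get_stream();

  // A 4 x 3 row-major matrix; contents left unset for brevity.
  int n_rows = 4, n_cols = 3;
  rmm::device_uvector<float> in(n_rows * n_cols, stream);
  rmm::device_uvector<float> out(n_rows * n_cols, stream);

  // Same overload the benchmark calls: writes the n_cols x n_rows
  // transpose of `in` into `out`.
  raft::linalg::transpose(handle, in.data(), out.data(), n_rows, n_cols, stream);

  handle.sync_stream();
  return 0;
}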
67 changes: 53 additions & 14 deletions cpp/include/raft/linalg/detail/transpose.cuh
@@ -38,7 +38,9 @@ template <typename IndexType, int TILE_DIM, int BLOCK_ROWS>
RAFT_KERNEL transpose_half_kernel(IndexType n_rows,
                                  IndexType n_cols,
                                  const half* __restrict__ in,
-                                 half* __restrict__ out)
+                                 half* __restrict__ out,
+                                 const IndexType stride_in,
+                                 const IndexType stride_out)
{
  __shared__ half tile[TILE_DIM][TILE_DIM + 1];

@@ -49,7 +51,7 @@ RAFT_KERNEL transpose_half_kernel(IndexType n_rows,

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < n_cols && (y + j) < n_rows) {
-         tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * n_cols + x]);
+         tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * stride_in + x]);
        }
      }
      __syncthreads();

@@ -59,17 +61,41 @@

      for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
        if (x < n_rows && (y + j) < n_cols) {
-         out[(y + j) * n_rows + x] = tile[threadIdx.x][threadIdx.y + j];
+         out[(y + j) * stride_out + x] = tile[threadIdx.x][threadIdx.y + j];
        }
      }
      __syncthreads();
    }
  }
}
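(An aside on an unchanged line above: the TILE_DIM + 1 padding on the shared-memory tile is the standard trick for avoiding shared-memory bank conflicts when the tile is written row-wise and read back column-wise.)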

/**
 * @brief Transposes a matrix stored in row-major order.
 *
 * This function transposes a matrix of half-precision floating-point numbers (`half`).
 * Both the input (`in`) and output (`out`) matrices are assumed to be stored in row-major order.
 *
 * @tparam IndexType The type used for indexing the matrix dimensions (e.g., int).
 * @param handle The RAFT resource handle.
 * @param n_rows The number of rows in the input matrix.
 * @param n_cols The number of columns in the input matrix.
 * @param in Pointer to the input matrix in row-major order.
 * @param out Pointer to the output matrix in row-major order, where the transposed matrix will be
 * stored.
 * @param stride_in The stride (number of elements between consecutive rows) of the input matrix.
 * The default of 1 is a sentinel meaning the input is contiguous, i.e. the effective stride is
 * `n_cols`.
 * @param stride_out The stride (number of elements between consecutive rows) of the output matrix.
 * The default of 1 is a sentinel meaning the output is contiguous, i.e. the effective stride is
 * `n_rows`.
 */
template <typename IndexType>
-void transpose_half(
-  raft::resources const& handle, IndexType n_rows, IndexType n_cols, const half* in, half* out)
+void transpose_half(raft::resources const& handle,
+                    IndexType n_rows,
+                    IndexType n_cols,
+                    const half* in,
+                    half* out,
+                    const IndexType stride_in  = 1,
+                    const IndexType stride_out = 1)
{
  if (n_cols == 0 || n_rows == 0) return;
  auto stream = resource::get_cuda_stream(handle);
@@ -100,8 +126,13 @@ void transpose_half(

  dim3 grids(adjusted_grid_x, adjusted_grid_y);

- transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
-   <<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out);
+ if (stride_in > 1 || stride_out > 1) {
+   transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
+     <<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out, stride_in, stride_out);
+ } else {
+   transpose_half_kernel<IndexType, block_dim_x, block_dim_y>
+     <<<grids, blocks, 0, stream>>>(n_rows, n_cols, in, out, n_cols, n_rows);
+ }

  RAFT_CUDA_TRY(cudaPeekAtLastError());
}
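To make the stride semantics concrete, here is a small hypothetical call of this detail-level helper on a padded tile. The shapes and strides are invented for illustration; user code would normally reach this path through the public mdspan overloads shown further down.

#include <raft/core/device_resources.hpp>
#include <raft/linalg/detail/transpose.cuh>

#include <rmm/device_uvector.hpp>

#include <cuda_fp16.h>

void strided_transpose_sketch(raft::device_resources const& handle)
{
  auto stream = handle.get_stream();

  // A 2 x 3 half tile embedded in a 2 x 8 row-major allocation (row stride 8),
  // transposed into a 3 x 4 allocation (row stride 4).
  int n_rows = 2, n_cols = 3;
  int stride_in = 8, stride_out = 4;

  rmm::device_uvector<half> in(n_rows * stride_in, stream);
  rmm::device_uvector<half> out(n_cols * stride_out, stream);

  // Element (i, j) is read from in[i * stride_in + j] and written to
  // out[j * stride_out + i], matching the kernel's index arithmetic.
  raft::linalg::detail::transpose_half(
    handle, n_rows, n_cols, in.data(), out.data(), stride_in, stride_out);
}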
@@ -118,7 +149,7 @@ void transpose(raft::resources const& handle,
  int out_n_cols = n_rows;

  if constexpr (std::is_same_v<math_t, half>) {
-   transpose_half(handle, out_n_rows, out_n_cols, in, out);
+   transpose_half(handle, n_cols, n_rows, in, out);
  } else {
    cublasHandle_t cublas_h = resource::get_cublas_handle(handle);
    RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));
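(Note that this change is a rename rather than a behavioral one: out_n_rows == n_cols and out_n_cols == n_rows at this point, and the defaulted strides resolve to the contiguous case inside transpose_half.)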
@@ -195,9 +226,13 @@ void transpose_row_major_impl(
  raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
  raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
- auto out_n_rows = in.extent(1);
- auto out_n_cols = in.extent(0);
- transpose_half<IndexType>(handle, out_n_cols, out_n_rows, in.data_handle(), out.data_handle());
+ transpose_half<IndexType>(handle,
+                           in.extent(0),
+                           in.extent(1),
+                           in.data_handle(),
+                           out.data_handle(),
+                           in.stride(0),
+                           out.stride(0));
}
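The public mdspan overload of raft::linalg::transpose dispatches here for row-major half inputs. A minimal sketch of that entry point follows; the 4 x 3 shape is illustrative, not from the commit.

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/transpose.cuh>

#include <cuda_fp16.h>

void mdspan_transpose_sketch(raft::device_resources const& handle)
{
  // 4 x 3 row-major half matrix in, 3 x 4 out. The overload deduces the
  // layout from the view types and routes to transpose_row_major_impl,
  // which forwards extents and strides to transpose_half.
  auto in  = raft::make_device_matrix<half, int, raft::row_major>(handle, 4, 3);
  auto out = raft::make_device_matrix<half, int, raft::row_major>(handle, 3, 4);
  raft::linalg::transpose(handle, in.view(), out.view());
}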

template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
@@ -233,9 +268,13 @@ void transpose_col_major_impl(
  raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
  raft::mdspan<half, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
{
- auto out_n_rows = in.extent(1);
- auto out_n_cols = in.extent(0);
- transpose_half<IndexType>(handle, out_n_rows, out_n_cols, in.data_handle(), out.data_handle());
+ transpose_half<IndexType>(handle,
+                           in.extent(1),
+                           in.extent(0),
+                           in.data_handle(),
+                           out.data_handle(),
+                           in.stride(1),
+                           out.stride(1));
}

}; // end namespace detail