Skip to content

Commit

Permalink
Merge branch 'branch-24.10' into cagra-bf
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong authored Sep 18, 2024
2 parents 59a3898 + 2071841 commit 1064895
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions cpp/include/raft/sparse/op/detail/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ struct TupleComp {
* @param vals vals array from coo matrix
* @param stream: cuda stream to use
*/
template <typename T>
void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream)
{
auto coo_indices = thrust::make_zip_iterator(thrust::make_tuple(rows, cols));

Expand All @@ -83,10 +83,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t
* @param in: COO to sort by row
* @param stream: the cuda stream to use
*/
template <typename T>
void coo_sort(COO<T>* const in, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(COO<T, IdxT>* const in, cudaStream_t stream)
{
coo_sort<T>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
coo_sort<T, IdxT>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
}

/**
Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/sparse/op/sort.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,8 +37,8 @@ namespace op {
* @param vals vals array from coo matrix
* @param stream: cuda stream to use
*/
template <typename T>
void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream)
{
detail::coo_sort(m, n, nnz, rows, cols, vals, stream);
}
Expand All @@ -49,10 +49,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t
* @param in: COO to sort by row
* @param stream: the cuda stream to use
*/
template <typename T>
void coo_sort(COO<T>* const in, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(COO<T, IdxT>* const in, cudaStream_t stream)
{
coo_sort<T>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
coo_sort<T, IdxT>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
}

/**
Expand All @@ -75,4 +75,4 @@ void coo_sort_by_weight(
}; // end NAMESPACE sparse
}; // end NAMESPACE raft

#endif
#endif

0 comments on commit 1064895

Please sign in to comment.