Skip to content

Commit

Permalink
Apply cuda::proclaim_return_type to vertex_result.cu and sampling_pos…
Browse files Browse the repository at this point in the history
…t_processing_test.cu.
  • Loading branch information
bdice committed Dec 8, 2023
1 parent 8ad9d3b commit d0507a7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
10 changes: 6 additions & 4 deletions cpp/src/mtmg/vertex_result.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <detail/graph_partition_utils.cuh>

#include <thrust/functional.h>
#include <thrust/gather.h>

namespace cugraph {
Expand Down Expand Up @@ -91,10 +92,11 @@ rmm::device_uvector<result_t> vertex_result_view_t<result_t>::gather(
auto vertex_partition =
vertex_partition_device_view_t<vertex_t, multi_gpu>(vertex_partition_view);

auto iter =
thrust::make_transform_iterator(local_vertices.begin(), [vertex_partition] __device__(auto v) {
auto iter = thrust::make_transform_iterator(
local_vertices.begin(),
cuda::proclaim_return_type<vertex_t>([vertex_partition] __device__(auto v) {
return vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v);
});
}));

thrust::gather(handle.get_thrust_policy(),
iter,
Expand All @@ -111,7 +113,7 @@ rmm::device_uvector<result_t> vertex_result_view_t<result_t>::gather(
vertex_gpu_ids.begin(),
vertex_gpu_ids.end(),
thrust::make_zip_iterator(local_vertices.begin(), vertex_pos.begin(), tmp_result.begin()),
[] __device__(int gpu) { return gpu; },
thrust::identity{},
handle.get_stream());

//
Expand Down
10 changes: 6 additions & 4 deletions cpp/tests/sampling/sampling_post_processing_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#include <thrust/sort.h>
#include <thrust/unique.h>

#include <cuda/functional>

struct SamplingPostProcessing_Usecase {
size_t num_labels{};
size_t num_seeds_per_label{};
Expand Down Expand Up @@ -318,15 +320,15 @@ bool check_renumber_map_invariants(

auto renumbered_merged_vertex_first = thrust::make_transform_iterator(
merged_vertices.begin(),
[sorted_org_vertices =
cuda::proclaim_return_type<vertex_t>([sorted_org_vertices =
raft::device_span<vertex_t const>(sorted_org_vertices.data(), sorted_org_vertices.size()),
matching_renumbered_vertices = raft::device_span<vertex_t const>(
matching_renumbered_vertices.data(),
matching_renumbered_vertices.size())] __device__(vertex_t major) {
auto it = thrust::lower_bound(
thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), major);
return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)];
});
}));

thrust::reduce_by_key(handle.get_thrust_policy(),
sort_key_first,
Expand Down Expand Up @@ -1020,7 +1022,7 @@ class Tests_SamplingPostProcessing
? this_label_output_edgelist_srcs.begin()
: this_label_output_edgelist_dsts.begin()) +
old_size,
[offsets = raft::device_span<size_t const>(d_offsets.data(), d_offsets.size()),
cuda::proclaim_return_type<vertex_t>([offsets = raft::device_span<size_t const>(d_offsets.data(), d_offsets.size()),
nzd_vertices =
renumbered_and_compressed_nzd_vertices
? thrust::make_optional<raft::device_span<vertex_t const>>(
Expand All @@ -1036,7 +1038,7 @@ class Tests_SamplingPostProcessing
} else {
return base_v + static_cast<vertex_t>(idx);
}
});
}));
thrust::copy(handle.get_thrust_policy(),
renumbered_and_compressed_edgelist_minors.begin() + h_offsets[0],
renumbered_and_compressed_edgelist_minors.begin() + h_offsets.back(),
Expand Down

0 comments on commit d0507a7

Please sign in to comment.