Skip to content

Commit

Permalink
MSM - fixed bug in reduction phase (ingonyama-zk#549)
Browse files Browse the repository at this point in the history
This PR fixes a bug in the iterative reduction algorithm.
There were unsynchronized threads reading and writing to the same
addresses that caused MSM to fail a small percentage of the time - this is fixed now.
  • Loading branch information
HadarIngonyama committed Jun 30, 2024
1 parent f812f07 commit 4fef542
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions icicle/src/msm/msm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ namespace msm {
__global__ void single_stage_multi_reduction_kernel(
const P* v,
P* v_r,
unsigned orig_block_size,
unsigned block_size,
unsigned write_stride,
unsigned buckets_per_bm,
Expand All @@ -107,11 +108,11 @@ namespace msm {
// only for write_phase=1 because of its read pattern.
const int shifted_block_id = write_phase ? block_id + (block_id + step) / step : block_id;
const int block_tid = shifted_tid % jump;
const unsigned read_ind = block_size * shifted_block_id + block_tid;
const unsigned read_ind = orig_block_size * shifted_block_id + block_tid;
const unsigned write_ind = jump * shifted_block_id + block_tid;
const unsigned v_r_key =
write_stride ? ((write_ind / buckets_per_bm) * 2 + write_phase) * write_stride + write_ind % buckets_per_bm
: write_ind;
: read_ind;
v_r[v_r_key] = v[read_ind] + v[read_ind + jump];
}

Expand Down Expand Up @@ -745,7 +746,7 @@ namespace msm {
NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
} else {
// the recursive reduction algorithm works with 2 types of reduction that can run on parallel streams
// the iterative reduction algorithm works with 2 types of reduction that can run on parallel streams
cudaStream_t stream_reduction;
cudaEvent_t event_finished_reduction;
CHK_IF_RETURN(cudaStreamCreate(&stream_reduction));
Expand All @@ -766,10 +767,10 @@ namespace msm {
const unsigned target_buckets_count = target_windows_count << target_bits_count; // new_bms*2^new_c
CHK_IF_RETURN(cudaMallocAsync(&target_buckets, sizeof(P) * target_buckets_count * batch_size, stream));
CHK_IF_RETURN(cudaMallocAsync(
&temp_buckets1, sizeof(P) * source_buckets_count / 2 * batch_size,
&temp_buckets1, sizeof(P) * source_buckets_count * batch_size,
stream)); // for type1 reduction (strided, bottom window - evens)
CHK_IF_RETURN(cudaMallocAsync(
&temp_buckets2, sizeof(P) * source_buckets_count / 2 * batch_size,
&temp_buckets2, sizeof(P) * source_buckets_count * batch_size,
stream)); // for type2 reduction (serial, top window - odds)
initialize_buckets_kernel<<<(target_buckets_count * batch_size + 255) / 256, 256>>>(
target_buckets, target_buckets_count * batch_size); // initialization is needed for the odd c case
Expand All @@ -788,9 +789,9 @@ namespace msm {
if (!is_odd_c || !is_first_iter) { // skip if c is odd and it's the first iteration
single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
is_first_iter || (is_second_iter && is_odd_c) ? source_buckets : temp_buckets1,
is_last_iter ? target_buckets : temp_buckets1, 1 << (source_bits_count - j + (is_odd_c ? 1 : 0)),
is_last_iter ? 1 << target_bits_count : 0, 1 << target_bits_count, 0 /*=write_phase*/,
(1 << target_bits_count) - 1, nof_threads);
is_last_iter ? target_buckets : temp_buckets1, 1 << source_bits_count,
1 << (source_bits_count - j + (is_odd_c ? 1 : 0)), is_last_iter ? 1 << target_bits_count : 0,
1 << target_bits_count, 0 /*=write_phase*/, (1 << target_bits_count) - 1, nof_threads);
}

nof_threads = (((source_windows_count << (source_bits_count - target_bits_count)) - source_windows_count)
Expand All @@ -801,7 +802,7 @@ namespace msm {
NUM_BLOCKS = (nof_threads + NUM_THREADS - 1) / NUM_THREADS;
single_stage_multi_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream_reduction>>>(
is_first_iter ? source_buckets : temp_buckets2, is_last_iter ? target_buckets : temp_buckets2,
1 << (target_bits_count - j), is_last_iter ? 1 << target_bits_count : 0,
1 << target_bits_count, 1 << (target_bits_count - j), is_last_iter ? 1 << target_bits_count : 0,
1 << (target_bits_count - (is_odd_c ? 1 : 0)), 1 /*=write_phase*/,
(1 << (target_bits_count - (is_odd_c ? 1 : 0))) - 1, nof_threads);
}
Expand Down

0 comments on commit 4fef542

Please sign in to comment.