Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate/jit/matmul tiling 2d #1472

Merged
merged 28 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
variable declarations above loops
  • Loading branch information
louisfd committed Mar 19, 2024
commit 8b504f7dd042cd1e02cdedd5683669c3430062f2
19 changes: 11 additions & 8 deletions crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,28 @@ pub fn computation_loop(
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = results.item().elem();

let lhs_sm_position = scope.create_local(Elem::UInt);
let rhs_sm_position = scope.create_local(Elem::UInt);

let registered_m = scope.create_local(elem);
let registered_n = scope.create_local(elem);

let multiplied = scope.create_local(elem);
let results_position = scope.create_local(Elem::UInt);
let results_before = scope.create_local(elem);
let results_after = scope.create_local(elem);

gpu!(
scope,
range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each(
|dot_index, scope| {
// Load a subcolumn of values from lhs
let lhs_sm_position = scope.create_local(Elem::UInt);
gpu!(scope, lhs_sm_position = thread_row / 4u32);
gpu!(scope, lhs_sm_position *= block_size_k);
gpu!(scope, lhs_sm_position += dot_index);
gpu!(scope, register_m = shared_lhs[lhs_sm_position]);

// Load a subrow of values from rhs
let rhs_sm_position = scope.create_local(Elem::UInt);
gpu!(scope, rhs_sm_position = dot_index * block_size_n);
gpu!(scope, rhs_sm_position += thread_col);
gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32);
Expand All @@ -46,25 +55,19 @@ pub fn computation_loop(
scope,
range(0u32, shader.config.tile_size_n as u32, shader.unroll)
.for_each(|res_idx_n, scope| {
let registered_m = scope.create_local(elem);
let registered_n = scope.create_local(elem);
gpu!(scope, registered_m = register_m[res_idx_m]);
gpu!(scope, registered_n = register_n[res_idx_n]);

let multiplied = scope.create_local(elem);
gpu!(scope, multiplied = registered_m * registered_n);

let results_position = scope.create_local(Elem::UInt);
gpu!(
scope,
results_position =
res_idx_m * shader.config.tile_size_n
);
gpu!(scope, results_position += res_idx_n);

let results_before = scope.create_local(elem);
gpu!(scope, results_before = results[results_position]);
let results_after = scope.create_local(elem);
gpu!(scope, results_after = results_before + multiplied);

gpu!(scope, results[results_position] = results_after);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,26 @@ fn load_shared_memory_with_bound_check(
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = input.item().elem();

let current = scope.create_local(Elem::UInt);
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
let sm_position = scope.create_local(Elem::UInt);
let within_input = scope.create_local(Elem::Bool);
let current_with_k = scope.create_local(Elem::UInt);
let remain_at_least_1 = scope.create_local(Elem::Bool);
let read_condition = scope.create_local(Elem::Bool);
let val_vec4 = scope.create_local(shared_memory.item());

gpu!(
scope,
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
let current = scope.create_local(Elem::UInt);
gpu!(scope, current = thread_idx_1 + j);

let aligned_with_shared_memory = scope.create_local(Elem::Bool);
gpu!(scope, aligned_with_shared_memory = current < block_size_k);

// To avoid overwriting following row in shared memory
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{

// Position in shared memory
let sm_position = scope.create_local(Elem::UInt);
match input_identifier {
InputIdentifier::Lhs => {
gpu!(scope, sm_position = thread_idx_2 / 4u32);
Expand All @@ -102,53 +108,53 @@ fn load_shared_memory_with_bound_check(
}

// To pad with zeros if outside lhs
let within_input = scope.create_local(Elem::Bool);
let current_with_k = scope.create_local(Elem::UInt);
let remain_at_least_1 = scope.create_local(Elem::Bool);
let read_condition = scope.create_local(Elem::Bool);
gpu!(scope, current_with_k = current + k);
gpu!(scope, within_input = current_with_k < dim_k);
gpu!(scope, remain_at_least_1 = remain >= 1u32);
gpu!(scope, read_condition = within_input && remain_at_least_1);

gpu!(scope, if(read_condition).then(|scope| {
let position_0 = scope.create_local(Elem::UInt);
gpu!(scope, position_0 = k + current);
gpu!(scope, position_0 *= stride_1);
let tmp = scope.create_local(Elem::UInt);
gpu!(scope, tmp = thread_idx_2 * stride_2);
gpu!(scope, position_0 += tmp);
gpu!(scope, position_0 += input_offset);
let position_0 = scope.create_local(Elem::UInt);
let position_1 = scope.create_local(Elem::UInt);
let position_2 = scope.create_local(Elem::UInt);
let position_3 = scope.create_local(Elem::UInt);
gpu!(scope, position_1 = position_0 + stride_2);
gpu!(scope, position_2 = position_1 + stride_2);
gpu!(scope, position_3 = position_2 + stride_2);
let remain_n = scope.create_local(Elem::Bool);

let val_0 = scope.zero(elem);
let val_1 = scope.zero(elem);
let val_2 = scope.zero(elem);
let val_3 = scope.zero(elem);

let remain_n = scope.create_local(Elem::Bool);
gpu!(scope, position_0 = k + current);
gpu!(scope, position_0 *= stride_1);
gpu!(scope, tmp = thread_idx_2 * stride_2);
gpu!(scope, position_0 += tmp);
gpu!(scope, position_0 += input_offset);
gpu!(scope, position_1 = position_0 + stride_2);
gpu!(scope, position_2 = position_1 + stride_2);
gpu!(scope, position_3 = position_2 + stride_2);

gpu!(scope, remain_n = remain >= 4u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);
gpu!(scope, val_3 = input[position_3]);

}).else(|scope|{
gpu!(scope, remain_n = remain == 3u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);

}).else(|scope|{
gpu!(scope, remain_n = remain == 2u32);
gpu!(scope, if(remain_n).then(|scope|{
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);

}).else(|scope|{
gpu!(scope, remain_n = remain == 1u32);
gpu!(scope, if(remain_n).then(|scope|{
Expand All @@ -158,12 +164,11 @@ fn load_shared_memory_with_bound_check(
}));
}));

let val_vec4 = scope.create_local(shared_memory.item());
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
gpu!(scope, shared_memory[sm_position] = val_vec4);

}).else(|scope|{
let val_0 = scope.zero(elem);
let val_vec4 = scope.create_local(shared_memory.item());
gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
gpu!(scope, shared_memory[sm_position] = val_vec4);
}));
Expand Down Expand Up @@ -206,19 +211,31 @@ fn load_shared_memory_no_bound_check(
let block_size_n: Variable = shader.config.block_size_n.into();
let elem = input.item().elem();

let current = scope.create_local(Elem::UInt);
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
let sm_position = scope.create_local(Elem::UInt);

let tmp = scope.create_local(Elem::UInt);
let position_0 = scope.create_local(Elem::UInt);
let position_1 = scope.create_local(Elem::UInt);
let position_2 = scope.create_local(Elem::UInt);
let position_3 = scope.create_local(Elem::UInt);
let val_0 = scope.create_local(elem);
let val_1 = scope.create_local(elem);
let val_2 = scope.create_local(elem);
let val_3 = scope.create_local(elem);
let val_vec4 = scope.create_local(shared_memory.item());

gpu!(
scope,
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
let current = scope.create_local(Elem::UInt);
gpu!(scope, current = thread_idx_1 + j);

let aligned_with_shared_memory = scope.create_local(Elem::Bool);
gpu!(scope, aligned_with_shared_memory = current < block_size_k);

// To avoid overwriting following row in shared memory
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{

let sm_position = scope.create_local(Elem::UInt);
match input_identifier {
InputIdentifier::Lhs => {
gpu!(scope, sm_position = thread_idx_2 / 4u32);
Expand All @@ -232,30 +249,20 @@ fn load_shared_memory_no_bound_check(
}
}

let position_0 = scope.create_local(Elem::UInt);
gpu!(scope, position_0 = k + current);
gpu!(scope, position_0 *= stride_1);
let tmp = scope.create_local(Elem::UInt);
gpu!(scope, tmp = thread_idx_2 * stride_2);
gpu!(scope, position_0 += tmp);
gpu!(scope, position_0 += input_offset);
let position_1 = scope.create_local(Elem::UInt);
let position_2 = scope.create_local(Elem::UInt);
let position_3 = scope.create_local(Elem::UInt);
gpu!(scope, position_1 = position_0 + stride_2);
gpu!(scope, position_2 = position_1 + stride_2);
gpu!(scope, position_3 = position_2 + stride_2);

let val_0 = scope.create_local(elem);
let val_1 = scope.create_local(elem);
let val_2 = scope.create_local(elem);
let val_3 = scope.create_local(elem);
gpu!(scope, val_0 = input[position_0]);
gpu!(scope, val_1 = input[position_1]);
gpu!(scope, val_2 = input[position_2]);
gpu!(scope, val_3 = input[position_3]);

let val_vec4 = scope.create_local(shared_memory.item());
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
gpu!(scope, shared_memory[sm_position] = val_vec4);
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ pub(crate) fn gather_shader_information(

// Shapes
let rank = Variable::Rank;
let ultimate_dim = scope.create_local(Elem::UInt);
let penultimate_dim = scope.create_local(Elem::UInt);
gpu!(scope, ultimate_dim = rank - 1u32);
gpu!(scope, penultimate_dim = rank - 2u32);
let last_dim = scope.create_local(Elem::UInt);
let second_to_last_dim = scope.create_local(Elem::UInt);
let dim_m = scope.create_local(Elem::UInt);
let dim_k = scope.create_local(Elem::UInt);
let dim_n = scope.create_local(Elem::UInt);
gpu!(scope, dim_m = shape(lhs, penultimate_dim));
gpu!(scope, dim_k = shape(lhs, ultimate_dim));
gpu!(scope, dim_n = shape(rhs, ultimate_dim));
gpu!(scope, last_dim = rank - 1u32);
gpu!(scope, second_to_last_dim = rank - 2u32);
gpu!(scope, dim_m = shape(lhs, second_to_last_dim));
gpu!(scope, dim_k = shape(lhs, last_dim));
gpu!(scope, dim_n = shape(rhs, last_dim));

// Strides
let lhs_stride_row = scope.create_local(Elem::UInt);
Expand All @@ -45,28 +45,28 @@ pub(crate) fn gather_shader_information(
let rhs_stride_col = scope.create_local(Elem::UInt);
let out_stride_row = scope.create_local(Elem::UInt);
let out_stride_col = scope.create_local(Elem::UInt);
gpu!(scope, lhs_stride_row = stride(lhs, penultimate_dim));
gpu!(scope, lhs_stride_col = stride(lhs, ultimate_dim));
gpu!(scope, rhs_stride_row = stride(rhs, penultimate_dim));
gpu!(scope, rhs_stride_col = stride(rhs, ultimate_dim));
gpu!(scope, out_stride_row = stride(out, penultimate_dim));
gpu!(scope, out_stride_col = stride(out, ultimate_dim));
gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim));
gpu!(scope, lhs_stride_col = stride(lhs, last_dim));
gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim));
gpu!(scope, rhs_stride_col = stride(rhs, last_dim));
gpu!(scope, out_stride_row = stride(out, second_to_last_dim));
gpu!(scope, out_stride_col = stride(out, last_dim));

// Workgroup offset
let skip_row = scope.create_local(Elem::UInt);
let skip_col = scope.create_local(Elem::UInt);
let workgroup_id_x = Variable::WorkgroupIdX;
let workgroup_id_y = Variable::WorkgroupIdY;
gpu!(scope, skip_row = workgroup_id_x);
gpu!(scope, skip_row *= block_size_m);
let skip_col = scope.create_local(Elem::UInt);
let workgroup_id_y = Variable::WorkgroupIdY;
gpu!(scope, skip_col = workgroup_id_y);
gpu!(scope, skip_col *= block_size_n);

// Position of the first element of the thread, relative to the block
let thread_row = scope.create_local(Elem::UInt);
let thread_col = scope.create_local(Elem::UInt);
gpu!(scope, thread_row = local_idx / n_threads_per_row);
gpu!(scope, thread_row *= tile_size_m);
let thread_col = scope.create_local(Elem::UInt);
gpu!(scope, thread_col = local_idx % n_threads_per_row);
gpu!(scope, thread_col *= tile_size_n);

Expand All @@ -89,30 +89,29 @@ pub(crate) fn gather_shader_information(
gpu!(scope, offset_output = offset_output * batch);

// Batch offset for the lhs & rhs matrices.
let stride_lhs = scope.create_local(Elem::UInt);
let stride_rhs = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_lhs = scope.create_local(Elem::UInt);
let shape_rhs = scope.create_local(Elem::UInt);
let tmp = scope.create_local(Elem::UInt);
let tmp_lhs = scope.create_local(Elem::UInt);
let tmp_rhs = scope.create_local(Elem::UInt);
gpu!(scope, batch_dims = rank - 2u32);
gpu!(
scope,
range(0u32, batch_dims).for_each(|b, scope| {
let stride_lhs = scope.create_local(Elem::UInt);
let stride_rhs = scope.create_local(Elem::UInt);
let stride_output = scope.create_local(Elem::UInt);
let shape_lhs = scope.create_local(Elem::UInt);
let shape_rhs = scope.create_local(Elem::UInt);

gpu!(scope, stride_lhs = stride(lhs, b));
gpu!(scope, stride_rhs = stride(rhs, b));
gpu!(scope, stride_output = stride(out, b));
gpu!(scope, shape_lhs = shape(lhs, b));
gpu!(scope, shape_rhs = shape(rhs, b));

let tmp = scope.create_local(Elem::UInt);
gpu!(scope, tmp = offset_output / stride_output);
let tmp_lhs = scope.create_local(Elem::UInt);
gpu!(scope, tmp_lhs = tmp % shape_lhs);
gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs);
gpu!(scope, offset_lhs += tmp_lhs);

let tmp_rhs = scope.create_local(Elem::UInt);
gpu!(scope, tmp_rhs = tmp % shape_rhs);
gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs);
gpu!(scope, offset_rhs += tmp_rhs);
Expand All @@ -134,7 +133,9 @@ pub(crate) fn gather_shader_information(
shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
);

// Calculate exact number of loop iterations
let n_loops = scope.create_local(Elem::UInt);
let k = scope.create_local(Elem::UInt);
if shader.bounds_check_required {
let dim_k_float = scope.create_local(elem);
let block_size_k_float = scope.create_local(elem);
Expand All @@ -148,8 +149,6 @@ pub(crate) fn gather_shader_information(
gpu!(scope, n_loops = dim_k / block_size_k);
}

let k = scope.create_local(Elem::UInt);

Tiling2dState {
n_loops,
k,
Expand Down
Loading